Skip to content

Commit fdc88b1

Browse files
committed
Update local demo
1 parent 66e4dba commit fdc88b1

File tree

6 files changed

+118
-5
lines changed

6 files changed

+118
-5
lines changed

assets/meta/mean.npy

2.18 KB
Binary file not shown.

assets/meta/mean_eval.npy

2.18 KB
Binary file not shown.

assets/meta/std.npy

2.18 KB
Binary file not shown.

assets/meta/std_eval.npy

2.18 KB
Binary file not shown.

configs/webui.yaml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ DEVICE: [0] # Index of gpus eg. [0] or [0,1,2,3]
77
TRAIN:
88
#---------------------------------
99
STAGE: lm_instruct
10-
DATASETS: ['humanml3d'] # Training datasets
1110
NUM_WORKERS: 32 # Number of workers
1211
BATCH_SIZE: 16 # Size of batches
1312
START_EPOCH: 0 # Start epochMMOTIONENCODER
@@ -23,14 +22,12 @@ TRAIN:
2322

2423
# Evaluating Configuration
2524
EVAL:
26-
DATASETS: ['humanml3d'] # Evaluating datasets
2725
BATCH_SIZE: 32 # Evaluating Batch size
2826
SPLIT: test
2927

3028
# Test Configuration
3129
TEST:
3230
CHECKPOINTS: checkpoints/MotionGPT-base/motiongpt_s3_h3d.tar
33-
DATASETS: ['humanml3d'] # training datasets
3431
SPLIT: test
3532
BATCH_SIZE: 32 # training Batch size
3633
MEAN: False
@@ -39,8 +36,8 @@ TEST:
3936

4037
# Datasets Configuration
4138
DATASET:
42-
JOINT_TYPE: 'humanml3d' # join type
43-
CODE_PATH: 'VQBEST'
39+
target: mGPT.data.webui.HumanML3DDataModule
40+
4441
METRIC:
4542
TYPE: ['TM2TMetrics']
4643
# Losses Configuration

mGPT/data/webui.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import numpy as np
2+
import torch
3+
from os.path import join as pjoin
4+
from .humanml.scripts.motion_process import (process_file, recover_from_ric)
5+
from . import BASEDataModule
6+
from .humanml import Text2MotionDatasetEval, Text2MotionDataset, Text2MotionDatasetCB, MotionDataset, MotionDatasetVQ, Text2MotionDatasetToken, Text2MotionDatasetM2T
7+
from .utils import humanml3d_collate
8+
9+
10+
class HumanML3DDataModule(BASEDataModule):
11+
def __init__(self, cfg, **kwargs):
12+
13+
super().__init__(collate_fn=humanml3d_collate)
14+
self.cfg = cfg
15+
self.save_hyperparameters(logger=False)
16+
17+
# Basic info of the dataset
18+
cfg.DATASET.JOINT_TYPE = 'humanml3d'
19+
self.name = "humanml3d"
20+
self.njoints = 22
21+
22+
# Path to the dataset
23+
data_root = cfg.DATASET.HUMANML3D.ROOT
24+
self.hparams.data_root = data_root
25+
self.hparams.text_dir = pjoin(data_root, "texts")
26+
self.hparams.motion_dir = pjoin(data_root, 'new_joint_vecs')
27+
28+
# Mean and std of the dataset
29+
self.hparams.mean = np.load(pjoin('assets/meta', "mean.npy"))
30+
self.hparams.std = np.load(pjoin('assets/meta', "std.npy"))
31+
32+
# Mean and std for fair evaluation
33+
self.hparams.mean_eval = np.load(pjoin('assets/meta', "mean_eval.npy"))
34+
self.hparams.std_eval = np.load(pjoin('assets/meta', "std_eval.npy"))
35+
36+
# Length of the dataset
37+
self.hparams.max_motion_length = cfg.DATASET.HUMANML3D.MAX_MOTION_LEN
38+
self.hparams.min_motion_length = cfg.DATASET.HUMANML3D.MIN_MOTION_LEN
39+
self.hparams.max_text_len = cfg.DATASET.HUMANML3D.MAX_TEXT_LEN
40+
self.hparams.unit_length = cfg.DATASET.HUMANML3D.UNIT_LEN
41+
42+
# Additional parameters
43+
self.hparams.debug = cfg.DEBUG
44+
self.hparams.stage = cfg.TRAIN.STAGE
45+
46+
# Dataset switch
47+
self.DatasetEval = Text2MotionDatasetEval
48+
49+
if cfg.TRAIN.STAGE == "vae":
50+
if cfg.model.params.motion_vae.target.split('.')[-1].lower() == "vqvae":
51+
self.hparams.win_size = 64
52+
self.Dataset = MotionDatasetVQ
53+
else:
54+
self.Dataset = MotionDataset
55+
elif 'lm' in cfg.TRAIN.STAGE:
56+
self.hparams.code_path = cfg.DATASET.CODE_PATH
57+
self.hparams.task_path = cfg.DATASET.TASK_PATH
58+
self.hparams.std_text = cfg.DATASET.HUMANML3D.STD_TEXT
59+
self.Dataset = Text2MotionDatasetCB
60+
elif cfg.TRAIN.STAGE == "token":
61+
self.Dataset = Text2MotionDatasetToken
62+
self.DatasetEval = Text2MotionDatasetToken
63+
elif cfg.TRAIN.STAGE == "m2t":
64+
self.Dataset = Text2MotionDatasetM2T
65+
self.DatasetEval = Text2MotionDatasetM2T
66+
else:
67+
self.Dataset = Text2MotionDataset
68+
69+
# Get additional info of the dataset
70+
self.nfeats = 263
71+
cfg.DATASET.NFEATS = self.nfeats
72+
73+
74+
def feats2joints(self, features):
75+
mean = torch.tensor(self.hparams.mean).to(features)
76+
std = torch.tensor(self.hparams.std).to(features)
77+
features = features * std + mean
78+
return recover_from_ric(features, self.njoints)
79+
80+
def joints2feats(self, features):
81+
features = process_file(features, self.njoints)[0]
82+
return features
83+
84+
def normalize(self, features):
85+
mean = torch.tensor(self.hparams.mean).to(features)
86+
std = torch.tensor(self.hparams.std).to(features)
87+
features = (features - mean) / std
88+
return features
89+
90+
def denormalize(self, features):
91+
mean = torch.tensor(self.hparams.mean).to(features)
92+
std = torch.tensor(self.hparams.std).to(features)
93+
features = features * std + mean
94+
return features
95+
96+
def renorm4t2m(self, features):
97+
# renorm to t2m norms for using t2m evaluators
98+
ori_mean = torch.tensor(self.hparams.mean).to(features)
99+
ori_std = torch.tensor(self.hparams.std).to(features)
100+
eval_mean = torch.tensor(self.hparams.mean_eval).to(features)
101+
eval_std = torch.tensor(self.hparams.std_eval).to(features)
102+
features = features * ori_std + ori_mean
103+
features = (features - eval_mean) / eval_std
104+
return features
105+
106+
def mm_mode(self, mm_on=True):
107+
if mm_on:
108+
self.is_mm = True
109+
self.name_list = self.test_dataset.name_list
110+
self.mm_list = np.random.choice(self.name_list,
111+
self.cfg.METRIC.MM_NUM_SAMPLES,
112+
replace=False)
113+
self.test_dataset.name_list = self.mm_list
114+
else:
115+
self.is_mm = False
116+
self.test_dataset.name_list = self.name_list

0 commit comments

Comments
 (0)