55from pathlib import Path
66from tqdm import tqdm
77from mGPT .config import parse_args
8- from mGPT .data .build_data import get_datasets
8+ from mGPT .data .build_data import build_data
99from mGPT .models .build_model import build_model
10-
10+ from mGPT . utils . load_checkpoint import load_pretrained , load_pretrained_vae
1111
1212def main ():
1313
@@ -28,48 +28,28 @@ def main():
2828 os .environ ["TOKENIZERS_PARALLELISM" ] = "false"
2929
3030 # create dataset
31- datasets = get_datasets (cfg , phase = "test" )[ 0 ]
31+ datamodule = build_data (cfg , phase = "test" )
3232 print ("datasets module {} initialized" .format ("" .join (cfg .TRAIN .DATASETS )))
3333
3434 os .makedirs (output_dir , exist_ok = True )
3535
3636 # create model
37- model = build_model (cfg , datasets )
38- if hasattr (model , "motion_vae" ):
39- model .vae = model .motion_vae
40- print ("model {} loaded" .format (cfg .model .model_type ))
37+ model = build_model (cfg , datamodule )
38+ print ("model {} loaded" .format (cfg .model .target ))
4139
4240 # Strict load vae model
4341 if cfg .TRAIN .PRETRAINED_VAE :
44- state_dict = torch .load (cfg .TRAIN .PRETRAINED_VAE ,
45- map_location = "cpu" )['state_dict' ]
46- print (f"Loading pretrain vae from { cfg .TRAIN .PRETRAINED_VAE } " )
47-
48- from collections import OrderedDict
49- vae_dict = OrderedDict ()
50- for k , v in state_dict .items ():
51- if "motion_vae" in k :
52- name = k .replace ("motion_vae." , "" )
53- vae_dict [name ] = v
54- elif "vae" in k :
55- name = k .replace ("vae." , "" )
56- vae_dict [name ] = v
57- if hasattr (model , 'vae' ):
58- model .vae .load_state_dict (vae_dict , strict = True )
59- else :
60- model .motion_vae .load_state_dict (vae_dict , strict = True )
61-
62- # Strict load pretrianed model
63- if cfg .TRAIN .PRETRAINED :
64- state_dict = torch .load (cfg .TRAIN .PRETRAINED ,
65- map_location = "cpu" )["state_dict" ]
66- model .load_state_dict (state_dict , strict = True )
42+ load_pretrained_vae (cfg , model )
6743
44+ # loading state dict
45+ if cfg .TEST .CHECKPOINTS :
46+ load_pretrained (cfg , model , phase = "test" )
47+
6848 if cfg .ACCELERATOR == "gpu" :
6949 model = model .cuda ()
7050
7151 model .eval ()
72- codes = cfg .model .codebook_size
52+ codes = cfg .model .params . codebook_size
7353 with torch .no_grad ():
7454 for i in tqdm (range (codes )):
7555
0 commit comments