Skip to content

Commit 79fabdc

Browse files
committed
Fix code visualization script
1 parent 30527a4 commit 79fabdc

File tree

2 files changed

+15
-33
lines changed

2 files changed

+15
-33
lines changed

mGPT/utils/load_checkpoint.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import torch
22

3-
def load_pretrained(cfg, model, logger, phase="train"):
4-
logger.info(f"Loading pretrain model from {cfg.TRAIN.PRETRAINED}")
3+
def load_pretrained(cfg, model, logger=None, phase="train"):
4+
if logger is not None:
5+
logger.info(f"Loading pretrain model from {cfg.TRAIN.PRETRAINED}")
6+
57
if phase == "train":
68
ckpt_path = cfg.TRAIN.PRETRAINED
79
elif phase == "test":

scripts/get_code_visual.py

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from pathlib import Path
66
from tqdm import tqdm
77
from mGPT.config import parse_args
8-
from mGPT.data.build_data import get_datasets
8+
from mGPT.data.build_data import build_data
99
from mGPT.models.build_model import build_model
10-
10+
from mGPT.utils.load_checkpoint import load_pretrained, load_pretrained_vae
1111

1212
def main():
1313

@@ -28,48 +28,28 @@ def main():
2828
os.environ["TOKENIZERS_PARALLELISM"] = "false"
2929

3030
# create dataset
31-
datasets = get_datasets(cfg, phase="test")[0]
31+
datamodule = build_data(cfg, phase="test")
3232
print("datasets module {} initialized".format("".join(cfg.TRAIN.DATASETS)))
3333

3434
os.makedirs(output_dir, exist_ok=True)
3535

3636
# create model
37-
model = build_model(cfg, datasets)
38-
if hasattr(model, "motion_vae"):
39-
model.vae = model.motion_vae
40-
print("model {} loaded".format(cfg.model.model_type))
37+
model = build_model(cfg, datamodule)
38+
print("model {} loaded".format(cfg.model.target))
4139

4240
# Strict load vae model
4341
if cfg.TRAIN.PRETRAINED_VAE:
44-
state_dict = torch.load(cfg.TRAIN.PRETRAINED_VAE,
45-
map_location="cpu")['state_dict']
46-
print(f"Loading pretrain vae from {cfg.TRAIN.PRETRAINED_VAE}")
47-
48-
from collections import OrderedDict
49-
vae_dict = OrderedDict()
50-
for k, v in state_dict.items():
51-
if "motion_vae" in k:
52-
name = k.replace("motion_vae.", "")
53-
vae_dict[name] = v
54-
elif "vae" in k:
55-
name = k.replace("vae.", "")
56-
vae_dict[name] = v
57-
if hasattr(model, 'vae'):
58-
model.vae.load_state_dict(vae_dict, strict=True)
59-
else:
60-
model.motion_vae.load_state_dict(vae_dict, strict=True)
61-
62-
# Strict load pretrianed model
63-
if cfg.TRAIN.PRETRAINED:
64-
state_dict = torch.load(cfg.TRAIN.PRETRAINED,
65-
map_location="cpu")["state_dict"]
66-
model.load_state_dict(state_dict, strict=True)
42+
load_pretrained_vae(cfg, model)
6743

44+
# loading state dict
45+
if cfg.TEST.CHECKPOINTS:
46+
load_pretrained(cfg, model, phase="test")
47+
6848
if cfg.ACCELERATOR == "gpu":
6949
model = model.cuda()
7050

7151
model.eval()
72-
codes = cfg.model.codebook_size
52+
codes = cfg.model.params.codebook_size
7353
with torch.no_grad():
7454
for i in tqdm(range(codes)):
7555

0 commit comments

Comments
 (0)