Skip to content
This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit 3de0ee9

Browse files
committed
CLIP text encoder saving
1 parent 8ed13d7 commit 3de0ee9

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

utils/convert_diffusers_to_original_ms_text_to_video.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ def convert_text_enc_state_dict(text_enc_dict):
310310

311311
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
312312
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
313+
parser.add_argument("--clip_checkpoint_path", default=None, type=str, required=True, help="Path to the output CLIP model.")
313314
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
314315
parser.add_argument(
315316
"--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
@@ -321,6 +322,8 @@ def convert_text_enc_state_dict(text_enc_dict):
321322

322323
assert args.checkpoint_path is not None, "Must provide a checkpoint path!"
323324

325+
assert args.clip_checkpoint_path is not None, "Must provide a CLIP checkpoint path!"
326+
324327
# Path for safetensors
325328
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
326329
vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
@@ -369,8 +372,7 @@ def convert_text_enc_state_dict(text_enc_dict):
369372
text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}
370373

371374
# DON'T PUT TOGETHER FOR THE NEW CHECKPOINT AS MODELSCOPE USES THEM IN THE SPLIT FORM --kabachuha
372-
373-
375+
# Save CLIP and the Diffusion model to their own files
374376

375377
#state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
376378
print ('Saving UNET')
@@ -383,4 +385,16 @@ def convert_text_enc_state_dict(text_enc_dict):
383385
save_file(state_dict, args.checkpoint_path)
384386
else:
385387
state_dict = {"state_dict": state_dict}
386-
torch.save(state_dict, args.checkpoint_path)
388+
torch.save(state_dict, args.checkpoint_path)
389+
390+
# Save the CLIP text encoder separately — ModelScope consumes the UNet/VAE
# and the text encoder as split checkpoints rather than one merged file.
print('Saving CLIP')
state_dict = {**text_enc_dict}

if args.half:
    # Optionally downcast weights to fp16 to halve the file size.
    state_dict = {k: v.half() for k, v in state_dict.items()}

if args.use_safetensors:
    # BUG FIX: this previously wrote to args.checkpoint_path, silently
    # overwriting the diffusion checkpoint saved just above. The CLIP
    # weights must go to the dedicated --clip_checkpoint_path output.
    save_file(state_dict, args.clip_checkpoint_path)
else:
    # torch.save path: wrap under "state_dict" to match the original
    # ModelScope checkpoint layout (same convention as the UNet save above).
    state_dict = {"state_dict": state_dict}
    torch.save(state_dict, args.clip_checkpoint_path)

0 commit comments

Comments
 (0)