@@ -310,6 +310,7 @@ def convert_text_enc_state_dict(text_enc_dict):
310310
# CLI options: input diffusers model directory plus the two output checkpoint
# paths (the diffusion model and the CLIP text encoder are saved separately).
parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
parser.add_argument("--clip_checkpoint_path", default=None, type=str, required=True, help="Path to the output CLIP model.")
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
314315 parser .add_argument (
315316 "--use_safetensors" , action = "store_true" , help = "Save weights use safetensors, default is ckpt."
@@ -321,6 +322,8 @@ def convert_text_enc_state_dict(text_enc_dict):
321322
# NOTE(review): argparse already enforces required=True for both paths, so
# these asserts should be unreachable; kept for parity with the original.
assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

assert args.clip_checkpoint_path is not None, "Must provide a CLIP checkpoint path!"

# Expected locations of the split diffusers weights (safetensors layout).
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
@@ -369,8 +372,7 @@ def convert_text_enc_state_dict(text_enc_dict):
# Re-key the text-encoder weights under the LDM-style prefix expected by
# downstream consumers.
text_enc_dict = {f"cond_stage_model.transformer.{key}": tensor for key, tensor in text_enc_dict.items()}

# DON'T PUT TOGETHER FOR THE NEW CHECKPOINT AS MODELSCOPE USES THEM IN THE SPLITTED FORM --kabachuha
# Save CLIP and the Diffusion model to their own files

# state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
print('Saving UNET')
@@ -383,4 +385,16 @@ def convert_text_enc_state_dict(text_enc_dict):
383385 save_file (state_dict , args .checkpoint_path )
384386 else :
385387 state_dict = {"state_dict" : state_dict }
386- torch .save (state_dict , args .checkpoint_path )
388+ torch .save (state_dict , args .checkpoint_path )
389+
print('Saving CLIP')
# The CLIP text encoder is written to its own file because ModelScope
# consumes the UNET and CLIP weights in split form (see comment above).
state_dict = {**text_enc_dict}

if args.half:
    # Optionally cast every tensor to fp16 to halve the output file size.
    state_dict = {k: v.half() for k, v in state_dict.items()}

if args.use_safetensors:
    # BUG FIX: this branch previously wrote to args.checkpoint_path, which
    # silently overwrote the UNET checkpoint saved just above. The CLIP
    # weights must go to the dedicated --clip_checkpoint_path output, as the
    # non-safetensors branch below already does.
    save_file(state_dict, args.clip_checkpoint_path)
else:
    state_dict = {"state_dict": state_dict}
    torch.save(state_dict, args.clip_checkpoint_path)
0 commit comments