unet_conversion_map_resnet = [
    # (ModelScope, HF Diffusers)

    # SD
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),

    # MS
    ("temopral_conv", "temp_conv"),  # ROFL, they have a typo here --kabachuha
]
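
# Illustrative sketch (not part of the original script): the pairs above are
# consumed as plain substring replacements over each state-dict key, HF name
# in, ModelScope/SD name out. The sample key and helper name are hypothetical.
def _demo_resnet_key_conversion():
    key = "down_blocks.0.resnets.0.norm1.weight"  # hypothetical HF Diffusers key
    for sd_part, hf_part in unet_conversion_map_resnet:
        key = key.replace(hf_part, sd_part)
    return key  # -> "down_blocks.0.resnets.0.in_layers.0.weight"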

unet_conversion_map_layer = []

for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks

        # Spatial SD stuff
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

        # Temporal MS stuff
        hf_down_res_prefix = f"down_blocks.{i}.temp_convs.{j}."
        sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.temp_attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks

        # Spatial SD stuff
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

        # Temporal MS stuff
        hf_up_res_prefix = f"up_blocks.{i}.temp_convs.{j}."
        sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.temp_attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    # Up/Downsamplers are 2D, so we don't need to touch them
    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        # ... (the remaining up/downsampler prefix mappings are unchanged from
        # the SD conversion script and are elided in this diff)
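
# Orientation note (added for exposition, not in the original script): the
# index arithmetic above flattens HF's (block i, layer j) layout into the
# sequential ModelScope/SD numbering, e.g.
#   down_blocks.1.resnets.0  -> input_blocks.4.0.   (3*1 + 0 + 1 = 4)
#   up_blocks.2.attentions.1 -> output_blocks.7.1.  (3*2 + 1 = 7)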
# Handle the middle block

# Spatial
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2 * j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))

# Temporal
hf_mid_atn_prefix = "mid_block.temp_attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.temp_convs.{j}."
    sd_mid_res_prefix = f"middle_block.{2 * j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
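
# Worked example (illustrative, not in the original script) of the two-stage
# rename for one hypothetical mid-block key:
#   "mid_block.resnets.0.norm1.weight"
#     resnet map: norm1 -> in_layers.0        => "mid_block.resnets.0.in_layers.0.weight"
#     layer map:  mid_block.resnets.0. -> middle_block.0.
#                                             => "middle_block.0.in_layers.0.weight"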

# The pipeline
def convert_unet_state_dict(unet_state_dict):
    print('Converting the UNET')
    # buyer beware: this is a *brittle* function,
@@ -146,6 +188,10 @@ def convert_unet_state_dict(unet_state_dict):
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
        elif "temp_convs" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
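        # (exposition note, not in the original script: the "temp_convs" branch
        # above mirrors the "resnets" branch so temporal conv keys also receive
        # the resnet-style renames, including ModelScope's "temopral_conv" typo)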
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
@@ -326,7 +372,7 @@ def convert_text_enc_state_dict(text_enc_dict):

    # Path for safetensors
    unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
    # vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
    text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")

    # Load models from safetensors if they exist, otherwise fall back to the pytorch .bin files
@@ -336,11 +382,11 @@ def convert_text_enc_state_dict(text_enc_dict):
        unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
        unet_state_dict = torch.load(unet_path, map_location="cpu")

    # if osp.exists(vae_path):
    #     vae_state_dict = load_file(vae_path, device="cpu")
    # else:
    #     vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
    #     vae_state_dict = torch.load(vae_path, map_location="cpu")

    if osp.exists(text_enc_path):
        text_enc_dict = load_file(text_enc_path, device="cpu")
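        # (load_file here follows the safetensors.torch.load_file signature, as
        # imported at the elided top of the upstream SD conversion script)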