Skip to content
This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit 5c428c6

Browse files
committed
add support for temporal Unet parts
1 parent 3de0ee9 commit 5c428c6

File tree

1 file changed

+52
-6
lines changed

1 file changed

+52
-6
lines changed

utils/convert_diffusers_to_original_ms_text_to_video.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,17 @@
4747

4848
unet_conversion_map_resnet = [
4949
# (ModelScope, HF Diffusers)
50+
51+
# SD
5052
("in_layers.0", "norm1"),
5153
("in_layers.2", "conv1"),
5254
("out_layers.0", "norm2"),
5355
("out_layers.3", "conv2"),
5456
("emb_layers.1", "time_emb_proj"),
5557
("skip_connection", "conv_shortcut"),
58+
59+
# MS
60+
("temopral_conv", "temp_conv"), # ROFL, they have a typo here --kabachuha
5661
]
5762

5863
unet_conversion_map_layer = []
@@ -85,6 +90,8 @@
8590

8691
for j in range(2):
8792
# loop over resnets/attentions for downblocks
93+
94+
# Spatial SD stuff
8895
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
8996
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
9097
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
@@ -94,9 +101,22 @@
94101
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
95102
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
96103
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
104+
105+
# Temporal MS stuff
106+
hf_down_res_prefix = f"down_blocks.{i}.temp_convs.{j}."
107+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
108+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
109+
110+
if i < 3:
111+
# no attention layers in down_blocks.3
112+
hf_down_atn_prefix = f"down_blocks.{i}.temp_attentions.{j}."
113+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
114+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
97115

98116
for j in range(3):
99117
# loop over resnets/attentions for upblocks
118+
119+
# Spatial SD stuff
100120
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
101121
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
102122
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
@@ -106,7 +126,19 @@
106126
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
107127
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
108128
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
129+
130+
# Temporal MS stuff
131+
hf_up_res_prefix = f"up_blocks.{i}.temp_convs.{j}."
132+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
133+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
134+
135+
if i > 0:
136+
# no attention layers in up_blocks.0
137+
hf_up_atn_prefix = f"up_blocks.{i}.temp_attentions.{j}."
138+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
139+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
109140

141+
# Up/Downsamplers are 2D, so don't need to touch them
110142
if i < 3:
111143
# no downsample in down_blocks.3
112144
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
@@ -121,6 +153,7 @@
121153

122154
# Handle the middle block
123155

156+
# Spatial
124157
hf_mid_atn_prefix = "mid_block.attentions.0."
125158
sd_mid_atn_prefix = "middle_block.1."
126159
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
@@ -130,8 +163,17 @@
130163
sd_mid_res_prefix = f"middle_block.{2*j}."
131164
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
132165

166+
# Temporal
167+
hf_mid_atn_prefix = "mid_block.temp_attentions.0."
168+
sd_mid_atn_prefix = "middle_block.1."
169+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
133170

171+
for j in range(2):
172+
hf_mid_res_prefix = f"mid_block.temp_convs.{j}."
173+
sd_mid_res_prefix = f"middle_block.{2*j}."
174+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
134175

176+
# The pipeline
135177
def convert_unet_state_dict(unet_state_dict):
136178
print ('Converting the UNET')
137179
# buyer beware: this is a *brittle* function,
@@ -146,6 +188,10 @@ def convert_unet_state_dict(unet_state_dict):
146188
for sd_part, hf_part in unet_conversion_map_resnet:
147189
v = v.replace(hf_part, sd_part)
148190
mapping[k] = v
191+
elif "temp_convs" in k:
192+
for sd_part, hf_part in unet_conversion_map_resnet:
193+
v = v.replace(hf_part, sd_part)
194+
mapping[k] = v
149195
for k, v in mapping.items():
150196
for sd_part, hf_part in unet_conversion_map_layer:
151197
v = v.replace(hf_part, sd_part)
@@ -326,7 +372,7 @@ def convert_text_enc_state_dict(text_enc_dict):
326372

327373
# Path for safetensors
328374
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
329-
vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
375+
#vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
330376
text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")
331377

332378
# Load models from safetensors if it exists, if it doesn't pytorch
@@ -336,11 +382,11 @@ def convert_text_enc_state_dict(text_enc_dict):
336382
unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
337383
unet_state_dict = torch.load(unet_path, map_location="cpu")
338384

339-
if osp.exists(vae_path):
340-
vae_state_dict = load_file(vae_path, device="cpu")
341-
else:
342-
vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
343-
vae_state_dict = torch.load(vae_path, map_location="cpu")
385+
# if osp.exists(vae_path):
386+
# vae_state_dict = load_file(vae_path, device="cpu")
387+
# else:
388+
# vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
389+
# vae_state_dict = torch.load(vae_path, map_location="cpu")
344390

345391
if osp.exists(text_enc_path):
346392
text_enc_dict = load_file(text_enc_path, device="cpu")

0 commit comments

Comments
 (0)