Skip to content
This repository was archived by the owner on Dec 14, 2023. It is now read-only.

Commit e8149f5

Browse files
committed
convert the input temporal transformer
1 parent 5c428c6 commit e8149f5

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

utils/convert_diffusers_to_original_ms_text_to_video.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,19 @@
2424
("time_embed.2.weight", "time_embedding.linear_2.weight"),
2525
("time_embed.2.bias", "time_embedding.linear_2.bias"),
2626

27-
# from Modelscope only
28-
("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
29-
("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
30-
("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
31-
("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
27+
# # from Modelscope only
28+
# ("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
29+
# ("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
30+
# ("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
31+
# ("label_emb.0.2.bias", "class_embedding.linear_2.bias"),
3232

3333
# from Vanilla ModelScope/StableDiffusion
3434
("input_blocks.0.0.weight", "conv_in.weight"),
3535
("input_blocks.0.0.bias", "conv_in.bias"),
3636

3737
# from Modelscope only
38-
("input_blocks.0.1.weight", "transformer_in.weight"),
39-
("input_blocks.0.1.bias", "transformer_in.bias"),
38+
#("input_blocks.0.1", "transformer_in.weight"),
39+
#("input_blocks.0.1.bias", "transformer_in.bias"),
4040

4141
# from Vanilla ModelScope/StableDiffusion
4242
("out.0.weight", "conv_norm_out.weight"),
@@ -62,6 +62,9 @@
6262

6363
unet_conversion_map_layer = []
6464

65+
# Convert input TemporalTransformer
66+
unet_conversion_map_layer.append(('input_blocks.0.1', 'transformer_in'))
67+
6568
# Reference for the default settings
6669

6770
# "model_cfg": {
@@ -85,10 +88,10 @@
8588

8689
# hardcoded number of downblocks and resnets/attentions...
8790
# would need smarter logic for other networks.
88-
for i in range(4):
91+
for i in range(4):# 4 UP/DOWN BLOCKS CONFIRMED --kabachuha
8992
# loop over downblocks/upblocks
9093

91-
for j in range(2):
94+
for j in range(2): # 2 RESNET BLOCKS CONFIRMED --kabachuha
9295
# loop over resnets/attentions for downblocks
9396

9497
# Spacial SD stuff
@@ -181,6 +184,8 @@ def convert_unet_state_dict(unet_state_dict):
181184
# the exact order in which I have arranged them.
182185
mapping = {k: k for k in unet_state_dict.keys()}
183186

187+
188+
184189
for sd_name, hf_name in unet_conversion_map:
185190
mapping[hf_name] = sd_name
186191
for k, v in mapping.items():
@@ -444,3 +449,5 @@ def convert_text_enc_state_dict(text_enc_dict):
444449
else:
445450
state_dict = {"state_dict": state_dict}
446451
torch.save(state_dict, args.clip_checkpoint_path)
452+
453+
print('Operation successful')

0 commit comments

Comments
 (0)