|
24 | 24 | ("time_embed.2.weight", "time_embedding.linear_2.weight"), |
25 | 25 | ("time_embed.2.bias", "time_embedding.linear_2.bias"), |
26 | 26 |
|
27 | | - # from Modelscope only |
28 | | - ("label_emb.0.0.weight", "class_embedding.linear_1.weight"), |
29 | | - ("label_emb.0.0.bias", "class_embedding.linear_1.bias"), |
30 | | - ("label_emb.0.2.weight", "class_embedding.linear_2.weight"), |
31 | | - ("label_emb.0.2.bias", "class_embedding.linear_2.bias"), |
| 27 | + # # from Modelscope only |
| 28 | + # ("label_emb.0.0.weight", "class_embedding.linear_1.weight"), |
| 29 | + # ("label_emb.0.0.bias", "class_embedding.linear_1.bias"), |
| 30 | + # ("label_emb.0.2.weight", "class_embedding.linear_2.weight"), |
| 31 | + # ("label_emb.0.2.bias", "class_embedding.linear_2.bias"), |
32 | 32 |
|
33 | 33 | # from Vanilla ModelScope/StableDiffusion |
34 | 34 | ("input_blocks.0.0.weight", "conv_in.weight"), |
35 | 35 | ("input_blocks.0.0.bias", "conv_in.bias"), |
36 | 36 |
|
37 | 37 | # from Modelscope only |
38 | | - ("input_blocks.0.1.weight", "transformer_in.weight"), |
39 | | - ("input_blocks.0.1.bias", "transformer_in.bias"), |
| 38 | + #("input_blocks.0.1", "transformer_in.weight"), |
| 39 | + #("input_blocks.0.1.bias", "transformer_in.bias"), |
40 | 40 |
|
41 | 41 | # from Vanilla ModelScope/StableDiffusion |
42 | 42 | ("out.0.weight", "conv_norm_out.weight"), |
|
62 | 62 |
|
63 | 63 | unet_conversion_map_layer = [] |
64 | 64 |
|
| 65 | +# Convert input TemporalTransformer |
| 66 | +unet_conversion_map_layer.append(('input_blocks.0.1', 'transformer_in')) |
| 67 | + |
65 | 68 | # Reference for the default settings |
66 | 69 |
|
67 | 70 | # "model_cfg": { |
|
85 | 88 |
|
86 | 89 | # hardcoded number of downblocks and resnets/attentions... |
87 | 90 | # would need smarter logic for other networks. |
88 | | -for i in range(4): |
| 91 | +for i in range(4):  # 4 UP/DOWN BLOCKS CONFIRMED --kabachuha |
89 | 92 | # loop over downblocks/upblocks |
90 | 93 |
|
91 | | - for j in range(2): |
| 94 | + for j in range(2): # 2 RESNET BLOCKS CONFIRMED --kabachuha |
92 | 95 | # loop over resnets/attentions for downblocks |
93 | 96 |
|
94 | 97 | # Spacial SD stuff |
@@ -181,6 +184,8 @@ def convert_unet_state_dict(unet_state_dict): |
181 | 184 | # the exact order in which I have arranged them. |
182 | 185 | mapping = {k: k for k in unet_state_dict.keys()} |
183 | 186 |
|
| 187 | + |
| 188 | + |
184 | 189 | for sd_name, hf_name in unet_conversion_map: |
185 | 190 | mapping[hf_name] = sd_name |
186 | 191 | for k, v in mapping.items(): |
@@ -444,3 +449,5 @@ def convert_text_enc_state_dict(text_enc_dict): |
444 | 449 | else: |
445 | 450 | state_dict = {"state_dict": state_dict} |
446 | 451 | torch.save(state_dict, args.clip_checkpoint_path) |
| 452 | + |
| 453 | + print('Operation successful') |
0 commit comments