# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet and the text encoder.
# Does not convert optimizer state or anything else.
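#
# Example invocation (a sketch; the script filename and all paths below are placeholders):
#   python convert_diffusers_to_sd_checkpoint.py \
#       --model_path /path/to/diffusers_pipeline \
#       --checkpoint_path /path/to/output_checkpoint.ckpt \
#       [--half] [--use_safetensors]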

import argparse
import os.path as osp
import re

import torch
from safetensors.torch import load_file, save_file

# =================#
# UNet Conversion #
# =================#

print('Initializing the conversion map')

unet_conversion_map = [
    # (ModelScope, HF Diffusers)

    # from Vanilla ModelScope/StableDiffusion
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),

    # from Modelscope only
    ("label_emb.0.0.weight", "class_embedding.linear_1.weight"),
    ("label_emb.0.0.bias", "class_embedding.linear_1.bias"),
    ("label_emb.0.2.weight", "class_embedding.linear_2.weight"),
    ("label_emb.0.2.bias", "class_embedding.linear_2.bias"),

    # from Vanilla ModelScope/StableDiffusion
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),

    # from Modelscope only
    ("input_blocks.0.1.weight", "transformer_in.weight"),
    ("input_blocks.0.1.bias", "transformer_in.bias"),

    # from Vanilla ModelScope/StableDiffusion
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]

unet_conversion_map_resnet = [
    # (ModelScope, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]
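
# Note: unlike unet_conversion_map above (whole-key renames), the pairs in
# unet_conversion_map_resnet and unet_conversion_map_layer (built below) are
# substring replacements applied inside each key by convert_unet_state_dict.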

unet_conversion_map_layer = []

# Reference for the default settings

# "model_cfg": {
# "unet_in_dim": 4,
# "unet_dim": 320,
# "unet_y_dim": 768,
# "unet_context_dim": 1024,
# "unet_out_dim": 4,
# "unet_dim_mult": [1, 2, 4, 4],
# "unet_num_heads": 8,
# "unet_head_dim": 64,
# "unet_res_blocks": 2,
# "unet_attn_scales": [1, 0.5, 0.25],
# "unet_dropout": 0.1,
# "temporal_attention": "True",
# "num_timesteps": 1000,
# "mean_type": "eps",
# "var_type": "fixed_small",
# "loss_type": "mse"
# }

# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
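
# A few of the (ModelScope, HF Diffusers) prefix pairs generated by the loop above, for i = 0:
#   ("input_blocks.1.0.", "down_blocks.0.resnets.0.")
#   ("input_blocks.1.1.", "down_blocks.0.attentions.0.")
#   ("input_blocks.3.0.op.", "down_blocks.0.downsamplers.0.conv.")
#   ("output_blocks.2.1.", "up_blocks.0.upsamplers.0.")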


# Handle the middle block

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2 * j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


def convert_unet_state_dict(unet_state_dict):
    print('Converting the UNET')
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}

    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict
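
# Illustrative example of what the maps above do to a single key:
#   "down_blocks.0.resnets.0.norm1.weight"  ->  "input_blocks.1.0.in_layers.0.weight"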

# TODO: VAE conversion. We don't train it in most cases, but it may be handy in the future. --kabachuha

# =========================#
# Text Encoder Conversion #
# =========================#

# IT IS THE SAME CLIP ENCODER, SO JUST COPYPASTING IT --kabachuha


textenc_conversion_lst = [
    # (stable-diffusion, HF Diffusers)
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
    ("ln_final.", "transformer.text_model.final_layer_norm."),
    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
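
# For illustration, textenc_pattern rewrites HF Diffusers names back to the original CLIP layout, e.g.
#   "transformer.text_model.encoder.layers.0.layer_norm1.weight" -> "transformer.resblocks.0.ln_1.weight"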

# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}
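# The separate q/k/v projections captured below are concatenated in this q, k, v order
# to rebuild the fused in_proj_weight / in_proj_bias tensors of the original checkpoint.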


def convert_text_enc_state_dict_v20(text_enc_dict):
    print('Converting the text encoder')
    new_state_dict = {}
    capture_qkv_weight = {}
    capture_qkv_bias = {}
    for k, v in text_enc_dict.items():
        if (
            k.endswith(".self_attn.q_proj.weight")
            or k.endswith(".self_attn.k_proj.weight")
            or k.endswith(".self_attn.v_proj.weight")
        ):
            k_pre = k[: -len(".q_proj.weight")]
            k_code = k[-len("q_proj.weight")]
            if k_pre not in capture_qkv_weight:
                capture_qkv_weight[k_pre] = [None, None, None]
            capture_qkv_weight[k_pre][code2idx[k_code]] = v
            continue

        if (
            k.endswith(".self_attn.q_proj.bias")
            or k.endswith(".self_attn.k_proj.bias")
            or k.endswith(".self_attn.v_proj.bias")
        ):
            k_pre = k[: -len(".q_proj.bias")]
            k_code = k[-len("q_proj.bias")]
            if k_pre not in capture_qkv_bias:
                capture_qkv_bias[k_pre] = [None, None, None]
            capture_qkv_bias[k_pre][code2idx[k_code]] = v
            continue

        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
        new_state_dict[relabelled_key] = v

    for k_pre, tensors in capture_qkv_weight.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
        new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)

    for k_pre, tensors in capture_qkv_bias.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
        new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)

    return new_state_dict


def convert_text_enc_state_dict(text_enc_dict):
    return text_enc_dict


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
    parser.add_argument(
        "--use_safetensors", action="store_true", help="Save weights using safetensors; the default is a ckpt file."
    )

    args = parser.parse_args()

    assert args.model_path is not None, "Must provide a model path!"

    assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

    # Paths for safetensors
    unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
    vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
    text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")

    # Load each model from safetensors if the file exists; otherwise fall back to the PyTorch .bin weights
    if osp.exists(unet_path):
        unet_state_dict = load_file(unet_path, device="cpu")
    else:
        unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
        unet_state_dict = torch.load(unet_path, map_location="cpu")

    if osp.exists(vae_path):
        vae_state_dict = load_file(vae_path, device="cpu")
    else:
        vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
        vae_state_dict = torch.load(vae_path, map_location="cpu")

    if osp.exists(text_enc_path):
        text_enc_dict = load_file(text_enc_path, device="cpu")
    else:
        text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
        text_enc_dict = torch.load(text_enc_path, map_location="cpu")

    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    # vae_state_dict = convert_vae_state_dict(vae_state_dict)
    # vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # The easiest way to identify a v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
    is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
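    # (The v1.x CLIP ViT-L text encoder has only 12 layers, so a "layers.22" key can only
    # come from the deeper OpenCLIP text encoder shipped with v2.x models.)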

    if is_v20_model:

        # MODELSCOPE always uses the 2.X encoder, btw --kabachuha

        # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
        text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
        text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
        text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
    else:
        text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # DON'T PUT THEM TOGETHER INTO THE NEW CHECKPOINT, AS MODELSCOPE USES THEM IN SPLIT FORM --kabachuha

    # state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    print('Saving UNET')
    state_dict = {**unet_state_dict}

    if args.half:
        state_dict = {k: v.half() for k, v in state_dict.items()}

    if args.use_safetensors:
        save_file(state_dict, args.checkpoint_path)
    else:
        state_dict = {"state_dict": state_dict}
        torch.save(state_dict, args.checkpoint_path)