diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 6c0221d2092a..566b2167362b 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -799,6 +799,24 @@ def _copy_layers(hf_layers, pt_layers):
 
 
 def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
+    """Manually copy in relevant values from checkpoint to proxy object.
+
+    Create a TextEncoder object of the type we THINK was used in the checkpoint.
+    Then extract key values from the checkpoint, and hand-copy them into that.
+    Note: "text_encoder" is more like "text_encoder_type_template"
+    """
+    keys = list(checkpoint.keys())
+
+    text_model_dict = {}
+
+    remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"]
+
+    for key in keys:
+        for prefix in remove_prefixes:
+            if key.startswith(prefix):
+                text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]
+
+    # For most uses, this will actually be None
     if text_encoder is None:
         config_name = "openai/clip-vit-large-patch14"
         try:
@@ -807,6 +825,19 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
             raise ValueError(
                 f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
             )
+
+        # Shenanigans to handle special cases like LongCLIP
+        maxkey = config.max_position_embeddings
+        # Make this a list, because some checkpoints use
+        # a different key name, but I don't know them right now
+        for maxkeyname in ["text_model.embeddings.position_embedding.weight"]:
+            if maxkeyname in text_model_dict:
+                maxkey = int(text_model_dict[maxkeyname].shape[0])
+                logger.debug("maxkey is %s", maxkey)
+
+        if config.max_position_embeddings != maxkey:
+            logger.debug("changing max_position_embeddings to %s", maxkey)
+            config.max_position_embeddings = maxkey
 
         ctx = init_empty_weights if is_accelerate_available() else nullcontext
         with ctx():
@@ -814,17 +845,6 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
     else:
         text_model = text_encoder
 
-    keys = list(checkpoint.keys())
-
-    text_model_dict = {}
-
-    remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"]
-
-    for key in keys:
-        for prefix in remove_prefixes:
-            if key.startswith(prefix):
-                text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]
-
     if is_accelerate_available():
         for param_name, param in text_model_dict.items():
             set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
@@ -1142,6 +1162,21 @@ def convert_controlnet_checkpoint(
 
     return controlnet
 
+
+def get_extractor(model, local):
+    """Load the image processor ("feature extractor") saved under `model`.
+
+    Uses `AutoImageProcessor` when the installed transformers provides it and falls back to `CLIPImageProcessor`
+    otherwise. Load errors are not caught here, so each caller keeps control of its own error message.
+    """
+    try:
+        from transformers import AutoImageProcessor as processor_cls
+    except ImportError:
+        # fallback for older transformers
+        from transformers import CLIPImageProcessor as processor_cls
+
+    return processor_cls.from_pretrained(model, local_files_only=local)
+
 
 def download_from_original_stable_diffusion_ckpt(
     checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -1665,9 +1700,7 @@ def download_from_original_stable_diffusion_ckpt(
                 f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
             )
         try:
-            feature_extractor = AutoFeatureExtractor.from_pretrained(
-                "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
-            )
+            feature_extractor = get_extractor("CompVis/stable-diffusion-safety-checker", local=local_files_only)
         except Exception:
             raise ValueError(
                 f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'."
@@ -1699,9 +1732,7 @@ def download_from_original_stable_diffusion_ckpt(
             safety_checker = StableDiffusionSafetyChecker.from_pretrained(
                 "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
             )
-            feature_extractor = AutoFeatureExtractor.from_pretrained(
-                "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
-            )
+            feature_extractor = get_extractor("CompVis/stable-diffusion-safety-checker", local=local_files_only)
 
         if controlnet:
             pipe = pipeline_class(