Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 52 additions & 22 deletions src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,24 @@ def _copy_layers(hf_layers, pt_layers):


def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
"""Manually copy in relevant values from checkpoint to proxy object

Create a TextEncoder object of the type we THINK was used in the checkpoint.
Then exctract key values from the checkpoint, and hand-copy them into that.
Note: "text_encoder" is more like "text_encoder_type_template"
"""
keys = list(checkpoint.keys())

text_model_dict = {}

remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"]

for key in keys:
for prefix in remove_prefixes:
if key.startswith(prefix):
text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]

# For most uses, this will actually be None
if text_encoder is None:
config_name = "openai/clip-vit-large-patch14"
try:
Expand All @@ -807,24 +825,26 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
)
# Shennanigans to handle special cases like LongCLIP
maxkey = config.max_position_embeddings
# Make this a list, because some checkpoints use
# a different key name, but I dont know them right now
for maxkeyname in ["text_model.embeddings.position_embedding.weight"]:
if maxkeyname in text_model_dict:
maxkey = int(text_model_dict[maxkeyname].shape[0])
logger.debug("maxkey is", maxkey)

if config.max_position_embeddings != maxkey:
logger.debug("changing max_position_embeddings to", maxkey)
config.max_position_embeddings = maxkey


ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
text_model = CLIPTextModel(config)
else:
text_model = text_encoder

keys = list(checkpoint.keys())

text_model_dict = {}

remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"]

for key in keys:
for prefix in remove_prefixes:
if key.startswith(prefix):
text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]

if is_accelerate_available():
for param_name, param in text_model_dict.items():
set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
Expand Down Expand Up @@ -1142,6 +1162,20 @@ def convert_controlnet_checkpoint(

return controlnet

def get_extractor(model, local):
try:
from transformers import AutoImageProcessor
return AutoImageProcessor.from_pretrained(
model,
local_files_only=local,
)
except Exception:
# fallback for older transformers
from transformers import CLIPImageProcessor
return CLIPImageProcessor.from_pretrained(
model,
local_files_only=local,
)

def download_from_original_stable_diffusion_ckpt(
checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]],
Expand Down Expand Up @@ -1664,14 +1698,10 @@ def download_from_original_stable_diffusion_ckpt(
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
)
try:
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
)
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'."
)
feature_extractor = get_extractor(
"CompVis/stable-diffusion-safety-checker",
local_files_only
)
pipe = PaintByExamplePipeline(
vae=vae,
image_encoder=vision_model,
Expand Down Expand Up @@ -1699,10 +1729,10 @@ def download_from_original_stable_diffusion_ckpt(
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
feature_extractor = get_extractor(
"CompVis/stable-diffusion-safety-checker",
local_files_only
)

if controlnet:
pipe = pipeline_class(
vae=vae,
Expand Down