diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 76a850b63c4e..702ed5daf4eb 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -14,6 +14,7 @@ import importlib import inspect import os +import sys import traceback import warnings from collections import OrderedDict @@ -28,10 +29,16 @@ from typing_extensions import Self from ..configuration_utils import ConfigMixin, FrozenDict -from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj +from ..pipelines.pipeline_loading_utils import ( + LOADABLE_CLASSES, + _fetch_class_library_tuple, + _unwrap_model, + simple_get_class_obj, +) from ..utils import PushToHubMixin, is_accelerate_available, logging from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code from ..utils.hub_utils import load_or_create_model_card, populate_model_card +from ..utils.torch_utils import is_compiled_module from .components_manager import ComponentsManager from .modular_pipeline_utils import ( MODULAR_MODEL_CARD_TEMPLATE, @@ -1819,29 +1826,111 @@ def from_pretrained( ) return pipeline - def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs): + def save_pretrained( + self, + save_directory: str | os.PathLike, + safe_serialization: bool = True, + variant: str | None = None, + max_shard_size: int | str | None = None, + push_to_hub: bool = False, + **kwargs, + ): """ - Save the pipeline to a directory. It does not save components, you need to save them separately. + Save the pipeline and all its components to a directory, so that it can be re-loaded using the + [`~ModularPipeline.from_pretrained`] class method. Args: save_directory (`str` or `os.PathLike`): - Path to the directory where the pipeline will be saved. 
- push_to_hub (`bool`, optional): - Whether to push the pipeline to the huggingface hub. - **kwargs: Additional arguments passed to `save_config()` method - """ + Directory to save the pipeline to. Will be created if it doesn't exist. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + variant (`str`, *optional*): + If specified, weights are saved in the format `pytorch_model.<variant>.bin`. + max_shard_size (`int` or `str`, defaults to `None`): + The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be + smaller than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`). + If expressed as an integer, the unit is bytes. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the pipeline to the Hugging Face model hub after saving it. + **kwargs: Additional keyword arguments passed along to the push to hub method. + """ + overwrite_modular_index = kwargs.pop("overwrite_modular_index", False) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + + for component_name, component_spec in self._component_specs.items(): + sub_model = getattr(self, component_name, None) + if sub_model is None: + continue + + model_cls = sub_model.__class__ + if is_compiled_module(sub_model): + sub_model = _unwrap_model(sub_model) + model_cls = sub_model.__class__ + + save_method_name = None + for library_name, library_classes in LOADABLE_CLASSES.items(): + if library_name in sys.modules: + library = importlib.import_module(library_name) + else: + logger.info( + f"{library_name} is not installed. 
Cannot save {component_name} as {library_classes} from {library_name}" + ) + continue + + for base_class, save_load_methods in library_classes.items(): + class_candidate = getattr(library, base_class, None) + if class_candidate is not None and issubclass(model_cls, class_candidate): + save_method_name = save_load_methods[0] + break + if save_method_name is not None: + break + + if save_method_name is None: + logger.warning(f"self.{component_name}={sub_model} of type {type(sub_model)} cannot be saved.") + continue + + save_method = getattr(sub_model, save_method_name) + save_method_signature = inspect.signature(save_method) + save_method_accept_safe = "safe_serialization" in save_method_signature.parameters + save_method_accept_variant = "variant" in save_method_signature.parameters + save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters + + save_kwargs = {} + if save_method_accept_safe: + save_kwargs["safe_serialization"] = safe_serialization + if save_method_accept_variant: + save_kwargs["variant"] = variant + if save_method_accept_max_shard_size and max_shard_size is not None: + save_kwargs["max_shard_size"] = max_shard_size + + save_method(os.path.join(save_directory, component_name), **save_kwargs) + if push_to_hub: commit_message = kwargs.pop("commit_message", None) private = kwargs.pop("private", None) create_pr = kwargs.pop("create_pr", False) token = kwargs.pop("token", None) - repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id - # Generate modular pipeline card content - card_content = generate_modular_model_card_content(self.blocks) + if overwrite_modular_index: + for component_name, component_spec in self._component_specs.items(): + if component_spec.default_creation_method != "from_pretrained": + continue + if component_name not in self.config: + continue + + library, class_name, component_spec_dict = 
self.config[component_name] + component_spec_dict["pretrained_model_name_or_path"] = repo_id + component_spec_dict["subfolder"] = component_name + if variant is not None and "variant" in component_spec_dict: + component_spec_dict["variant"] = variant + + self.register_to_config(**{component_name: (library, class_name, component_spec_dict)}) + + self.save_config(save_directory=save_directory) - # Create a new empty model card and eventually tag it + if push_to_hub: + card_content = generate_modular_model_card_content(self.blocks) model_card = load_or_create_model_card( repo_id, token=token, @@ -1850,13 +1939,8 @@ def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = is_modular=True, ) model_card = populate_model_card(model_card, tags=card_content["tags"]) - model_card.save(os.path.join(save_directory, "README.md")) - # YiYi TODO: maybe order the json file to make it more readable: configs first, then components - self.save_config(save_directory=save_directory) - - if push_to_hub: self._upload_folder( save_directory, repo_id,