diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 63e50af61771..5b0f8a3a0d64 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -675,6 +675,7 @@ def save_pretrained(
         variant: Optional[str] = None,
         max_shard_size: Union[int, str] = "10GB",
         push_to_hub: bool = False,
+        use_flashpack: bool = False,
         **kwargs,
     ):
         """
@@ -707,6 +708,10 @@ def save_pretrained(
                 Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                 namespace).
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                Whether to save the model in the [FlashPack](https://github.com/fal-ai/flashpack) format. FlashPack
+                is a binary format that enables faster weight loading. Requires the `flashpack` library to be
+                installed.
             kwargs (`Dict[str, Any]`, *optional*):
                 Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
@@ -734,7 +739,18 @@ def save_pretrained(
             )

         os.makedirs(save_directory, exist_ok=True)

+        if use_flashpack:
+            if not is_main_process:
+                return
+            from ..utils.flashpack_utils import save_flashpack
+
+            save_flashpack(
+                self,
+                save_directory,
+                variant=variant,
+            )
+
         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
             private = kwargs.pop("private", None)
@@ -939,6 +952,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                 `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                 weights. If set to `False`, `safetensors` weights are not loaded.
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                If set to `True`, the model is first loaded from [`flashpack`](https://github.com/fal-ai/flashpack)
+                weights when a compatible `.flashpack` file is found. If `flashpack` is unavailable or the
+                `.flashpack` file cannot be used, loading automatically falls back to the standard loading path
+                (for example, `safetensors`).
             disable_mmap ('bool', *optional*, defaults to 'False'):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the
                 model is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
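To make the intended behavior of the two new flags concrete, here is a minimal usage sketch, assuming this patch is applied and the `flashpack` package is installed; the model class, checkpoint id, and local paths are illustrative placeholders rather than part of the change.

```python
# Sketch of the intended end-to-end usage (placeholders: model class, repo id, paths).
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)

# Writes `model.flashpack` (plus `flashpack_config.json`) into the directory,
# alongside the regular serialization output.
unet.save_pretrained("./unet-flashpack", use_flashpack=True)

# Loads from the `.flashpack` file when one is found; otherwise this falls
# back to the standard safetensors/pickle loading path.
unet = UNet2DConditionModel.from_pretrained("./unet-flashpack", use_flashpack=True)
```

As written, the FlashPack save runs in addition to the standard serialization path, so the output directory stays loadable by diffusers versions that know nothing about `.flashpack` files.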
@@ -982,6 +999,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        use_flashpack = kwargs.pop("use_flashpack", False)
         quantization_config = kwargs.pop("quantization_config", None)
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
@@ -1199,26 +1217,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

             model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
+        else:
-            # in the case it is sharded, we have already the index
-            if is_sharded:
-                resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
-                    pretrained_model_name_or_path,
-                    index_file,
-                    cache_dir=cache_dir,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    token=token,
-                    user_agent=user_agent,
-                    revision=revision,
-                    subfolder=subfolder or "",
-                    dduf_entries=dduf_entries,
-                )
-            elif use_safetensors:
+            flashpack_file = None
+            if use_flashpack:
                 try:
-                    resolved_model_file = _get_model_file(
+                    flashpack_file = _get_model_file(
                         pretrained_model_name_or_path,
-                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+                        weights_name=_add_variant("model.flashpack", variant),
                         cache_dir=cache_dir,
                         force_download=force_download,
                         proxies=proxies,
@@ -1230,33 +1236,68 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                         commit_hash=commit_hash,
                         dduf_entries=dduf_entries,
                     )
+                except EnvironmentError:
+                    flashpack_file = None
-                except IOError as e:
-                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
-                    if not allow_pickle:
-                        raise
-                    logger.warning(
-                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+            if flashpack_file is None:
+                # in the case it is sharded, we have already the index
+                if is_sharded:
+                    resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
+                        pretrained_model_name_or_path,
+                        index_file,
+                        cache_dir=cache_dir,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        token=token,
+                        user_agent=user_agent,
+                        revision=revision,
+                        subfolder=subfolder or "",
+                        dduf_entries=dduf_entries,
                     )
+                elif use_safetensors:
+                    logger.debug("Attempting to load the model weights in the safetensors format.")
+                    try:
+                        resolved_model_file = _get_model_file(
+                            pretrained_model_name_or_path,
+                            weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            proxies=proxies,
+                            local_files_only=local_files_only,
+                            token=token,
+                            revision=revision,
+                            subfolder=subfolder,
+                            user_agent=user_agent,
+                            commit_hash=commit_hash,
+                            dduf_entries=dduf_entries,
+                        )
-            if resolved_model_file is None and not is_sharded:
-                resolved_model_file = _get_model_file(
-                    pretrained_model_name_or_path,
-                    weights_name=_add_variant(WEIGHTS_NAME, variant),
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    token=token,
-                    revision=revision,
-                    subfolder=subfolder,
-                    user_agent=user_agent,
-                    commit_hash=commit_hash,
-                    dduf_entries=dduf_entries,
-                )
+                    except IOError as e:
+                        logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
+                        if not allow_pickle:
+                            raise
+                        logger.warning(
+                            "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                        )
+
+                if resolved_model_file is None and not is_sharded:
+                    resolved_model_file = _get_model_file(
+                        pretrained_model_name_or_path,
+                        weights_name=_add_variant(WEIGHTS_NAME, variant),
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        token=token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                        commit_hash=commit_hash,
+                        dduf_entries=dduf_entries,
+                    )
-        if not isinstance(resolved_model_file, list):
-            resolved_model_file = [resolved_model_file]
+                if not isinstance(resolved_model_file, list):
+                    resolved_model_file = [resolved_model_file]

         # set dtype to instantiate the model under:
         # 1. If torch_dtype is not None, we use that dtype
@@ -1280,7 +1321,31 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)

+        if flashpack_file is not None:
+            from ..utils.flashpack_utils import load_flashpack
+
+            # Even when using FlashPack, we preserve `low_cpu_mem_usage` behavior by initializing
+            # the model with meta tensors. Since FlashPack cannot write into meta tensors, we
+            # explicitly materialize parameters before loading to ensure correctness and parity
+            # with the standard loading path.
+            if any(p.device.type == "meta" for p in model.parameters()):
+                model.to_empty(device="cpu")
+
+            load_flashpack(model, flashpack_file)
+            model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+            model.eval()
+
+            if output_loading_info:
+                return model, {
+                    "missing_keys": [],
+                    "unexpected_keys": [],
+                    "mismatched_keys": [],
+                    "error_msgs": [],
+                }
+
+            return model
+
         state_dict = None
         if not is_sharded:
             # Time to load the checkpoint
@@ -1372,9 +1434,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

         if output_loading_info:
             return model, loading_info

         return model

+    # Adapted from `transformers`.
     @wraps(torch.nn.Module.cuda)
     def cuda(self, *args, **kwargs):
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 57d4eaa8f89e..2fbb1028f9da 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -756,6 +756,7 @@ def load_sub_model(
     low_cpu_mem_usage: bool,
     cached_folder: Union[str, os.PathLike],
     use_safetensors: bool,
+    use_flashpack: bool,
     dduf_entries: Optional[Dict[str, DDUFEntry]],
     provider_options: Any,
     disable_mmap: bool,
@@ -838,6 +839,9 @@ def load_sub_model(
         loading_kwargs["variant"] = model_variants.pop(name, None)
         loading_kwargs["use_safetensors"] = use_safetensors

+    if is_diffusers_model:
+        loading_kwargs["use_flashpack"] = use_flashpack
+
     if from_flax:
         loading_kwargs["from_flax"] = True
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index b96305c74131..34e42f42862f 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -243,6 +243,7 @@ def save_pretrained(
         variant: Optional[str] = None,
         max_shard_size: Optional[Union[int, str]] = None,
         push_to_hub: bool = False,
+        use_flashpack: bool = False,
         **kwargs,
     ):
         """
@@ -268,7 +269,9 @@ class implements both a save and loading method. The pipeline is easily reloaded
                 Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                 namespace).
-
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library:
+                `pip install flashpack`.
             kwargs (`Dict[str, Any]`, *optional*):
                 Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
""" @@ -340,6 +343,7 @@ def is_saveable_module(name, value): save_method_accept_safe = "safe_serialization" in save_method_signature.parameters save_method_accept_variant = "variant" in save_method_signature.parameters save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters + save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters save_kwargs = {} if save_method_accept_safe: @@ -349,6 +353,8 @@ def is_saveable_module(name, value): if save_method_accept_max_shard_size and max_shard_size is not None: # max_shard_size is expected to not be None in ModelMixin save_kwargs["max_shard_size"] = max_shard_size + if save_method_accept_flashpack: + save_kwargs["use_flashpack"] = use_flashpack save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs) @@ -707,6 +713,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P If set to `None`, the safetensors weights are downloaded if they're available **and** if the safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors weights. If set to `False`, safetensors weights are not loaded. + use_flashpack (`bool`, *optional*, defaults to `False`): + If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file + is found. If flashpack is unavailable or the `.flashpack` file cannot be used, automatic fallback to + the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install + flashpack`. use_onnx (`bool`, *optional*, defaults to `None`): If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is @@ -772,6 +783,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P variant = kwargs.pop("variant", None) dduf_file = kwargs.pop("dduf_file", None) use_safetensors = kwargs.pop("use_safetensors", None) + use_flashpack = kwargs.pop("use_flashpack", False) use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) quantization_config = kwargs.pop("quantization_config", None) @@ -1061,6 +1073,7 @@ def load_module(name, value): low_cpu_mem_usage=low_cpu_mem_usage, cached_folder=cached_folder, use_safetensors=use_safetensors, + use_flashpack=use_flashpack, dduf_entries=dduf_entries, provider_options=provider_options, disable_mmap=disable_mmap, diff --git a/src/diffusers/utils/flashpack_utils.py b/src/diffusers/utils/flashpack_utils.py new file mode 100644 index 000000000000..14031a7c543a --- /dev/null +++ b/src/diffusers/utils/flashpack_utils.py @@ -0,0 +1,83 @@ +import json +import os +from typing import Optional +from .import_utils import is_flashpack_available +from .logging import get_logger +from ..utils import _add_variant + +logger = get_logger(__name__) + +def save_flashpack( + model, + save_directory: str, + variant: Optional[str] = None, + is_main_process: bool = True, +): + """ + Save model weights in FlashPack format along with a metadata config. + + Args: + model: Diffusers model instance + save_directory (`str`): Directory to save weights + variant (`str`, *optional*): Model variant + """ + if not is_flashpack_available(): + raise ImportError( + "The `use_flashpack=True` argument requires the `flashpack` package. " + "Install it with `pip install flashpack`." 
+ ) + + from flashpack import pack_to_file + + os.makedirs(save_directory, exist_ok=True) + + weights_name = _add_variant("model.flashpack", variant) + weights_path = os.path.join(save_directory, weights_name) + config_path = os.path.join(save_directory, "flashpack_config.json") + + try: + target_dtype = getattr(model, "dtype", None) + logger.warning(f"Dtype used for FlashPack save: {target_dtype}") + + # 1. Save binary weights + pack_to_file(model, weights_path, target_dtype=target_dtype) + + # 2. Save config metadata (best-effort) + if hasattr(model, "config"): + try: + if hasattr(model.config, "to_dict"): + config_data = model.config.to_dict() + else: + config_data = dict(model.config) + + with open(config_path, "w") as f: + json.dump(config_data, f, indent=4) + + except Exception as config_err: + logger.warning( + f"FlashPack weights saved, but config serialization failed: {config_err}" + ) + + except Exception as e: + logger.error(f"Failed to save weights in FlashPack format: {e}") + raise + +def load_flashpack(model, flashpack_file: str): + """ + Assign FlashPack weights from a file into an initialized PyTorch model. + """ + if not is_flashpack_available(): + raise ImportError( + "FlashPack weights require the `flashpack` package. " + "Install with `pip install flashpack`." + ) + + from flashpack import assign_from_file + logger.warning(f"Loading FlashPack weights from {flashpack_file}") + + try: + assign_from_file(model, flashpack_file) + except Exception as e: + raise RuntimeError( + f"Failed to load FlashPack weights from {flashpack_file}" + ) from e \ No newline at end of file diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 425c360a3110..2b99e42a26f7 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -231,7 +231,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _kornia_available, _kornia_version = _is_package_available("kornia") _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True) _av_available, _av_version = _is_package_available("av") - +_flashpack_available, _flashpack_version = _is_package_available("flashpack") def is_torch_available(): return _torch_available @@ -424,6 +424,8 @@ def is_kornia_available(): def is_av_available(): return _av_available +def is_flashpack_available(): + return _flashpack_available # docstyle-ignore FLAX_IMPORT_ERROR = """ @@ -941,6 +943,14 @@ def is_aiter_version(operation: str, version: str): return False return compare_versions(parse(_aiter_version), operation, version) +@cache +def is_flashpack_version(operation: str, version: str): + """ + Compares the current flashpack version to a given reference with an operation. + """ + if not _flashpack_available: + return False + return compare_versions(parse(_flashpack_version), operation, version) def get_objects_from_module(module): """