From ec541906c5167d95c38db94d831eff96f0a55fd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cdevanshi00=E2=80=9D?= <“devanshi7309@gmail.com”> Date: Mon, 19 Jan 2026 14:52:15 +0530 Subject: [PATCH 1/4] added fal-flashpack support --- src/diffusers/models/modeling_utils.py | 111 ++++++++++++++++++ .../pipelines/pipeline_loading_utils.py | 4 + src/diffusers/pipelines/pipeline_utils.py | 15 ++- 3 files changed, 129 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 63e50af61771..a4bfa47f6d31 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -675,6 +675,7 @@ def save_pretrained( variant: Optional[str] = None, max_shard_size: Union[int, str] = "10GB", push_to_hub: bool = False, + use_flashpack: bool = False, **kwargs, ): """ @@ -707,6 +708,9 @@ def save_pretrained( Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). + use_flashpack (`bool`, *optional*, defaults to `False`): + Whether to save the model in FlashPack format. FlashPack is a binary format that allows for faster + loading. Requires the `flashpack` library to be installed. kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ @@ -734,6 +738,44 @@ def save_pretrained( ) os.makedirs(save_directory, exist_ok=True) + if use_flashpack: + if not is_main_process: + return + + try: + from flashpack import pack_to_file + import json # Ensure json is imported + except ImportError: + raise ImportError("The `use_flashpack=True` argument requires the `flashpack` library.") + + flashpack_weights_name = _add_variant("model.flashpack", variant) + save_path = os.path.join(save_directory, flashpack_weights_name) + # Define the config path - this is what your benchmark script is looking for + config_save_path = os.path.join(save_directory, "flashpack_config.json") + + try: + target_dtype = getattr(self, "dtype", None) + logger.warning(f"Dtype used: {target_dtype}") + # 1. Save the binary weights + pack_to_file(self, save_path, target_dtype=target_dtype) + + # 2. Save the metadata config + if hasattr(self, "config"): + try: + # Attempt to get dictionary representation + if hasattr(self.config, "to_dict"): + config_data = self.config.to_dict() + else: + config_data = dict(self.config) + + with open(config_save_path, "w") as f: + json.dump(config_data, f, indent=4) + except Exception as config_err: + logger.warning(f"Weights saved but config serialization failed: {config_err}") + + logger.info(f"Model weights saved in FlashPack format at {save_path}") + except Exception as e: + logger.error(f"Failed to save weights in FlashPack format: {e}") if push_to_hub: commit_message = kwargs.pop("commit_message", None) @@ -939,6 +981,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` weights. If set to `False`, `safetensors` weights are not loaded. + use_flashpack (`bool`, *optional*, defaults to `False`): + If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file + is found. 
If flashpack is unavailable or the `.flashpack` file cannot be used, the loader automatically falls back to
+                the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install
+                flashpack`.
            disable_mmap ('bool', *optional*, defaults to 'False'):
                Whether to disable mmap when loading a Safetensors model. This option can perform better when the
                model is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -982,6 +1029,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
+        use_flashpack = kwargs.pop("use_flashpack", False)
        quantization_config = kwargs.pop("quantization_config", None)
        dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
        disable_mmap = kwargs.pop("disable_mmap", False)
@@ -1200,6 +1248,69 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
        else:
+            # If we are using `use_flashpack`, we try to load the model from flashpack first.
+            # If flashpack is not available or the file cannot be loaded, we fall back to
+            # the standard loading path (e.g. safetensors or PyTorch).
+            if use_flashpack:
+                try:
+                    from flashpack import assign_from_file
+                except ImportError:
+                    pass
+                else:
+                    flashpack_weights_name = _add_variant("model.flashpack", variant)
+                    try:
+                        flashpack_file = _get_model_file(
+                            pretrained_model_name_or_path,
+                            weights_name=flashpack_weights_name,
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            proxies=proxies,
+                            local_files_only=local_files_only,
+                            token=token,
+                            revision=revision,
+                            subfolder=subfolder,
+                            user_agent=user_agent,
+                            commit_hash=commit_hash,
+                        )
+                    except EnvironmentError:
+                        pass
+                    else:
+                        dtype_orig = None
+                        if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
+                            if not isinstance(torch_dtype, torch.dtype):
+                                raise ValueError(
+                                    f"{torch_dtype} needs to be a `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+ ) + dtype_orig = cls._set_default_torch_dtype(torch_dtype) + + with no_init_weights(): + model = cls.from_config(config, **unused_kwargs) + + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + + try: + assign_from_file(model, flashpack_file) + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None): + model = model.to(torch_dtype) + + model.eval() + + if output_loading_info: + loading_info = { + "missing_keys": [], + "unexpected_keys": [], + "mismatched_keys": [], + "error_msgs": [], + } + return model, loading_info + + return model + + except Exception: + pass # in the case it is sharded, we have already the index if is_sharded: resolved_model_file, sharded_metadata = _get_checkpoint_shard_files( diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 57d4eaa8f89e..2fbb1028f9da 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -756,6 +756,7 @@ def load_sub_model( low_cpu_mem_usage: bool, cached_folder: Union[str, os.PathLike], use_safetensors: bool, + use_flashpack: bool, dduf_entries: Optional[Dict[str, DDUFEntry]], provider_options: Any, disable_mmap: bool, @@ -838,6 +839,9 @@ def load_sub_model( loading_kwargs["variant"] = model_variants.pop(name, None) loading_kwargs["use_safetensors"] = use_safetensors + if is_diffusers_model: + loading_kwargs["use_flashpack"] = use_flashpack + if from_flax: loading_kwargs["from_flax"] = True diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index b96305c74131..34e42f42862f 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -243,6 +243,7 @@ def save_pretrained( variant: Optional[str] = None, max_shard_size: Optional[Union[int, str]] = None, push_to_hub: bool = False, + use_flashpack: bool = False, **kwargs, ): """ @@ -268,7 +269,9 @@ class implements both a save and loading method. The pipeline is easily reloaded Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). - + use_flashpack (`bool`, *optional*, defaults to `False`): + Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip install + flashpack`. kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. 
""" @@ -340,6 +343,7 @@ def is_saveable_module(name, value): save_method_accept_safe = "safe_serialization" in save_method_signature.parameters save_method_accept_variant = "variant" in save_method_signature.parameters save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters + save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters save_kwargs = {} if save_method_accept_safe: @@ -349,6 +353,8 @@ def is_saveable_module(name, value): if save_method_accept_max_shard_size and max_shard_size is not None: # max_shard_size is expected to not be None in ModelMixin save_kwargs["max_shard_size"] = max_shard_size + if save_method_accept_flashpack: + save_kwargs["use_flashpack"] = use_flashpack save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs) @@ -707,6 +713,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P If set to `None`, the safetensors weights are downloaded if they're available **and** if the safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors weights. If set to `False`, safetensors weights are not loaded. + use_flashpack (`bool`, *optional*, defaults to `False`): + If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file + is found. If flashpack is unavailable or the `.flashpack` file cannot be used, automatic fallback to + the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install + flashpack`. use_onnx (`bool`, *optional*, defaults to `None`): If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is @@ -772,6 +783,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P variant = kwargs.pop("variant", None) dduf_file = kwargs.pop("dduf_file", None) use_safetensors = kwargs.pop("use_safetensors", None) + use_flashpack = kwargs.pop("use_flashpack", False) use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) quantization_config = kwargs.pop("quantization_config", None) @@ -1061,6 +1073,7 @@ def load_module(name, value): low_cpu_mem_usage=low_cpu_mem_usage, cached_folder=cached_folder, use_safetensors=use_safetensors, + use_flashpack=use_flashpack, dduf_entries=dduf_entries, provider_options=provider_options, disable_mmap=disable_mmap, From e5bb10cfe10dce1e806e442993f4171cc51ad426 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cdevanshi00=E2=80=9D?= <“devanshi7309@gmail.com”> Date: Wed, 21 Jan 2026 04:22:50 +0530 Subject: [PATCH 2/4] review comments resolved --- src/diffusers/models/modeling_utils.py | 148 ++++++++----------------- src/diffusers/utils/flashpack_utils.py | 83 ++++++++++++++ src/diffusers/utils/import_utils.py | 12 +- 3 files changed, 142 insertions(+), 101 deletions(-) create mode 100644 src/diffusers/utils/flashpack_utils.py diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index a4bfa47f6d31..62fbc41cd55f 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -708,9 +708,10 @@ def save_pretrained( Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). 
-            use_flashpack (`bool`, *optional*, defaults to `False`):
-                Whether to save the model in FlashPack format. FlashPack is a binary format that allows for faster
-                loading. Requires the `flashpack` library to be installed.
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format.
+                FlashPack is a binary format that allows for faster loading.
+                Requires the `flashpack` library to be installed.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
@@ -741,42 +742,12 @@ def save_pretrained(
        if use_flashpack:
            if not is_main_process:
                return
-
-            try:
-                from flashpack import pack_to_file
-                import json  # Ensure json is imported
-            except ImportError:
-                raise ImportError("The `use_flashpack=True` argument requires the `flashpack` library.")
-
-            flashpack_weights_name = _add_variant("model.flashpack", variant)
-            save_path = os.path.join(save_directory, flashpack_weights_name)
-            # Define the config path - this is what your benchmark script is looking for
-            config_save_path = os.path.join(save_directory, "flashpack_config.json")
-
-            try:
-                target_dtype = getattr(self, "dtype", None)
-                logger.warning(f"Dtype used: {target_dtype}")
-                # 1. Save the binary weights
-                pack_to_file(self, save_path, target_dtype=target_dtype)
-
-                # 2. Save the metadata config
-                if hasattr(self, "config"):
-                    try:
-                        # Attempt to get dictionary representation
-                        if hasattr(self.config, "to_dict"):
-                            config_data = self.config.to_dict()
-                        else:
-                            config_data = dict(self.config)
-
-                        with open(config_save_path, "w") as f:
-                            json.dump(config_data, f, indent=4)
-                    except Exception as config_err:
-                        logger.warning(f"Weights saved but config serialization failed: {config_err}")
-
-                logger.info(f"Model weights saved in FlashPack format at {save_path}")
-            except Exception as e:
-                logger.error(f"Failed to save weights in FlashPack format: {e}")
-
+            from ..utils.flashpack_utils import save_flashpack
+            save_flashpack(
+                self,
+                save_directory,
+                variant=variant,
+            )
        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
            private = kwargs.pop("private", None)
@@ -982,10 +953,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                weights. If set to `False`, `safetensors` weights are not loaded.
            use_flashpack (`bool`, *optional*, defaults to `False`):
-                If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
+                If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack) weights if a compatible `.flashpack` file
                is found. If flashpack is unavailable or the `.flashpack` file cannot be used, the loader automatically falls back to
-                the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install
-                flashpack`.
+                the standard loading path (for example, `safetensors`).
            disable_mmap ('bool', *optional*, defaults to 'False'):
                Whether to disable mmap when loading a Safetensors model. This option can perform better when the
                model is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -1252,65 +1222,43 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
            # If flashpack is not available or the file cannot be loaded, we fall back to
            # the standard loading path (e.g. safetensors or PyTorch).
            if use_flashpack:
+                weights_name = _add_variant("model.flashpack", variant)
+
                try:
-                    from flashpack import assign_from_file
-                except ImportError:
-                    pass
-                else:
-                    flashpack_weights_name = _add_variant("model.flashpack", variant)
-                    try:
-                        flashpack_file = _get_model_file(
-                            pretrained_model_name_or_path,
-                            weights_name=flashpack_weights_name,
-                            cache_dir=cache_dir,
-                            force_download=force_download,
-                            proxies=proxies,
-                            local_files_only=local_files_only,
-                            token=token,
-                            revision=revision,
-                            subfolder=subfolder,
-                            user_agent=user_agent,
-                            commit_hash=commit_hash,
-                        )
-                    except EnvironmentError:
-                        pass
-                    else:
-                        dtype_orig = None
-                        if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
-                            if not isinstance(torch_dtype, torch.dtype):
-                                raise ValueError(
-                                    f"{torch_dtype} needs to be a `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
-                                )
-                            dtype_orig = cls._set_default_torch_dtype(torch_dtype)
-
-                        with no_init_weights():
-                            model = cls.from_config(config, **unused_kwargs)
-
-                        if dtype_orig is not None:
-                            torch.set_default_dtype(dtype_orig)
-
-                        try:
-                            assign_from_file(model, flashpack_file)
-                            model.register_to_config(_name_or_path=pretrained_model_name_or_path)
-
-                            if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
-                                model = model.to(torch_dtype)
-
-                            model.eval()
-
-                            if output_loading_info:
-                                loading_info = {
-                                    "missing_keys": [],
-                                    "unexpected_keys": [],
-                                    "mismatched_keys": [],
-                                    "error_msgs": [],
-                                }
-                                return model, loading_info
-
-                            return model
-
-                        except Exception:
-                            pass
+                    resolved_model_file = _get_model_file(
+                        pretrained_model_name_or_path,
+                        weights_name=weights_name,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        token=token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                        commit_hash=commit_hash,
+                    )
+                except EnvironmentError:
+                    resolved_model_file = None
+                with no_init_weights():
+                    model = cls.from_config(config, **unused_kwargs)
+                if resolved_model_file is not None:
+                    from ..utils.flashpack_utils import load_flashpack
+                    load_flashpack(model, resolved_model_file)
+                model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+                model.eval()
+
+                if output_loading_info:
+                    loading_info = {
+                        "missing_keys": [],
+                        "unexpected_keys": [],
+                        "mismatched_keys": [],
+                        "error_msgs": [],
+                    }
+                    return model, loading_info
+
+                return model
+
            # in the case it is sharded, we have already the index
            if is_sharded:
                resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
diff --git a/src/diffusers/utils/flashpack_utils.py b/src/diffusers/utils/flashpack_utils.py
new file mode 100644
index 000000000000..14031a7c543a
--- /dev/null
+++ b/src/diffusers/utils/flashpack_utils.py
@@ -0,0 +1,87 @@
+import json
+import os
+from typing import Optional
+from .import_utils import is_flashpack_available
+from .logging import get_logger
+from ..utils import _add_variant
+
+logger = get_logger(__name__)
+
+def save_flashpack(
+    model,
+    save_directory: str,
+    variant: Optional[str] = None,
+    is_main_process: bool = True,
+):
+    """
+    Save model weights in FlashPack format along with a metadata config.
+
+    Args:
+        model: Diffusers model instance
+        save_directory (`str`): Directory to save weights
+        variant (`str`, *optional*): Model variant
+        is_main_process (`bool`, *optional*, defaults to `True`): Whether the calling process is the main one.
+    """
+    if not is_flashpack_available():
+        raise ImportError(
+            "The `use_flashpack=True` argument requires the `flashpack` package. "
" + "Install it with `pip install flashpack`." + ) + + from flashpack import pack_to_file + + os.makedirs(save_directory, exist_ok=True) + + weights_name = _add_variant("model.flashpack", variant) + weights_path = os.path.join(save_directory, weights_name) + config_path = os.path.join(save_directory, "flashpack_config.json") + + try: + target_dtype = getattr(model, "dtype", None) + logger.warning(f"Dtype used for FlashPack save: {target_dtype}") + + # 1. Save binary weights + pack_to_file(model, weights_path, target_dtype=target_dtype) + + # 2. Save config metadata (best-effort) + if hasattr(model, "config"): + try: + if hasattr(model.config, "to_dict"): + config_data = model.config.to_dict() + else: + config_data = dict(model.config) + + with open(config_path, "w") as f: + json.dump(config_data, f, indent=4) + + except Exception as config_err: + logger.warning( + f"FlashPack weights saved, but config serialization failed: {config_err}" + ) + + except Exception as e: + logger.error(f"Failed to save weights in FlashPack format: {e}") + raise + +def load_flashpack(model, flashpack_file: str): + """ + Assign FlashPack weights from a file into an initialized PyTorch model. + """ + if not is_flashpack_available(): + raise ImportError( + "FlashPack weights require the `flashpack` package. " + "Install with `pip install flashpack`." + ) + + from flashpack import assign_from_file + logger.warning(f"Loading FlashPack weights from {flashpack_file}") + + try: + assign_from_file(model, flashpack_file) + except Exception as e: + raise RuntimeError( + f"Failed to load FlashPack weights from {flashpack_file}" + ) from e \ No newline at end of file diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 425c360a3110..2b99e42a26f7 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -231,7 +231,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _kornia_available, _kornia_version = _is_package_available("kornia") _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True) _av_available, _av_version = _is_package_available("av") - +_flashpack_available, _flashpack_version = _is_package_available("flashpack") def is_torch_available(): return _torch_available @@ -424,6 +424,8 @@ def is_kornia_available(): def is_av_available(): return _av_available +def is_flashpack_available(): + return _flashpack_available # docstyle-ignore FLAX_IMPORT_ERROR = """ @@ -941,6 +943,14 @@ def is_aiter_version(operation: str, version: str): return False return compare_versions(parse(_aiter_version), operation, version) +@cache +def is_flashpack_version(operation: str, version: str): + """ + Compares the current flashpack version to a given reference with an operation. 
+ """ + if not _flashpack_available: + return False + return compare_versions(parse(_flashpack_version), operation, version) def get_objects_from_module(module): """ From 8cc38a75d3996491b2ac37834489cb0e0e5a0e8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cdevanshi00=E2=80=9D?= <“devanshi7309@gmail.com”> Date: Wed, 21 Jan 2026 12:27:42 +0530 Subject: [PATCH 3/4] redundant model initialisation removed --- src/diffusers/models/modeling_utils.py | 148 +++++++++++++------------ 1 file changed, 77 insertions(+), 71 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 62fbc41cd55f..763a245c4e4e 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1217,17 +1217,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file) + else: - # If we are using `use_flashpack`, we try to load the model from flashpack first. - # If flashpack is not available or the file cannot be loaded, we fall back to - # the standard loading path (e.g. safetensors or PyTorch). + flashpack_file = None if use_flashpack: - weights_name = _add_variant("model.flashpack", variant) - try: - resolved_model_file = _get_model_file( + flashpack_file = _get_model_file( pretrained_model_name_or_path, - weights_name=weights_name, + weights_name=_add_variant("model.flashpack", variant), cache_dir=cache_dir, force_download=force_download, proxies=proxies, @@ -1237,47 +1234,56 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, + dduf_entries=dduf_entries, ) except EnvironmentError: - resolved_model_file = None - with no_init_weights(): - model = cls.from_config(config, **unused_kwargs) - if resolved_model_file is not None: - from ..utils.flashpack_utils import load_flashpack - load_flashpack(model, resolved_model_file) - model.register_to_config(_name_or_path=pretrained_model_name_or_path) - model.eval() - - if output_loading_info: - loading_info = { - "missing_keys": [], - "unexpected_keys": [], - "mismatched_keys": [], - "error_msgs": [], - } - return model, loading_info - - return model - - # in the case it is sharded, we have already the index - if is_sharded: - resolved_model_file, sharded_metadata = _get_checkpoint_shard_files( - pretrained_model_name_or_path, - index_file, - cache_dir=cache_dir, - proxies=proxies, - local_files_only=local_files_only, - token=token, - user_agent=user_agent, - revision=revision, - subfolder=subfolder or "", - dduf_entries=dduf_entries, - ) - elif use_safetensors: - try: + flashpack_file = None + + if flashpack_file is None: + # in the case it is sharded, we have already the index + if is_sharded: + resolved_model_file, sharded_metadata = _get_checkpoint_shard_files( + pretrained_model_name_or_path, + index_file, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder or "", + dduf_entries=dduf_entries, + ) + elif use_safetensors: + logger.warning("Trying to load model weights with safetensors format.") + try: + resolved_model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + 
+                            proxies=proxies,
+                            local_files_only=local_files_only,
+                            token=token,
+                            revision=revision,
+                            subfolder=subfolder,
+                            user_agent=user_agent,
+                            commit_hash=commit_hash,
+                            dduf_entries=dduf_entries,
+                        )
+
+                    except IOError as e:
+                        logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
+                        if not allow_pickle:
+                            raise
+                        logger.warning(
+                            "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                        )
+
+                if resolved_model_file is None and not is_sharded:
                    resolved_model_file = _get_model_file(
                        pretrained_model_name_or_path,
-                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+                        weights_name=_add_variant(WEIGHTS_NAME, variant),
                        cache_dir=cache_dir,
                        force_download=force_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        token=token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                        commit_hash=commit_hash,
                        dduf_entries=dduf_entries,
                    )

-                except IOError as e:
-                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
-                    if not allow_pickle:
-                        raise
-                    logger.warning(
-                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
-                    )
-
-                if resolved_model_file is None and not is_sharded:
-                    resolved_model_file = _get_model_file(
-                        pretrained_model_name_or_path,
-                        weights_name=_add_variant(WEIGHTS_NAME, variant),
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        proxies=proxies,
-                        local_files_only=local_files_only,
-                        token=token,
-                        revision=revision,
-                        subfolder=subfolder,
-                        user_agent=user_agent,
-                        commit_hash=commit_hash,
-                        dduf_entries=dduf_entries,
-                    )
-
-                if not isinstance(resolved_model_file, list):
-                    resolved_model_file = [resolved_model_file]
+                if not isinstance(resolved_model_file, list):
+                    resolved_model_file = [resolved_model_file]

        # set dtype to instantiate the model under:
        # 1. If torch_dtype is not None, we use that dtype
@@ -1339,6 +1321,28 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
            if dtype_orig is not None:
                torch.set_default_dtype(dtype_orig)

+            if flashpack_file is not None:
+                from ..utils.flashpack_utils import load_flashpack
+                # Even when using FlashPack, we preserve `low_cpu_mem_usage` behavior by initializing
+                # the model with meta tensors. Since FlashPack cannot write into meta tensors, we
+                # explicitly materialize parameters before loading to ensure correctness and parity
+                # with the standard loading path.
+                if any(p.device.type == "meta" for p in model.parameters()):
+                    model.to_empty(device="cpu")
+                load_flashpack(model, flashpack_file)
+                model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+                model.eval()
+
+                if output_loading_info:
+                    return model, {
+                        "missing_keys": [],
+                        "unexpected_keys": [],
+                        "mismatched_keys": [],
+                        "error_msgs": [],
+                    }
+
+                return model
+
            state_dict = None
            if not is_sharded:
                # Time to load the checkpoint
@@ -1431,9 +1435,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

        if output_loading_info:
            return model, loading_info
+
+        logger.debug(f"Model {pretrained_model_name_or_path} loaded successfully.")

        return model

+
# Adapted from `transformers`.
@wraps(torch.nn.Module.cuda) def cuda(self, *args, **kwargs): From 3bc3fdb035f187fd7b60876454c0a34c979eadf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Cdevanshi00=E2=80=9D?= <“devanshi7309@gmail.com”> Date: Wed, 21 Jan 2026 12:31:43 +0530 Subject: [PATCH 4/4] redundant model initialisation removed final --- src/diffusers/models/modeling_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 763a245c4e4e..5b0f8a3a0d64 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -1296,8 +1296,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P dduf_entries=dduf_entries, ) - if not isinstance(resolved_model_file, list): - resolved_model_file = [resolved_model_file] + if not isinstance(resolved_model_file, list): + resolved_model_file = [resolved_model_file] # set dtype to instantiate the model under: # 1. If torch_dtype is not None, we use that dtype
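
A usage sketch of the `use_flashpack` flag this series adds, for quick reference. The model class, repository id, and local paths below are illustrative placeholders; `use_flashpack` is the only new argument introduced by the patches.

    import torch

    from diffusers import UNet2DConditionModel

    # Saving with use_flashpack=True writes `model.flashpack` plus a
    # `flashpack_config.json` holding the model config metadata.
    model = UNet2DConditionModel.from_pretrained(
        "org/some-repo",  # hypothetical repo id
        subfolder="unet",
        torch_dtype=torch.float16,
    )
    model.save_pretrained("./unet-flashpack", use_flashpack=True)

    # Loading with use_flashpack=True resolves `model.flashpack` first and falls
    # back to the standard safetensors/PyTorch loading path if the file or the
    # `flashpack` package is unavailable.
    model = UNet2DConditionModel.from_pretrained("./unet-flashpack", use_flashpack=True)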