149 changes: 107 additions & 42 deletions src/diffusers/models/modeling_utils.py
@@ -675,6 +675,7 @@ def save_pretrained(
variant: Optional[str] = None,
max_shard_size: Union[int, str] = "10GB",
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
Expand Down Expand Up @@ -707,6 +708,10 @@ def save_pretrained(
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
use_flashpack (`bool`, *optional*, defaults to `False`):
Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format.
FlashPack is a binary format that allows for faster loading.
Requires the `flashpack` library to be installed.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@@ -734,7 +739,15 @@ def save_pretrained(
)

os.makedirs(save_directory, exist_ok=True)

if use_flashpack:
if not is_main_process:
return
from ..utils.flashpack_utils import save_flashpack
save_flashpack(
self,
save_directory,
variant=variant,
)
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
private = kwargs.pop("private", None)
@@ -939,6 +952,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
weights. If set to `False`, `safetensors` weights are not loaded.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack) weights if a compatible `.flashpack` file
is found. If `flashpack` is unavailable or the `.flashpack` file cannot be used, loading automatically falls back to
the standard loading path (for example, `safetensors`).
disable_mmap (`bool`, *optional*, defaults to `False`):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
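And the load-side counterpart, as a sketch (paths illustrative):

from diffusers import UNet2DConditionModel

# Uses ./unet-flashpack/model.flashpack when present and usable;
# otherwise falls back to the standard path (e.g. safetensors).
model = UNet2DConditionModel.from_pretrained("./unet-flashpack", use_flashpack=True)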
@@ -982,6 +999,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_flashpack = kwargs.pop("use_flashpack", False)
quantization_config = kwargs.pop("quantization_config", None)
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)
@@ -1199,26 +1217,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)

else:
# in the case it is sharded, we already have the index
if is_sharded:
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_file,
cache_dir=cache_dir,
proxies=proxies,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder or "",
dduf_entries=dduf_entries,
)
elif use_safetensors:
flashpack_file = None
if use_flashpack:
try:
resolved_model_file = _get_model_file(
flashpack_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
weights_name=_add_variant("model.flashpack", variant),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
@@ -1230,33 +1236,68 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
except EnvironmentError:
flashpack_file = None

except IOError as e:
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
if not allow_pickle:
raise
logger.warning(
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
if flashpack_file is None:
# in the case it is sharded, we already have the index
if is_sharded:
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_file,
cache_dir=cache_dir,
proxies=proxies,
local_files_only=local_files_only,
token=token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder or "",
dduf_entries=dduf_entries,
)
elif use_safetensors:
logger.warning("Trying to load model weights with safetensors format.")
try:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)

if resolved_model_file is None and not is_sharded:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
except IOError as e:
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
if not allow_pickle:
raise
logger.warning(
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
)

if resolved_model_file is None and not is_sharded:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)

if not isinstance(resolved_model_file, list):
resolved_model_file = [resolved_model_file]
if not isinstance(resolved_model_file, list):
resolved_model_file = [resolved_model_file]

# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
@@ -1280,6 +1321,28 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

if flashpack_file is not None:
from ..utils.flashpack_utils import load_flashpack
# Even when using FlashPack, we preserve `low_cpu_mem_usage` behavior by initializing
# the model with meta tensors. Since FlashPack cannot write into meta tensors, we
# explicitly materialize parameters before loading to ensure correctness and parity
# with the standard loading path.
if any(p.device.type == "meta" for p in model.parameters()):
model.to_empty(device="cpu")
load_flashpack(model, flashpack_file)
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
model.eval()
Comment on lines +1333 to +1334

Member: Why should these have to be conditioned under `if flashpack_file is not None`?

Author: With the current early-return behavior, the FlashPack code path is intentionally isolated from the standard loading flow. After instantiating the model and loading weights via FlashPack, we finalize the model (registering `_name_or_path` and switching to eval mode) and return immediately. As a result, the model deliberately does not go through the usual checkpoint-loading logic (`load_state_dict`, `_load_pretrained_model`), device-map dispatch, sharding handling, quantizer post-processing, or the common end-of-function finalization.

Member: Let me look into it and see if I can help with further simplifying.


if output_loading_info:
return model, {
"missing_keys": [],
"unexpected_keys": [],
"mismatched_keys": [],
"error_msgs": [],
}

return model
Comment on lines +1336 to +1344

Member: This as well.

Author: This block exists to preserve the `from_pretrained(..., output_loading_info=True)` API contract for the FlashPack loading path. Since FlashPack bypasses `_load_pretrained_model`, none of the usual key-matching or error-collection logic is executed, so there are no meaningful loading diagnostics to report. Returning an empty but well-formed `loading_info` dict keeps the return type consistent with the standard loading path without implying that any missing or mismatched keys were checked.
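To illustrate the preserved contract, a sketch of the caller-facing behavior (path illustrative):

from diffusers import UNet2DConditionModel

# Both the FlashPack path and the standard path return the same tuple shape:
model, loading_info = UNet2DConditionModel.from_pretrained(
    "./unet-flashpack", use_flashpack=True, output_loading_info=True
)
assert set(loading_info) == {"missing_keys", "unexpected_keys", "mismatched_keys", "error_msgs"}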


state_dict = None
if not is_sharded:
# Time to load the checkpoint
@@ -1327,7 +1390,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
keep_in_fp32_modules=keep_in_fp32_modules,
dduf_entries=dduf_entries,
is_parallel_loading_enabled=is_parallel_loading_enabled,
disable_mmap=disable_mmap,
)
loading_info = {
"missing_keys": missing_keys,
@@ -1372,9 +1434,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

if output_loading_info:
return model, loading_info

logger.warning(f"Model till end {pretrained_model_name_or_path} loaded successfully")

return model


# Adapted from `transformers`.
@wraps(torch.nn.Module.cuda)
def cuda(self, *args, **kwargs):
4 changes: 4 additions & 0 deletions src/diffusers/pipelines/pipeline_loading_utils.py
@@ -756,6 +756,7 @@ def load_sub_model(
low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike],
use_safetensors: bool,
use_flashpack: bool,
dduf_entries: Optional[Dict[str, DDUFEntry]],
provider_options: Any,
disable_mmap: bool,
@@ -838,6 +839,9 @@ def load_sub_model(
loading_kwargs["variant"] = model_variants.pop(name, None)
loading_kwargs["use_safetensors"] = use_safetensors

if is_diffusers_model:
loading_kwargs["use_flashpack"] = use_flashpack

if from_flax:
loading_kwargs["from_flax"] = True

15 changes: 14 additions & 1 deletion src/diffusers/pipelines/pipeline_utils.py
@@ -243,6 +243,7 @@ def save_pretrained(
variant: Optional[str] = None,
max_shard_size: Optional[Union[int, str]] = None,
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
@@ -268,7 +269,9 @@ class implements both a save and loading method. The pipeline is easily reloaded
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).

use_flashpack (`bool`, *optional*, defaults to `False`):
Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip install
flashpack`.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@@ -340,6 +343,7 @@ def is_saveable_module(name, value):
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
save_method_accept_variant = "variant" in save_method_signature.parameters
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters

save_kwargs = {}
if save_method_accept_safe:
@@ -349,6 +353,8 @@ def is_saveable_module(name, value):
if save_method_accept_max_shard_size and max_shard_size is not None:
# max_shard_size is expected to not be None in ModelMixin
save_kwargs["max_shard_size"] = max_shard_size
if save_method_accept_flashpack:
save_kwargs["use_flashpack"] = use_flashpack

save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)

@@ -707,6 +713,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
is found. If `flashpack` is unavailable or the `.flashpack` file cannot be used, loading automatically falls back to
the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install
flashpack`.
use_onnx (`bool`, *optional*, defaults to `None`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
@@ -772,6 +783,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
variant = kwargs.pop("variant", None)
dduf_file = kwargs.pop("dduf_file", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_flashpack = kwargs.pop("use_flashpack", False)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
quantization_config = kwargs.pop("quantization_config", None)
@@ -1061,6 +1073,7 @@ def load_module(name, value):
low_cpu_mem_usage=low_cpu_mem_usage,
cached_folder=cached_folder,
use_safetensors=use_safetensors,
use_flashpack=use_flashpack,
dduf_entries=dduf_entries,
provider_options=provider_options,
disable_mmap=disable_mmap,
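Taken together, the flag threads from the pipeline entry points down to each diffusers sub-model. A usage sketch (the checkpoint id is illustrative):

from diffusers import DiffusionPipeline

# Each ModelMixin component receives use_flashpack via load_sub_model;
# non-diffusers components (e.g. tokenizers) are unaffected.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_flashpack=True)
pipe.save_pretrained("./sd15-flashpack", use_flashpack=True)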
83 changes: 83 additions & 0 deletions src/diffusers/utils/flashpack_utils.py
@@ -0,0 +1,83 @@
import json
import os
from typing import Optional

from .hub_utils import _add_variant
from .import_utils import is_flashpack_available
from .logging import get_logger

logger = get_logger(__name__)

def save_flashpack(
model,
save_directory: str,
variant: Optional[str] = None,
is_main_process: bool = True,
):
"""
Save model weights in FlashPack format along with a metadata config.

Args:
model: Diffusers model instance.
save_directory (`str`): Directory to save the weights to.
variant (`str`, *optional*): Model variant.
is_main_process (`bool`, *optional*, defaults to `True`): Whether the calling process is the main one. In distributed settings, only the main process should write files.
"""
if not is_main_process:
return

if not is_flashpack_available():
raise ImportError(
"The `use_flashpack=True` argument requires the `flashpack` package. "
"Install it with `pip install flashpack`."
)

from flashpack import pack_to_file

os.makedirs(save_directory, exist_ok=True)

weights_name = _add_variant("model.flashpack", variant)
weights_path = os.path.join(save_directory, weights_name)
config_path = os.path.join(save_directory, "flashpack_config.json")

try:
target_dtype = getattr(model, "dtype", None)
logger.warning(f"Dtype used for FlashPack save: {target_dtype}")

# 1. Save binary weights
pack_to_file(model, weights_path, target_dtype=target_dtype)

# 2. Save config metadata (best-effort)
if hasattr(model, "config"):
try:
if hasattr(model.config, "to_dict"):
config_data = model.config.to_dict()
else:
config_data = dict(model.config)

with open(config_path, "w") as f:
json.dump(config_data, f, indent=4)

except Exception as config_err:
logger.warning(
f"FlashPack weights saved, but config serialization failed: {config_err}"
)

except Exception as e:
logger.error(f"Failed to save weights in FlashPack format: {e}")
raise

def load_flashpack(model, flashpack_file: str):
"""
Assign FlashPack weights from a file into an initialized PyTorch model.
"""
if not is_flashpack_available():
raise ImportError(
"FlashPack weights require the `flashpack` package. "
"Install with `pip install flashpack`."
)

from flashpack import assign_from_file
logger.warning(f"Loading FlashPack weights from {flashpack_file}")

try:
assign_from_file(model, flashpack_file)
except Exception as e:
raise RuntimeError(
f"Failed to load FlashPack weights from {flashpack_file}"
) from e
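Finally, a round-trip sketch of the two helpers used directly (untested; assumes `flashpack` is installed, and the model class is an arbitrary example):

import torch
from diffusers import UNet2DModel
from diffusers.utils.flashpack_utils import save_flashpack, load_flashpack

model = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)
save_flashpack(model, "./tmp-flashpack")  # writes model.flashpack + flashpack_config.json

# load_flashpack assigns weights in place into an already-constructed model.
fresh = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)
load_flashpack(fresh, "./tmp-flashpack/model.flashpack")
assert all(torch.equal(p, q) for p, q in zip(model.parameters(), fresh.parameters()))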