[feat] added fal-flashpack support #12999
Changes to `save_pretrained` and `from_pretrained` in the model mixin:

```diff
@@ -675,6 +675,7 @@ def save_pretrained(
         variant: Optional[str] = None,
         max_shard_size: Union[int, str] = "10GB",
         push_to_hub: bool = False,
+        use_flashpack: bool = False,
         **kwargs,
     ):
         """
```
```diff
@@ -707,6 +708,10 @@ def save_pretrained(
                 Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                 namespace).
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format.
+                FlashPack is a binary format that allows for faster loading.
+                Requires the `flashpack` library to be installed.
             kwargs (`Dict[str, Any]`, *optional*):
                 Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
```
```diff
@@ -734,7 +739,15 @@ def save_pretrained(
             )

         os.makedirs(save_directory, exist_ok=True)

+        if use_flashpack:
+            if not is_main_process:
+                return
+            from ..utils.flashpack_utils import save_flashpack
+
+            save_flashpack(
+                self,
+                save_directory,
+                variant=variant,
+            )
         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
             private = kwargs.pop("private", None)
```
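Taken together, the two hunks above add an opt-in save path. A hedged usage sketch (the model class, repo id, and output directory are illustrative, not taken from the PR):

```python
from diffusers import UNet2DModel

# Illustrative checkpoint; any diffusers ModelMixin subclass works the same way.
model = UNet2DModel.from_pretrained("google/ddpm-cat-256", subfolder="unet")

# With use_flashpack=True, save_pretrained delegates to save_flashpack(), which
# writes a variant-aware `model.flashpack` plus a `flashpack_config.json`.
model.save_pretrained("./ddpm-unet-flashpack", use_flashpack=True)
```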
```diff
@@ -939,6 +952,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                 `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                 weights. If set to `False`, `safetensors` weights are not loaded.
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack)
+                weights if a compatible `.flashpack` file is found. If `flashpack` is unavailable or the `.flashpack`
+                file cannot be used, loading falls back automatically to the standard path (for example, `safetensors`).
             disable_mmap ('bool', *optional*, defaults to 'False'):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                 is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
```
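And the matching opt-in load path, as a sketch under the same assumptions as the save example above:

```python
from diffusers import UNet2DModel

# With use_flashpack=True, from_pretrained first looks for a compatible
# `.flashpack` file and silently falls back to safetensors if none is found.
model = UNet2DModel.from_pretrained("./ddpm-unet-flashpack", use_flashpack=True)
```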
```diff
@@ -982,6 +999,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        use_flashpack = kwargs.pop("use_flashpack", False)
         quantization_config = kwargs.pop("quantization_config", None)
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
```
```diff
@@ -1199,26 +1217,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

                 model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
         else:
-            # in the case it is sharded, we have already the index
-            if is_sharded:
-                resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
-                    pretrained_model_name_or_path,
-                    index_file,
-                    cache_dir=cache_dir,
-                    proxies=proxies,
-                    local_files_only=local_files_only,
-                    token=token,
-                    user_agent=user_agent,
-                    revision=revision,
-                    subfolder=subfolder or "",
-                    dduf_entries=dduf_entries,
-                )
-            elif use_safetensors:
+            flashpack_file = None
+            if use_flashpack:
                 try:
-                    resolved_model_file = _get_model_file(
+                    flashpack_file = _get_model_file(
                         pretrained_model_name_or_path,
-                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+                        weights_name=_add_variant("model.flashpack", variant),
                         cache_dir=cache_dir,
                         force_download=force_download,
                         proxies=proxies,
```
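Note the variant-aware weights name: the new `.flashpack` file reuses the naming scheme of the existing formats. A minimal sketch of that behavior (a re-implementation for illustration, not diffusers' `_add_variant` verbatim):

```python
from typing import Optional

def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    # Insert the variant just before the file extension, mirroring how
    # "model.safetensors" becomes "model.fp16.safetensors".
    if variant is not None:
        parts = weights_name.split(".")
        weights_name = ".".join(parts[:-1] + [variant] + parts[-1:])
    return weights_name

assert add_variant("model.flashpack") == "model.flashpack"
assert add_variant("model.flashpack", "fp16") == "model.fp16.flashpack"
```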
```diff
@@ -1230,33 +1236,68 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                         commit_hash=commit_hash,
                         dduf_entries=dduf_entries,
                     )
+                except EnvironmentError:
+                    flashpack_file = None

-                except IOError as e:
-                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
-                    if not allow_pickle:
-                        raise
-                    logger.warning(
-                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
-                    )
-
-                if resolved_model_file is None and not is_sharded:
-                    resolved_model_file = _get_model_file(
-                        pretrained_model_name_or_path,
-                        weights_name=_add_variant(WEIGHTS_NAME, variant),
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        proxies=proxies,
-                        local_files_only=local_files_only,
-                        token=token,
-                        revision=revision,
-                        subfolder=subfolder,
-                        user_agent=user_agent,
-                        commit_hash=commit_hash,
-                        dduf_entries=dduf_entries,
-                    )
-
-            if not isinstance(resolved_model_file, list):
-                resolved_model_file = [resolved_model_file]
+            if flashpack_file is None:
+                # in the case it is sharded, we have already the index
+                if is_sharded:
+                    resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
+                        pretrained_model_name_or_path,
+                        index_file,
+                        cache_dir=cache_dir,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        token=token,
+                        user_agent=user_agent,
+                        revision=revision,
+                        subfolder=subfolder or "",
+                        dduf_entries=dduf_entries,
+                    )
+                elif use_safetensors:
+                    logger.warning("Trying to load model weights with safetensors format.")
+                    try:
+                        resolved_model_file = _get_model_file(
+                            pretrained_model_name_or_path,
+                            weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            proxies=proxies,
+                            local_files_only=local_files_only,
+                            token=token,
+                            revision=revision,
+                            subfolder=subfolder,
+                            user_agent=user_agent,
+                            commit_hash=commit_hash,
+                            dduf_entries=dduf_entries,
+                        )
+                    except IOError as e:
+                        logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
+                        if not allow_pickle:
+                            raise
+                        logger.warning(
+                            "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                        )
+
+                if resolved_model_file is None and not is_sharded:
+                    resolved_model_file = _get_model_file(
+                        pretrained_model_name_or_path,
+                        weights_name=_add_variant(WEIGHTS_NAME, variant),
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        token=token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                        commit_hash=commit_hash,
+                        dduf_entries=dduf_entries,
+                    )
+
+                if not isinstance(resolved_model_file, list):
+                    resolved_model_file = [resolved_model_file]

             # set dtype to instantiate the model under:
             # 1. If torch_dtype is not None, we use that dtype
```
```diff
@@ -1280,6 +1321,28 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         if dtype_orig is not None:
             torch.set_default_dtype(dtype_orig)

+        if flashpack_file is not None:
+            from ..utils.flashpack_utils import load_flashpack
+
+            # Even when using FlashPack, we preserve `low_cpu_mem_usage` behavior by initializing
+            # the model with meta tensors. Since FlashPack cannot write into meta tensors, we
+            # explicitly materialize parameters before loading to ensure correctness and parity
+            # with the standard loading path.
+            if any(p.device.type == "meta" for p in model.parameters()):
+                model.to_empty(device="cpu")
+            load_flashpack(model, flashpack_file)
+            model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+            model.eval()
```
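The inline comment above is the crux of the FlashPack load path: weights can only be assigned into materialized tensors. A standalone sketch of that meta-tensor pattern in plain PyTorch, independent of this PR:

```python
import torch
from torch import nn

# Parameters created on the meta device carry shape and dtype but no storage,
# so no loader can write real values into them until they are materialized.
with torch.device("meta"):
    layer = nn.Linear(4, 4)

assert layer.weight.device.type == "meta"
layer.to_empty(device="cpu")  # allocate real (uninitialized) CPU storage
assert layer.weight.device.type == "cpu"
# A loader can now assign values, e.g. via layer.load_state_dict(...).
```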
Comment on lines +1333 to +1334

Member: Why should these have to be conditioned under

Author: With the current early-return behavior, the FlashPack code path is intentionally isolated from the standard loading flow. After instantiating the model and loading weights via FlashPack, we finalize the model (registering `_name_or_path` and switching to eval mode) and return immediately.

Member: Let me look into it and see if I can help with further simplifying.
```diff
+        if output_loading_info:
+            return model, {
+                "missing_keys": [],
+                "unexpected_keys": [],
+                "mismatched_keys": [],
+                "error_msgs": [],
+            }
+
+        return model
```
Comment on lines +1336 to +1344

Member: This as well.

Author: This block exists to preserve the `from_pretrained(..., output_loading_info=True)` API contract for the FlashPack loading path. Since FlashPack bypasses `_load_pretrained_model`, none of the usual key-matching or error-collection logic is executed, so there are no meaningful loading diagnostics to report. Returning an empty but well-formed `loading_info` dict keeps the return type consistent with the standard loading path without implying that any missing or mismatched keys were checked.
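A hedged sketch of the contract described above, reusing the illustrative directory from the earlier save example:

```python
from diffusers import UNet2DModel

# With output_loading_info=True callers always receive a well-formed dict; on the
# FlashPack path its four lists are simply empty because no key matching ran.
model, loading_info = UNet2DModel.from_pretrained(
    "./ddpm-unet-flashpack",  # illustrative path, not from the PR
    use_flashpack=True,
    output_loading_info=True,
)
assert set(loading_info) == {"missing_keys", "unexpected_keys", "mismatched_keys", "error_msgs"}
assert all(v == [] for v in loading_info.values())
```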
```diff

         state_dict = None
         if not is_sharded:
             # Time to load the checkpoint
```
```diff
@@ -1327,7 +1390,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 dduf_entries=dduf_entries,
                 is_parallel_loading_enabled=is_parallel_loading_enabled,
-                disable_mmap=disable_mmap,
             )
             loading_info = {
                 "missing_keys": missing_keys,
```
```diff
@@ -1372,9 +1434,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

         if output_loading_info:
             return model, loading_info

+        logger.warning(f"Model till end {pretrained_model_name_or_path} loaded successfully")
+
         return model

     # Adapted from `transformers`.
     @wraps(torch.nn.Module.cuda)
     def cuda(self, *args, **kwargs):
```
New file: the FlashPack helpers imported above as `..utils.flashpack_utils` (83 added lines):

```python
import json
import os
from typing import Optional

from .import_utils import is_flashpack_available
from .logging import get_logger
from ..utils import _add_variant

logger = get_logger(__name__)


def save_flashpack(
    model,
    save_directory: str,
    variant: Optional[str] = None,
    is_main_process: bool = True,
):
    """
    Save model weights in FlashPack format along with a metadata config.

    Args:
        model: Diffusers model instance
        save_directory (`str`): Directory to save weights
        variant (`str`, *optional*): Model variant
        is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process
    """
    if not is_flashpack_available():
        raise ImportError(
            "The `use_flashpack=True` argument requires the `flashpack` package. "
            "Install it with `pip install flashpack`."
        )

    from flashpack import pack_to_file

    os.makedirs(save_directory, exist_ok=True)

    weights_name = _add_variant("model.flashpack", variant)
    weights_path = os.path.join(save_directory, weights_name)
    config_path = os.path.join(save_directory, "flashpack_config.json")

    try:
        target_dtype = getattr(model, "dtype", None)
        logger.warning(f"Dtype used for FlashPack save: {target_dtype}")

        # 1. Save binary weights
        pack_to_file(model, weights_path, target_dtype=target_dtype)

        # 2. Save config metadata (best-effort)
        if hasattr(model, "config"):
            try:
                if hasattr(model.config, "to_dict"):
                    config_data = model.config.to_dict()
                else:
                    config_data = dict(model.config)

                with open(config_path, "w") as f:
                    json.dump(config_data, f, indent=4)

            except Exception as config_err:
                logger.warning(
                    f"FlashPack weights saved, but config serialization failed: {config_err}"
                )

    except Exception as e:
        logger.error(f"Failed to save weights in FlashPack format: {e}")
        raise


def load_flashpack(model, flashpack_file: str):
    """
    Assign FlashPack weights from a file into an initialized PyTorch model.
    """
    if not is_flashpack_available():
        raise ImportError(
            "FlashPack weights require the `flashpack` package. "
            "Install with `pip install flashpack`."
        )

    from flashpack import assign_from_file

    logger.warning(f"Loading FlashPack weights from {flashpack_file}")

    try:
        assign_from_file(model, flashpack_file)
    except Exception as e:
        raise RuntimeError(
            f"Failed to load FlashPack weights from {flashpack_file}"
        ) from e
```
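A hedged round-trip sketch for these two helpers, assuming the `flashpack` package is installed and that the module ships as `diffusers.utils.flashpack_utils` (the import path implied by the mixin changes above):

```python
import torch
from diffusers import UNet2DModel
from diffusers.utils.flashpack_utils import load_flashpack, save_flashpack

# Any instantiated ModelMixin should work; the default UNet2DModel config is illustrative.
model = UNet2DModel()
save_flashpack(model, "./tmp-flashpack")  # writes ./tmp-flashpack/model.flashpack

# load_flashpack assigns into an already-built model with a matching architecture.
fresh = UNet2DModel()
load_flashpack(fresh, "./tmp-flashpack/model.flashpack")

first_key = next(iter(model.state_dict()))
assert torch.equal(model.state_dict()[first_key], fresh.state_dict()[first_key])
```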