diff --git a/examples/diffusers/quantization/calibration.py b/examples/diffusers/quantization/calibration.py new file mode 100644 index 000000000..aa5d37847 --- /dev/null +++ b/examples/diffusers/quantization/calibration.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path +from typing import Any + +from models_utils import MODEL_DEFAULTS, ModelType +from pipeline_manager import PipelineManager +from quantize_config import CalibrationConfig +from tqdm import tqdm +from utils import load_calib_prompts + + +class Calibrator: + """Handles model calibration for quantization.""" + + def __init__( + self, + pipeline_manager: PipelineManager, + config: CalibrationConfig, + model_type: ModelType, + logger: logging.Logger, + ): + """ + Initialize calibrator. + + Args: + pipeline_manager: Pipeline manager with main and upsampler pipelines + config: Calibration configuration + model_type: Type of model being calibrated + logger: Logger instance + """ + self.pipeline_manager = pipeline_manager + self.pipe = pipeline_manager.pipe + self.pipe_upsample = pipeline_manager.pipe_upsample + self.config = config + self.model_type = model_type + self.logger = logger + + def load_and_batch_prompts(self) -> list[list[str]]: + """ + Load calibration prompts from file. + + Returns: + List of batched calibration prompts + """ + self.logger.info(f"Loading calibration prompts from {self.config.prompts_dataset}") + if isinstance(self.config.prompts_dataset, Path): + return load_calib_prompts( + self.config.batch_size, + self.config.prompts_dataset, + ) + + return load_calib_prompts( + self.config.batch_size, + self.config.prompts_dataset["name"], + self.config.prompts_dataset["split"], + self.config.prompts_dataset["column"], + ) + + def run_calibration(self, batched_prompts: list[list[str]]) -> None: + """ + Run calibration steps on the pipeline. 
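+
+        Each prompt batch is dispatched to a model-specific handler for LTX-2,
+        LTX-Video, or Wan video models; all other models run the generic
+        text-to-image pipeline call with the model's default inference arguments.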
+ + Args: + batched_prompts: List of batched calibration prompts + """ + self.logger.info(f"Starting calibration with {self.config.num_batches} batches") + extra_args = MODEL_DEFAULTS.get(self.model_type, {}).get("inference_extra_args", {}) + + with tqdm(total=self.config.num_batches, desc="Calibration", unit="batch") as pbar: + for i, prompt_batch in enumerate(batched_prompts): + if i >= self.config.num_batches: + break + + if self.model_type == ModelType.LTX2: + self._run_ltx2_calibration(prompt_batch, extra_args) + elif self.model_type == ModelType.LTX_VIDEO_DEV: + # Special handling for LTX-Video + self._run_ltx_video_calibration(prompt_batch, extra_args) + elif self.model_type in [ModelType.WAN22_T2V_14b, ModelType.WAN22_T2V_5b]: + # Special handling for WAN video models + self._run_wan_video_calibration(prompt_batch, extra_args) + else: + common_args = { + "prompt": prompt_batch, + "num_inference_steps": self.config.n_steps, + } + self.pipe(**common_args, **extra_args).images + pbar.update(1) + self.logger.debug(f"Completed calibration batch {i + 1}/{self.config.num_batches}") + self.logger.info("Calibration completed successfully") + + def _run_wan_video_calibration( + self, prompt_batch: list[str], extra_args: dict[str, Any] + ) -> None: + kwargs = {} + kwargs["negative_prompt"] = extra_args["negative_prompt"] + kwargs["height"] = extra_args["height"] + kwargs["width"] = extra_args["width"] + kwargs["num_frames"] = extra_args["num_frames"] + kwargs["guidance_scale"] = extra_args["guidance_scale"] + if "guidance_scale_2" in extra_args: + kwargs["guidance_scale_2"] = extra_args["guidance_scale_2"] + kwargs["num_inference_steps"] = self.config.n_steps + + self.pipe(prompt=prompt_batch, **kwargs).frames + + def _run_ltx2_calibration(self, prompt_batch: list[str], extra_args: dict[str, Any]) -> None: + from ltx_core.model.video_vae import TilingConfig + + prompt = prompt_batch[0] + extra_params = self.pipeline_manager.config.extra_params + kwargs = { + "negative_prompt": extra_args.get( + "negative_prompt", "worst quality, inconsistent motion, blurry, jittery, distorted" + ), + "seed": extra_params.get("seed", 0), + "height": extra_params.get("height", extra_args.get("height", 1024)), + "width": extra_params.get("width", extra_args.get("width", 1536)), + "num_frames": extra_params.get("num_frames", extra_args.get("num_frames", 121)), + "frame_rate": extra_params.get("frame_rate", extra_args.get("frame_rate", 24.0)), + "num_inference_steps": self.config.n_steps, + "cfg_guidance_scale": extra_params.get( + "cfg_guidance_scale", extra_args.get("cfg_guidance_scale", 4.0) + ), + "images": extra_params.get("images", []), + "tiling_config": extra_params.get("tiling_config", TilingConfig.default()), + } + self.pipe(prompt=prompt, **kwargs) + + def _run_ltx_video_calibration( + self, prompt_batch: list[str], extra_args: dict[str, Any] + ) -> None: + """ + Run calibration for LTX-Video model using the full multi-stage pipeline. 
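+
+        Latents are first generated at roughly 2/3 of the target resolution
+        (rounded down to the VAE spatial compression ratio) and then, if the
+        latent upsampler pipeline is loaded, upscaled in latent space.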
+
+        Args:
+            prompt_batch: Batch of prompts
+            extra_args: Model-specific arguments
+        """
+        # Extract specific args for LTX-Video
+        expected_height = extra_args.get("height", 512)
+        expected_width = extra_args.get("width", 704)
+        num_frames = extra_args.get("num_frames", 121)
+        negative_prompt = extra_args.get(
+            "negative_prompt", "worst quality, inconsistent motion, blurry, jittery, distorted"
+        )
+
+        def round_to_nearest_resolution_acceptable_by_vae(height, width):
+            height = height - (height % self.pipe.vae_spatial_compression_ratio)
+            width = width - (width % self.pipe.vae_spatial_compression_ratio)
+            return height, width
+
+        downscale_factor = 2 / 3
+        # Part 1: Generate video at smaller resolution
+        downscaled_height, downscaled_width = (
+            int(expected_height * downscale_factor),
+            int(expected_width * downscale_factor),
+        )
+        downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(
+            downscaled_height, downscaled_width
+        )
+
+        # Generate initial latents at lower resolution
+        latents = self.pipe(
+            conditions=None,
+            prompt=prompt_batch,
+            negative_prompt=negative_prompt,
+            width=downscaled_width,
+            height=downscaled_height,
+            num_frames=num_frames,
+            num_inference_steps=self.config.n_steps,
+            output_type="latent",
+        ).frames
+
+        # Part 2: Upscale generated video using latent upsampler (if available)
+        if self.pipe_upsample is not None:
+            _ = self.pipe_upsample(latents=latents, output_type="latent").frames
+
+        # Part 3: Denoise the upscaled video with a few steps to improve texture.
+        # However, in this example code, we omit this final denoise step since it's optional.
diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py
index 8fb6d7788..9a061622e 100644
--- a/examples/diffusers/quantization/models_utils.py
+++ b/examples/diffusers/quantization/models_utils.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging from collections.abc import Callable from enum import Enum from typing import Any @@ -42,6 +43,7 @@ class ModelType(str, Enum): FLUX_DEV = "flux-dev" FLUX_SCHNELL = "flux-schnell" LTX_VIDEO_DEV = "ltx-video-dev" + LTX2 = "ltx-2" WAN22_T2V_14b = "wan2.2-t2v-14b" WAN22_T2V_5b = "wan2.2-t2v-5b" @@ -64,6 +66,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.SD3_MEDIUM: filter_func_default, ModelType.SD35_MEDIUM: filter_func_default, ModelType.LTX_VIDEO_DEV: filter_func_ltx_video, + ModelType.LTX2: filter_func_ltx_video, ModelType.WAN22_T2V_14b: filter_func_wan_video, ModelType.WAN22_T2V_5b: filter_func_wan_video, } @@ -80,11 +83,12 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev", ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell", ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev", + ModelType.LTX2: "Lightricks/LTX-2", ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers", ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers", } -MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline]] = { +MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline] | None] = { ModelType.SDXL_BASE: DiffusionPipeline, ModelType.SDXL_TURBO: DiffusionPipeline, ModelType.SD3_MEDIUM: StableDiffusion3Pipeline, @@ -92,6 +96,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.FLUX_DEV: FluxPipeline, ModelType.FLUX_SCHNELL: FluxPipeline, ModelType.LTX_VIDEO_DEV: LTXConditionPipeline, + ModelType.LTX2: None, ModelType.WAN22_T2V_14b: WanPipeline, ModelType.WAN22_T2V_5b: WanPipeline, } @@ -154,6 +159,18 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", }, }, + ModelType.LTX2: { + "backbone": "transformer", + "dataset": _SD_PROMPTS_DATASET, + "inference_extra_args": { + "height": 1024, + "width": 1536, + "num_frames": 121, + "frame_rate": 24.0, + "cfg_guidance_scale": 4.0, + "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", + }, + }, ModelType.WAN22_T2V_14b: { **_WAN_BASE_CONFIG, "from_pretrained_extra_args": { @@ -192,3 +209,48 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: }, }, } + + +def _coerce_extra_param_value(value: str) -> Any: + lowered = value.lower() + if lowered in {"true", "false"}: + return lowered == "true" + try: + return int(value) + except ValueError: + pass + try: + return float(value) + except ValueError: + return value + + +def parse_extra_params( + kv_args: list[str], unknown_args: list[str], logger: logging.Logger +) -> dict[str, Any]: + extra_params: dict[str, Any] = {} + for item in kv_args: + if "=" not in item: + raise ValueError(f"Invalid --extra-param value: '{item}'. Expected KEY=VALUE.") + key, value = item.split("=", 1) + extra_params[key] = _coerce_extra_param_value(value) + + i = 0 + while i < len(unknown_args): + token = unknown_args[i] + if token.startswith("--extra_param."): + key = token[len("--extra_param.") :] + value = "true" + if i + 1 < len(unknown_args) and not unknown_args[i + 1].startswith("--"): + value = unknown_args[i + 1] + i += 1 + extra_params[key] = _coerce_extra_param_value(value) + elif token.startswith("--extra_param"): + raise ValueError( + "Use --extra_param.KEY VALUE or --extra-param KEY=VALUE for extra parameters." 
+ ) + else: + logger.warning("Ignoring unknown argument: %s", token) + i += 1 + + return extra_params diff --git a/examples/diffusers/quantization/pipeline_manager.py b/examples/diffusers/quantization/pipeline_manager.py new file mode 100644 index 000000000..fc821a099 --- /dev/null +++ b/examples/diffusers/quantization/pipeline_manager.py @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any + +import torch +from diffusers import DiffusionPipeline, LTXLatentUpsamplePipeline +from models_utils import MODEL_DEFAULTS, MODEL_PIPELINE, MODEL_REGISTRY, ModelType +from quantize_config import ModelConfig + + +class PipelineManager: + """Manages diffusion pipeline creation and configuration.""" + + def __init__(self, config: ModelConfig, logger: logging.Logger): + """ + Initialize pipeline manager. + + Args: + config: Model configuration + logger: Logger instance + """ + self.config = config + self.logger = logger + self.pipe: Any | None = None + self.pipe_upsample: LTXLatentUpsamplePipeline | None = None # For LTX-Video upsampling + self._transformer: torch.nn.Module | None = None + + @staticmethod + def create_pipeline_from( + model_type: ModelType, + torch_dtype: torch.dtype | dict[str, str | torch.dtype] = torch.bfloat16, + override_model_path: str | None = None, + ) -> DiffusionPipeline: + """ + Create and return an appropriate pipeline based on configuration. + + Returns: + Configured diffusion pipeline + + Raises: + ValueError: If model type is unsupported + """ + try: + pipeline_cls = MODEL_PIPELINE[model_type] + if pipeline_cls is None: + raise ValueError(f"Model type {model_type.value} does not use diffusers pipelines.") + model_id = ( + MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path + ) + pipe = pipeline_cls.from_pretrained( + model_id, + torch_dtype=torch_dtype, + use_safetensors=True, + **MODEL_DEFAULTS[model_type].get("from_pretrained_extra_args", {}), + ) + pipe.set_progress_bar_config(disable=True) + return pipe + except Exception as e: + raise e + + def create_pipeline(self) -> Any: + """ + Create and return an appropriate pipeline based on configuration. 
+ + Returns: + Configured diffusion pipeline + + Raises: + ValueError: If model type is unsupported + """ + self.logger.info(f"Creating pipeline for {self.config.model_type.value}") + self.logger.info(f"Model path: {self.config.model_path}") + self.logger.info(f"Data type: {self.config.model_dtype}") + + try: + if self.config.model_type == ModelType.LTX2: + from modelopt.torch.quantization.plugins.diffusion import ltx2 as ltx2_plugin + + ltx2_plugin.register_ltx2_quant_linear() + self.pipe = self._create_ltx2_pipeline() + self.logger.info("LTX-2 pipeline created successfully") + return self.pipe + + pipeline_cls = MODEL_PIPELINE[self.config.model_type] + if pipeline_cls is None: + raise ValueError( + f"Model type {self.config.model_type.value} does not use diffusers pipelines." + ) + self.pipe = pipeline_cls.from_pretrained( + self.config.model_path, + torch_dtype=self.config.model_dtype, + use_safetensors=True, + **MODEL_DEFAULTS[self.config.model_type].get("from_pretrained_extra_args", {}), + ) + if self.config.model_type == ModelType.LTX_VIDEO_DEV: + # Optionally load the upsampler pipeline for LTX-Video + if not self.config.ltx_skip_upsampler: + self.logger.info("Loading LTX-Video upsampler pipeline...") + self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained( + "Lightricks/ltxv-spatial-upscaler-0.9.7", + vae=self.pipe.vae, + torch_dtype=self.config.model_dtype, + ) + self.pipe_upsample.set_progress_bar_config(disable=True) + else: + self.logger.info("Skipping upsampler pipeline for faster calibration") + self.pipe.set_progress_bar_config(disable=True) + + self.logger.info("Pipeline created successfully") + return self.pipe + + except Exception as e: + self.logger.error(f"Failed to create pipeline: {e}") + raise + + def setup_device(self) -> None: + """Configure pipeline device placement.""" + if not self.pipe: + raise RuntimeError("Pipeline not created. Call create_pipeline() first.") + + if self.config.model_type == ModelType.LTX2: + self.logger.info("Skipping device setup for LTX-2 pipeline (handled internally)") + return + + if self.config.cpu_offloading: + self.logger.info("Enabling CPU offloading for memory efficiency") + self.pipe.enable_model_cpu_offload() + if self.pipe_upsample: + self.pipe_upsample.enable_model_cpu_offload() + else: + self.logger.info("Moving pipeline to CUDA") + self.pipe.to("cuda") + if self.pipe_upsample: + self.logger.info("Moving upsampler pipeline to CUDA") + self.pipe_upsample.to("cuda") + # Enable VAE tiling for LTX-Video to save memory + if self.config.model_type == ModelType.LTX_VIDEO_DEV: + if hasattr(self.pipe, "vae") and hasattr(self.pipe.vae, "enable_tiling"): + self.logger.info("Enabling VAE tiling for LTX-Video") + self.pipe.vae.enable_tiling() + + def get_backbone(self) -> torch.nn.Module: + """ + Get the backbone model (transformer or UNet). + + Returns: + Backbone model module + """ + if not self.pipe: + raise RuntimeError("Pipeline not created. Call create_pipeline() first.") + + if self.config.model_type == ModelType.LTX2: + self._ensure_ltx2_transformer_cached() + return self._transformer + return getattr(self.pipe, self.config.backbone) + + def _ensure_ltx2_transformer_cached(self) -> None: + if not self.pipe: + raise RuntimeError("Pipeline not created. 
Call create_pipeline() first.") + if self._transformer is None: + transformer = self.pipe.stage_1_model_ledger.transformer() + self.pipe.stage_1_model_ledger.transformer = lambda: transformer + self._transformer = transformer + + def _create_ltx2_pipeline(self) -> Any: + params = dict(self.config.extra_params) + checkpoint_path = params.pop("checkpoint_path", None) + distilled_lora_path = params.pop("distilled_lora_path", None) + distilled_lora_strength = params.pop("distilled_lora_strength", 0.8) + spatial_upsampler_path = params.pop("spatial_upsampler_path", None) + gemma_root = params.pop("gemma_root", None) + fp8transformer = params.pop("fp8transformer", False) + + if not checkpoint_path: + raise ValueError("Missing required extra_param: checkpoint_path.") + if not distilled_lora_path: + raise ValueError("Missing required extra_param: distilled_lora_path.") + if not spatial_upsampler_path: + raise ValueError("Missing required extra_param: spatial_upsampler_path.") + if not gemma_root: + raise ValueError("Missing required extra_param: gemma_root.") + + from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps + from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline + + distilled_lora = [ + LoraPathStrengthAndSDOps( + str(distilled_lora_path), + float(distilled_lora_strength), + LTXV_LORA_COMFY_RENAMING_MAP, + ) + ] + pipeline_kwargs = { + "checkpoint_path": str(checkpoint_path), + "distilled_lora": distilled_lora, + "spatial_upsampler_path": str(spatial_upsampler_path), + "gemma_root": str(gemma_root), + "loras": [], + "fp8transformer": bool(fp8transformer), + } + pipeline_kwargs.update(params) + return TI2VidTwoStagesPipeline(**pipeline_kwargs) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 3f88b2911..e45959b4e 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -17,13 +17,12 @@ import logging import sys import time as time -from dataclasses import dataclass, field -from enum import Enum from pathlib import Path from typing import Any import torch import torch.nn as nn +from calibration import Calibrator from config import ( FP8_DEFAULT_CONFIG, INT8_DEFAULT_CONFIG, @@ -32,173 +31,27 @@ reset_set_int8_config, set_quant_config_attr, ) - -# This is a workaround for making the onnx export of models that use the torch RMSNorm work. We will -# need to move on to use dynamo based onnx export to properly fix the problem. The issue has been hit -# by both external users https://github.com/NVIDIA/Model-Optimizer/issues/262, and our -# internal users from MLPerf Inference. 
-# -if __name__ == "__main__": - from diffusers.models.normalization import RMSNorm as DiffuserRMSNorm - - torch.nn.RMSNorm = DiffuserRMSNorm - torch.nn.modules.normalization.RMSNorm = DiffuserRMSNorm - -from diffusers import DiffusionPipeline, LTXLatentUpsamplePipeline -from models_utils import ( - MODEL_DEFAULTS, - MODEL_PIPELINE, - MODEL_REGISTRY, - ModelType, - get_model_filter_func, -) +from diffusers import DiffusionPipeline +from models_utils import MODEL_DEFAULTS, ModelType, get_model_filter_func, parse_extra_params from onnx_utils.export import generate_fp8_scales, modelopt_export_sd -from tqdm import tqdm -from utils import check_conv_and_mha, check_lora, load_calib_prompts +from pipeline_manager import PipelineManager +from quantize_config import ( + CalibrationConfig, + CollectMethod, + DataType, + ExportConfig, + ModelConfig, + QuantAlgo, + QuantFormat, + QuantizationConfig, +) +from utils import check_conv_and_mha, check_lora import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq from modelopt.torch.export import export_hf_checkpoint -class DataType(str, Enum): - """Supported data types for model loading.""" - - HALF = "Half" - BFLOAT16 = "BFloat16" - FLOAT = "Float" - - @property - def torch_dtype(self) -> torch.dtype: - return self._dtype_map[self.value] - - -DataType._dtype_map = { - DataType.HALF: torch.float16, - DataType.BFLOAT16: torch.bfloat16, - DataType.FLOAT: torch.float32, -} - - -class QuantFormat(str, Enum): - """Supported quantization formats.""" - - INT8 = "int8" - FP8 = "fp8" - FP4 = "fp4" - - -class QuantAlgo(str, Enum): - """Supported quantization algorithms.""" - - MAX = "max" - SVDQUANT = "svdquant" - SMOOTHQUANT = "smoothquant" - - -class CollectMethod(str, Enum): - """Calibration collection methods.""" - - GLOBAL_MIN = "global_min" - MIN_MAX = "min-max" - MIN_MEAN = "min-mean" - MEAN_MAX = "mean-max" - DEFAULT = "default" - - -@dataclass -class QuantizationConfig: - """Configuration for model quantization.""" - - format: QuantFormat = QuantFormat.INT8 - algo: QuantAlgo = QuantAlgo.MAX - percentile: float = 1.0 - collect_method: CollectMethod = CollectMethod.DEFAULT - alpha: float = 1.0 # SmoothQuant alpha - lowrank: int = 32 # SVDQuant lowrank - quantize_mha: bool = False - compress: bool = False - - def validate(self) -> None: - """Validate configuration consistency.""" - if self.format == QuantFormat.FP8 and self.collect_method != CollectMethod.DEFAULT: - raise NotImplementedError("Only 'default' collect method is implemented for FP8.") - if self.quantize_mha and self.format == QuantFormat.INT8: - raise ValueError("MHA quantization is only supported for FP8, not INT8.") - if self.compress and self.format == QuantFormat.INT8: - raise ValueError("Compression is only supported for FP8 and FP4, not INT8.") - - -@dataclass -class CalibrationConfig: - """Configuration for calibration process.""" - - prompts_dataset: dict | Path - batch_size: int = 2 - calib_size: int = 128 - n_steps: int = 30 - - def validate(self) -> None: - """Validate calibration configuration.""" - if self.batch_size <= 0: - raise ValueError("Batch size must be positive.") - if self.calib_size <= 0: - raise ValueError("Calibration size must be positive.") - if self.n_steps <= 0: - raise ValueError("Number of steps must be positive.") - - @property - def num_batches(self) -> int: - """Calculate number of calibration batches.""" - return self.calib_size // self.batch_size - - -@dataclass -class ModelConfig: - """Configuration for model loading and inference.""" - - 
model_type: ModelType = ModelType.FLUX_DEV - model_dtype: dict[str, torch.dtype] = field(default_factory=lambda: {"default": torch.float16}) - backbone: str = "" - trt_high_precision_dtype: DataType = DataType.HALF - override_model_path: Path | None = None - cpu_offloading: bool = False - ltx_skip_upsampler: bool = False # Skip upsampler for LTX-Video (faster calibration) - - @property - def model_path(self) -> str: - """Get the model path (override or default).""" - if self.override_model_path: - return str(self.override_model_path) - return MODEL_REGISTRY[self.model_type] - - -@dataclass -class ExportConfig: - """Configuration for model export.""" - - quantized_torch_ckpt_path: Path | None = None - onnx_dir: Path | None = None - hf_ckpt_dir: Path | None = None - restore_from: Path | None = None - - def validate(self) -> None: - """Validate export configuration.""" - if self.restore_from and not self.restore_from.exists(): - raise FileNotFoundError(f"Restore checkpoint not found: {self.restore_from}") - - if self.quantized_torch_ckpt_path: - parent_dir = self.quantized_torch_ckpt_path.parent - if not parent_dir.exists(): - parent_dir.mkdir(parents=True, exist_ok=True) - - if self.onnx_dir and not self.onnx_dir.exists(): - self.onnx_dir.mkdir(parents=True, exist_ok=True) - - if self.hf_ckpt_dir and not self.hf_ckpt_dir.exists(): - self.hf_ckpt_dir.mkdir(parents=True, exist_ok=True) - - def setup_logging(verbose: bool = False) -> logging.Logger: """ Set up logging configuration. @@ -232,275 +85,6 @@ def setup_logging(verbose: bool = False) -> logging.Logger: return logger -class PipelineManager: - """Manages diffusion pipeline creation and configuration.""" - - def __init__(self, config: ModelConfig, logger: logging.Logger): - """ - Initialize pipeline manager. - - Args: - config: Model configuration - logger: Logger instance - """ - self.config = config - self.logger = logger - self.pipe: DiffusionPipeline | None = None - self.pipe_upsample: LTXLatentUpsamplePipeline | None = None # For LTX-Video upsampling - - @staticmethod - def create_pipeline_from( - model_type: ModelType, - torch_dtype: torch.dtype | dict[str, str | torch.dtype] = torch.bfloat16, - override_model_path: str | None = None, - ) -> DiffusionPipeline: - """ - Create and return an appropriate pipeline based on configuration. - - Returns: - Configured diffusion pipeline - - Raises: - ValueError: If model type is unsupported - """ - try: - model_id = ( - MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path - ) - pipe = MODEL_PIPELINE[model_type].from_pretrained( - model_id, - torch_dtype=torch_dtype, - use_safetensors=True, - **MODEL_DEFAULTS[model_type].get("from_pretrained_extra_args", {}), - ) - pipe.set_progress_bar_config(disable=True) - return pipe - except Exception as e: - raise e - - def create_pipeline(self) -> DiffusionPipeline: - """ - Create and return an appropriate pipeline based on configuration. 
- - Returns: - Configured diffusion pipeline - - Raises: - ValueError: If model type is unsupported - """ - self.logger.info(f"Creating pipeline for {self.config.model_type.value}") - self.logger.info(f"Model path: {self.config.model_path}") - self.logger.info(f"Data type: {self.config.model_dtype}") - - try: - self.pipe = MODEL_PIPELINE[self.config.model_type].from_pretrained( - self.config.model_path, - torch_dtype=self.config.model_dtype, - use_safetensors=True, - **MODEL_DEFAULTS[self.config.model_type].get("from_pretrained_extra_args", {}), - ) - if self.config.model_type == ModelType.LTX_VIDEO_DEV: - # Optionally load the upsampler pipeline for LTX-Video - if not self.config.ltx_skip_upsampler: - self.logger.info("Loading LTX-Video upsampler pipeline...") - self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained( - "Lightricks/ltxv-spatial-upscaler-0.9.7", - vae=self.pipe.vae, - torch_dtype=self.config.model_dtype, - ) - self.pipe_upsample.set_progress_bar_config(disable=True) - else: - self.logger.info("Skipping upsampler pipeline for faster calibration") - self.pipe.set_progress_bar_config(disable=True) - - self.logger.info("Pipeline created successfully") - return self.pipe - - except Exception as e: - self.logger.error(f"Failed to create pipeline: {e}") - raise - - def setup_device(self) -> None: - """Configure pipeline device placement.""" - if not self.pipe: - raise RuntimeError("Pipeline not created. Call create_pipeline() first.") - - if self.config.cpu_offloading: - self.logger.info("Enabling CPU offloading for memory efficiency") - self.pipe.enable_model_cpu_offload() - if self.pipe_upsample: - self.pipe_upsample.enable_model_cpu_offload() - else: - self.logger.info("Moving pipeline to CUDA") - self.pipe.to("cuda") - if self.pipe_upsample: - self.logger.info("Moving upsampler pipeline to CUDA") - self.pipe_upsample.to("cuda") - # Enable VAE tiling for LTX-Video to save memory - if self.config.model_type == ModelType.LTX_VIDEO_DEV: - if hasattr(self.pipe, "vae") and hasattr(self.pipe.vae, "enable_tiling"): - self.logger.info("Enabling VAE tiling for LTX-Video") - self.pipe.vae.enable_tiling() - - def get_backbone(self) -> torch.nn.Module: - """ - Get the backbone model (transformer or UNet). - - Returns: - Backbone model module - """ - if not self.pipe: - raise RuntimeError("Pipeline not created. Call create_pipeline() first.") - - return getattr(self.pipe, self.config.backbone) - - -class Calibrator: - """Handles model calibration for quantization.""" - - def __init__( - self, - pipeline_manager: PipelineManager, - config: CalibrationConfig, - model_type: ModelType, - logger: logging.Logger, - ): - """ - Initialize calibrator. - - Args: - pipeline_manager: Pipeline manager with main and upsampler pipelines - config: Calibration configuration - model_type: Type of model being calibrated - logger: Logger instance - """ - self.pipeline_manager = pipeline_manager - self.pipe = pipeline_manager.pipe - self.pipe_upsample = pipeline_manager.pipe_upsample - self.config = config - self.model_type = model_type - self.logger = logger - - def load_and_batch_prompts(self) -> list[list[str]]: - """ - Load calibration prompts from file. 
- - Returns: - List of batched calibration prompts - """ - self.logger.info(f"Loading calibration prompts from {self.config.prompts_dataset}") - if isinstance(self.config.prompts_dataset, Path): - return load_calib_prompts( - self.config.batch_size, - self.config.prompts_dataset, - ) - - return load_calib_prompts( - self.config.batch_size, - self.config.prompts_dataset["name"], - self.config.prompts_dataset["split"], - self.config.prompts_dataset["column"], - ) - - def run_calibration(self, batched_prompts: list[list[str]]) -> None: - """ - Run calibration steps on the pipeline. - - Args: - batched_prompts: List of batched calibration prompts - """ - self.logger.info(f"Starting calibration with {self.config.num_batches} batches") - extra_args = MODEL_DEFAULTS.get(self.model_type, {}).get("inference_extra_args", {}) - - with tqdm(total=self.config.num_batches, desc="Calibration", unit="batch") as pbar: - for i, prompt_batch in enumerate(batched_prompts): - if i >= self.config.num_batches: - break - - if self.model_type == ModelType.LTX_VIDEO_DEV: - # Special handling for LTX-Video - self._run_ltx_video_calibration(prompt_batch, extra_args) - elif self.model_type in [ModelType.WAN22_T2V_14b, ModelType.WAN22_T2V_5b]: - # Special handling for WAN video models - self._run_wan_video_calibration(prompt_batch, extra_args) - else: - common_args = { - "prompt": prompt_batch, - "num_inference_steps": self.config.n_steps, - } - self.pipe(**common_args, **extra_args).images # type: ignore[misc] - pbar.update(1) - self.logger.debug(f"Completed calibration batch {i + 1}/{self.config.num_batches}") - self.logger.info("Calibration completed successfully") - - def _run_wan_video_calibration( - self, prompt_batch: list[str], extra_args: dict[str, Any] - ) -> None: - kwargs = {} - kwargs["negative_prompt"] = extra_args["negative_prompt"] - kwargs["height"] = extra_args["height"] - kwargs["width"] = extra_args["width"] - kwargs["num_frames"] = extra_args["num_frames"] - kwargs["guidance_scale"] = extra_args["guidance_scale"] - if "guidance_scale_2" in extra_args: - kwargs["guidance_scale_2"] = extra_args["guidance_scale_2"] - kwargs["num_inference_steps"] = self.config.n_steps - - self.pipe(prompt=prompt_batch, **kwargs).frames # type: ignore[misc] - - def _run_ltx_video_calibration( - self, prompt_batch: list[str], extra_args: dict[str, Any] - ) -> None: - """ - Run calibration for LTX-Video model using the full multi-stage pipeline. 
- - Args: - prompt_batch: Batch of prompts - extra_args: Model-specific arguments - """ - # Extract specific args for LTX-Video - expected_height = extra_args.get("height", 512) - expected_width = extra_args.get("width", 704) - num_frames = extra_args.get("num_frames", 121) - negative_prompt = extra_args.get( - "negative_prompt", "worst quality, inconsistent motion, blurry, jittery, distorted" - ) - - def round_to_nearest_resolution_acceptable_by_vae(height, width): - height = height - (height % self.pipe.vae_spatial_compression_ratio) # type: ignore[union-attr] - width = width - (width % self.pipe.vae_spatial_compression_ratio) # type: ignore[union-attr] - return height, width - - downscale_factor = 2 / 3 - # Part 1: Generate video at smaller resolution - downscaled_height, downscaled_width = ( - int(expected_height * downscale_factor), - int(expected_width * downscale_factor), - ) - downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae( - downscaled_height, downscaled_width - ) - - # Generate initial latents at lower resolution - latents = self.pipe( # type: ignore[misc] - conditions=None, - prompt=prompt_batch, - negative_prompt=negative_prompt, - width=downscaled_width, - height=downscaled_height, - num_frames=num_frames, - num_inference_steps=self.config.n_steps, - output_type="latent", - ).frames - - # Part 2: Upscale generated video using latent upsampler (if available) - if self.pipe_upsample is not None: - _ = self.pipe_upsample(latents=latents, output_type="latent").frames - - # Part 3: Denoise the upscaled video with few steps to improve texture - # However, in this example code, we will omit the upscale step since its optional. - - class Quantizer: """Handles model quantization operations.""" @@ -568,7 +152,7 @@ def quantize_model( backbone: torch.nn.Module, quant_config: Any, forward_loop: callable, # type: ignore[valid-type] - ) -> None: + ) -> torch.nn.Module: """ Apply quantization to the model. @@ -590,6 +174,7 @@ def quantize_model( mtq.disable_quantizer(backbone, model_filter_func) self.logger.info("Quantization completed successfully") + return backbone class ExportManager: @@ -691,7 +276,8 @@ def restore_checkpoint(self, backbone: nn.Module) -> None: mto.restore(backbone, str(self.config.restore_from)) self.logger.info("Model restored successfully") - def export_hf_ckpt(self, pipe: DiffusionPipeline) -> None: + # TODO: should not do the any data type + def export_hf_ckpt(self, pipe: Any) -> None: """ Export quantized model to HuggingFace checkpoint format. @@ -754,7 +340,7 @@ def create_argument_parser() -> argparse.ArgumentParser: model_group.add_argument( "--model-dtype", type=str, - default="Half", + default="BFloat16", choices=[d.value for d in DataType], help="Precision for loading the pipeline. If you want different dtypes for separate components, " "please specify using --component-dtype", @@ -778,6 +364,16 @@ def create_argument_parser() -> argparse.ArgumentParser: action="store_true", help="Skip upsampler pipeline for LTX-Video (faster calibration, only quantizes main transformer)", ) + model_group.add_argument( + "--extra-param", + action="append", + default=[], + metavar="KEY=VALUE", + help=( + "Extra model-specific parameters in KEY=VALUE form. Can be provided multiple times. " + "These override model-specific CLI arguments when present." 
+ ), + ) quant_group = parser.add_argument_group("Quantization Configuration") quant_group.add_argument( "--format", @@ -859,7 +455,7 @@ def create_argument_parser() -> argparse.ArgumentParser: def main() -> None: parser = create_argument_parser() - args = parser.parse_args() + args, unknown_args = parser.parse_known_args() model_type = ModelType(args.model) if args.backbone is None: @@ -875,6 +471,7 @@ def main() -> None: logger.info("Starting Enhanced Diffusion Model Quantization") try: + extra_params = parse_extra_params(args.extra_param, unknown_args, logger) model_config = ModelConfig( model_type=model_type, model_dtype=model_dtype, @@ -885,6 +482,7 @@ def main() -> None: else None, cpu_offloading=args.cpu_offloading, ltx_skip_upsampler=args.ltx_skip_upsampler, + extra_params=extra_params, ) quant_config = QuantizationConfig( @@ -950,6 +548,7 @@ def main() -> None: quantizer = Quantizer(quant_config, model_config, logger) backbone_quant_config = quantizer.get_quant_config(calib_config.n_steps, backbone) + # Pipe loads the ckpt just before the inference. def forward_loop(mod): calibrator.run_calibration(batched_prompts) diff --git a/examples/diffusers/quantization/quantize_config.py b/examples/diffusers/quantization/quantize_config.py new file mode 100644 index 000000000..980d39f31 --- /dev/null +++ b/examples/diffusers/quantization/quantize_config.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
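+
+"""Configuration dataclasses and enums for the diffusers quantization example."""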
+from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any + +import torch +from models_utils import MODEL_REGISTRY, ModelType + + +class DataType(str, Enum): + """Supported data types for model loading.""" + + HALF = "Half" + BFLOAT16 = "BFloat16" + FLOAT = "Float" + + @property + def torch_dtype(self) -> torch.dtype: + return self._dtype_map[self.value] + + +DataType._dtype_map = { + DataType.HALF: torch.float16, + DataType.BFLOAT16: torch.bfloat16, + DataType.FLOAT: torch.float32, +} + + +class QuantFormat(str, Enum): + """Supported quantization formats.""" + + INT8 = "int8" + FP8 = "fp8" + FP4 = "fp4" + + +class QuantAlgo(str, Enum): + """Supported quantization algorithms.""" + + MAX = "max" + SVDQUANT = "svdquant" + SMOOTHQUANT = "smoothquant" + + +class CollectMethod(str, Enum): + """Calibration collection methods.""" + + GLOBAL_MIN = "global_min" + MIN_MAX = "min-max" + MIN_MEAN = "min-mean" + MEAN_MAX = "mean-max" + DEFAULT = "default" + + +@dataclass +class QuantizationConfig: + """Configuration for model quantization.""" + + format: QuantFormat = QuantFormat.INT8 + algo: QuantAlgo = QuantAlgo.MAX + percentile: float = 1.0 + collect_method: CollectMethod = CollectMethod.DEFAULT + alpha: float = 1.0 # SmoothQuant alpha + lowrank: int = 32 # SVDQuant lowrank + quantize_mha: bool = False + compress: bool = False + + def validate(self) -> None: + """Validate configuration consistency.""" + if self.format == QuantFormat.FP8 and self.collect_method != CollectMethod.DEFAULT: + raise NotImplementedError("Only 'default' collect method is implemented for FP8.") + if self.quantize_mha and self.format == QuantFormat.INT8: + raise ValueError("MHA quantization is only supported for FP8, not INT8.") + if self.compress and self.format == QuantFormat.INT8: + raise ValueError("Compression is only supported for FP8 and FP4, not INT8.") + + +@dataclass +class CalibrationConfig: + """Configuration for calibration process.""" + + prompts_dataset: dict | Path + batch_size: int = 2 + calib_size: int = 128 + n_steps: int = 30 + + def validate(self) -> None: + """Validate calibration configuration.""" + if self.batch_size <= 0: + raise ValueError("Batch size must be positive.") + if self.calib_size <= 0: + raise ValueError("Calibration size must be positive.") + if self.n_steps <= 0: + raise ValueError("Number of steps must be positive.") + + @property + def num_batches(self) -> int: + """Calculate number of calibration batches.""" + return self.calib_size // self.batch_size + + +@dataclass +class ModelConfig: + """Configuration for model loading and inference.""" + + model_type: ModelType = ModelType.FLUX_DEV + model_dtype: dict[str, torch.dtype] = field(default_factory=lambda: {"default": torch.float16}) + backbone: str = "" + trt_high_precision_dtype: DataType = DataType.HALF + override_model_path: Path | None = None + cpu_offloading: bool = False + ltx_skip_upsampler: bool = False # Skip upsampler for LTX-Video (faster calibration) + extra_params: dict[str, Any] = field(default_factory=dict) + + @property + def model_path(self) -> str: + """Get the model path (override or default).""" + if self.override_model_path: + return str(self.override_model_path) + return MODEL_REGISTRY[self.model_type] + + +@dataclass +class ExportConfig: + """Configuration for model export.""" + + quantized_torch_ckpt_path: Path | None = None + onnx_dir: Path | None = None + hf_ckpt_dir: Path | None = None + restore_from: Path | None = None + + def validate(self) -> None: 
+ """Validate export configuration.""" + if self.restore_from and not self.restore_from.exists(): + raise FileNotFoundError(f"Restore checkpoint not found: {self.restore_from}") + + if self.quantized_torch_ckpt_path: + parent_dir = self.quantized_torch_ckpt_path.parent + if not parent_dir.exists(): + parent_dir.mkdir(parents=True, exist_ok=True) + + if self.onnx_dir and not self.onnx_dir.exists(): + self.onnx_dir.mkdir(parents=True, exist_ok=True) + + if self.hf_ckpt_dir and not self.hf_ckpt_dir.exists(): + self.hf_ckpt_dir.mkdir(parents=True, exist_ok=True) diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index e5cc7c015..be9be0424 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -25,7 +25,7 @@ from diffusers.utils import load_image import modelopt.torch.quantization as mtq -from modelopt.torch.quantization.plugins.diffusers import AttentionModuleMixin +from modelopt.torch.quantization.plugins.diffusion.diffusers import AttentionModuleMixin USE_PEFT = True try: @@ -69,7 +69,9 @@ def check_conv_and_mha(backbone, if_fp4, quantize_mha): def filter_func_ltx_video(name: str) -> bool: """Filter function specifically for LTX-Video models.""" - pattern = re.compile(r".*(proj_in|time_embed|caption_projection|proj_out).*") + pattern = re.compile( + r".*(proj_in|time_embed|caption_projection|proj_out|patchify_proj|adaln_single).*" + ) return pattern.match(name) is not None diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py index 001324cba..7667afde1 100644 --- a/modelopt/torch/export/diffusers_utils.py +++ b/modelopt/torch/export/diffusers_utils.py @@ -16,16 +16,26 @@ """Code that export quantized Hugging Face models for deployment.""" import warnings +from collections.abc import Callable from contextlib import contextmanager from importlib import import_module from typing import Any import torch import torch.nn as nn -from diffusers import DiffusionPipeline from .layer_utils import is_quantlinear +DiffusionPipeline: type[Any] | None +try: # diffusers is optional for LTX-2 export paths + from diffusers import DiffusionPipeline as _DiffusionPipeline + + DiffusionPipeline = _DiffusionPipeline + _HAS_DIFFUSERS = True +except Exception: # pragma: no cover + DiffusionPipeline = None + _HAS_DIFFUSERS = False + def generate_diffusion_dummy_inputs( model: nn.Module, device: torch.device, dtype: torch.dtype @@ -288,6 +298,126 @@ def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None: return None +def generate_diffusion_dummy_forward_fn(model: nn.Module) -> Callable[[], None]: + """Create a dummy forward function for diffusion(-like) models. + + - For diffusers components, this uses `generate_diffusion_dummy_inputs()` and calls `model(**kwargs)`. + - For LTX-2 stage-1 transformer (X0Model), the forward signature is + `model(video: Modality|None, audio: Modality|None, perturbations: BatchedPerturbationConfig)`, + so we build tiny `ltx_core` dataclasses and call the model directly. 
+ """ + # Duck-typed LTX-2 stage-1 transformer wrapper + velocity_model = getattr(model, "velocity_model", None) + if velocity_model is not None: + + def _ltx2_dummy_forward() -> None: + try: + from ltx_core.guidance.perturbations import BatchedPerturbationConfig + from ltx_core.model.transformer.modality import Modality + except Exception as e: # pragma: no cover + raise RuntimeError( + "LTX-2 export requires `ltx_core` to be installed (Modality, BatchedPerturbationConfig)." + ) from e + + # Small shapes for speed/memory + batch_size = 1 + v_seq_len = 8 + a_seq_len = 8 + ctx_len = 4 + + device = next(model.parameters()).device + default_dtype = next(model.parameters()).dtype + + def _param_dtype(module: Any, fallback: torch.dtype) -> torch.dtype: + w = getattr(getattr(module, "weight", None), "dtype", None) + return w if isinstance(w, torch.dtype) else fallback + + def _positions(bounds_dims: int, seq_len: int) -> torch.Tensor: + # [B, dims, seq_len, 2] bounds (start/end) + pos = torch.zeros( + (batch_size, bounds_dims, seq_len, 2), device=device, dtype=torch.float32 + ) + pos[..., 1] = 1.0 + return pos + + has_video = hasattr(velocity_model, "patchify_proj") and hasattr( + velocity_model, "caption_projection" + ) + has_audio = hasattr(velocity_model, "audio_patchify_proj") and hasattr( + velocity_model, "audio_caption_projection" + ) + if not has_video and not has_audio: + raise ValueError( + "Unsupported LTX-2 velocity model: missing both video and audio preprocessors." + ) + + video = None + if has_video: + v_in = int(velocity_model.patchify_proj.in_features) + v_caption_in = int(velocity_model.caption_projection.linear_1.in_features) + v_latent_dtype = _param_dtype(velocity_model.patchify_proj, default_dtype) + v_ctx_dtype = _param_dtype( + velocity_model.caption_projection.linear_1, default_dtype + ) + video = Modality( + enabled=True, + latent=torch.randn( + batch_size, v_seq_len, v_in, device=device, dtype=v_latent_dtype + ), + # LTX `X0Model` uses `timesteps` as the sigma tensor in `to_denoised(sample, velocity, sigma)`. + # It must be broadcastable to `[B, T, D]`, so we use `[B, T, 1]`. 
+ timesteps=torch.full( + (batch_size, v_seq_len, 1), 0.5, device=device, dtype=torch.float32 + ), + positions=_positions(bounds_dims=3, seq_len=v_seq_len), + context=torch.randn( + batch_size, ctx_len, v_caption_in, device=device, dtype=v_ctx_dtype + ), + context_mask=None, + ) + + audio = None + if has_audio: + a_in = int(velocity_model.audio_patchify_proj.in_features) + a_caption_in = int(velocity_model.audio_caption_projection.linear_1.in_features) + a_latent_dtype = _param_dtype(velocity_model.audio_patchify_proj, default_dtype) + a_ctx_dtype = _param_dtype( + velocity_model.audio_caption_projection.linear_1, default_dtype + ) + audio = Modality( + enabled=True, + latent=torch.randn( + batch_size, a_seq_len, a_in, device=device, dtype=a_latent_dtype + ), + timesteps=torch.full( + (batch_size, a_seq_len, 1), 0.5, device=device, dtype=torch.float32 + ), + positions=_positions(bounds_dims=1, seq_len=a_seq_len), + context=torch.randn( + batch_size, ctx_len, a_caption_in, device=device, dtype=a_ctx_dtype + ), + context_mask=None, + ) + + perturbations = BatchedPerturbationConfig.empty(batch_size) + model(video, audio, perturbations) + + return _ltx2_dummy_forward + + # Default: diffusers-style `model(**kwargs)` + def _diffusers_dummy_forward() -> None: + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + dummy_inputs = generate_diffusion_dummy_inputs(model, device, dtype) + if dummy_inputs is None: + raise ValueError( + f"Unknown model type '{type(model).__name__}', cannot generate dummy inputs." + ) + model(**dummy_inputs) + + return _diffusers_dummy_forward + + def is_qkv_projection(module_name: str) -> bool: """Check if a module name corresponds to a QKV projection layer. @@ -377,17 +507,19 @@ def get_qkv_group_key(module_name: str) -> str: return f"{parent_path}.{qkv_type}" -def get_diffusers_components( - model: DiffusionPipeline | nn.Module, +def get_diffusion_components( + model: Any, components: list[str] | None = None, ) -> dict[str, Any]: - """Get all exportable components from a diffusers pipeline. + """Get all exportable components from a diffusion(-like) pipeline. - This function extracts all components from a DiffusionPipeline including - nn.Module models, tokenizers, schedulers, feature extractors, etc. + Supports: + - diffusers `DiffusionPipeline`: returns `pipeline.components` + - diffusers component `nn.Module` (e.g., UNet / transformer) + - LTX-2 pipeline (duck-typed): returns stage-1 transformer only as `stage_1_transformer` Args: - model: The diffusers pipeline. + model: The pipeline or component. components: Optional list of component names to filter. If None, all components are returned. @@ -395,7 +527,21 @@ def get_diffusers_components( Dictionary mapping component names to their instances (can be nn.Module, tokenizers, schedulers, etc.). 
""" - if isinstance(model, DiffusionPipeline): + # LTX-2 pipeline: duck-typed stage-1 transformer export + stage_1 = getattr(model, "stage_1_model_ledger", None) + transformer_fn = getattr(stage_1, "transformer", None) + if stage_1 is not None and callable(transformer_fn): + all_components: dict[str, Any] = {"stage_1_transformer": stage_1.transformer()} + if components is not None: + filtered = {name: comp for name, comp in all_components.items() if name in components} + missing = set(components) - set(filtered.keys()) + if missing: + warnings.warn(f"Requested components not found in pipeline: {missing}") + return filtered + return all_components + + # diffusers pipeline + if _HAS_DIFFUSERS and DiffusionPipeline is not None and isinstance(model, DiffusionPipeline): # Get all components from the pipeline all_components = {name: comp for name, comp in model.components.items() if comp is not None} @@ -427,6 +573,10 @@ def get_diffusers_components( raise TypeError(f"Expected DiffusionPipeline or nn.Module, got {type(model).__name__}") +# Backward-compatible alias +get_diffusers_components = get_diffusion_components + + @contextmanager def hide_quantizers_from_state_dict(model: nn.Module): """Context manager that temporarily removes quantizer modules from the model. diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index ce4b557d9..d9abe190b 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -35,8 +35,8 @@ from diffusers import DiffusionPipeline, ModelMixin from .diffusers_utils import ( - generate_diffusion_dummy_inputs, - get_diffusers_components, + generate_diffusion_dummy_forward_fn, + get_diffusion_components, get_qkv_group_key, hide_quantizers_from_state_dict, infer_dtype_from_model, @@ -92,6 +92,11 @@ to_quantized_weight, ) +try: # optional for LTX-2 export paths + from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline +except ImportError: # pragma: no cover + TI2VidTwoStagesPipeline = None + __all__ = ["export_hf_checkpoint"] @@ -219,7 +224,7 @@ def _fuse_shared_input_modules( # Fuse each group separately for group_key, group_modules in qkv_groups.items(): - if len(group_modules) > 1: + if len(group_modules) >= 2: preprocess_linear_fusion(group_modules, resmooth_only=False) fused_count += 1 module_names = [getattr(m, "name", "unknown") for m in group_modules] @@ -703,7 +708,9 @@ def _export_transformers_checkpoint( return quantized_state_dict, quant_config -def _fuse_qkv_linears_diffusion(model: nn.Module) -> None: +def _fuse_qkv_linears_diffusion( + model: nn.Module, dummy_forward_fn: Callable[[], None] | None = None +) -> None: """Fuse QKV linear layers that share the same input for diffusion models. This function uses forward hooks to dynamically identify linear modules that @@ -718,33 +725,22 @@ def _fuse_qkv_linears_diffusion(model: nn.Module) -> None: Args: model: The diffusion model component (e.g., transformer, unet). + dummy_forward_fn: Optional callable to run a dummy forward pass. Use this + for diffusion-like models whose forward signature is not compatible + with `generate_diffusion_dummy_inputs`. 
""" quantization_format = get_quantization_format(model) if quantization_format == QUANTIZATION_NONE: return - # Define the dummy forward function for diffusion models - def diffusion_dummy_forward(): - device = next(model.parameters()).device - dtype = next(model.parameters()).dtype - - # Generate appropriate dummy inputs based on model type - dummy_inputs = generate_diffusion_dummy_inputs(model, device, dtype) - - if dummy_inputs is None: - model_class_name = type(model).__name__ - raise ValueError( - f"Unknown model type '{model_class_name}', cannot generate dummy inputs." - ) - - # Run forward pass with dummy inputs - model(**dummy_inputs) + if dummy_forward_fn is None: + dummy_forward_fn = generate_diffusion_dummy_forward_fn(model) # Collect modules sharing the same input try: input_to_linear, _ = _collect_shared_input_modules( - model, diffusion_dummy_forward, collect_layernorms=False + model, dummy_forward_fn, collect_layernorms=False ) except Exception as e: print(f"Warning: Failed to run dummy forward for QKV fusion: {e}") @@ -767,20 +763,20 @@ def diffusion_dummy_forward(): def _export_diffusers_checkpoint( - pipe: "DiffusionPipeline | ModelMixin", + pipe: Any, dtype: torch.dtype | None, export_dir: Path, components: list[str] | None, max_shard_size: int | str = "10GB", ) -> None: - """Internal: Export Diffusers model/pipeline checkpoint. + """Internal: Export diffusion(-like) model/pipeline checkpoint. - This function handles the export of diffusers models, including - DiffusionPipeline and individual ModelMixin components. It exports all - components including nn.Module models, tokenizers, schedulers, etc. + This function handles the export of: + - diffusers models: DiffusionPipeline and individual ModelMixin components. + - LTX-2 pipelines (duck-typed): exports stage-1 transformer only. Args: - pipe: The diffusers model or pipeline to export. + pipe: The model or pipeline to export. dtype: The data type for weight conversion. If None, will be inferred from model. export_dir: The directory to save the exported checkpoint. components: Optional list of component names to export. Only used for pipelines. @@ -792,7 +788,7 @@ def _export_diffusers_checkpoint( export_dir = Path(export_dir) # Step 1: Get all pipeline components (nn.Module, tokenizers, schedulers, etc.) 
- all_components = get_diffusers_components(pipe, components) + all_components = get_diffusion_components(pipe, components) if not all_components: warnings.warn("No exportable components found in the model.") @@ -803,6 +799,16 @@ def _export_diffusers_checkpoint( name: comp for name, comp in all_components.items() if isinstance(comp, nn.Module) } + # Best-effort diffusers pipeline check (kept for folder layout + model_index.json behavior) + is_diffusers_pipe = False + if HAS_DIFFUSERS: + try: + from diffusers import DiffusionPipeline as _DiffusionPipeline + + is_diffusers_pipe = isinstance(pipe, _DiffusionPipeline) + except Exception: + is_diffusers_pipe = False + # Step 3: Export each nn.Module component with quantization handling for component_name, component in module_components.items(): is_quantized = has_quantized_modules(component) @@ -811,7 +817,7 @@ def _export_diffusers_checkpoint( # Determine component export directory # For pipelines, each component goes in a subfolder - if isinstance(pipe, DiffusionPipeline): + if is_diffusers_pipe: component_export_dir = export_dir / component_name else: component_export_dir = export_dir @@ -835,11 +841,26 @@ def _export_diffusers_checkpoint( quant_config = get_quant_config(component, is_modelopt_qlora=False) # Step 6: Save the component - # Note: diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter - # (unlike transformers), so we use a context manager to temporarily hide quantizers - # from the state dict during save. This avoids saving quantizer buffers like _amax. - with hide_quantizers_from_state_dict(component): - component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) + # - diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter + # - for non-diffusers modules (e.g., LTX-2 transformer), fall back to torch.save + if hasattr(component, "save_pretrained"): + with hide_quantizers_from_state_dict(component): + component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) + else: + with hide_quantizers_from_state_dict(component): + cpu_state_dict = { + k: v.detach().contiguous().cpu() for k, v in component.state_dict().items() + } + save_file(cpu_state_dict, str(component_export_dir / "model.safetensors")) + with open(component_export_dir / "config.json", "w") as f: + json.dump( + { + "_class_name": type(component).__name__, + "_export_format": "safetensors_state_dict", + }, + f, + indent=4, + ) # Step 7: Update config.json with quantization info if quant_config is not None: @@ -852,14 +873,28 @@ def _export_diffusers_checkpoint( config_data["quantization_config"] = hf_quant_config with open(config_path, "w") as file: json.dump(config_data, file, indent=4) - else: - # Non-quantized component: just save as-is + # Non-quantized component: just save as-is + elif hasattr(component, "save_pretrained"): component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) + else: + cpu_state_dict = { + k: v.detach().contiguous().cpu() for k, v in component.state_dict().items() + } + save_file(cpu_state_dict, str(component_export_dir / "model.safetensors")) + with open(component_export_dir / "config.json", "w") as f: + json.dump( + { + "_class_name": type(component).__name__, + "_export_format": "safetensors_state_dict", + }, + f, + indent=4, + ) print(f" Saved to: {component_export_dir}") # Step 4: Export non-nn.Module components (tokenizers, schedulers, feature extractors, etc.) 
- if isinstance(pipe, DiffusionPipeline): + if is_diffusers_pipe: for component_name, component in all_components.items(): # Skip nn.Module components (already handled above) if isinstance(component, nn.Module): @@ -887,7 +922,7 @@ def _export_diffusers_checkpoint( print(f" Saved to: {component_export_dir}") # Step 5: For pipelines, also save the model_index.json - if isinstance(pipe, DiffusionPipeline): + if is_diffusers_pipe: model_index_path = export_dir / "model_index.json" if hasattr(pipe, "config") and pipe.config is not None: # Save a simplified model_index.json that points to the exported components @@ -911,7 +946,7 @@ def _export_diffusers_checkpoint( def export_hf_checkpoint( - model: "nn.Module | DiffusionPipeline", + model: Any, dtype: torch.dtype | None = None, export_dir: Path | str = tempfile.gettempdir(), save_modelopt_state: bool = False, @@ -936,13 +971,15 @@ def export_hf_checkpoint( export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) - # Check for diffusers models (only when diffusers is installed) + is_diffusers_obj = False if HAS_DIFFUSERS: - from diffusers import DiffusionPipeline, ModelMixin - - if isinstance(model, (DiffusionPipeline, ModelMixin)): - _export_diffusers_checkpoint(model, dtype, export_dir, components) - return + diffusers_types: tuple[type, ...] = (DiffusionPipeline, ModelMixin) + if TI2VidTwoStagesPipeline is not None: + diffusers_types = (*diffusers_types, TI2VidTwoStagesPipeline) + is_diffusers_obj = isinstance(model, diffusers_types) + if is_diffusers_obj: + _export_diffusers_checkpoint(model, dtype, export_dir, components) + return # Transformers model export # NOTE: (hg) Early exit for speculative decoding models diff --git a/modelopt/torch/opt/dynamic.py b/modelopt/torch/opt/dynamic.py index a2834329e..8950d1e5c 100644 --- a/modelopt/torch/opt/dynamic.py +++ b/modelopt/torch/opt/dynamic.py @@ -584,6 +584,14 @@ def export(self) -> nn.Module: assert not is_dynamic, "Exported module must not be a DynamicModule anymore!" delattr(self, "_dm_attribute_manager") + # If this module had a monkey-patched forward before DynamicModule.convert(), we may have + # overridden it by binding the dynamic forward onto the instance (to follow the MRO). + # On final export, restore the original forward to avoid leaking a dynamic forward + # (e.g., DistillationModel.forward) onto the exported (non-dynamic) module instance. + if hasattr(self, "_forward_pre_dm"): + setattr(self, "forward", getattr(self, "_forward_pre_dm")) + delattr(self, "_forward_pre_dm") + return self @classmethod @@ -621,6 +629,10 @@ def bind_forward_method_if_needed(self): # accelerate patched module bind_forward_method(self, self.__class__.forward) else: + if not hasattr(self, "_forward_pre_dm"): + # Keep the patched forward for downstream modules that want to call it. + self._forward_pre_dm = self.forward + bind_forward_method(self, self.__class__.forward) warnings.warn( "Received a module with monkey patched forward method. Dynamic converted module" " might not work." 
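Taken together with the QuantInputBase change in the next file, the dynamic.py hunks above stash a monkey-patched instance forward as `_forward_pre_dm` when the class forward is bound, and put it back on export so the exported module keeps the user's patch. The snippet below is a minimal, self-contained sketch of that save/bind/restore pattern in plain PyTorch; ToyDynamicLinear and its convert()/export() methods are illustrative stand-ins, not ModelOpt APIs.

import types

import torch
import torch.nn as nn


class ToyDynamicLinear(nn.Linear):
    """Stand-in for a converted DynamicModule; not a ModelOpt class."""

    def forward(self, x):
        # the "dynamic" forward that convert() binds onto the instance
        return super().forward(x)

    def convert(self):
        if "forward" in self.__dict__:  # instance-level monkey patch detected
            self._forward_pre_dm = self.forward  # keep it, as in the hunk above
        self.forward = types.MethodType(type(self).forward, self)  # bind the class forward

    def export(self):
        if hasattr(self, "_forward_pre_dm"):
            self.forward = self._forward_pre_dm  # restore the user's patch
            del self._forward_pre_dm
        return self


layer = ToyDynamicLinear(4, 4)
# user (or another library) monkey-patches the instance forward
layer.forward = types.MethodType(lambda self, x: nn.Linear.forward(self, x) * 2.0, layer)
layer.convert()                        # dynamic forward now shadows the patch
layer.export()                         # patch restored on the exported module
print(layer(torch.randn(2, 4)).shape)  # runs the restored (doubled) forward
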
diff --git a/modelopt/torch/quantization/nn/modules/quant_module.py b/modelopt/torch/quantization/nn/modules/quant_module.py index 12aaee3f8..e00e7c77d 100644 --- a/modelopt/torch/quantization/nn/modules/quant_module.py +++ b/modelopt/torch/quantization/nn/modules/quant_module.py @@ -110,7 +110,25 @@ class QuantInputBase(QuantModule): def forward(self, input, *args, **kwargs): """Quantize the input before calling the original forward method.""" input = self.input_quantizer(input) - output = super().forward(input, *args, **kwargs) + if hasattr(self, "_forward_pre_dm"): + pre_fwd = getattr(self, "_forward_pre_dm") + + def _is_forward_in_mro(bound_or_func) -> bool: + # If this is a bound method, compare its underlying function to any `forward` + # implementation in the current MRO. If it matches, it's not an external monkey-patch. + if hasattr(bound_or_func, "__func__"): + fn = bound_or_func.__func__ + for cls in type(self).mro(): + if cls.__dict__.get("forward") is fn: + return True + return False + + if pre_fwd is getattr(self, "forward") or _is_forward_in_mro(pre_fwd): + output = super().forward(input, *args, **kwargs) + else: + output = pre_fwd(input, *args, **kwargs) + else: + output = super().forward(input, *args, **kwargs) if isinstance(output, tuple): return (self.output_quantizer(output[0]), *output[1:]) return self.output_quantizer(output) diff --git a/modelopt/torch/quantization/plugins/__init__.py b/modelopt/torch/quantization/plugins/__init__.py index 71be25d94..ecd24d81e 100644 --- a/modelopt/torch/quantization/plugins/__init__.py +++ b/modelopt/torch/quantization/plugins/__init__.py @@ -41,7 +41,7 @@ from .custom import * with import_plugin("diffusers"): - from .diffusers import * + from .diffusion.diffusers import * with import_plugin("fairscale"): from .fairscale import * @@ -77,4 +77,4 @@ from .trl import * with import_plugin("fastvideo"): - from .fastvideo import * + from .diffusion.fastvideo import * diff --git a/modelopt/torch/quantization/plugins/diffusers.py b/modelopt/torch/quantization/plugins/diffusion/diffusers.py similarity index 98% rename from modelopt/torch/quantization/plugins/diffusers.py rename to modelopt/torch/quantization/plugins/diffusion/diffusers.py index 440d190d3..2ec057766 100644 --- a/modelopt/torch/quantization/plugins/diffusers.py +++ b/modelopt/torch/quantization/plugins/diffusion/diffusers.py @@ -45,8 +45,8 @@ else: # torch >= 2.9 from torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext -from ..export_onnx import export_fp8_mha -from ..nn import ( +from ...export_onnx import export_fp8_mha +from ...nn import ( QuantConv2d, QuantInputBase, QuantLinear, @@ -54,7 +54,7 @@ QuantModuleRegistry, TensorQuantizer, ) -from .custom import _QuantFunctionalMixin +from ..custom import _QuantFunctionalMixin onnx_dtype_map = { "BFloat16": onnx.TensorProto.BFLOAT16, diff --git a/modelopt/torch/quantization/plugins/fastvideo.py b/modelopt/torch/quantization/plugins/diffusion/fastvideo.py similarity index 92% rename from modelopt/torch/quantization/plugins/fastvideo.py rename to modelopt/torch/quantization/plugins/diffusion/fastvideo.py index f518873c0..2fd6a5945 100644 --- a/modelopt/torch/quantization/plugins/fastvideo.py +++ b/modelopt/torch/quantization/plugins/diffusion/fastvideo.py @@ -20,10 +20,10 @@ from fastvideo.layers.linear import ReplicatedLinear from fastvideo.models.vaes.wanvae import WanCausalConv3d -from ..nn import QuantLinearConvBase, QuantModuleRegistry -from ..nn.modules.quant_conv import _QuantConv3d -from 
..nn.modules.quant_linear import _QuantLinear -from ..utils import is_torch_export_mode +from ...nn import QuantLinearConvBase, QuantModuleRegistry +from ...nn.modules.quant_conv import _QuantConv3d +from ...nn.modules.quant_linear import _QuantLinear +from ...utils import is_torch_export_mode @QuantModuleRegistry.register({WanCausalConv3d: "WanCausalConv3d"}) diff --git a/modelopt/torch/quantization/plugins/diffusion/ltx2.py b/modelopt/torch/quantization/plugins/diffusion/ltx2.py new file mode 100644 index 000000000..d89fe4b82 --- /dev/null +++ b/modelopt/torch/quantization/plugins/diffusion/ltx2.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LTX-2 quantization plugin.""" + +import contextlib + +import torch + +from modelopt.torch.quantization.nn.modules.quant_linear import _QuantLinear +from modelopt.torch.quantization.nn.modules.quant_module import QuantModuleRegistry +from modelopt.torch.quantization.utils import is_torch_export_mode + +_FP8_DTYPES = tuple( + dtype + for dtype_name in ("float8_e4m3fn", "float8_e5m2", "float8_e4m3fnuz", "float8_e5m2fnuz") + if (dtype := getattr(torch, dtype_name, None)) is not None +) + + +def _upcast_fp8_weight( + weight: torch.Tensor, target_dtype: torch.dtype, seed: int = 0 +) -> torch.Tensor: + if target_dtype is torch.bfloat16: + try: + from ltx_core.loader.fuse_loras import fused_add_round_launch + + return fused_add_round_launch( + torch.zeros_like(weight, dtype=target_dtype), + weight, + seed, + ) + except Exception: + pass + return weight.to(target_dtype) + + +class _QuantLTX2Linear(_QuantLinear): + """Quantized Linear with FP8 upcast before weight quantization.""" + + @staticmethod + def _get_quantized_weight(module: "_QuantLTX2Linear", weight: torch.Tensor) -> torch.Tensor: + if _FP8_DTYPES and weight.dtype in _FP8_DTYPES: + weight = _upcast_fp8_weight(weight, torch.bfloat16, 0) + if module._enable_weight_quantization or is_torch_export_mode(): + return module.weight_quantizer(weight) + return weight + + +def register_ltx2_quant_linear() -> None: + """Register the LTX-2 quantized Linear, overriding the default mapping.""" + with contextlib.suppress(KeyError): + QuantModuleRegistry.unregister(torch.nn.Linear) + QuantModuleRegistry.register({torch.nn.Linear: "nn.Linear"})(_QuantLTX2Linear) diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index b663ef5f2..6cf6bc90f 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -229,9 +229,7 @@ def weight_attr_names(module: nn.Module) -> Generator[str, None, None]: # the standard weight and quantizer case weight = getattr(module, "weight", None) weight_quantizer = getattr(module, "weight_quantizer", None) - if isinstance(weight, nn.Parameter) and isinstance( - weight_quantizer, (TensorQuantizer, SequentialQuantizer) - ): + if isinstance(weight_quantizer, 
(TensorQuantizer, SequentialQuantizer)): yield "weight" # other weight and quantizer case