Changes from all commits · 37 commits
a33cf13
Your commit message describing all changes
jingyu-ml Jan 14, 2026
dff152b
Merge the diffusion and llms layer fusion code
jingyu-ml Jan 14, 2026
9e94843
Create a diffusers utils function, moved some functions to it
jingyu-ml Jan 14, 2026
db61c20
Merge branch 'main' into jingyux/diffusion.export-fixed
jingyu-ml Jan 14, 2026
8a81723
Fixed some bugs in the CI/CD
jingyu-ml Jan 14, 2026
16a2bbf
Merge branch 'main' into jingyux/diffusion.export-fixed
jingyu-ml Jan 14, 2026
68d5665
Move one function to diffusers utils
jingyu-ml Jan 14, 2026
ace5773
Merge branch 'main' into jingyux/diffusion.export-fixed
jingyu-ml Jan 15, 2026
95dfb52
removed the DiffusionPipeline import
jingyu-ml Jan 15, 2026
302e2f4
Update the example
jingyu-ml Jan 15, 2026
8eed21b
Fixed the CI/CD
jingyu-ml Jan 16, 2026
01d31d7
Update the CI/CD
jingyu-ml Jan 16, 2026
ca3fdaa
Update the Flux example & address Chenjie's comments
jingyu-ml Jan 16, 2026
44345f8
use single line of code
jingyu-ml Jan 16, 2026
78f12cc
Update the test case
jingyu-ml Jan 16, 2026
3911a3d
Add the support for the WAN video
jingyu-ml Jan 16, 2026
4cf9e76
Moved the has_quantized_modules to quant utils
jingyu-ml Jan 20, 2026
1da2b46
moving model specific configs to separate files
jingyu-ml Jan 20, 2026
eafedde
Merge branch 'main' into jingyux/diffusion.export-fixed
jingyu-ml Jan 20, 2026
3fb8320
Fixed the CI/CD
jingyu-ml Jan 20, 2026
372c6f7
Fixed the cicd
jingyu-ml Jan 20, 2026
e67bf85
reduce the repeated code
jingyu-ml Jan 21, 2026
9b5cf13
Merge branch 'main' into jingyux/diffusion.export-fixed
jingyu-ml Jan 21, 2026
e931fbc
Update the lint
jingyu-ml Jan 21, 2026
8b29228
Merge branch 'main' into jingyux/diffusion.export-fixed
jingyu-ml Jan 21, 2026
b8b5eaf
Merge branch 'main' into jingyux/2-3-diffusion-export
jingyu-ml Jan 22, 2026
b717bae
Add the LTX2 FP8/BF16 support + Some core code changes
jingyu-ml Jan 23, 2026
0d93e1a
Merge branch 'main' into jingyux/2-3-diffusion-export
jingyu-ml Jan 23, 2026
c2aadca
Update
jingyu-ml Jan 23, 2026
109c010
Merge branch 'main' into jingyux/2-3-diffusion-export
jingyu-ml Jan 23, 2026
d7aef93
Fixed the CICD
jingyu-ml Jan 23, 2026
ac5fcd0
Fixed more CICD
jingyu-ml Jan 24, 2026
a96d58c
Merge branch 'main' into jingyux/2-3-diffusion-export
jingyu-ml Jan 26, 2026
e566834
Update
jingyu-ml Jan 26, 2026
626ae02
Update the example script
jingyu-ml Jan 26, 2026
796c298
Merge branch 'main' into jingyux/2-3-diffusion-export
jingyu-ml Jan 26, 2026
9f0e998
update the qkv fusion rules
jingyu-ml Jan 26, 2026
195 changes: 195 additions & 0 deletions examples/diffusers/quantization/calibration.py
@@ -0,0 +1,195 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pathlib import Path
from typing import Any

from models_utils import MODEL_DEFAULTS, ModelType
from pipeline_manager import PipelineManager
from quantize_config import CalibrationConfig
from tqdm import tqdm
from utils import load_calib_prompts


class Calibrator:
"""Handles model calibration for quantization."""

def __init__(
self,
pipeline_manager: PipelineManager,
config: CalibrationConfig,
model_type: ModelType,
logger: logging.Logger,
):
"""
Initialize calibrator.

Args:
pipeline_manager: Pipeline manager with main and upsampler pipelines
config: Calibration configuration
model_type: Type of model being calibrated
logger: Logger instance
"""
self.pipeline_manager = pipeline_manager
self.pipe = pipeline_manager.pipe
self.pipe_upsample = pipeline_manager.pipe_upsample
self.config = config
self.model_type = model_type
self.logger = logger

def load_and_batch_prompts(self) -> list[list[str]]:
"""
Load calibration prompts from file.

Returns:
List of batched calibration prompts
"""
self.logger.info(f"Loading calibration prompts from {self.config.prompts_dataset}")
if isinstance(self.config.prompts_dataset, Path):
return load_calib_prompts(
self.config.batch_size,
self.config.prompts_dataset,
)

return load_calib_prompts(
self.config.batch_size,
self.config.prompts_dataset["name"],
self.config.prompts_dataset["split"],
self.config.prompts_dataset["column"],
)

def run_calibration(self, batched_prompts: list[list[str]]) -> None:
"""
Run calibration steps on the pipeline.

Args:
batched_prompts: List of batched calibration prompts
"""
self.logger.info(f"Starting calibration with {self.config.num_batches} batches")
extra_args = MODEL_DEFAULTS.get(self.model_type, {}).get("inference_extra_args", {})

with tqdm(total=self.config.num_batches, desc="Calibration", unit="batch") as pbar:
for i, prompt_batch in enumerate(batched_prompts):
if i >= self.config.num_batches:
break

if self.model_type == ModelType.LTX2:
self._run_ltx2_calibration(prompt_batch, extra_args)
elif self.model_type == ModelType.LTX_VIDEO_DEV:
# Special handling for LTX-Video
self._run_ltx_video_calibration(prompt_batch, extra_args)
elif self.model_type in [ModelType.WAN22_T2V_14b, ModelType.WAN22_T2V_5b]:
# Special handling for WAN video models
self._run_wan_video_calibration(prompt_batch, extra_args)
else:
common_args = {
"prompt": prompt_batch,
"num_inference_steps": self.config.n_steps,
}
self.pipe(**common_args, **extra_args).images
pbar.update(1)
self.logger.debug(f"Completed calibration batch {i + 1}/{self.config.num_batches}")
self.logger.info("Calibration completed successfully")

def _run_wan_video_calibration(
self, prompt_batch: list[str], extra_args: dict[str, Any]
) -> None:
kwargs = {}
kwargs["negative_prompt"] = extra_args["negative_prompt"]
kwargs["height"] = extra_args["height"]
kwargs["width"] = extra_args["width"]
kwargs["num_frames"] = extra_args["num_frames"]
kwargs["guidance_scale"] = extra_args["guidance_scale"]
if "guidance_scale_2" in extra_args:
kwargs["guidance_scale_2"] = extra_args["guidance_scale_2"]
kwargs["num_inference_steps"] = self.config.n_steps

self.pipe(prompt=prompt_batch, **kwargs).frames

def _run_ltx2_calibration(self, prompt_batch: list[str], extra_args: dict[str, Any]) -> None:
from ltx_core.model.video_vae import TilingConfig

prompt = prompt_batch[0]
extra_params = self.pipeline_manager.config.extra_params
kwargs = {
"negative_prompt": extra_args.get(
"negative_prompt", "worst quality, inconsistent motion, blurry, jittery, distorted"
),
"seed": extra_params.get("seed", 0),
"height": extra_params.get("height", extra_args.get("height", 1024)),
"width": extra_params.get("width", extra_args.get("width", 1536)),
"num_frames": extra_params.get("num_frames", extra_args.get("num_frames", 121)),
"frame_rate": extra_params.get("frame_rate", extra_args.get("frame_rate", 24.0)),
"num_inference_steps": self.config.n_steps,
"cfg_guidance_scale": extra_params.get(
"cfg_guidance_scale", extra_args.get("cfg_guidance_scale", 4.0)
),
"images": extra_params.get("images", []),
"tiling_config": extra_params.get("tiling_config", TilingConfig.default()),
}
self.pipe(prompt=prompt, **kwargs)

def _run_ltx_video_calibration(
self, prompt_batch: list[str], extra_args: dict[str, Any]
) -> None:
"""
Run calibration for LTX-Video model using the full multi-stage pipeline.

Args:
prompt_batch: Batch of prompts
extra_args: Model-specific arguments
"""
# Extract specific args for LTX-Video
expected_height = extra_args.get("height", 512)
expected_width = extra_args.get("width", 704)
num_frames = extra_args.get("num_frames", 121)
negative_prompt = extra_args.get(
"negative_prompt", "worst quality, inconsistent motion, blurry, jittery, distorted"
)

def round_to_nearest_resolution_acceptable_by_vae(height, width):
height = height - (height % self.pipe.vae_spatial_compression_ratio)
width = width - (width % self.pipe.vae_spatial_compression_ratio)
return height, width

downscale_factor = 2 / 3
# Part 1: Generate video at smaller resolution
downscaled_height, downscaled_width = (
int(expected_height * downscale_factor),
int(expected_width * downscale_factor),
)
downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(
downscaled_height, downscaled_width
)

# Generate initial latents at lower resolution
latents = self.pipe(
conditions=None,
prompt=prompt_batch,
negative_prompt=negative_prompt,
width=downscaled_width,
height=downscaled_height,
num_frames=num_frames,
num_inference_steps=self.config.n_steps,
output_type="latent",
).frames

# Part 2: Upscale generated video using latent upsampler (if available)
if self.pipe_upsample is not None:
_ = self.pipe_upsample(latents=latents, output_type="latent").frames

        # Part 3: Denoise the upscaled video with a few steps to improve texture.
        # However, this example omits that final denoising pass since it's optional.
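For reviewers who want to exercise the new `Calibrator` end to end, a minimal driver could look like the sketch below. This is an assumption-laden sketch: the `CalibrationConfig` field names match the attributes `Calibrator` reads in this diff, but the `PipelineManager` construction is elided because its signature is not shown here.

```python
# Hypothetical driver for the new Calibrator. Field names follow what
# Calibrator reads in this diff (batch_size, num_batches, n_steps,
# prompts_dataset); PipelineManager construction is model-specific and elided.
import logging
from pathlib import Path

from calibration import Calibrator
from models_utils import ModelType
from pipeline_manager import PipelineManager
from quantize_config import CalibrationConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("calibration-demo")

config = CalibrationConfig(
    prompts_dataset=Path("calib_prompts.txt"),  # or a {"name", "split", "column"} dict
    batch_size=2,
    num_batches=16,
    n_steps=20,
)
pipeline_manager = PipelineManager(...)  # see pipeline_manager.py for the real args

calibrator = Calibrator(pipeline_manager, config, ModelType.FLUX_DEV, logger)
batched_prompts = calibrator.load_and_batch_prompts()
calibrator.run_calibration(batched_prompts)
```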
64 changes: 63 additions & 1 deletion examples/diffusers/quantization/models_utils.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from collections.abc import Callable
from enum import Enum
from typing import Any
@@ -42,6 +43,7 @@ class ModelType(str, Enum):
FLUX_DEV = "flux-dev"
FLUX_SCHNELL = "flux-schnell"
LTX_VIDEO_DEV = "ltx-video-dev"
LTX2 = "ltx-2"
WAN22_T2V_14b = "wan2.2-t2v-14b"
WAN22_T2V_5b = "wan2.2-t2v-5b"

@@ -64,6 +66,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.SD3_MEDIUM: filter_func_default,
ModelType.SD35_MEDIUM: filter_func_default,
ModelType.LTX_VIDEO_DEV: filter_func_ltx_video,
ModelType.LTX2: filter_func_ltx_video,
ModelType.WAN22_T2V_14b: filter_func_wan_video,
ModelType.WAN22_T2V_5b: filter_func_wan_video,
}
Expand All @@ -80,18 +83,20 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev",
ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell",
ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev",
ModelType.LTX2: "Lightricks/LTX-2",
ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
}

-MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline]] = {
+MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline] | None] = {
ModelType.SDXL_BASE: DiffusionPipeline,
ModelType.SDXL_TURBO: DiffusionPipeline,
ModelType.SD3_MEDIUM: StableDiffusion3Pipeline,
ModelType.SD35_MEDIUM: StableDiffusion3Pipeline,
ModelType.FLUX_DEV: FluxPipeline,
ModelType.FLUX_SCHNELL: FluxPipeline,
ModelType.LTX_VIDEO_DEV: LTXConditionPipeline,
ModelType.LTX2: None,
Review thread on `ModelType.LTX2: None`:

Contributor: Do we know if HF diffusers has plans to support LTX-2, given that LTX-1 is supported?

@jingyu-ml (Contributor, Author), Jan 26, 2026: As far as I know, there are no plans at the moment for text-to-video support. Diffusers currently only supports image-to-video for LTX2.

@jingyu-ml (Contributor, Author): We don't need to stick with Diffusers. The trend I'm seeing is that newer models usually come with their own codebases. This MR already extends support to non-Diffusers code.

ModelType.WAN22_T2V_14b: WanPipeline,
ModelType.WAN22_T2V_5b: WanPipeline,
}
@@ -154,6 +159,18 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
"negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
},
},
ModelType.LTX2: {
"backbone": "transformer",
"dataset": _SD_PROMPTS_DATASET,
"inference_extra_args": {
"height": 1024,
"width": 1536,
"num_frames": 121,
"frame_rate": 24.0,
"cfg_guidance_scale": 4.0,
"negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
},
},
ModelType.WAN22_T2V_14b: {
**_WAN_BASE_CONFIG,
"from_pretrained_extra_args": {
@@ -192,3 +209,48 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
},
},
}


def _coerce_extra_param_value(value: str) -> Any:
    """Best-effort coercion of a CLI string to bool, int, or float; falls back to str."""
    lowered = value.lower()
if lowered in {"true", "false"}:
return lowered == "true"
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
return value


def parse_extra_params(
    kv_args: list[str], unknown_args: list[str], logger: logging.Logger
) -> dict[str, Any]:
    """Merge --extra-param KEY=VALUE pairs and dotted --extra_param.KEY VALUE
    tokens (left in argparse's unknown args) into one dict of coerced values."""
    extra_params: dict[str, Any] = {}
for item in kv_args:
if "=" not in item:
raise ValueError(f"Invalid --extra-param value: '{item}'. Expected KEY=VALUE.")
key, value = item.split("=", 1)
extra_params[key] = _coerce_extra_param_value(value)

i = 0
while i < len(unknown_args):
token = unknown_args[i]
if token.startswith("--extra_param."):
key = token[len("--extra_param.") :]
value = "true"
if i + 1 < len(unknown_args) and not unknown_args[i + 1].startswith("--"):
value = unknown_args[i + 1]
i += 1
extra_params[key] = _coerce_extra_param_value(value)
elif token.startswith("--extra_param"):
raise ValueError(
"Use --extra_param.KEY VALUE or --extra-param KEY=VALUE for extra parameters."
)
else:
logger.warning("Ignoring unknown argument: %s", token)
i += 1

return extra_params
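The two accepted spellings and the value coercion are easiest to see with a quick example; this is a sketch of expected behavior based on the code above, not a test shipped with this PR.

```python
# Sketch: exercising parse_extra_params with both accepted spellings.
import logging

from models_utils import parse_extra_params

logger = logging.getLogger(__name__)

# --extra-param KEY=VALUE form (argparse collects these into a list)
kv_args = ["height=1024", "frame_rate=24.0", "offload=true"]
# --extra_param.KEY VALUE form arrives via argparse's unknown args;
# a dotted key with no following value defaults to boolean true.
unknown_args = ["--extra_param.num_frames", "121", "--extra_param.low_vram"]

params = parse_extra_params(kv_args, unknown_args, logger)
# params == {"height": 1024, "frame_rate": 24.0, "offload": True,
#            "num_frames": 121, "low_vram": True}
```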