1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -13,6 +13,7 @@ NVIDIA Model Optimizer Changelog (Linux)
- Add standalone type inference option (``--use_standalone_type_inference``) in ONNX AutoCast as an alternative to ONNX's ``infer_shapes``. This experimental feature performs type-only inference without shape inference, useful as a workaround when shape inference fails or to avoid unnecessary shape inference overhead.
- Add support for Kimi K2 Thinking model quantization from the original int4 checkpoint.
- Add support for ``params`` constraint based automatic neural architecture search in Minitron pruning (``mcore_minitron``) as an alternative to manual pruning (using ``export_config``). See `examples/pruning/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/pruning>`_ for more details on its usage.
- Add ``--opset`` option to ONNX quantization CLI to specify the target opset version for the quantized model.

0.41 (2026-01-19)
^^^^^^^^^^^^^^^^^
40 changes: 37 additions & 3 deletions modelopt/onnx/autocast/convert.py
@@ -33,6 +33,7 @@
from modelopt.onnx.autocast.nodeclassifier import NodeClassifier, NodeRuleBase
from modelopt.onnx.autocast.precisionconverter import PrecisionConverter
from modelopt.onnx.autocast.referencerunner import ReferenceRunner
from modelopt.onnx.utils import get_min_opset_for_precisions, get_qdq_precisions

"""
FP16 accuracy degrades as the magnitude of the data increases.
@@ -84,7 +85,7 @@ def convert_to_mixed_precision(
trt_plugins_precision: List indicating the precision for each custom op.
max_depth_of_reduction: Maximum depth of reduction for node classification.
opset: Target ONNX opset version. If None, uses default minimum opset based on low_precision_type
(22 for bf16, 13 for fp16). The opset may be automatically increased if certain operations
(22 for bf16, 19 for fp16). The opset may be automatically increased if certain operations
require a higher version.
use_standalone_type_inference: If True, use standalone type inference implementation instead of ONNX's
infer_shapes. This is a workaround (WAR) when only type inference is
@@ -202,6 +203,7 @@ def convert_to_f16(
tensor_block_dict: dict[str, dict[str, list[int]]] = {},
trt_plugins: list[str] | None = [],
use_standalone_type_inference: bool = False,
opset: int | None = None,
) -> onnx.ModelProto:
"""Convert model to mixed precision, using PrecisionConverter.

@@ -217,13 +219,45 @@
use_standalone_type_inference: If True, use standalone type inference implementation instead of ONNX's
infer_shapes. This is a workaround (WAR) when only type inference is
needed without shape inference. Default: False.
opset: Target ONNX opset version. If None, uses default minimum opset based on precision type
(22 for bf16, 19 for fp16) and Q/DQ node requirements. The opset may be automatically
increased if Q/DQ nodes in the model require a higher version (e.g., FP8 requires 19,
INT4 requires 21, NVFP4 requires 23).
"""
assert low_precision_type in ["fp16", "bf16"], "low_precision_type must be either fp16 or bf16"

# Opset 21 is needed for NVFP4 quantization support (DQ with 'block_size' attribute)
# Check Q/DQ precision types in the model and determine required opset
qdq_precisions = get_qdq_precisions(model)
qdq_min_opset = get_min_opset_for_precisions(qdq_precisions)

# Base minimum opset for FP16/BF16 conversion
# Opset 19 is the first to support fp16 scales in Q/DQ nodes
base_min_opset = 22 if low_precision_type == "bf16" else 19

# Determine target opset version
if opset is not None:
min_opset = opset
# Check if Q/DQ nodes require a higher opset
if qdq_precisions and qdq_min_opset > min_opset:
logger.warning(
f"Model contains Q/DQ nodes with precisions {qdq_precisions} that require "
f"opset >= {qdq_min_opset}. Upgrading from specified opset {opset} to {qdq_min_opset}."
)
min_opset = qdq_min_opset
# Also ensure we meet base minimum for precision type
if min_opset < base_min_opset:
logger.warning(
f"Opset {min_opset} is below minimum opset {base_min_opset} for {low_precision_type}. "
f"Upgrading to opset {base_min_opset}."
)
min_opset = base_min_opset
else:
# Use the highest required opset between base and Q/DQ requirements
min_opset = max(base_min_opset, qdq_min_opset)

sanitizer = GraphSanitizer(
model,
min_opset=21,
min_opset=min_opset,
trt_plugins=trt_plugins,
max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT,
)
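The opset handling added to ``convert_to_f16`` can be exercised directly from Python. The following is a minimal sketch, not taken from the PR: it assumes the ``ModelProto`` is the first positional argument and that the keyword names match the diff above; the model path is hypothetical.

```python
# Minimal sketch of convert_to_f16 with the new opset argument (assumptions noted above).
import onnx

from modelopt.onnx.autocast.convert import convert_to_f16

model = onnx.load("quantized_model.onnx")  # hypothetical path

# opset=None (default): the target opset is max(base minimum, Q/DQ minimum),
# i.e. 19 for fp16 scales, raised to 21 if INT4 Q/DQ nodes are present, 23 for NVFP4.
converted = convert_to_f16(model, low_precision_type="fp16")

# An explicit opset below those minimums is upgraded with a warning, e.g.
# opset=13 on an INT4-quantized model still results in opset 21.
converted = convert_to_f16(model, low_precision_type="fp16", opset=13)
```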
10 changes: 10 additions & 0 deletions modelopt/onnx/quantization/__main__.py
@@ -286,6 +286,15 @@ def get_parser() -> argparse.ArgumentParser:
"The currently supported precisions are {fp16, int8, fp8}."
),
)
argparser.add_argument(
"--opset",
type=int,
help=(
"Target ONNX opset version for the quantized model. If not specified, uses default minimum opset "
"(19 for fp16 scales support, 21 for int4, 23 for nvfp4). The opset may be automatically increased "
"if certain operations require a higher version."
),
)
return argparser


@@ -352,6 +361,7 @@ def main():
simplify=args.simplify,
calibrate_per_node=args.calibrate_per_node,
direct_io_types=args.direct_io_types,
opset=args.opset,
)


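A quick end-to-end check of the new ``--opset`` flag goes through the module entry point. This is a hedged sketch: ``--opset`` is the flag added above, while ``--onnx_path`` and ``--quantize_mode`` are assumed to be the CLI's existing input-path and mode flags.

```python
# Hedged sketch: driving the quantization CLI with the new --opset flag via subprocess.
# --opset is added in this PR; --onnx_path and --quantize_mode are assumed flags.
import subprocess

subprocess.run(
    [
        "python", "-m", "modelopt.onnx.quantization",
        "--onnx_path", "model.onnx",   # hypothetical input model
        "--quantize_mode", "int8",     # assumed existing mode flag
        "--opset", "21",               # new flag: target opset for the quantized model
    ],
    check=True,
)
```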
2 changes: 2 additions & 0 deletions modelopt/onnx/quantization/fp8.py
@@ -182,6 +182,7 @@ def quantize(
calibrate_per_node: bool = False,
custom_ops_to_quantize: list[str] = [],
direct_io_types: bool = False,
opset: int | None = None,
**kwargs,
) -> onnx.ModelProto:
"""Applies FP8 GEMM only quantization to an ONNX file.
@@ -328,6 +329,7 @@
tensor_block_dict=custom_ops_to_cast_fp32 or {},
low_precision_type=high_precision_dtype,
trt_plugins=trt_extra_plugin_lib_paths,
opset=opset,
)

current_opsets = {opset.domain: opset.version for opset in onnx_model.opset_import}
2 changes: 2 additions & 0 deletions modelopt/onnx/quantization/int8.py
@@ -132,6 +132,7 @@ def quantize(
calibrate_per_node: bool = False,
custom_ops_to_quantize: list[str] = [],
direct_io_types: bool = False,
opset: int | None = None,
**kwargs,
) -> onnx.ModelProto:
"""Applies INT8 quantization to an ONNX file using the compiler friendly heuristics.
@@ -289,6 +290,7 @@
tensor_block_dict=custom_ops_to_cast_fp32 or {},
low_precision_type=high_precision_dtype,
trt_plugins=trt_extra_plugin_lib_paths,
opset=opset,
)

if nodes_to_quantize:
66 changes: 58 additions & 8 deletions modelopt/onnx/quantization/quantize.py
@@ -69,6 +69,8 @@
)
from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag, load_onnx_model
from modelopt.onnx.utils import (
BASE_MIN_OPSET,
QDQ_PRECISION_MIN_OPSET,
duplicate_shared_constants,
get_opset_version,
name_onnx_nodes,
@@ -78,6 +80,17 @@
__all__ = ["quantize"]


def _normalize_quantize_mode_for_opset(quantize_mode: str) -> str:
"""Map variants like "int4_awq", "int4_rtn", "nvfp4" to their base precision types for lookup purposes."""
mode_lower = quantize_mode.lower()
if "int4" in mode_lower:
return "int4"
if "nvfp4" in mode_lower or "float4" in mode_lower:
return "float4_e2m1fn"
# For "int8", "fp8", etc., return as-is (fp8 falls back to BASE_MIN_OPSET which is correct)
return quantize_mode


def _preprocess_onnx(
onnx_path: str,
use_external_data_format: bool,
@@ -88,6 +101,7 @@ def _preprocess_onnx(
override_shapes: str,
simplify: bool = False,
quantize_mode: str = "int8",
opset: int | None = None,
) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict, dict]:
logger.info(f"Preprocessing the model {onnx_path}")
intermediate_generated_files = []
@@ -118,16 +132,45 @@
" '--trt_plugins' flag (requires TRT 10+)."
)

# Per-Channel support with QDQ format requires onnx opset version 13 or above
opset_version = get_opset_version(onnx_model)
# Opset 19 is the minimum required for fp16 scales in Q/DQ nodes
# Higher opsets required for specific quantization modes (int4: 21, nvfp4: 23)
original_opset_version = get_opset_version(onnx_model)

# Determine minimum required opset based on quantization mode
# Normalize quantize_mode to handle variants like "int4_awq", "nvfp4", etc.
normalized_mode = _normalize_quantize_mode_for_opset(quantize_mode)
mode_min_opset = QDQ_PRECISION_MIN_OPSET.get(normalized_mode, BASE_MIN_OPSET)

# Determine target opset version
if opset is not None:
target_opset = opset
# Warn if the user-specified opset is below the mode minimum and upgrade to that minimum
if opset < mode_min_opset:
logger.warning(
f"Opset {opset} is below the minimum opset {mode_min_opset} required for "
f"{quantize_mode} quantization. Upgrading to opset {mode_min_opset}."
)
target_opset = mode_min_opset
# Warn if user-specified opset is lower than original
if opset < original_opset_version:
logger.warning(
f"Specified opset {opset} is lower than the original model's opset {original_opset_version}. "
f"Using original model's opset {original_opset_version}."
)
target_opset = max(target_opset, original_opset_version)
else:
# Use model's opset if it's >= mode_min_opset, otherwise upgrade to mode_min_opset
target_opset = (
max(original_opset_version, mode_min_opset)
if original_opset_version != 1
else mode_min_opset
)

required_opset_version = 13
if opset_version < required_opset_version and opset_version != 1:
opset_version = required_opset_version
onnx_model = onnx.version_converter.convert_version(onnx_model, opset_version)
onnx_path = os.path.join(output_dir, f"{model_name}_opset{opset_version}.onnx")
if original_opset_version < target_opset and original_opset_version != 1:
onnx_model = onnx.version_converter.convert_version(onnx_model, target_opset)
onnx_path = os.path.join(output_dir, f"{model_name}_opset{target_opset}.onnx")
save_onnx(onnx_model, onnx_path, use_external_data_format)
logger.info(f"Model is cloned to {onnx_path} with opset_version {opset_version}")
logger.info(f"Model is cloned to {onnx_path} with opset_version {target_opset}")
intermediate_generated_files.append(onnx_path)

# Simplify model if requested
@@ -231,6 +274,7 @@ def quantize(
calibrate_per_node: bool = False,
input_shapes_profile: Sequence[dict[str, str]] | None = None,
direct_io_types: bool = False,
opset: int | None = None,
**kwargs: Any,
) -> None:
"""Quantizes the provided ONNX model.
@@ -350,6 +394,10 @@
direct_io_types:
If True, modify the I/O types in the quantized ONNX model to be lower precision whenever possible.
If False, keep the I/O types in the quantized ONNX model the same as in the given ONNX model.
opset:
Target ONNX opset version for the quantized model. If None, uses required minimum opset
(19 for int8/fp8, 21 for int4, 23 for nvfp4). If the specified opset is lower than the required minimum,
a warning will be issued and the opset will be upgraded to the required minimum.
kwargs:
Additional keyword arguments for int4 quantization, including:
- awqlite_alpha_step (float): Alpha step for lite, range [0, 1].
@@ -420,6 +468,7 @@
override_shapes, # type: ignore[arg-type]
simplify,
quantize_mode,
opset,
)
trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins) # type: ignore[arg-type]

@@ -481,6 +530,7 @@
calibrate_per_node=calibrate_per_node,
custom_ops_to_quantize=list(custom_ops_to_quantize.keys()),
direct_io_types=direct_io_types,
opset=opset,
**kwargs,
)
elif "int4" in quantize_mode:
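The Python API mirrors the CLI: ``_preprocess_onnx`` reconciles the model's original opset, the user-supplied ``opset``, and the per-mode minimum, and the highest applicable value wins. A minimal sketch, assuming ``quantize`` accepts ``onnx_path`` and ``quantize_mode`` keywords alongside the new ``opset`` parameter:

```python
# Minimal sketch of the quantize() API with the new opset parameter.
# Assumptions: onnx_path and quantize_mode are the existing keyword names;
# calibration data and output paths are omitted for brevity.
from modelopt.onnx.quantization.quantize import quantize

# opset=None: int8/fp8 default to opset 19, int4 to 21, nvfp4 to 23, and the
# original model's opset is never downgraded.
quantize(onnx_path="model.onnx", quantize_mode="int8")

# An explicit opset is honored when it meets the mode minimum and the original
# model's opset; otherwise it is upgraded with a warning.
quantize(onnx_path="model.onnx", quantize_mode="int8", opset=21)
```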
64 changes: 64 additions & 0 deletions modelopt/onnx/utils.py
@@ -30,6 +30,9 @@

from modelopt.onnx.logging_config import logger

# Base minimum opset for quantization (opset 19 is the first to support fp16 scales)
BASE_MIN_OPSET = 19


def get_input_names_from_bytes(model_bytes: bytes, external_inputs_only: bool = True) -> list[str]:
"""This function returns the inputs names of the given onnx model in bytes.
@@ -696,6 +699,67 @@ def get_opset_version(model: onnx.ModelProto) -> int:
return ai_onnx_domain[0].version


def get_qdq_precisions(model: onnx.ModelProto) -> set:
"""Gets the Q/DQ precision types present in the model.

Args:
model: Loaded in-memory onnx ModelProto.

Returns:
set: Set of Q/DQ precision types present in the model (e.g., 'float8_e4m3fn', 'int8',
'int4', 'float4_e2m1fn').
"""
graph = gs.import_onnx(model)
precisions = set()

# Check for custom 'NVFP4' nodes
custom_fp4_q_nodes = [node for node in graph.nodes if node.op == "TRT_FP4DynamicQuantize"]
if custom_fp4_q_nodes:
precisions.add("float4_e2m1fn")

# Check for precision in DQ nodes
dq_nodes = [node for node in graph.nodes if node.op == "DequantizeLinear"]
for dq_node in dq_nodes:
if len(dq_node.inputs) >= 3 and dq_node.inputs[2] is not None:
# If a zero-point is present, record its dtype as the quantization precision
if isinstance(dq_node.inputs[2], Constant) and dq_node.inputs[2].values is not None:
precisions.add(dq_node.inputs[2].values.dtype.name)
elif isinstance(dq_node.inputs[0], Constant) and dq_node.inputs[0].values is not None:
# Otherwise, record the node's input precision (e.g. 'NVFP4' weight quantization)
precisions.add(dq_node.inputs[0].values.dtype.name)

return precisions


# Minimum opset requirements by quantization mode/precision
# Base minimum is 19 (first opset that allows fp16 scales in Q/DQ nodes)
# Keys cover both quantize-mode names (e.g., "int8", "int4") and Q/DQ tensor dtype names (e.g., "float8_e4m3fn", "float4_e2m1fn")
QDQ_PRECISION_MIN_OPSET = {
"int8": BASE_MIN_OPSET,
"float8_e4m3fn": BASE_MIN_OPSET,
"int4": 21,
"uint4": 21,
"float4_e2m1fn": 23,
}


def get_min_opset_for_precisions(precisions: set) -> int:
"""Gets the minimum required opset version for a set of Q/DQ precision types.

Args:
precisions: Set of precision type strings (e.g., 'float8_e4m3fn', 'int4').

Returns:
int: Minimum required opset version for the given precisions.
"""
min_opset = BASE_MIN_OPSET # Base minimum for fp16 scales support
for precision in precisions:
# Precisions not in the table keep the base minimum opset
if precision in QDQ_PRECISION_MIN_OPSET:
min_opset = max(min_opset, QDQ_PRECISION_MIN_OPSET[precision])
return min_opset


def bfloat16_to_float32(bf16_array):
"""Converts a bfloat16 array (as raw data) to a float32 array."""
uint32_array = bf16_array.astype(np.uint32) << 16
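The two helpers added to ``modelopt/onnx/utils.py`` compose as shown below. A minimal sketch, assuming a quantized model already exists on disk (the path is hypothetical):

```python
# Minimal sketch: inspect a quantized model's Q/DQ precisions and the minimum
# opset they imply. The model path is hypothetical.
import onnx

from modelopt.onnx.utils import get_min_opset_for_precisions, get_qdq_precisions

model = onnx.load("model.quant.onnx")  # hypothetical quantized model

precisions = get_qdq_precisions(model)  # e.g. {"int8"} or {"float8_e4m3fn", "int4"}
min_opset = get_min_opset_for_precisions(precisions)

# int8/fp8-only models report 19, INT4 raises the minimum to 21, NVFP4 to 23.
print(f"Q/DQ precisions: {precisions}, minimum opset: {min_opset}")
```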