diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 452f36538..68855430f 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -13,6 +13,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add standalone type inference option (``--use_standalone_type_inference``) in ONNX AutoCast as an alternative to ONNX's ``infer_shapes``. This experimental feature performs type-only inference without shape inference, useful as a workaround when shape inference fails or to avoid unnecessary shape inference overhead.
 - Add support for Kimi K2 Thinking model quantization from the original int4 checkpoint.
 - Add support for ``params`` constraint based automatic neural architecture search in Minitron pruning (``mcore_minitron``) as an alternative to manual pruning (using ``export_config``). See `examples/pruning/README.md `_ for more details on its usage.
+- Add ``--opset`` option to ONNX quantization CLI to specify the target opset version for the quantized model.
 
 0.41 (2026-01-19)
 ^^^^^^^^^^^^^^^^^
diff --git a/modelopt/onnx/autocast/convert.py b/modelopt/onnx/autocast/convert.py
index 73d2bea4d..2391e312b 100644
--- a/modelopt/onnx/autocast/convert.py
+++ b/modelopt/onnx/autocast/convert.py
@@ -33,6 +33,7 @@
 from modelopt.onnx.autocast.nodeclassifier import NodeClassifier, NodeRuleBase
 from modelopt.onnx.autocast.precisionconverter import PrecisionConverter
 from modelopt.onnx.autocast.referencerunner import ReferenceRunner
+from modelopt.onnx.utils import get_min_opset_for_precisions, get_qdq_precisions
 
 """
 FP16 accuracy decreases in accordance with the data's magnitude.
@@ -84,7 +85,7 @@ def convert_to_mixed_precision(
         trt_plugins_precision: List indicating the precision for each custom op.
         max_depth_of_reduction: Maximum depth of reduction for node classification.
         opset: Target ONNX opset version. If None, uses default minimum opset based on low_precision_type
-            (22 for bf16, 13 for fp16). The opset may be automatically increased if certain operations
+            (22 for bf16, 19 for fp16). The opset may be automatically increased if certain operations
             require a higher version.
         use_standalone_type_inference: If True, use standalone type inference implementation
             instead of ONNX's infer_shapes. This is a workaround (WAR) when only type inference is
@@ -202,6 +203,7 @@ def convert_to_f16(
     tensor_block_dict: dict[str, dict[str, list[int]]] = {},
     trt_plugins: list[str] | None = [],
     use_standalone_type_inference: bool = False,
+    opset: int | None = None,
 ) -> onnx.ModelProto:
     """Convert model to mixed precision, using PrecisionConverter.
 
@@ -217,13 +219,45 @@ def convert_to_f16(
         use_standalone_type_inference: If True, use standalone type inference implementation
             instead of ONNX's infer_shapes. This is a workaround (WAR) when only type inference is
             needed without shape inference. Default: False.
+        opset: Target ONNX opset version. If None, uses default minimum opset based on precision type
+            (22 for bf16, 19 for fp16) and Q/DQ node requirements. The opset may be automatically
+            increased if Q/DQ nodes in the model require a higher version (e.g., FP8 requires 19,
+            INT4 requires 21, NVFP4 requires 23).
""" assert low_precision_type in ["fp16", "bf16"], "low_precision_type must be either fp16 or bf16" - # Opset 21 is needed for NVFP4 quantization support (DQ with 'block_size' attribute) + # Check Q/DQ precision types in the model and determine required opset + qdq_precisions = get_qdq_precisions(model) + qdq_min_opset = get_min_opset_for_precisions(qdq_precisions) + + # Base minimum opset for FP16/BF16 conversion + # Opset 19 is the first to support fp16 scales in Q/DQ nodes + base_min_opset = 22 if low_precision_type == "bf16" else 19 + + # Determine target opset version + if opset is not None: + min_opset = opset + # Check if Q/DQ nodes require a higher opset + if qdq_precisions and qdq_min_opset > min_opset: + logger.warning( + f"Model contains Q/DQ nodes with precisions {qdq_precisions} that require " + f"opset >= {qdq_min_opset}. Upgrading from specified opset {opset} to {qdq_min_opset}." + ) + min_opset = qdq_min_opset + # Also ensure we meet base minimum for precision type + if min_opset < base_min_opset: + logger.warning( + f"Opset {min_opset} is below minimum opset {base_min_opset} for {low_precision_type}. " + f"Upgrading to opset {base_min_opset}." + ) + min_opset = base_min_opset + else: + # Use the highest required opset between base and Q/DQ requirements + min_opset = max(base_min_opset, qdq_min_opset) + sanitizer = GraphSanitizer( model, - min_opset=21, + min_opset=min_opset, trt_plugins=trt_plugins, max_ir_version=LATEST_IR_VERSION_SUPPORTED_BY_ORT, ) diff --git a/modelopt/onnx/quantization/__main__.py b/modelopt/onnx/quantization/__main__.py index e416a3e93..6c79d9317 100644 --- a/modelopt/onnx/quantization/__main__.py +++ b/modelopt/onnx/quantization/__main__.py @@ -286,6 +286,15 @@ def get_parser() -> argparse.ArgumentParser: "The currently supported precisions are {fp16, int8, fp8}." ), ) + argparser.add_argument( + "--opset", + type=int, + help=( + "Target ONNX opset version for the quantized model. If not specified, uses default minimum opset " + "(19 for fp16 scales support, 21 for int4, 23 for nvfp4). The opset may be automatically increased " + "if certain operations require a higher version." + ), + ) return argparser @@ -352,6 +361,7 @@ def main(): simplify=args.simplify, calibrate_per_node=args.calibrate_per_node, direct_io_types=args.direct_io_types, + opset=args.opset, ) diff --git a/modelopt/onnx/quantization/fp8.py b/modelopt/onnx/quantization/fp8.py index cab92483c..76a3e8167 100755 --- a/modelopt/onnx/quantization/fp8.py +++ b/modelopt/onnx/quantization/fp8.py @@ -182,6 +182,7 @@ def quantize( calibrate_per_node: bool = False, custom_ops_to_quantize: list[str] = [], direct_io_types: bool = False, + opset: int | None = None, **kwargs, ) -> onnx.ModelProto: """Applies FP8 GEMM only quantization to an ONNX file. @@ -328,6 +329,7 @@ def quantize( tensor_block_dict=custom_ops_to_cast_fp32 or {}, low_precision_type=high_precision_dtype, trt_plugins=trt_extra_plugin_lib_paths, + opset=opset, ) current_opsets = {opset.domain: opset.version for opset in onnx_model.opset_import} diff --git a/modelopt/onnx/quantization/int8.py b/modelopt/onnx/quantization/int8.py index 01929667c..6e350a16f 100755 --- a/modelopt/onnx/quantization/int8.py +++ b/modelopt/onnx/quantization/int8.py @@ -132,6 +132,7 @@ def quantize( calibrate_per_node: bool = False, custom_ops_to_quantize: list[str] = [], direct_io_types: bool = False, + opset: int | None = None, **kwargs, ) -> onnx.ModelProto: """Applies INT8 quantization to an ONNX file using the compiler friendly heuristics. 
@@ -289,6 +290,7 @@ def quantize(
             tensor_block_dict=custom_ops_to_cast_fp32 or {},
             low_precision_type=high_precision_dtype,
             trt_plugins=trt_extra_plugin_lib_paths,
+            opset=opset,
         )
 
     if nodes_to_quantize:
diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py
index ecc494c43..da7ff126d 100755
--- a/modelopt/onnx/quantization/quantize.py
+++ b/modelopt/onnx/quantization/quantize.py
@@ -69,6 +69,8 @@
 )
 from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag, load_onnx_model
 from modelopt.onnx.utils import (
+    BASE_MIN_OPSET,
+    QDQ_PRECISION_MIN_OPSET,
     duplicate_shared_constants,
     get_opset_version,
     name_onnx_nodes,
@@ -78,6 +80,17 @@
 __all__ = ["quantize"]
 
 
+def _normalize_quantize_mode_for_opset(quantize_mode: str) -> str:
+    """Map variants like "int4_awq", "int4_rtn", "nvfp4" to their base precision types for lookup purposes."""
+    mode_lower = quantize_mode.lower()
+    if "int4" in mode_lower:
+        return "int4"
+    if "nvfp4" in mode_lower or "float4" in mode_lower:
+        return "float4_e2m1fn"
+    # For "int8", "fp8", etc., return as-is (fp8 falls back to BASE_MIN_OPSET which is correct)
+    return quantize_mode
+
+
 def _preprocess_onnx(
     onnx_path: str,
     use_external_data_format: bool,
@@ -88,6 +101,7 @@ def _preprocess_onnx(
     override_shapes: str,
     simplify: bool = False,
     quantize_mode: str = "int8",
+    opset: int | None = None,
 ) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict, dict]:
     logger.info(f"Preprocessing the model {onnx_path}")
     intermediate_generated_files = []
@@ -118,16 +132,45 @@ def _preprocess_onnx(
             " '--trt_plugins' flag (requires TRT 10+)."
         )
 
-    # Per-Channel support with QDQ format requires onnx opset version 13 or above
-    opset_version = get_opset_version(onnx_model)
+    # Opset 19 is the minimum required for fp16 scales in Q/DQ nodes
+    # Higher opsets required for specific quantization modes (int4: 21, nvfp4: 23)
+    original_opset_version = get_opset_version(onnx_model)
+
+    # Determine minimum required opset based on quantization mode
+    # Normalize quantize_mode to handle variants like "int4_awq", "nvfp4", etc.
+    normalized_mode = _normalize_quantize_mode_for_opset(quantize_mode)
+    mode_min_opset = QDQ_PRECISION_MIN_OPSET.get(normalized_mode, BASE_MIN_OPSET)
+
+    # Determine target opset version
+    if opset is not None:
+        target_opset = opset
+        # Warn if the user-specified opset is below the mode minimum and upgrade to that minimum
+        if opset < mode_min_opset:
+            logger.warning(
+                f"Opset {opset} is below the minimum opset {mode_min_opset} required for "
+                f"{quantize_mode} quantization. Upgrading to opset {mode_min_opset}."
+            )
+            target_opset = mode_min_opset
+        # Warn if user-specified opset is lower than original
+        if opset < original_opset_version:
+            logger.warning(
+                f"Specified opset {opset} is lower than the original model's opset {original_opset_version}. "
+                f"Using original model's opset {original_opset_version}."
+            )
+            target_opset = max(target_opset, original_opset_version)
+    else:
+        # Use model's opset if it's >= mode_min_opset, otherwise upgrade to mode_min_opset
+        target_opset = (
+            max(original_opset_version, mode_min_opset)
+            if original_opset_version != 1
+            else mode_min_opset
+        )
 
-    required_opset_version = 13
-    if opset_version < required_opset_version and opset_version != 1:
-        opset_version = required_opset_version
-        onnx_model = onnx.version_converter.convert_version(onnx_model, opset_version)
-        onnx_path = os.path.join(output_dir, f"{model_name}_opset{opset_version}.onnx")
+    if original_opset_version < target_opset and original_opset_version != 1:
+        onnx_model = onnx.version_converter.convert_version(onnx_model, target_opset)
+        onnx_path = os.path.join(output_dir, f"{model_name}_opset{target_opset}.onnx")
         save_onnx(onnx_model, onnx_path, use_external_data_format)
-        logger.info(f"Model is cloned to {onnx_path} with opset_version {opset_version}")
+        logger.info(f"Model is cloned to {onnx_path} with opset_version {target_opset}")
         intermediate_generated_files.append(onnx_path)
 
     # Simplify model if requested
@@ -231,6 +274,7 @@ def quantize(
     calibrate_per_node: bool = False,
     input_shapes_profile: Sequence[dict[str, str]] | None = None,
     direct_io_types: bool = False,
+    opset: int | None = None,
     **kwargs: Any,
 ) -> None:
     """Quantizes the provided ONNX model.
@@ -350,6 +394,10 @@ def quantize(
         direct_io_types: If True, modify the I/O types in the quantized ONNX model to be
             lower precision whenever possible. If False, keep the I/O types in the quantized
             ONNX model the same as in the given ONNX model.
+        opset:
+            Target ONNX opset version for the quantized model. If None, uses required minimum opset
+            (19 for int8/fp8, 21 for int4, 23 for nvfp4). If the specified opset is lower than the required minimum,
+            a warning will be issued and the opset will be upgraded to the required minimum.
         kwargs:
             Additional keyword arguments for int4 quantization, including:
             - awqlite_alpha_step (float): Alpha step for lite, range [0, 1].
@@ -420,6 +468,7 @@ def quantize(
         override_shapes,  # type: ignore[arg-type]
         simplify,
         quantize_mode,
+        opset,
     )
 
     trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins)  # type: ignore[arg-type]
@@ -481,6 +530,7 @@ def quantize(
             calibrate_per_node=calibrate_per_node,
             custom_ops_to_quantize=list(custom_ops_to_quantize.keys()),
             direct_io_types=direct_io_types,
+            opset=opset,
             **kwargs,
         )
     elif "int4" in quantize_mode:
diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py
index 02306792a..669259aa1 100644
--- a/modelopt/onnx/utils.py
+++ b/modelopt/onnx/utils.py
@@ -30,6 +30,9 @@
 from modelopt.onnx.logging_config import logger
 
+# Base minimum opset for quantization (opset 19 is the first to support fp16 scales)
+BASE_MIN_OPSET = 19
+
 
 def get_input_names_from_bytes(model_bytes: bytes, external_inputs_only: bool = True) -> list[str]:
     """This function returns the inputs names of the given onnx model in bytes.
 
@@ -696,6 +699,67 @@ def get_opset_version(model: onnx.ModelProto) -> int:
     return ai_onnx_domain[0].version
 
 
+def get_qdq_precisions(model: onnx.ModelProto) -> set:
+    """Gets the Q/DQ precision types present in the model.
+
+    Args:
+        model: Loaded in-memory onnx ModelProto.
+
+    Returns:
+        set: Set of Q/DQ precision types present in the model (e.g., 'float8_e4m3fn', 'int8',
+            'int4', 'float4_e2m1fn').
+    """
+    graph = gs.import_onnx(model)
+    precisions = set()
+
+    # Check for custom 'NVFP4' nodes
+    custom_fp4_q_nodes = [node for node in graph.nodes if node.op == "TRT_FP4DynamicQuantize"]
+    if custom_fp4_q_nodes:
+        precisions.add("float4_e2m1fn")
+
+    # Check for precision in DQ nodes
+    dq_nodes = [node for node in graph.nodes if node.op == "DequantizeLinear"]
+    for dq_node in dq_nodes:
+        if len(dq_node.inputs) >= 3 and dq_node.inputs[2] is not None:
+            # If a zero-point is set, record its dtype as the quantization precision
+            if isinstance(dq_node.inputs[2], Constant) and dq_node.inputs[2].values is not None:
+                precisions.add(dq_node.inputs[2].values.dtype.name)
+        elif isinstance(dq_node.inputs[0], Constant) and dq_node.inputs[0].values is not None:
+            # Otherwise, record the node's data input precision (ex: 'NVFP4' weight quantization)
+            precisions.add(dq_node.inputs[0].values.dtype.name)
+
+    return precisions
+
+
+# Minimum opset requirements by quantization mode/precision
+# Base minimum is 19 (first opset that allows fp16 scales in Q/DQ nodes)
+# Keys are Q/DQ precision/dtype names; precisions not listed here fall back to BASE_MIN_OPSET
+QDQ_PRECISION_MIN_OPSET = {
+    "int8": BASE_MIN_OPSET,
+    "float8_e4m3fn": BASE_MIN_OPSET,
+    "int4": 21,
+    "uint4": 21,
+    "float4_e2m1fn": 23,
+}
+
+
+def get_min_opset_for_precisions(precisions: set) -> int:
+    """Gets the minimum required opset version for a set of Q/DQ precision types.
+
+    Args:
+        precisions: Set of precision type strings (e.g., 'float8_e4m3fn', 'int4').
+
+    Returns:
+        int: Minimum required opset version for the given precisions.
+    """
+    min_opset = BASE_MIN_OPSET  # Base minimum for fp16 scales support
+    for precision in precisions:
+        # Precisions not present in the table keep the base minimum
+        if precision in QDQ_PRECISION_MIN_OPSET:
+            min_opset = max(min_opset, QDQ_PRECISION_MIN_OPSET[precision])
+    return min_opset
+
+
 def bfloat16_to_float32(bf16_array):
     """Converts a bfloat16 array (as raw data) to a float32 array."""
     uint32_array = bf16_array.astype(np.uint32) << 16
diff --git a/tests/unit/onnx/test_quantize_api.py b/tests/unit/onnx/test_quantize_api.py
new file mode 100644
index 000000000..0fffa9636
--- /dev/null
+++ b/tests/unit/onnx/test_quantize_api.py
@@ -0,0 +1,122 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for ONNX quantization opset handling."""
+
+import os
+
+import onnx
+import onnxruntime
+import pytest
+import torch
+from _test_utils.onnx.lib_test_models import SimpleMLP, export_as_onnx
+from packaging import version
+
+import modelopt.onnx.quantization as moq
+from modelopt.onnx.utils import get_opset_version
+
+# Mapping of quantization mode to minimum required opset
+MIN_OPSET = {
+    "int8": 19,
+    "fp8": 19,
+    "int4": 21,
+}
+
+# onnxruntime version that supports opset 22+
+ORT_VERSION_FOR_OPSET_22 = version.parse("1.23.0")
+
+
+@pytest.mark.parametrize("quant_mode", ["int8", "fp8", "int4"])
+def test_opset_below_minimum_upgrades_to_minimum(tmp_path, quant_mode):
+    """Test that specifying opset below minimum upgrades to minimum."""
+    model_torch = SimpleMLP()
+    input_tensor = torch.randn(2, 16, 16)
+
+    onnx_path = os.path.join(tmp_path, "model.onnx")
+    export_as_onnx(model_torch, input_tensor, onnx_filename=onnx_path)
+
+    min_opset = MIN_OPSET[quant_mode]
+
+    # Request opset below minimum
+    moq.quantize(onnx_path, quantize_mode=quant_mode, opset=min_opset - 1)
+
+    # Verify output model was upgraded to minimum opset
+    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
+    output_model = onnx.load(output_onnx_path)
+    output_opset = get_opset_version(output_model)
+
+    assert output_opset == min_opset, (
+        f"Expected opset {min_opset} for {quant_mode}, got {output_opset}"
+    )
+
+
+@pytest.mark.parametrize("quant_mode", ["int8", "fp8", "int4"])
+def test_opset_below_original_uses_original(tmp_path, quant_mode):
+    """Test that specifying opset below original model's opset uses original."""
+    model_torch = SimpleMLP()
+    input_tensor = torch.randn(2, 16, 16)
+
+    min_opset = MIN_OPSET[quant_mode]
+    higher_opset = min_opset + 1
+
+    # Skip if required opset exceeds onnxruntime support
+    ort_version = version.parse(onnxruntime.__version__)
+    if higher_opset >= 22 and ort_version < ORT_VERSION_FOR_OPSET_22:
+        pytest.skip(
+            f"Opset {higher_opset} requires onnxruntime >= {ORT_VERSION_FOR_OPSET_22}, have {ort_version}"
+        )
+
+    onnx_path = os.path.join(tmp_path, "model.onnx")
+    export_as_onnx(model_torch, input_tensor, onnx_filename=onnx_path, opset=higher_opset)
+
+    # Verify the exported model has the higher opset
+    original_model = onnx.load(onnx_path)
+    assert get_opset_version(original_model) == higher_opset
+
+    # Request opset below original (but above minimum)
+    moq.quantize(onnx_path, quantize_mode=quant_mode, opset=min_opset)
+
+    # Verify output model preserves the higher original opset
+    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
+    output_model = onnx.load(output_onnx_path)
+    output_opset = get_opset_version(output_model)
+
+    assert output_opset == higher_opset, (
+        f"Expected original opset {higher_opset} to be preserved, got {output_opset}"
+    )
+
+
+@pytest.mark.parametrize("quant_mode", ["int8", "fp8", "int4"])
+def test_opset_above_minimum(tmp_path, quant_mode):
+    """Test that specifying opset at or above minimum is respected."""
+    model_torch = SimpleMLP()
+    input_tensor = torch.randn(2, 16, 16)
+
+    min_opset = MIN_OPSET[quant_mode]
+    target_opset = min_opset + 1
+
+    onnx_path = os.path.join(tmp_path, "model.onnx")
+    export_as_onnx(model_torch, input_tensor, onnx_filename=onnx_path)
+
+    moq.quantize(onnx_path, quantize_mode=quant_mode, opset=target_opset)
+
+    # Verify output model has the requested opset
+    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
+    output_model = onnx.load(output_onnx_path)
+    output_opset = get_opset_version(output_model)
+
+    assert output_opset == target_opset, (
+        f"Expected opset {target_opset} for {quant_mode}, got {output_opset}"
+    )
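Usage sketch (not part of the change itself): the snippet below drives the new ``opset`` parameter through the Python API the same way the tests above do, then inspects the result with the helpers this change adds to ``modelopt/onnx/utils.py``. The model path, the ``.quant.onnx`` output name (taken from the tests' convention), and the values in the comments are illustrative assumptions rather than guaranteed outputs::

    import onnx

    import modelopt.onnx.quantization as moq
    from modelopt.onnx.utils import (
        get_min_opset_for_precisions,
        get_opset_version,
        get_qdq_precisions,
    )

    # Quantize with an explicit target opset. If the requested value were below the
    # minimum for the chosen mode (19 for int8/fp8, 21 for int4, 23 for nvfp4) or
    # below the input model's own opset, quantize() warns and upgrades it instead.
    # The same knob is exposed on the command line as --opset.
    moq.quantize("model.onnx", quantize_mode="int8", opset=21)  # placeholder path

    # The tests above expect the output next to the input as model.quant.onnx.
    quantized = onnx.load("model.quant.onnx")
    print(get_opset_version(quantized))  # 21 for this request

    # The new utils helpers report which Q/DQ precisions are present and the
    # minimum opset they require (int8 -> 19).
    qdq_precisions = get_qdq_precisions(quantized)
    print(qdq_precisions, get_min_opset_for_precisions(qdq_precisions))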