3 changes: 3 additions & 0 deletions examples/llm_ptq/hf_ptq.py
@@ -84,6 +84,7 @@
"w4a8_mxfp4_fp8": mtq.W4A8_MXFP4_FP8_CFG,
"nvfp4_mlp_only": mtq.NVFP4_MLP_ONLY_CFG,
"nvfp4_svdquant": mtq.NVFP4_SVDQUANT_DEFAULT_CFG,
"mxfp8": mtq.MXFP8_DEFAULT_CFG,
}

KV_QUANT_CFG_CHOICES = {
@@ -185,6 +186,7 @@ def auto_quantize(
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
"mxfp8",
]
for args.qformat in qformat_list
), "One or more quantization formats provided are not supported for unified checkpoint export"
@@ -776,6 +778,7 @@ def quantize_main(
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
"nvfp4_mlp_only",
"mxfp8",
]
or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
), f"Plain quantization format {args.qformat} not supported for HF export path"
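With this entry in place, `--qformat mxfp8` resolves to `mtq.MXFP8_DEFAULT_CFG` and flows through the usual `mtq.quantize` path. A minimal, self-contained sketch of that flow follows; the toy model and calibration data are placeholders, not what `hf_ptq.py` actually builds.

```python
import torch
import torch.nn as nn
import modelopt.torch.quantization as mtq

# Toy stand-ins for the HF model and calibration dataloader used by hf_ptq.py.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
calib_batches = [torch.randn(4, 64) for _ in range(8)]

def forward_loop(m):
    for x in calib_batches:
        m(x)

# "mxfp8" in QUANT_CFG_CHOICES maps to mtq.MXFP8_DEFAULT_CFG (added above).
model = mtq.quantize(model, mtq.MXFP8_DEFAULT_CFG, forward_loop=forward_loop)
```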
4 changes: 2 additions & 2 deletions examples/llm_ptq/scripts/huggingface_example.sh
@@ -53,9 +53,9 @@ esac
IFS=","
for qformat in $QFORMAT; do
case $qformat in
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only | nvfp4_svdquant) ;;
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only | nvfp4_svdquant | mxfp8) ;;
*)
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only, nvfp4_svdquant]" >&2
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only, nvfp4_svdquant, mxfp8]" >&2
exit 1
;;
esac
1 change: 1 addition & 0 deletions modelopt/torch/export/model_config.py
@@ -36,6 +36,7 @@
QUANTIZATION_NVFP4_SVDQUANT = "nvfp4_svdquant"
QUANTIZATION_W4A8_NVFP4_FP8 = "w4a8_nvfp4_fp8"
QUANTIZATION_MXFP4 = "mxfp4"
QUANTIZATION_MXFP8 = "mxfp8"
QUANTIZATION_W4A8_MXFP4_FP8 = "w4a8_mxfp4_fp8"
QUANTIZATION_NVFP4_AWQ = "nvfp4_awq"
QUANTIZATION_FP8_PB_REAL = "fp8_pb_real"
21 changes: 21 additions & 0 deletions modelopt/torch/export/quant_utils.py
@@ -34,6 +34,7 @@
from modelopt.torch.quantization.qtensor import (
FP8QTensor,
MXFP4QTensor,
MXFP8QTensor,
NVFP4QTensor,
QTensorWrapper,
)
@@ -58,6 +59,7 @@
QUANTIZATION_INT8_SQ,
QUANTIZATION_INT8_WO,
QUANTIZATION_MXFP4,
QUANTIZATION_MXFP8,
QUANTIZATION_NONE,
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
@@ -326,6 +328,9 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
return MXFP4QTensor.quantize(weight, block_size=weight_quantizer.block_sizes[-1])[
1
].reshape(*weight.shape[:-1], -1)

if quantization_format == QUANTIZATION_MXFP8:
return MXFP8QTensor.get_weights_scaling_factor_from_quantizer(weight, weight_quantizer)
return get_scaling_factor(weight_quantizer)
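`MXFP8QTensor.get_weights_scaling_factor_from_quantizer` is new in this PR and its body is not shown here. Below is a minimal sketch of the kind of per-block E8M0 scale it would produce, using one common OCP-MX recipe; the exact rounding and clamping in ModelOpt may differ, and the helper name is illustrative.

```python
import torch

BLOCK_SIZE = 32   # OCP MX block size
E8M0_BIAS = 127   # bias of the uint8 power-of-two exponent
E4M3_MAX = 448.0  # largest FP8 E4M3 magnitude

def e8m0_weight_scale_sketch(weight: torch.Tensor) -> torch.Tensor:
    """Per-32-element-block power-of-two scale, stored as a biased uint8 exponent."""
    blocks = weight.reshape(*weight.shape[:-1], -1, BLOCK_SIZE)
    amax = blocks.abs().amax(dim=-1).float()
    # Smallest power of two such that amax / scale fits into the E4M3 range.
    exponent = torch.ceil(torch.log2(amax.clamp(min=2.0**-126) / E4M3_MAX))
    return (exponent + E8M0_BIAS).clamp(0, 255).to(torch.uint8)

w = torch.randn(8, 64)
print(e8m0_weight_scale_sketch(w).shape)  # torch.Size([8, 2])
```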


@@ -524,6 +529,14 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
if weight_quantizer.num_bits == (4, 3):
if weight_quantizer.block_sizes:
assert weight_quantizer.block_sizes[-1] > 0, "Invalid block_sizes for FP8 quantizer"
# Check if this is MXFP8 (dynamic block quantization with scale_bits (8, 0))
block_sizes = getattr(weight_quantizer, "block_sizes")
if (
isinstance(block_sizes, dict)
and block_sizes.get("type", "static") == "dynamic"
and block_sizes.get("scale_bits") == (8, 0)
):
return QUANTIZATION_MXFP8
if weight_quantizer.fake_quant:
return QUANTIZATION_FP8_PB_WO
else:
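For reference, the weight-quantizer attributes that this new branch classifies as MXFP8 look roughly like the following. This is a sketch inferred from the conditions above and from the block-size assert later in this PR; the exact contents of `mtq.MXFP8_DEFAULT_CFG` may differ.

```python
# Attributes the branch above maps to QUANTIZATION_MXFP8:
# FP8 E4M3 elements, dynamic (data-derived) per-block scaling, E8M0 scale bits.
mxfp8_weight_quantizer_attrs = {
    "num_bits": (4, 3),        # FP8 E4M3 elements
    "block_sizes": {
        -1: 32,                # 32-element blocks along the last axis (MX block size)
        "type": "dynamic",     # scales computed from data rather than calibrated
        "scale_bits": (8, 0),  # E8M0 power-of-two scales
    },
}
```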
@@ -724,6 +737,11 @@ def process_layer_quant_config(layer_config_dict):
"quant_algo": "NVFP4_SVD",
"group_size": block_size_value,
}
elif v == "mxfp8":
layer_config = {
"quant_algo": "MXFP8",
"group_size": block_size_value,
}
else:
layer_config = {"quant_algo": v}
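The per-layer entry produced for an MXFP8 layer then looks like this; the layer name is invented for illustration, and `group_size` comes from the layer's recorded block size, which for MX formats is 32.

```python
# Hypothetical per-layer entry emitted by the branch above for an MXFP8 layer.
layer_quant_config = {
    "model.layers.0.self_attn.q_proj": {"quant_algo": "MXFP8", "group_size": 32},
}
```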

@@ -828,6 +846,9 @@ def to_quantized_weight(
if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)

if quantization == QUANTIZATION_MXFP8:
return MXFP8QTensor.quantize_with_scale(weight, weights_scaling_factor)

if quantization == QUANTIZATION_FP8_PB_WO:
return FP8QTensor.quantize(
weight, weights_scaling_factor.squeeze(), block_sizes={-1: block_size, -2: block_size}
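`MXFP8QTensor.quantize_with_scale` itself is not part of this diff. The operation it is expected to perform is roughly the following; this is a sketch assuming 32-element blocks along the last dimension and a uint8 E8M0 scale, and the real method may handle padding and rounding differently.

```python
import torch

BLOCK_SIZE = 32   # MX block size
E8M0_BIAS = 127   # bias of the uint8 power-of-two exponent
E4M3_MAX = 448.0  # largest FP8 E4M3 magnitude

def quantize_with_e8m0_scale_sketch(weight: torch.Tensor, e8m0_scale: torch.Tensor) -> torch.Tensor:
    """Divide each 32-element block by its decoded power-of-two scale, then cast to FP8 E4M3."""
    scale = torch.exp2(e8m0_scale.float() - E8M0_BIAS)           # decode the uint8 exponent
    blocks = weight.reshape(*weight.shape[:-1], -1, BLOCK_SIZE)
    q = (blocks / scale.unsqueeze(-1)).reshape(weight.shape)
    return q.clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)  # requires PyTorch >= 2.1
```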
12 changes: 11 additions & 1 deletion modelopt/torch/export/unified_export_hf.py
@@ -51,7 +51,7 @@

from modelopt.torch.quantization import set_quantizer_by_cfg_context
from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer
from modelopt.torch.quantization.qtensor import NVFP4QTensor
from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor
from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names

from .convert_hf_config import convert_hf_quant_config_format
@@ -67,6 +67,7 @@
QUANTIZATION_FP8,
QUANTIZATION_FP8_PB_REAL,
QUANTIZATION_FP8_PC_PT,
QUANTIZATION_MXFP8,
QUANTIZATION_NONE,
QUANTIZATION_NVFP4,
QUANTIZATION_NVFP4_AWQ,
@@ -426,6 +427,15 @@ def _export_quantized_weight(
weight_quantizer._scale.to(torch.float32),
)
del weight_quantizer._scale
elif quantization_format == QUANTIZATION_MXFP8:
# MXFP8 uses dynamic block quantization with E8M0 scales (uint8)
weight = getattr(sub_module, weight_name)
e8m0_scale = MXFP8QTensor.get_weights_scaling_factor_from_quantizer(
weight, weight_quantizer
)
sub_module.register_buffer(quantizer_attrs.weight_scale, e8m0_scale)
if hasattr(weight_quantizer, "_scale") and weight_quantizer._scale is not None:
del weight_quantizer._scale
else:
sub_module.register_buffer(
quantizer_attrs.weight_scale, get_weight_scaling_factor(sub_module, weight_name)
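On the consumer side, the exported pair of buffers (FP8 E4M3 weight plus uint8 E8M0 `weight_scale`) can be expanded back to high precision roughly as follows. This is a sketch of the inverse of the export above; actual inference runtimes fuse this into their kernels.

```python
import torch

BLOCK_SIZE = 32
E8M0_BIAS = 127

def dequantize_mxfp8_sketch(weight_fp8: torch.Tensor, weight_scale_u8: torch.Tensor) -> torch.Tensor:
    """Multiply each 32-element block of the FP8 weight by its decoded E8M0 scale."""
    scale = torch.exp2(weight_scale_u8.float() - E8M0_BIAS)
    blocks = weight_fp8.float().reshape(*weight_fp8.shape[:-1], -1, BLOCK_SIZE)
    return (blocks * scale.unsqueeze(-1)).reshape(weight_fp8.shape)
```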
41 changes: 27 additions & 14 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -49,6 +49,7 @@
INT4QTensor,
INT8QTensor,
MXFP4QTensor,
MXFP8QTensor,
NF4QTensor,
NVFP4QTensor,
QTensorWrapper,
@@ -649,8 +650,32 @@ def _real_quantize(self, inputs):
assert self._is_real_quantize_support(), "Real quantization not supported for this format."

buffer_to_register = {}
if self._num_bits == (4, 3):
# FP8 quantization
# Check MX formats first (before FP8) since MXFP8 also has num_bits=(4,3)
if (
self._block_sizes
and self._block_sizes.get("scale_bits") == (8, 0)
and self._block_sizes.get("type") == "dynamic"
):
# MX quantization (MXFP4/MXFP8)
if self._num_bits == (2, 1):
# MXFP4
outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1])
buffer_to_register["_scale"] = scales
elif self._num_bits == (4, 3):
# MXFP8
assert self._block_sizes[-1] == MXFP8QTensor.BLOCK_SIZE, (
f"MXFP8 requires block size {MXFP8QTensor.BLOCK_SIZE}, "
f"got {self._block_sizes[-1]}"
)
outputs, scales = MXFP8QTensor.quantize(inputs)
buffer_to_register["_scale"] = scales
else:
raise ValueError(
f"Unsupported MX format: num_bits={self._num_bits}. "
f"Expected (2, 1) for MXFP4 or (4, 3) for MXFP8."
)
elif self._num_bits == (4, 3):
# FP8 quantization (non-MX)
# For per-tensor/per-channel quantization, we might need amax which is synced across all ranks
# For blockwise quantization, amax will be recomputed in the kernel
use_amax = self.amax is not None and not (self._block_sizes and self.amax.numel() == 1)
@@ -683,18 +708,6 @@
buffer_to_register["_scale"] = _scale
buffer_to_register["_double_scale"] = _double_scale
buffer_to_register["_scale_zeros"] = _scale_zeros
elif (
self._block_sizes.get("scale_bits") == (8, 0)
and self._block_sizes.get("type") == "dynamic"
):
# MX quantization
if self._num_bits == (2, 1):
outputs, scales = MXFP4QTensor.quantize(inputs, self._block_sizes[-1])
buffer_to_register["_scale"] = scales
else:
raise ValueError(
f"Real quantization for MX {self._num_bits} format is not supported."
)
elif self._block_sizes.get("scale_bits") == (4, 3):
# NVFP4 default quantization
# Return real quantized tensor and store scales inside TensorQuantizer
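The behavioral change in `_real_quantize` is the dispatch order: the MX branch is now checked before the generic FP8 branch, because MXFP8 shares `num_bits == (4, 3)` with plain FP8. A condensed, self-contained restatement of the new decision logic follows (illustrative helper, not ModelOpt API).

```python
def classify_real_quant_format(num_bits, block_sizes):
    """Mirror the dispatch order introduced above: MX formats are matched before plain FP8."""
    is_mx = (
        bool(block_sizes)
        and block_sizes.get("scale_bits") == (8, 0)
        and block_sizes.get("type") == "dynamic"
    )
    if is_mx:
        if num_bits == (2, 1):
            return "MXFP4"
        if num_bits == (4, 3):
            return "MXFP8"
        raise ValueError(f"Unsupported MX format: num_bits={num_bits}")
    if num_bits == (4, 3):
        return "FP8"
    return "other"

assert classify_real_quant_format((4, 3), {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}) == "MXFP8"
assert classify_real_quant_format((2, 1), {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}) == "MXFP4"
assert classify_real_quant_format((4, 3), None) == "FP8"  # non-MX FP8 path unchanged
```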
1 change: 1 addition & 0 deletions modelopt/torch/quantization/qtensor/__init__.py
@@ -20,5 +20,6 @@
from .int4_tensor import *
from .int8_tensor import *
from .mxfp4_tensor import *
from .mxfp8_tensor import *
from .nf4_tensor import *
from .nvfp4_tensor import *