diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index 198dacc5f..5417b7fa2 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -236,6 +236,31 @@ def get_scaling_factor(quantizer: TensorQuantizer) -> torch.Tensor:
     return scaling_factor
 
 
+def _ensure_weight_quantizer_calibrated(
+    weight_quantizer: TensorQuantizer, weight: torch.Tensor, module_name: str = ""
+) -> None:
+    """Calibrate weight quantizer if amax is not set.
+
+    This is a lazy calibration pattern used during export when weight quantizers
+    may not have been calibrated during the main calibration phase.
+
+    Args:
+        weight_quantizer: The weight quantizer to calibrate
+        weight: The weight tensor to use for calibration
+        module_name: Optional module name for better warning messages
+    """
+    if not hasattr(weight_quantizer, "_amax") or weight_quantizer._amax is None:
+        warn(
+            f"Weight quantizer{f' for {module_name}' if module_name else ''} was not calibrated. "
+            f"Computing amax from weights. This may occur if some experts were not activated "
+            f"during calibration (expected for MoE models); try increasing --calib_size."
+        )
+        weight_quantizer.reset_amax()
+        enable_stats_collection(weight_quantizer)
+        weight_quantizer(weight)
+        finish_stats_collection(weight_quantizer)
+
+
 def get_activation_scaling_factor(
     module: nn.Module, input_quantizer_name: str = "input_quantizer"
 ) -> torch.Tensor:
@@ -279,6 +304,10 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_NVFP4_SVDQUANT,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
+        # Calibrate weight quantizer if amax is not set
+        module_name = f"{type(module).__name__}.{weight_name}"
+        _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
+
         if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
             # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
             # This is because the kernel dequantizes weight to fp8, which is in range 448.
@@ -307,13 +336,26 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
     if weight_quantizer is None:
         return None
 
-    if get_quantization_format(module) in [
+    quantization_format = get_quantization_format(module)
+
+    # Calibrate weight quantizer if amax is not set for all NVFP4 variants
+    if quantization_format in [
+        QUANTIZATION_NVFP4,
+        QUANTIZATION_NVFP4_AWQ,
+        QUANTIZATION_NVFP4_SVDQUANT,
+        QUANTIZATION_W4A8_NVFP4_FP8,
+    ]:
+        weight = getattr(module, weight_name)
+        module_name = f"{type(module).__name__}.{weight_name}"
+        _ensure_weight_quantizer_calibrated(weight_quantizer, weight, module_name)
+
+    if quantization_format in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_NVFP4_SVDQUANT,
     ]:
         return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
-    elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
+    elif quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
         # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
         # This is because the kernel dequantizes weight to fp8, which is in range 448.
         return weight_quantizer._amax.float() / 448.0
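
For reference, the snippet below is a minimal, self-contained sketch of the lazy-calibration pattern that `_ensure_weight_quantizer_calibrated` introduces. `MockQuantizer` and `ensure_calibrated` are hypothetical stand-ins used only for illustration; the real helper operates on modelopt's `TensorQuantizer` via `reset_amax`, `enable_stats_collection`, and `finish_stats_collection`, as shown in the diff.

```python
# Hypothetical sketch (not the modelopt API): a quantizer whose amax statistic
# is filled in lazily at export time if the main calibration pass never set it.
import torch


class MockQuantizer:
    """Stand-in for a weight quantizer that tracks an abs-max (amax) statistic."""

    def __init__(self) -> None:
        self._amax: torch.Tensor | None = None

    def collect(self, tensor: torch.Tensor) -> None:
        # Record the abs-max of the observed tensor, mimicking stats collection.
        self._amax = tensor.abs().amax()


def ensure_calibrated(quantizer: MockQuantizer, weight: torch.Tensor) -> None:
    # Lazy calibration: compute amax only if it was never set, e.g. because an
    # MoE expert received no tokens during the calibration pass.
    if getattr(quantizer, "_amax", None) is None:
        quantizer.collect(weight)


quantizer = MockQuantizer()
weight = torch.randn(128, 64)
ensure_calibrated(quantizer, weight)          # fills in amax from the weights
assert torch.equal(quantizer._amax, weight.abs().amax())
ensure_calibrated(quantizer, torch.zeros(1))  # no-op: amax already present
```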