diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py
index 7f85f33..8456268 100644
--- a/fms_mo/aiu_addons/fp8/fp8_linear.py
+++ b/fms_mo/aiu_addons/fp8/fp8_linear.py
@@ -38,6 +38,46 @@
     shard_base_linear,
 )
 from fms.modules.tp import ShardType, TPModule
+
+# Register decomps for torchao >= 0.12 + AIU.
+# This import only succeeds if torchao is 0.12 or higher
+try:
+    # Third Party
+    from torchao.quantization.quant_primitives import _expand_scale_to_tensor_shape
+
+    # This function is copied from _quantize_affine_float8
+    # in torchao.quantization.quant_primitives, but removing
+    # the wrapping that turns it into a custom pytorch op
+    def _quantize_affine_float8_custom(
+        tensor: torch.Tensor,
+        scale: torch.Tensor,
+        float8_dtype: torch.dtype = torch.float8_e4m3fn,
+    ) -> torch.Tensor:
+        """
+        Quantizes the high precision floating point tensor
+        to a float8 tensor, using the given scaling factor.
+        """
+        tensor_fp32 = tensor.to(torch.float32)
+
+        # Expand scale to match tensor dimensions for block-wise quantization
+        scale_expanded = _expand_scale_to_tensor_shape(scale, tensor.shape)
+
+        tensor_scaled = tensor_fp32 / scale_expanded
+        max_value = torch.finfo(float8_dtype).max
+        tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
+        fp8_tensor = tensor_clamped.to(float8_dtype)
+        return fp8_tensor
+
+    quant_lib = torch.library.Library("torchao", "FRAGMENT")
+    quant_lib.impl(
+        "quantize_affine_float8",
+        _quantize_affine_float8_custom,
+        "CompositeImplicitAutograd",
+    )
+except ImportError:
+    pass
+
+# Third Party
 from torchao.dtypes.affine_quantized_tensor import (
     AffineQuantizedTensor,
     to_affine_quantized_floatx,
diff --git a/pyproject.toml b/pyproject.toml
index 06476e6..e3f2300 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,7 +35,7 @@ dependencies = [
 
 [project.optional-dependencies]
 examples = ["ninja>=1.11.1.1,<2.0", "evaluate", "huggingface_hub"]
-fp8 = ["llmcompressor", "torchao"]
+fp8 = ["llmcompressor", "torchao>=0.11,<=0.12"]
 gptq = ["Cython", "gptqmodel>=1.7.3"]
 mx = ["microxcaling>=1.1"]
 opt = ["fms-model-optimizer[fp8, gptq, mx]"]
diff --git a/tox.ini b/tox.ini
index 78d792b..f01bd4c 100644
--- a/tox.ini
+++ b/tox.ini
@@ -34,7 +34,7 @@ deps =
     pylint>=2.16.2,<4.0
     pylint-pydantic
     ibm-fms
-    torchao
+    torchao>=0.11,<=0.12
 commands = {basepython} -m pylint --load-plugins pylint_pydantic fms_mo/ tests/
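
For reference, a minimal sketch of exercising the overridden op, assuming torchao >= 0.12 (so that torch.ops.torchao.quantize_affine_float8 exists with the same (tensor, scale, float8_dtype) signature as the copied function); the per-tensor scale, shapes, and values below are illustrative only:

# Illustrative smoke test: call torchao::quantize_affine_float8 after the
# module patched above has registered its CompositeImplicitAutograd impl.
import torch

import fms_mo.aiu_addons.fp8.fp8_linear  # noqa: F401  (registration runs at import time)

x = torch.randn(4, 8)        # arbitrary high-precision input
scale = torch.tensor(0.05)   # illustrative per-tensor scale
y = torch.ops.torchao.quantize_affine_float8(x, scale, torch.float8_e4m3fn)
assert y.dtype == torch.float8_e4m3fn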