40 changes: 40 additions & 0 deletions fms_mo/aiu_addons/fp8/fp8_linear.py
@@ -38,6 +38,46 @@
    shard_base_linear,
)
from fms.modules.tp import ShardType, TPModule

# Register decomps for torchao >= 0.12 + AIU.
# This import only succeeds if torchao is 0.12 or higher.
try:
    # Third Party
    from torchao.quantization.quant_primitives import _expand_scale_to_tensor_shape

    # This function is copied from _quantize_affine_float8
    # in torchao.quantization.quant_primitives, with the wrapping
    # that turns it into a custom PyTorch op removed.
    def _quantize_affine_float8_custom(
        tensor: torch.Tensor,
        scale: torch.Tensor,
        float8_dtype: torch.dtype = torch.float8_e4m3fn,
    ) -> torch.Tensor:
        """
        Quantize a high-precision floating-point tensor to a float8
        tensor, using the given scaling factor.
        """
        tensor_fp32 = tensor.to(torch.float32)

        # Expand scale to match tensor dimensions for block-wise quantization
        scale_expanded = _expand_scale_to_tensor_shape(scale, tensor.shape)

        tensor_scaled = tensor_fp32 / scale_expanded
        max_value = torch.finfo(float8_dtype).max
        tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
        fp8_tensor = tensor_clamped.to(float8_dtype)
        return fp8_tensor

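    # Register the decomposable implementation above for torchao's
    # existing quantize_affine_float8 op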
    quant_lib = torch.library.Library("torchao", "FRAGMENT")
    quant_lib.impl(
        "quantize_affine_float8",
        _quantize_affine_float8_custom,
        "CompositeImplicitAutograd",
    )
except ImportError:
    pass

# Third Party
from torchao.dtypes.affine_quantized_tensor import (
    AffineQuantizedTensor,
    to_affine_quantized_floatx,
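For context (not part of the diff): the function above divides by a scale, clamps to the float8 dynamic range, and casts. Below is a minimal sketch of that math using an illustrative per-tensor scale; torchao also supports block-wise scales, which _expand_scale_to_tensor_shape broadcasts to the tensor shape.

import torch

x = torch.randn(4, 8)
max_fp8 = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
scale = x.abs().amax() / max_fp8  # illustrative per-tensor scale
x_fp8 = (x / scale).clamp(min=-max_fp8, max=max_fp8).to(torch.float8_e4m3fn)
x_deq = x_fp8.to(torch.float32) * scale  # approximate round-trip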
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -35,7 +35,7 @@ dependencies = [

[project.optional-dependencies]
examples = ["ninja>=1.11.1.1,<2.0", "evaluate", "huggingface_hub"]
-fp8 = ["llmcompressor", "torchao"]
+fp8 = ["llmcompressor", "torchao>=0.11,<=0.12"]
gptq = ["Cython", "gptqmodel>=1.7.3"]
mx = ["microxcaling>=1.1"]
opt = ["fms-model-optimizer[fp8, gptq, mx]"]
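The pin matches the guarded import in fp8_linear.py, which only takes effect on torchao 0.12. As a hypothetical alternative (not what the PR does), the same gate could be expressed as an explicit version check, assuming the packaging library is available:

from importlib.metadata import version

from packaging.version import Version

if Version(version("torchao")) >= Version("0.12"):
    ...  # register the decomposed float8 quantization op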
2 changes: 1 addition & 1 deletion tox.ini
@@ -34,7 +34,7 @@ deps =
    pylint>=2.16.2,<4.0
    pylint-pydantic
    ibm-fms
-    torchao
+    torchao>=0.11,<=0.12
commands =
    {basepython} -m pylint --load-plugins pylint_pydantic fms_mo/ tests/
