From b2c0d54029736d63c33e9d358e74325a272d689b Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Wed, 25 Jun 2025 01:19:14 +0000 Subject: [PATCH 1/5] Addons for FP8 attention bmm in FMS Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/__init__.py | 0 fms_mo/aiu_addons/fp8/fp8_aiu_bmm.py | 228 +++++++++++++++++++++++++++ fms_mo/aiu_addons/fp8/fp8_aiu_op.py | 84 ++++++++++ 3 files changed, 312 insertions(+) create mode 100644 fms_mo/aiu_addons/fp8/__init__.py create mode 100644 fms_mo/aiu_addons/fp8/fp8_aiu_bmm.py create mode 100644 fms_mo/aiu_addons/fp8/fp8_aiu_op.py diff --git a/fms_mo/aiu_addons/fp8/__init__.py b/fms_mo/aiu_addons/fp8/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fms_mo/aiu_addons/fp8/fp8_aiu_bmm.py b/fms_mo/aiu_addons/fp8/fp8_aiu_bmm.py new file mode 100644 index 00000000..21b182c2 --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_aiu_bmm.py @@ -0,0 +1,228 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""FMS registration of attention BMM operation using torch-registered scaled BMM.""" + +# Standard +from importlib.util import find_spec +from typing import NotRequired, Unpack +import math + +# Third Party +from fms.modules.attention import ( + AttentionKwargs, + _sdpa_update_attn_kwargs, + register_attention_op, +) +from torch import Tensor +import torch + +# Local +import fms_mo.aiu_addons.fp8.fp8_aiu_op # pylint: disable=unused-import + +if find_spec("torchao"): + TORCHAO_INSTALLED = True + # Third Party + from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor + from torchao.dtypes.floatx.float8_layout import ( + Float8AQTTensorImpl, + Float8Layout, + Float8MMConfig, + ) + from torchao.quantization.granularity import PerTensor + from torchao.quantization.observer import get_block_size + from torchao.quantization.quant_primitives import ZeroPointDomain +else: + TORCHAO_INSTALLED = False + + +class MathFP8AttentionKwargs(AttentionKwargs): + """TypedDict for FP8 attention.""" + + mask: NotRequired[Tensor] + do_scale_q: bool + is_causal_mask: bool + + +# TODO: Doesn't quite work yet, more discussion needed +Q_RANGE = 200.0 +K_RANGE = 200.0 +V_RANGE = 100.0 + + +def _construct_fp8_cache( + tensor: Tensor, scale: Tensor, orig_dtype: torch.dtype +) -> AffineQuantizedTensor: + """Construct the torchao tensor to save kv cache with its scales.""" + + weight_granularity = PerTensor() + fp8_layout = Float8Layout(Float8MMConfig(use_fast_accum=True)) + return AffineQuantizedTensor( + Float8AQTTensorImpl.from_plain( + tensor, + scale, + None, + fp8_layout, + ), + get_block_size(tensor.shape, weight_granularity), + tensor.shape, + zero_point_domain=ZeroPointDomain.NONE, + dtype=orig_dtype, + ) + + +def _math_fp8_store_op( + keys: Tensor, # pylint: disable=unused-argument + values: Tensor, + key_cache: Tensor | None, + value_cache: Tensor | None, + **attn_kwargs: Unpack[MathFP8AttentionKwargs], +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Implement math of KV cache storing.""" + + orig_dtype 
= keys.dtype + + if isinstance(key_cache, AffineQuantizedTensor) and isinstance( + value_cache, AffineQuantizedTensor + ): + k_scale = key_cache.tensor_impl.scale + v_scale = value_cache.tensor_impl.scale + else: + k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32) + v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32) + + keys = (keys / k_scale).to(torch.float8_e4m3fn).transpose(2, 1) + values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1) + + if ( + isinstance(key_cache, AffineQuantizedTensor) + and isinstance(value_cache, AffineQuantizedTensor) + and value_cache.numel() > 0 + ): + key_cache = torch.cat((key_cache.tensor_impl.float8_data, keys), dim=2) + value_cache = torch.cat((value_cache.tensor_impl.float8_data, values), dim=2) + key_cache = _construct_fp8_cache(key_cache, k_scale, orig_dtype) + value_cache = _construct_fp8_cache(value_cache, v_scale, orig_dtype) + return ( + key_cache, + value_cache, + key_cache, + value_cache, + ) + + keys = _construct_fp8_cache(keys, k_scale, orig_dtype) + values = _construct_fp8_cache(values, v_scale, orig_dtype) + return (keys, values, keys, values) + + +def _math_fp8_compute_op( + query: Tensor, + key_cache: Tensor, + value_cache: Tensor, + nheads: int, + kvheads: int, + p_dropout: float, + scale_factor: float | None, + **attn_kwargs: Unpack[MathFP8AttentionKwargs], +) -> Tensor: + """Implement computation of attention BMM, leveraging the custom scaled attention + BMM op that was pre-registered for torch.compile.""" + + orig_dtype = query.dtype + + q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device) + if attn_kwargs.get("do_scale_q", False): + q_scale.copy_(torch.abs(query).max() / Q_RANGE) + query = query / q_scale + + query = query.to(torch.float8_e4m3fn).transpose(2, 1) + + if isinstance(key_cache, AffineQuantizedTensor) and isinstance( + value_cache, AffineQuantizedTensor + ): + k_scale = key_cache.tensor_impl.scale + v_scale = value_cache.tensor_impl.scale + key_cache = key_cache.tensor_impl.float8_data + value_cache = value_cache.tensor_impl.float8_data + else: + k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32) + v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32) + key_cache = (key_cache / k_scale).to(torch.float8_e4m3fn) + value_cache = (value_cache / v_scale).to(torch.float8_e4m3fn) + + # no longer transposing prior to store, so need to check this in case of no cache + # TODO: Refactor FMS to avoid edge cases where this fails; add use_cache param here + if key_cache.shape[1] != kvheads and key_cache.shape[2] == kvheads: + key_cache = key_cache.transpose(2, 1) + value_cache = value_cache.transpose(2, 1) + + mask = attn_kwargs.get("mask", None) + if mask is not None: + # Our expected mask format is bs x q_len x k_len, so to make it broadcastable + # we need to create the nheads dimension + while len(mask.size()) != 4: # expects bs (x nheads) x q_len x kv_len + mask = mask.unsqueeze(1) + + L, S = query.size(-2), key_cache.size(-2) + scale_factor = ( + 1 / math.sqrt(query.size(-1)) if scale_factor is None else scale_factor + ) + attn_bias = torch.zeros(L, S, dtype=orig_dtype, device=query.device) + if attn_kwargs.get("is_causal_mask", False): + assert mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(torch.float32) + + if mask is not None: + if mask.dtype == torch.bool: + attn_bias.masked_fill_(mask.logical_not(), 
float("-inf")) + else: + attn_bias = mask + attn_bias + + expansion = nheads // kvheads + if expansion > 1: + key_cache = key_cache.repeat_interleave( + query.size(-3) // key_cache.size(-3), -3 + ) + value_cache = value_cache.repeat_interleave( + query.size(-3) // value_cache.size(-3), -3 + ) + + attn_weight = ( + torch.ops.sendnn.scaled_bmm( + query, + key_cache.transpose(-2, -1), + q_scale, + k_scale, + out_dtype=orig_dtype, + use_fast_accum=True, + ) + * scale_factor + ) + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, p_dropout, train=True) + # Do matmul in orig_dtype + attn = attn_weight @ (value_cache.to(dtype=orig_dtype) * v_scale) + + attn = attn.to(orig_dtype).transpose(2, 1).contiguous() + return attn + + +register_attention_op( + "math_fp8", + _math_fp8_store_op, + _math_fp8_compute_op, + update_attn_kwargs_op=_sdpa_update_attn_kwargs, +) diff --git a/fms_mo/aiu_addons/fp8/fp8_aiu_op.py b/fms_mo/aiu_addons/fp8/fp8_aiu_op.py new file mode 100644 index 00000000..483cff16 --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_aiu_op.py @@ -0,0 +1,84 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Torch registration of FP8xFP8 operation for attention BMMs.""" + +# Third Party +from torch import Tensor +import torch + +# pylint: disable=unused-argument +# abstract op must be registered with specific I/O, even if not in use by the op function + + +@torch.library.custom_op("sendnn::scaled_bmm", mutates_args=()) +def sendnn_scaled_bmm( + mat1: Tensor, + mat2: Tensor, + scale1: Tensor, + scale2: Tensor, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> Tensor: + """Implement a custom scaled attention BMM op: a batched version of _scaled_mm. + The operations that are part of this function are not exposed to the computational + graph, but are invoked when running on non-AIU devices. + """ + + assert ( + mat1.shape[:-2] == mat2.shape[:-2] + ), "batch dimensions must match for mat1 and mat2" + assert ( + mat1.shape[:-2] == scale1.shape[:-2] + ), "batch dimensions must match for mat1 and scale1" + assert ( + mat2.shape[:-2] == scale2.shape[:-2] + ), "batch dimensions must match for mat2 and scale2" + + mat1 = mat1.view(-1, *mat1.shape[-2:]) + mat2 = mat2.view(-1, *mat2.shape[-2:]) + scale1 = scale1.view(-1, *scale1.shape[-2:]) + scale2 = scale2.view(-1, *scale2.shape[-2:]) + out = torch.empty( + (mat1.shape[0], mat1.shape[1], mat2.shape[2]), + dtype=out_dtype, + device=mat1.device, + ) + for b_idx in range(mat1.shape[0]): + out[b_idx] = torch._scaled_mm( + mat1[b_idx], + mat2[b_idx], + scale1[b_idx], + scale2[b_idx], + out_dtype, + use_fast_accum, + ) + return out + + +@sendnn_scaled_bmm.register_fake +def _( + mat1: Tensor, + mat2: Tensor, + scale1: Tensor, + scale2: Tensor, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> Tensor: + """Template for scaled attention BMM operation. 
I/O retain the expected size.""" + + return torch.empty( + (*mat1.shape[:-2], mat1.shape[-2], mat2.shape[-1]), + dtype=out_dtype, + device=mat1.device, + ) From 6f289b09861d83e0295c126f73a297505ea5bf42 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Fri, 27 Jun 2025 23:27:55 +0000 Subject: [PATCH 2/5] Update FP8 bmm Signed-off-by: Andrea Fasoli --- .../fp8/{fp8_aiu_bmm.py => fp8_bmm.py} | 103 ++++-------- fms_mo/aiu_addons/fp8/fp8_linear.py | 0 .../fp8/{fp8_aiu_op.py => fp8_spyre_op.py} | 21 +-- fms_mo/aiu_addons/fp8/fp8_utils.py | 148 ++++++++++++++++++ 4 files changed, 186 insertions(+), 86 deletions(-) rename fms_mo/aiu_addons/fp8/{fp8_aiu_bmm.py => fp8_bmm.py} (64%) create mode 100644 fms_mo/aiu_addons/fp8/fp8_linear.py rename fms_mo/aiu_addons/fp8/{fp8_aiu_op.py => fp8_spyre_op.py} (80%) create mode 100644 fms_mo/aiu_addons/fp8/fp8_utils.py diff --git a/fms_mo/aiu_addons/fp8/fp8_aiu_bmm.py b/fms_mo/aiu_addons/fp8/fp8_bmm.py similarity index 64% rename from fms_mo/aiu_addons/fp8/fp8_aiu_bmm.py rename to fms_mo/aiu_addons/fp8/fp8_bmm.py index 21b182c2..e04cce9b 100644 --- a/fms_mo/aiu_addons/fp8/fp8_aiu_bmm.py +++ b/fms_mo/aiu_addons/fp8/fp8_bmm.py @@ -14,7 +14,6 @@ """FMS registration of attention BMM operation using torch-registered scaled BMM.""" # Standard -from importlib.util import find_spec from typing import NotRequired, Unpack import math @@ -24,79 +23,44 @@ _sdpa_update_attn_kwargs, register_attention_op, ) -from torch import Tensor import torch # Local -import fms_mo.aiu_addons.fp8.fp8_aiu_op # pylint: disable=unused-import - -if find_spec("torchao"): - TORCHAO_INSTALLED = True - # Third Party - from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor - from torchao.dtypes.floatx.float8_layout import ( - Float8AQTTensorImpl, - Float8Layout, - Float8MMConfig, - ) - from torchao.quantization.granularity import PerTensor - from torchao.quantization.observer import get_block_size - from torchao.quantization.quant_primitives import ZeroPointDomain -else: - TORCHAO_INSTALLED = False +from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor +import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import class MathFP8AttentionKwargs(AttentionKwargs): """TypedDict for FP8 attention.""" - mask: NotRequired[Tensor] + mask: NotRequired[torch.Tensor] do_scale_q: bool is_causal_mask: bool -# TODO: Doesn't quite work yet, more discussion needed +# TODO: Figure out better scales for AIU? 
These come from vLLM Q_RANGE = 200.0 K_RANGE = 200.0 V_RANGE = 100.0 -def _construct_fp8_cache( - tensor: Tensor, scale: Tensor, orig_dtype: torch.dtype -) -> AffineQuantizedTensor: - """Construct the torchao tensor to save kv cache with its scales.""" - - weight_granularity = PerTensor() - fp8_layout = Float8Layout(Float8MMConfig(use_fast_accum=True)) - return AffineQuantizedTensor( - Float8AQTTensorImpl.from_plain( - tensor, - scale, - None, - fp8_layout, - ), - get_block_size(tensor.shape, weight_granularity), - tensor.shape, - zero_point_domain=ZeroPointDomain.NONE, - dtype=orig_dtype, - ) +def _construct_fp8_cache(tensor: torch.Tensor, scale: torch.Tensor) -> ScaledTensor: + """Construct the custom object to save KV cache with its scales.""" + return ScaledTensor(tensor, scale) def _math_fp8_store_op( - keys: Tensor, # pylint: disable=unused-argument - values: Tensor, - key_cache: Tensor | None, - value_cache: Tensor | None, + keys: torch.Tensor, # pylint: disable=unused-argument + values: torch.Tensor, + key_cache: torch.Tensor | None, + value_cache: torch.Tensor | None, **attn_kwargs: Unpack[MathFP8AttentionKwargs], -) -> tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[ScaledTensor, ScaledTensor, ScaledTensor, ScaledTensor]: """Implement math of KV cache storing.""" - orig_dtype = keys.dtype - - if isinstance(key_cache, AffineQuantizedTensor) and isinstance( - value_cache, AffineQuantizedTensor - ): - k_scale = key_cache.tensor_impl.scale - v_scale = value_cache.tensor_impl.scale + if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor): + k_scale = key_cache._scale + v_scale = value_cache._scale else: k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32) v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32) @@ -105,36 +69,35 @@ def _math_fp8_store_op( values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1) if ( - isinstance(key_cache, AffineQuantizedTensor) - and isinstance(value_cache, AffineQuantizedTensor) + isinstance(key_cache, ScaledTensor) + and isinstance(value_cache, ScaledTensor) and value_cache.numel() > 0 ): - key_cache = torch.cat((key_cache.tensor_impl.float8_data, keys), dim=2) - value_cache = torch.cat((value_cache.tensor_impl.float8_data, values), dim=2) - key_cache = _construct_fp8_cache(key_cache, k_scale, orig_dtype) - value_cache = _construct_fp8_cache(value_cache, v_scale, orig_dtype) + key_cache = torch.cat((key_cache._data, keys), dim=2) + value_cache = torch.cat((value_cache._data, values), dim=2) + key_cache = _construct_fp8_cache(key_cache, k_scale) + value_cache = _construct_fp8_cache(value_cache, v_scale) return ( key_cache, value_cache, key_cache, value_cache, ) - - keys = _construct_fp8_cache(keys, k_scale, orig_dtype) - values = _construct_fp8_cache(values, v_scale, orig_dtype) + keys = _construct_fp8_cache(keys.contiguous(), k_scale) + values = _construct_fp8_cache(values.contiguous(), v_scale) return (keys, values, keys, values) def _math_fp8_compute_op( - query: Tensor, - key_cache: Tensor, - value_cache: Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, nheads: int, kvheads: int, p_dropout: float, scale_factor: float | None, **attn_kwargs: Unpack[MathFP8AttentionKwargs], -) -> Tensor: +) -> torch.Tensor: """Implement computation of attention BMM, leveraging the custom scaled attention BMM op that was pre-registered for torch.compile.""" @@ -147,13 +110,11 @@ def _math_fp8_compute_op( query = query.to(torch.float8_e4m3fn).transpose(2, 1) - if 
isinstance(key_cache, AffineQuantizedTensor) and isinstance( - value_cache, AffineQuantizedTensor - ): - k_scale = key_cache.tensor_impl.scale - v_scale = value_cache.tensor_impl.scale - key_cache = key_cache.tensor_impl.float8_data - value_cache = value_cache.tensor_impl.float8_data + if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor): + k_scale = key_cache._scale + v_scale = value_cache._scale + key_cache = key_cache._data + value_cache = value_cache._data else: k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32) v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py new file mode 100644 index 00000000..e69de29b diff --git a/fms_mo/aiu_addons/fp8/fp8_aiu_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py similarity index 80% rename from fms_mo/aiu_addons/fp8/fp8_aiu_op.py rename to fms_mo/aiu_addons/fp8/fp8_spyre_op.py index 483cff16..ad91fab0 100644 --- a/fms_mo/aiu_addons/fp8/fp8_aiu_op.py +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -21,7 +21,7 @@ # abstract op must be registered with specific I/O, even if not in use by the op function -@torch.library.custom_op("sendnn::scaled_bmm", mutates_args=()) +@torch.library.custom_op("spyre::scaled_bmm", mutates_args=()) def sendnn_scaled_bmm( mat1: Tensor, mat2: Tensor, @@ -38,17 +38,8 @@ def sendnn_scaled_bmm( assert ( mat1.shape[:-2] == mat2.shape[:-2] ), "batch dimensions must match for mat1 and mat2" - assert ( - mat1.shape[:-2] == scale1.shape[:-2] - ), "batch dimensions must match for mat1 and scale1" - assert ( - mat2.shape[:-2] == scale2.shape[:-2] - ), "batch dimensions must match for mat2 and scale2" - mat1 = mat1.view(-1, *mat1.shape[-2:]) mat2 = mat2.view(-1, *mat2.shape[-2:]) - scale1 = scale1.view(-1, *scale1.shape[-2:]) - scale2 = scale2.view(-1, *scale2.shape[-2:]) out = torch.empty( (mat1.shape[0], mat1.shape[1], mat2.shape[2]), dtype=out_dtype, @@ -58,12 +49,12 @@ def sendnn_scaled_bmm( out[b_idx] = torch._scaled_mm( mat1[b_idx], mat2[b_idx], - scale1[b_idx], - scale2[b_idx], - out_dtype, - use_fast_accum, + scale1, + scale2, + out_dtype=out_dtype, + use_fast_accum=use_fast_accum, ) - return out + return out.view(*mat1.shape[:-2], mat1.shape[1], mat2.shape[2]) @sendnn_scaled_bmm.register_fake diff --git a/fms_mo/aiu_addons/fp8/fp8_utils.py b/fms_mo/aiu_addons/fp8/fp8_utils.py new file mode 100644 index 00000000..0dd21024 --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_utils.py @@ -0,0 +1,148 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
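For reference, a minimal sketch of exercising the scaled-BMM op registered above outside of the attention path. The op name "spyre::scaled_bmm" and its signature come from this patch; the shapes, the unit scales, and the CUDA device are illustrative assumptions, and the eager fallback loops over torch._scaled_mm, which only executes on FP8-capable hardware.

    import torch

    import fms_mo.aiu_addons.fp8.fp8_spyre_op  # noqa: F401  registers spyre::scaled_bmm

    # Illustrative batched FP8 operands. torch._scaled_mm expects the second
    # operand column-major, hence mat2 is materialized transposed (the attention
    # code does the same with key_cache.transpose(-2, -1)).
    mat1 = (torch.randn(4, 16, 64, device="cuda") * 0.1).to(torch.float8_e4m3fn)
    mat2 = (torch.randn(4, 32, 64, device="cuda") * 0.1).to(torch.float8_e4m3fn).transpose(-2, -1)
    scale1 = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    scale2 = torch.tensor(1.0, dtype=torch.float32, device="cuda")

    # The eager fallback calls torch._scaled_mm per batch entry, so this needs
    # FP8-capable hardware when not running through the compiled Spyre path.
    out = torch.ops.spyre.scaled_bmm(
        mat1, mat2, scale1, scale2, out_dtype=torch.bfloat16, use_fast_accum=True
    )
    # out: (4, 16, 32) in bfloat16
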
+"""FMS registration of attention BMM operation using torch-registered scaled BMM.""" + +# Standard +import functools + +# Third Party +import torch + +# pylint: disable=unused-argument +# unusued arguments are needed for templates + + +_HANDLED_FUNCTIONS = {} + + +def _implements(torch_function): + """Register a torch function override""" + + def decorator(func): + @functools.wraps(torch_function) + def wrapper(f, types, args, kwargs): + return func(f, types, args, kwargs) + + _HANDLED_FUNCTIONS[torch_function] = wrapper + return func + + return decorator + + +class ScaledTensor(torch.Tensor): + """Representation of a quantized tensor and its scale.""" + + def __new__( + cls, + data: torch.Tensor, + scale: torch.Tensor, + ): + return torch.Tensor._make_wrapper_subclass( + cls, + data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + dtype=data.dtype, + layout=data.layout, + requires_grad=data.requires_grad, + device=data.device, + ) + + def __init__( + self, + data: torch.Tensor, + scale: torch.Tensor, + ): + self._data = data + self._scale = scale + + def __tensor_flatten__(self): + ctx = {} + return ["_data", "_scale"], ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): + assert len(inner_tensors) == 2 + return ScaledTensor( + inner_tensors["_data"], + inner_tensors["_scale"], + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + if func in _HANDLED_FUNCTIONS: + return _HANDLED_FUNCTIONS[func](func, types, args, kwargs) + + arg_types = tuple(type(arg) for arg in args) + kwarg_types = {k: type(arg) for k, arg in kwargs.items()} + raise NotImplementedError( + f"{cls.__name__} dispatch: attempting to run unimplemented " + f"operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}" + ) + + def __repr__(self): + return f"{self._data.__repr__()}\n{self._scale.__repr__()}" + + +def _infer_quantization_config(quant_config: dict) -> dict | None: + # There's many quantization packages compatible with HF + # We initially focus on llm-compressor as it is the one used in FMS-MO + + # llm-compressor saves its checkpoints with quant_method = compressed-tensors + # quantization_status tells us whether the model has already been quantized + # We only support loading already quantized models (compressed status) + if ( + quant_config["quant_method"] == "compressed-tensors" + and quant_config["quantization_status"] == "compressed" + ): + # FP8 quantization will have FP8 weights + # We assume a single quantization group (group_0), to follow fms-mo checkpoints + # num_bits and type tells us "float" with "8" bits, aka FP8 + if ( + quant_config["config_groups"]["group_0"]["weights"]["type"] == "float" + and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8 + ): + # This is used by get_linear to decide whether a linear layer + # will be quantized or not inside the model + def fp8_linear_type(name: str) -> str: + # We need to translate HF names to FMS names + translations = { + "lm_head": "head", + } + for ignored_layer in quant_config["ignore"]: + assert isinstance(ignored_layer, str) + fms_ign_layer = translations.get(ignored_layer, ignored_layer) + if name in fms_ign_layer: + return "torch_linear" + for pattern in quant_config["config_groups"]["group_0"]["targets"]: + # Special case from llm-compressor that covers all linear layers + # not in the ignore pattern + assert isinstance(pattern, str) + if pattern == "Linear": + return "fp8" + if name in translations.get(pattern, 
pattern): + return "fp8" + return "torch_linear" + + return { + "linear_type": fp8_linear_type, + "input_activations": quant_config["config_groups"]["group_0"][ + "input_activations" + ], + "output_activations": quant_config["config_groups"]["group_0"][ + "output_activations" + ], + "weights": quant_config["config_groups"]["group_0"]["weights"], + } + return None From 30f76a90230af491a613bbd0bc59197d3a513e51 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Sat, 28 Jun 2025 00:11:09 +0000 Subject: [PATCH 3/5] Add FP8 adapter step Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_adapter.py | 55 +++++++++++++++++++++++++++ fms_mo/aiu_addons/fp8/fp8_spyre_op.py | 2 +- fms_mo/aiu_addons/fp8/fp8_utils.py | 2 +- 3 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 fms_mo/aiu_addons/fp8/fp8_adapter.py diff --git a/fms_mo/aiu_addons/fp8/fp8_adapter.py b/fms_mo/aiu_addons/fp8/fp8_adapter.py new file mode 100644 index 00000000..57e52a35 --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_adapter.py @@ -0,0 +1,55 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implement and register FMS adapters for FP8 checkpoint loading.""" + +# Standard +from typing import Any, Mapping + +# Third Party +from fms.modules.linear import get_linear_type +from fms.utils import serialization +from fms.utils.config import ModelConfig + +# NOTE: this adapter step must be registered before the adapter that uses it (such as +# the llama adapter in fms.models.llama) +# TODO: may be shared with gptq llama +# TODO: generalize across architectures if possible +def _hf_fp8_llama_check( + input_sd: Mapping[str, Any], model_config: ModelConfig | None = None, **kwargs +) -> Mapping[str, Any]: + """Implementation of adapter step for FMS Llama: ensure that when FP8 quantization + is in use, weights are unfused. + """ + + has_fused_weights = True + linear_type = "torch_linear" + if model_config: + if not model_config.fused_weights: + has_fused_weights = False + if model_config.linear_config: + linear_type = model_config.linear_config["linear_type"] + if callable(linear_type): + # Calling this with "any" guarantees "fp8" to be returned + # when loading an HF fp8 checkpoint, and never in any other condition + linear_type = get_linear_type(model_config.linear_config, "any") + + if "fp8" in linear_type and has_fused_weights: + raise ValueError( + "FP8 HF llama checkpoints cannot be loaded into a model with fused weights" + ) + + return input_sd + + +serialization.register_adapter_step("llama", "hf_fp8_llama_check", _hf_fp8_llama_check) diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py index ad91fab0..c6280d4a 100644 --- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -32,7 +32,7 @@ def sendnn_scaled_bmm( ) -> Tensor: """Implement a custom scaled attention BMM op: a batched version of _scaled_mm. 
The operations that are part of this function are not exposed to the computational - graph, but are invoked when running on non-AIU devices. + graph, but are invoked when running on non-Spyre devices. """ assert ( diff --git a/fms_mo/aiu_addons/fp8/fp8_utils.py b/fms_mo/aiu_addons/fp8/fp8_utils.py index 0dd21024..6ad38752 100644 --- a/fms_mo/aiu_addons/fp8/fp8_utils.py +++ b/fms_mo/aiu_addons/fp8/fp8_utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""FMS registration of attention BMM operation using torch-registered scaled BMM.""" +"""Utility functions and components for FP8 addon implementation.""" # Standard import functools From 496bf44c395b4caa6185b7fc3af09b2c2ca53381 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Sat, 28 Jun 2025 00:26:52 +0000 Subject: [PATCH 4/5] Add FP8 linear to FMS addon Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_adapter.py | 4 + fms_mo/aiu_addons/fp8/fp8_linear.py | 325 +++++++++++++++++++++++++++ fms_mo/aiu_addons/fp8/fp8_utils.py | 16 +- 3 files changed, 339 insertions(+), 6 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_adapter.py b/fms_mo/aiu_addons/fp8/fp8_adapter.py index 57e52a35..0155ef39 100644 --- a/fms_mo/aiu_addons/fp8/fp8_adapter.py +++ b/fms_mo/aiu_addons/fp8/fp8_adapter.py @@ -21,6 +21,10 @@ from fms.utils import serialization from fms.utils.config import ModelConfig +# pylint: disable=unused-argument +# Retaining kwargs input arguments for consistency. + + # NOTE: this adapter step must be registered before the adapter that uses it (such as # the llama adapter in fms.models.llama) # TODO: may be shared with gptq llama diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index e69de29b..ecbf681a 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -0,0 +1,325 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
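Before the linear module itself, a short sketch of how the configuration helper added in the previous patch is expected to be consumed. The dictionary below is an assumption modeled on an llm-compressor FP8 checkpoint and only carries the fields _infer_quantization_config actually reads; real quantization_config blocks contain more.

    from fms_mo.aiu_addons.fp8.fp8_utils import _infer_quantization_config

    # Assumed llm-compressor-style config; only fields read by the helper are shown.
    quant_config = {
        "quant_method": "compressed-tensors",
        "quantization_status": "compressed",
        "ignore": ["lm_head"],
        "config_groups": {
            "group_0": {
                "targets": ["Linear"],
                "weights": {
                    "type": "float", "num_bits": 8,
                    "symmetric": True, "dynamic": False, "strategy": "tensor",
                },
                "input_activations": {
                    "type": "float", "num_bits": 8,
                    "symmetric": True, "dynamic": True, "strategy": "tensor",
                },
                "output_activations": None,
            }
        },
    }

    linear_config = _infer_quantization_config(quant_config)
    # linear_config["linear_type"] is a callable used by get_linear:
    # quantized layers resolve to "fp8", ignored ones to "torch_linear".
    assert linear_config["linear_type"]("attn.query") == "fp8"
    assert linear_config["linear_type"]("head") == "torch_linear"
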
+"""Implement FP8 linear module to be loaded via FMS.""" + +# Standard +from importlib.util import find_spec +from typing import Any, Mapping + +# Third Party +from fms.modules.linear import ( + LinearModuleShardingInfo, + LinearParameterShardingInfo, + register_linear_type_to_module_map, + register_linear_type_to_sharding_map, + shard_base_linear, +) +from fms.modules.tp import ShardType, TPModule +import torch + +# pylint: disable=not-callable +# torch.nn.functional.linear not recognized as callable +# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 + + +### FP8 linear layers +if find_spec("torchao"): + TORCHAO_INSTALLED = True + + # Third Party + from torchao.dtypes.affine_quantized_tensor import ( # type: ignore + AffineQuantizedTensor, + to_affine_quantized_floatx, + to_affine_quantized_floatx_static, + ) + from torchao.dtypes.floatx.float8_layout import ( # type: ignore + Float8AQTTensorImpl, + Float8Layout, + Float8MMConfig, + preprocess_data, + preprocess_scale, + ) + from torchao.dtypes.utils import get_out_shape # type: ignore + from torchao.float8.inference import ( # type: ignore + _is_rowwise_scaled, + addmm_float8_unwrapped_inference, + ) + from torchao.quantization.granularity import PerRow, PerTensor # type: ignore + from torchao.quantization.observer import get_block_size # type: ignore + from torchao.quantization.quant_primitives import ZeroPointDomain # type: ignore +else: + TORCHAO_INSTALLED = False + + +class FP8Linear(torch.nn.Module): + """Class handles FP8 weights loading and uses torchao for the matmuls.""" + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_config: Mapping[str, Any], + ): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.has_bias = bias + self.linear_config = linear_config + + assert ( + self.linear_config["weights"] is not None + ), "Weights must always be quantized for FP8Linear" + assert self.linear_config["weights"][ + "symmetric" + ], "We only support symmetric weights for now" + assert not self.linear_config["weights"][ + "dynamic" + ], "We only support pre-quantized weights for now" + + self.weight = torch.nn.Parameter( + torch.zeros(out_features, in_features, dtype=torch.float8_e4m3fn), + requires_grad=False, + ) + + weight_scale_shape = ( + (1,) + if self.linear_config["weights"]["strategy"] == "tensor" + else (out_features, 1) + ) + self.weight_scale = torch.nn.Parameter( + torch.ones(weight_scale_shape), requires_grad=False + ) + + self.has_bias = bias + if self.has_bias: + self.bias = torch.nn.Parameter(torch.zeros((out_features,))) + + if ( + self.linear_config["input_activations"] is not None + and not self.linear_config["input_activations"]["dynamic"] + ): + input_scale_shape = ( + (1,) + if self.linear_config["input_activations"]["strategy"] == "tensor" + else (out_features, 1) + ) + self.input_scale = torch.nn.Parameter( + torch.ones(input_scale_shape), requires_grad=False + ) + + def _input_activation_quant_func_fp8( + self, + x: torch.Tensor, + activation_granularity, + activation_dtype: torch.dtype, + scale: torch.Tensor | None = None, + ): + """Quantize the input activation tensor for an aqt_float variant. + If scale is not provided, it will be dynamically calculated, otherwise the + provided scale will be used. 
+ """ + + block_size = get_block_size(x.shape, activation_granularity) + if scale is None: + activation = to_affine_quantized_floatx( + input_float=x, + block_size=block_size, + target_dtype=activation_dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=None), # Config is stored on weight + ) + else: + assert isinstance( + activation_granularity, PerTensor + ), "Static quantization only supports PerTensor granularity" + activation = to_affine_quantized_floatx_static( + input_float=x, + block_size=block_size, + scale=scale, + target_dtype=activation_dtype, + _layout=Float8Layout(mm_config=None), # Config is stored on weight + ) + return activation + + def _construct_qweight_structure(self) -> "AffineQuantizedTensor": + """Construct the torchao machinery for the fp8 matmul""" + + weight_granularity = ( + PerTensor() + if self.linear_config["weights"]["strategy"] == "tensor" + else PerRow() + ) + fp8_layout = Float8Layout(Float8MMConfig(use_fast_accum=True)) + return AffineQuantizedTensor( + Float8AQTTensorImpl.from_plain( + self.weight, + self.weight_scale.squeeze().to(torch.float32), + None, + fp8_layout, + ), + get_block_size(self.weight.shape, weight_granularity), + self.weight.shape, + zero_point_domain=ZeroPointDomain.NONE, + dtype=self.weight_scale.dtype, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """If input quantization is active, compute FP8xFP8 addmm.""" + + # fp8 weight tensor for torchao + qweight: AffineQuantizedTensor = self._construct_qweight_structure() + + if self.linear_config["input_activations"] is not None: + # activations are also fp8, quantize as required by model + act_granularity = ( + PerTensor() + if self.linear_config["input_activations"]["strategy"] == "tensor" + else PerRow() + ) + input_quant_kwargs = { + "activation_granularity": act_granularity, + "activation_dtype": torch.float8_e4m3fn, + } + if not self.linear_config["input_activations"]["dynamic"]: + input_quant_kwargs["scale"] = self.input_scale.squeeze().to( + torch.float32 + ) + qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs) + + # Copied from torchao _linear_fp8_act_fp8_weight_impl (with changes to support fp8 out) + scaled_mm_config = Float8MMConfig(use_fast_accum=True) + out_shape = get_out_shape(qx.shape, qweight.shape) + + # Weight tensor preprocessing + w_tensor_impl = qweight.tensor_impl + assert not w_tensor_impl.transposed, "Weight tensor must be contiguous" + w_data = w_tensor_impl.float8_data + w_scale = w_tensor_impl.scale + + # Input tensor preprocessing + inpt_data = qx.tensor_impl.float8_data + input_scale = qx.tensor_impl.scale + # Handle case where input tensor is more than 2D + inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) + + # Handle rowwise case + if _is_rowwise_scaled(qweight): + assert _is_rowwise_scaled(qx), "Input tensor must be rowwise block size" + w_scale = w_scale.unsqueeze(-1).T + input_scale = preprocess_scale(input_scale, qx.shape) + + # Preprocess data + inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) + + # Perform the computation + return addmm_float8_unwrapped_inference( + inpt_data, + input_scale, + w_data, + w_scale, + output_dtype=qx.dtype, + bias=getattr(self, "bias", None), + use_fast_accum=scaled_mm_config.use_fast_accum, + ).reshape(out_shape) + + # activations not quantized, dequant fp8 weight and do regular matmul + out = torch.nn.functional.linear( + x, qweight.dequantize(), self.bias if self.has_bias else None + ) + return out + + def __repr__(self) -> str: + return ( 
+ f"{self.__class__.__name__}" + f"(in={self.in_features}, out={self.out_features}, " + f"bias={self.has_bias}, fp8_config={self.linear_config})" + ) + + +def get_fp8_linear( + in_features: int, + out_features: int, + bias: bool, + linear_config: Mapping[str, Any], +) -> FP8Linear: + """Retrieve an FP8 Linear module""" + + if not TORCHAO_INSTALLED: + raise ModuleNotFoundError("You need to install torchao for FP8 support in FMS!") + + return FP8Linear(in_features, out_features, bias, linear_config) + + +def shard_fp8_linear( + tensor_values: dict[str, torch.Tensor], + tp_module: TPModule, + module_sharding_info: dict[str, LinearModuleShardingInfo], +) -> set | None: + """ + | GPU | + sharding | param | shard | dim | + ----------+----------------+-------+-----| + colwise | weight | Y | 0 | + | weight_scale | N | - | + | input_scale | N | - | + | bias | Y | 0 | + ----------+----------------+-------+-----| + rowwise | weight | Y | 1 | + | weight_scale | Y/N | 0/- | + | input_scale | Y/N | 0/- | + | bias | 0 | - | + """ + param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {} + for module_name, module_info in module_sharding_info.items(): + linear_mod: torch.nn.Module = module_info.linear_module + weight_strategy = getattr(linear_mod, "linear_config")["input_activations"][ + "strategy" + ] + # Scales are per-row or per-tensor + # Only sharding needed when row parallel and per-row + shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 1 + params: dict[str, LinearParameterShardingInfo] = { + "weight": LinearParameterShardingInfo( + module_info.sharding_dim, ShardType.SHARD + ), + "weight_scale": LinearParameterShardingInfo( + module_info.sharding_dim, + ShardType.SHARD if shard_scales else ShardType.CLONE, + ), + } + if hasattr(linear_mod, "input_scale"): + params["input_scale"] = LinearParameterShardingInfo( + module_info.sharding_dim, + ShardType.SHARD if shard_scales else ShardType.CLONE, + ) + if hasattr(linear_mod, "bias") and linear_mod.bias is not None: + params["bias"] = LinearParameterShardingInfo( + module_info.sharding_dim, + ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0, + ) + param_sharding_info[module_name] = params + + unused_keys = shard_base_linear( + tensor_values, + tp_module, + module_sharding_info, + param_sharding_info, + ) + return unused_keys + + +register_linear_type_to_module_map("fp8", get_fp8_linear) +register_linear_type_to_sharding_map("fp8", shard_fp8_linear) diff --git a/fms_mo/aiu_addons/fp8/fp8_utils.py b/fms_mo/aiu_addons/fp8/fp8_utils.py index 6ad38752..0f5fbab4 100644 --- a/fms_mo/aiu_addons/fp8/fp8_utils.py +++ b/fms_mo/aiu_addons/fp8/fp8_utils.py @@ -59,7 +59,7 @@ def __new__( device=data.device, ) - def __init__( + def __init__( # pylint: disable=super-init-not-called self, data: torch.Tensor, scale: torch.Tensor, @@ -96,12 +96,16 @@ def __repr__(self): def _infer_quantization_config(quant_config: dict) -> dict | None: - # There's many quantization packages compatible with HF - # We initially focus on llm-compressor as it is the one used in FMS-MO + """Construct linear_config dictionary carrying FP8 configuration for FMS. 
+ + There's many quantization packages compatible with HF + We initially focus on llm-compressor as it is the one used in FMS-MO + + llm-compressor saves its checkpoints with quant_method = compressed-tensors + quantization_status tells us whether the model has already been quantized + We only support loading already quantized models (compressed status) + """ - # llm-compressor saves its checkpoints with quant_method = compressed-tensors - # quantization_status tells us whether the model has already been quantized - # We only support loading already quantized models (compressed status) if ( quant_config["quant_method"] == "compressed-tensors" and quant_config["quantization_status"] == "compressed" From c931ad73d54fa7f2173e24619f853335fcddc08d Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Sat, 28 Jun 2025 01:44:17 +0000 Subject: [PATCH 5/5] rename fp8 attention Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_adapter.py | 2 +- .../fp8/{fp8_bmm.py => fp8_attn.py} | 0 fms_mo/aiu_addons/fp8/fp8_linear.py | 20 ++++++++++--------- 3 files changed, 12 insertions(+), 10 deletions(-) rename fms_mo/aiu_addons/fp8/{fp8_bmm.py => fp8_attn.py} (100%) diff --git a/fms_mo/aiu_addons/fp8/fp8_adapter.py b/fms_mo/aiu_addons/fp8/fp8_adapter.py index 0155ef39..7f339f72 100644 --- a/fms_mo/aiu_addons/fp8/fp8_adapter.py +++ b/fms_mo/aiu_addons/fp8/fp8_adapter.py @@ -22,7 +22,7 @@ from fms.utils.config import ModelConfig # pylint: disable=unused-argument -# Retaining kwargs input arguments for consistency. +# Retaining kwargs input arguments for consistency with other adapter steps. # NOTE: this adapter step must be registered before the adapter that uses it (such as diff --git a/fms_mo/aiu_addons/fp8/fp8_bmm.py b/fms_mo/aiu_addons/fp8/fp8_attn.py similarity index 100% rename from fms_mo/aiu_addons/fp8/fp8_bmm.py rename to fms_mo/aiu_addons/fp8/fp8_attn.py diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index ecbf681a..3e2246ba 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -33,31 +33,31 @@ # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 -### FP8 linear layers +# Gated torchao imports for FP8 implementation if find_spec("torchao"): TORCHAO_INSTALLED = True # Third Party - from torchao.dtypes.affine_quantized_tensor import ( # type: ignore + from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, to_affine_quantized_floatx, to_affine_quantized_floatx_static, ) - from torchao.dtypes.floatx.float8_layout import ( # type: ignore + from torchao.dtypes.floatx.float8_layout import ( Float8AQTTensorImpl, Float8Layout, Float8MMConfig, preprocess_data, preprocess_scale, ) - from torchao.dtypes.utils import get_out_shape # type: ignore - from torchao.float8.inference import ( # type: ignore + from torchao.dtypes.utils import get_out_shape + from torchao.float8.inference import ( _is_rowwise_scaled, addmm_float8_unwrapped_inference, ) - from torchao.quantization.granularity import PerRow, PerTensor # type: ignore - from torchao.quantization.observer import get_block_size # type: ignore - from torchao.quantization.quant_primitives import ZeroPointDomain # type: ignore + from torchao.quantization.granularity import PerRow, PerTensor + from torchao.quantization.observer import get_block_size + from torchao.quantization.quant_primitives import ZeroPointDomain else: TORCHAO_INSTALLED = False @@ -177,7 +177,8 @@ def _construct_qweight_structure(self) -> "AffineQuantizedTensor": ) def 
forward(self, x: torch.Tensor) -> torch.Tensor: - """If input quantization is active, compute FP8xFP8 addmm.""" + """If input quantization is active, compute FP8xFP8 addmm leveraging torchao + functionalities. Otherwise compute non-quantized addmm.""" # fp8 weight tensor for torchao qweight: AffineQuantizedTensor = self._construct_qweight_structure() @@ -282,6 +283,7 @@ def shard_fp8_linear( | input_scale | Y/N | 0/- | | bias | 0 | - | """ + param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {} for module_name, module_info in module_sharding_info.items(): linear_mod: torch.nn.Module = module_info.linear_module
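
To close the loop on the linear path, a minimal sketch of instantiating the module registered here for the "fp8" linear type. The weight-only configuration below (input_activations set to None) deliberately exercises the dequantize-then-matmul fallback in forward(), which runs on any device as long as torchao is installed; the layer sizes, scale value, and random weights are illustrative assumptions, not values from the patch.

    import torch

    from fms_mo.aiu_addons.fp8.fp8_linear import get_fp8_linear

    linear_config = {
        "weights": {"symmetric": True, "dynamic": False, "strategy": "tensor"},
        "input_activations": None,   # weight-only FP8: no activation quantization
        "output_activations": None,
    }
    lin = get_fp8_linear(in_features=64, out_features=32, bias=False, linear_config=linear_config)

    # Emulate checkpoint loading: an FP8 weight tensor plus its per-tensor scale
    # (values here are illustrative, not taken from a real checkpoint).
    with torch.no_grad():
        lin.weight.copy_((torch.randn(32, 64) * 0.05).to(torch.float8_e4m3fn))
        lin.weight_scale.fill_(0.1)

    y = lin(torch.randn(2, 64))  # float32 activations against dequantized FP8 weights
    print(y.shape)               # torch.Size([2, 32])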