From b2c0d54029736d63c33e9d358e74325a272d689b Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Wed, 25 Jun 2025 01:19:14 +0000 Subject: [PATCH 01/15] Addons for FP8 attention bmm in FMS Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/__init__.py | 0 fms_mo/aiu_addons/fp8/fp8_aiu_bmm.py | 228 +++++++++++++++++++++++++++ fms_mo/aiu_addons/fp8/fp8_aiu_op.py | 84 ++++++++++ 3 files changed, 312 insertions(+) create mode 100644 fms_mo/aiu_addons/fp8/__init__.py create mode 100644 fms_mo/aiu_addons/fp8/fp8_aiu_bmm.py create mode 100644 fms_mo/aiu_addons/fp8/fp8_aiu_op.py diff --git a/fms_mo/aiu_addons/fp8/__init__.py b/fms_mo/aiu_addons/fp8/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fms_mo/aiu_addons/fp8/fp8_aiu_bmm.py b/fms_mo/aiu_addons/fp8/fp8_aiu_bmm.py new file mode 100644 index 00000000..21b182c2 --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_aiu_bmm.py @@ -0,0 +1,228 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""FMS registration of attention BMM operation using torch-registered scaled BMM.""" + +# Standard +from importlib.util import find_spec +from typing import NotRequired, Unpack +import math + +# Third Party +from fms.modules.attention import ( + AttentionKwargs, + _sdpa_update_attn_kwargs, + register_attention_op, +) +from torch import Tensor +import torch + +# Local +import fms_mo.aiu_addons.fp8.fp8_aiu_op # pylint: disable=unused-import + +if find_spec("torchao"): + TORCHAO_INSTALLED = True + # Third Party + from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor + from torchao.dtypes.floatx.float8_layout import ( + Float8AQTTensorImpl, + Float8Layout, + Float8MMConfig, + ) + from torchao.quantization.granularity import PerTensor + from torchao.quantization.observer import get_block_size + from torchao.quantization.quant_primitives import ZeroPointDomain +else: + TORCHAO_INSTALLED = False + + +class MathFP8AttentionKwargs(AttentionKwargs): + """TypedDict for FP8 attention.""" + + mask: NotRequired[Tensor] + do_scale_q: bool + is_causal_mask: bool + + +# TODO: Doesn't quite work yet, more discussion needed +Q_RANGE = 200.0 +K_RANGE = 200.0 +V_RANGE = 100.0 + + +def _construct_fp8_cache( + tensor: Tensor, scale: Tensor, orig_dtype: torch.dtype +) -> AffineQuantizedTensor: + """Construct the torchao tensor to save kv cache with its scales.""" + + weight_granularity = PerTensor() + fp8_layout = Float8Layout(Float8MMConfig(use_fast_accum=True)) + return AffineQuantizedTensor( + Float8AQTTensorImpl.from_plain( + tensor, + scale, + None, + fp8_layout, + ), + get_block_size(tensor.shape, weight_granularity), + tensor.shape, + zero_point_domain=ZeroPointDomain.NONE, + dtype=orig_dtype, + ) + + +def _math_fp8_store_op( + keys: Tensor, # pylint: disable=unused-argument + values: Tensor, + key_cache: Tensor | None, + value_cache: Tensor | None, + **attn_kwargs: Unpack[MathFP8AttentionKwargs], +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Implement math of KV cache storing.""" + + 
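+    # Reuse the per-tensor scales already stored on an AffineQuantizedTensor
+    # cache when one exists; otherwise derive them dynamically from the abs-max
+    # of keys/values over K_RANGE / V_RANGE. Tensors are then cast to
+    # float8_e4m3fn, transposed for storage, and appended to any existing cache.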
orig_dtype = keys.dtype + + if isinstance(key_cache, AffineQuantizedTensor) and isinstance( + value_cache, AffineQuantizedTensor + ): + k_scale = key_cache.tensor_impl.scale + v_scale = value_cache.tensor_impl.scale + else: + k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32) + v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32) + + keys = (keys / k_scale).to(torch.float8_e4m3fn).transpose(2, 1) + values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1) + + if ( + isinstance(key_cache, AffineQuantizedTensor) + and isinstance(value_cache, AffineQuantizedTensor) + and value_cache.numel() > 0 + ): + key_cache = torch.cat((key_cache.tensor_impl.float8_data, keys), dim=2) + value_cache = torch.cat((value_cache.tensor_impl.float8_data, values), dim=2) + key_cache = _construct_fp8_cache(key_cache, k_scale, orig_dtype) + value_cache = _construct_fp8_cache(value_cache, v_scale, orig_dtype) + return ( + key_cache, + value_cache, + key_cache, + value_cache, + ) + + keys = _construct_fp8_cache(keys, k_scale, orig_dtype) + values = _construct_fp8_cache(values, v_scale, orig_dtype) + return (keys, values, keys, values) + + +def _math_fp8_compute_op( + query: Tensor, + key_cache: Tensor, + value_cache: Tensor, + nheads: int, + kvheads: int, + p_dropout: float, + scale_factor: float | None, + **attn_kwargs: Unpack[MathFP8AttentionKwargs], +) -> Tensor: + """Implement computation of attention BMM, leveraging the custom scaled attention + BMM op that was pre-registered for torch.compile.""" + + orig_dtype = query.dtype + + q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device) + if attn_kwargs.get("do_scale_q", False): + q_scale.copy_(torch.abs(query).max() / Q_RANGE) + query = query / q_scale + + query = query.to(torch.float8_e4m3fn).transpose(2, 1) + + if isinstance(key_cache, AffineQuantizedTensor) and isinstance( + value_cache, AffineQuantizedTensor + ): + k_scale = key_cache.tensor_impl.scale + v_scale = value_cache.tensor_impl.scale + key_cache = key_cache.tensor_impl.float8_data + value_cache = value_cache.tensor_impl.float8_data + else: + k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32) + v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32) + key_cache = (key_cache / k_scale).to(torch.float8_e4m3fn) + value_cache = (value_cache / v_scale).to(torch.float8_e4m3fn) + + # no longer transposing prior to store, so need to check this in case of no cache + # TODO: Refactor FMS to avoid edge cases where this fails; add use_cache param here + if key_cache.shape[1] != kvheads and key_cache.shape[2] == kvheads: + key_cache = key_cache.transpose(2, 1) + value_cache = value_cache.transpose(2, 1) + + mask = attn_kwargs.get("mask", None) + if mask is not None: + # Our expected mask format is bs x q_len x k_len, so to make it broadcastable + # we need to create the nheads dimension + while len(mask.size()) != 4: # expects bs (x nheads) x q_len x kv_len + mask = mask.unsqueeze(1) + + L, S = query.size(-2), key_cache.size(-2) + scale_factor = ( + 1 / math.sqrt(query.size(-1)) if scale_factor is None else scale_factor + ) + attn_bias = torch.zeros(L, S, dtype=orig_dtype, device=query.device) + if attn_kwargs.get("is_causal_mask", False): + assert mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(torch.float32) + + if mask is not None: + if mask.dtype == torch.bool: + 
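+            # Boolean masks mark positions to keep; masked-out (False) positions
+            # receive a -inf additive bias.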
attn_bias.masked_fill_(mask.logical_not(), float("-inf")) + else: + attn_bias = mask + attn_bias + + expansion = nheads // kvheads + if expansion > 1: + key_cache = key_cache.repeat_interleave( + query.size(-3) // key_cache.size(-3), -3 + ) + value_cache = value_cache.repeat_interleave( + query.size(-3) // value_cache.size(-3), -3 + ) + + attn_weight = ( + torch.ops.sendnn.scaled_bmm( + query, + key_cache.transpose(-2, -1), + q_scale, + k_scale, + out_dtype=orig_dtype, + use_fast_accum=True, + ) + * scale_factor + ) + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, p_dropout, train=True) + # Do matmul in orig_dtype + attn = attn_weight @ (value_cache.to(dtype=orig_dtype) * v_scale) + + attn = attn.to(orig_dtype).transpose(2, 1).contiguous() + return attn + + +register_attention_op( + "math_fp8", + _math_fp8_store_op, + _math_fp8_compute_op, + update_attn_kwargs_op=_sdpa_update_attn_kwargs, +) diff --git a/fms_mo/aiu_addons/fp8/fp8_aiu_op.py b/fms_mo/aiu_addons/fp8/fp8_aiu_op.py new file mode 100644 index 00000000..483cff16 --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_aiu_op.py @@ -0,0 +1,84 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Torch registration of FP8xFP8 operation for attention BMMs.""" + +# Third Party +from torch import Tensor +import torch + +# pylint: disable=unused-argument +# abstract op must be registered with specific I/O, even if not in use by the op function + + +@torch.library.custom_op("sendnn::scaled_bmm", mutates_args=()) +def sendnn_scaled_bmm( + mat1: Tensor, + mat2: Tensor, + scale1: Tensor, + scale2: Tensor, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> Tensor: + """Implement a custom scaled attention BMM op: a batched version of _scaled_mm. + The operations that are part of this function are not exposed to the computational + graph, but are invoked when running on non-AIU devices. + """ + + assert ( + mat1.shape[:-2] == mat2.shape[:-2] + ), "batch dimensions must match for mat1 and mat2" + assert ( + mat1.shape[:-2] == scale1.shape[:-2] + ), "batch dimensions must match for mat1 and scale1" + assert ( + mat2.shape[:-2] == scale2.shape[:-2] + ), "batch dimensions must match for mat2 and scale2" + + mat1 = mat1.view(-1, *mat1.shape[-2:]) + mat2 = mat2.view(-1, *mat2.shape[-2:]) + scale1 = scale1.view(-1, *scale1.shape[-2:]) + scale2 = scale2.view(-1, *scale2.shape[-2:]) + out = torch.empty( + (mat1.shape[0], mat1.shape[1], mat2.shape[2]), + dtype=out_dtype, + device=mat1.device, + ) + for b_idx in range(mat1.shape[0]): + out[b_idx] = torch._scaled_mm( + mat1[b_idx], + mat2[b_idx], + scale1[b_idx], + scale2[b_idx], + out_dtype, + use_fast_accum, + ) + return out + + +@sendnn_scaled_bmm.register_fake +def _( + mat1: Tensor, + mat2: Tensor, + scale1: Tensor, + scale2: Tensor, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> Tensor: + """Template for scaled attention BMM operation. 
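+    This fake (meta) implementation is what torch.compile traces: it only
+    reports output metadata and performs no computation.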
I/O retain the expected size.""" + + return torch.empty( + (*mat1.shape[:-2], mat1.shape[-2], mat2.shape[-1]), + dtype=out_dtype, + device=mat1.device, + ) From 6f289b09861d83e0295c126f73a297505ea5bf42 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Fri, 27 Jun 2025 23:27:55 +0000 Subject: [PATCH 02/15] Update FP8 bmm Signed-off-by: Andrea Fasoli --- .../fp8/{fp8_aiu_bmm.py => fp8_bmm.py} | 103 ++++-------- fms_mo/aiu_addons/fp8/fp8_linear.py | 0 .../fp8/{fp8_aiu_op.py => fp8_spyre_op.py} | 21 +-- fms_mo/aiu_addons/fp8/fp8_utils.py | 148 ++++++++++++++++++ 4 files changed, 186 insertions(+), 86 deletions(-) rename fms_mo/aiu_addons/fp8/{fp8_aiu_bmm.py => fp8_bmm.py} (64%) create mode 100644 fms_mo/aiu_addons/fp8/fp8_linear.py rename fms_mo/aiu_addons/fp8/{fp8_aiu_op.py => fp8_spyre_op.py} (80%) create mode 100644 fms_mo/aiu_addons/fp8/fp8_utils.py diff --git a/fms_mo/aiu_addons/fp8/fp8_aiu_bmm.py b/fms_mo/aiu_addons/fp8/fp8_bmm.py similarity index 64% rename from fms_mo/aiu_addons/fp8/fp8_aiu_bmm.py rename to fms_mo/aiu_addons/fp8/fp8_bmm.py index 21b182c2..e04cce9b 100644 --- a/fms_mo/aiu_addons/fp8/fp8_aiu_bmm.py +++ b/fms_mo/aiu_addons/fp8/fp8_bmm.py @@ -14,7 +14,6 @@ """FMS registration of attention BMM operation using torch-registered scaled BMM.""" # Standard -from importlib.util import find_spec from typing import NotRequired, Unpack import math @@ -24,79 +23,44 @@ _sdpa_update_attn_kwargs, register_attention_op, ) -from torch import Tensor import torch # Local -import fms_mo.aiu_addons.fp8.fp8_aiu_op # pylint: disable=unused-import - -if find_spec("torchao"): - TORCHAO_INSTALLED = True - # Third Party - from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor - from torchao.dtypes.floatx.float8_layout import ( - Float8AQTTensorImpl, - Float8Layout, - Float8MMConfig, - ) - from torchao.quantization.granularity import PerTensor - from torchao.quantization.observer import get_block_size - from torchao.quantization.quant_primitives import ZeroPointDomain -else: - TORCHAO_INSTALLED = False +from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor +import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import class MathFP8AttentionKwargs(AttentionKwargs): """TypedDict for FP8 attention.""" - mask: NotRequired[Tensor] + mask: NotRequired[torch.Tensor] do_scale_q: bool is_causal_mask: bool -# TODO: Doesn't quite work yet, more discussion needed +# TODO: Figure out better scales for AIU? 
These come from vLLM Q_RANGE = 200.0 K_RANGE = 200.0 V_RANGE = 100.0 -def _construct_fp8_cache( - tensor: Tensor, scale: Tensor, orig_dtype: torch.dtype -) -> AffineQuantizedTensor: - """Construct the torchao tensor to save kv cache with its scales.""" - - weight_granularity = PerTensor() - fp8_layout = Float8Layout(Float8MMConfig(use_fast_accum=True)) - return AffineQuantizedTensor( - Float8AQTTensorImpl.from_plain( - tensor, - scale, - None, - fp8_layout, - ), - get_block_size(tensor.shape, weight_granularity), - tensor.shape, - zero_point_domain=ZeroPointDomain.NONE, - dtype=orig_dtype, - ) +def _construct_fp8_cache(tensor: torch.Tensor, scale: torch.Tensor) -> ScaledTensor: + """Construct the custom object to save KV cache with its scales.""" + return ScaledTensor(tensor, scale) def _math_fp8_store_op( - keys: Tensor, # pylint: disable=unused-argument - values: Tensor, - key_cache: Tensor | None, - value_cache: Tensor | None, + keys: torch.Tensor, # pylint: disable=unused-argument + values: torch.Tensor, + key_cache: torch.Tensor | None, + value_cache: torch.Tensor | None, **attn_kwargs: Unpack[MathFP8AttentionKwargs], -) -> tuple[Tensor, Tensor, Tensor, Tensor]: +) -> tuple[ScaledTensor, ScaledTensor, ScaledTensor, ScaledTensor]: """Implement math of KV cache storing.""" - orig_dtype = keys.dtype - - if isinstance(key_cache, AffineQuantizedTensor) and isinstance( - value_cache, AffineQuantizedTensor - ): - k_scale = key_cache.tensor_impl.scale - v_scale = value_cache.tensor_impl.scale + if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor): + k_scale = key_cache._scale + v_scale = value_cache._scale else: k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32) v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32) @@ -105,36 +69,35 @@ def _math_fp8_store_op( values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1) if ( - isinstance(key_cache, AffineQuantizedTensor) - and isinstance(value_cache, AffineQuantizedTensor) + isinstance(key_cache, ScaledTensor) + and isinstance(value_cache, ScaledTensor) and value_cache.numel() > 0 ): - key_cache = torch.cat((key_cache.tensor_impl.float8_data, keys), dim=2) - value_cache = torch.cat((value_cache.tensor_impl.float8_data, values), dim=2) - key_cache = _construct_fp8_cache(key_cache, k_scale, orig_dtype) - value_cache = _construct_fp8_cache(value_cache, v_scale, orig_dtype) + key_cache = torch.cat((key_cache._data, keys), dim=2) + value_cache = torch.cat((value_cache._data, values), dim=2) + key_cache = _construct_fp8_cache(key_cache, k_scale) + value_cache = _construct_fp8_cache(value_cache, v_scale) return ( key_cache, value_cache, key_cache, value_cache, ) - - keys = _construct_fp8_cache(keys, k_scale, orig_dtype) - values = _construct_fp8_cache(values, v_scale, orig_dtype) + keys = _construct_fp8_cache(keys.contiguous(), k_scale) + values = _construct_fp8_cache(values.contiguous(), v_scale) return (keys, values, keys, values) def _math_fp8_compute_op( - query: Tensor, - key_cache: Tensor, - value_cache: Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, nheads: int, kvheads: int, p_dropout: float, scale_factor: float | None, **attn_kwargs: Unpack[MathFP8AttentionKwargs], -) -> Tensor: +) -> torch.Tensor: """Implement computation of attention BMM, leveraging the custom scaled attention BMM op that was pre-registered for torch.compile.""" @@ -147,13 +110,11 @@ def _math_fp8_compute_op( query = query.to(torch.float8_e4m3fn).transpose(2, 1) - if 
isinstance(key_cache, AffineQuantizedTensor) and isinstance( - value_cache, AffineQuantizedTensor - ): - k_scale = key_cache.tensor_impl.scale - v_scale = value_cache.tensor_impl.scale - key_cache = key_cache.tensor_impl.float8_data - value_cache = value_cache.tensor_impl.float8_data + if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor): + k_scale = key_cache._scale + v_scale = value_cache._scale + key_cache = key_cache._data + value_cache = value_cache._data else: k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32) v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py new file mode 100644 index 00000000..e69de29b diff --git a/fms_mo/aiu_addons/fp8/fp8_aiu_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py similarity index 80% rename from fms_mo/aiu_addons/fp8/fp8_aiu_op.py rename to fms_mo/aiu_addons/fp8/fp8_spyre_op.py index 483cff16..ad91fab0 100644 --- a/fms_mo/aiu_addons/fp8/fp8_aiu_op.py +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -21,7 +21,7 @@ # abstract op must be registered with specific I/O, even if not in use by the op function -@torch.library.custom_op("sendnn::scaled_bmm", mutates_args=()) +@torch.library.custom_op("spyre::scaled_bmm", mutates_args=()) def sendnn_scaled_bmm( mat1: Tensor, mat2: Tensor, @@ -38,17 +38,8 @@ def sendnn_scaled_bmm( assert ( mat1.shape[:-2] == mat2.shape[:-2] ), "batch dimensions must match for mat1 and mat2" - assert ( - mat1.shape[:-2] == scale1.shape[:-2] - ), "batch dimensions must match for mat1 and scale1" - assert ( - mat2.shape[:-2] == scale2.shape[:-2] - ), "batch dimensions must match for mat2 and scale2" - mat1 = mat1.view(-1, *mat1.shape[-2:]) mat2 = mat2.view(-1, *mat2.shape[-2:]) - scale1 = scale1.view(-1, *scale1.shape[-2:]) - scale2 = scale2.view(-1, *scale2.shape[-2:]) out = torch.empty( (mat1.shape[0], mat1.shape[1], mat2.shape[2]), dtype=out_dtype, @@ -58,12 +49,12 @@ def sendnn_scaled_bmm( out[b_idx] = torch._scaled_mm( mat1[b_idx], mat2[b_idx], - scale1[b_idx], - scale2[b_idx], - out_dtype, - use_fast_accum, + scale1, + scale2, + out_dtype=out_dtype, + use_fast_accum=use_fast_accum, ) - return out + return out.view(*mat1.shape[:-2], mat1.shape[1], mat2.shape[2]) @sendnn_scaled_bmm.register_fake diff --git a/fms_mo/aiu_addons/fp8/fp8_utils.py b/fms_mo/aiu_addons/fp8/fp8_utils.py new file mode 100644 index 00000000..0dd21024 --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_utils.py @@ -0,0 +1,148 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
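+# This module defines ScaledTensor, a torch.Tensor wrapper subclass pairing FP8
+# data with its scale (used for the KV cache), plus a helper that translates an
+# HF/llm-compressor quantization config into an FMS linear_config.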
+"""FMS registration of attention BMM operation using torch-registered scaled BMM.""" + +# Standard +import functools + +# Third Party +import torch + +# pylint: disable=unused-argument +# unusued arguments are needed for templates + + +_HANDLED_FUNCTIONS = {} + + +def _implements(torch_function): + """Register a torch function override""" + + def decorator(func): + @functools.wraps(torch_function) + def wrapper(f, types, args, kwargs): + return func(f, types, args, kwargs) + + _HANDLED_FUNCTIONS[torch_function] = wrapper + return func + + return decorator + + +class ScaledTensor(torch.Tensor): + """Representation of a quantized tensor and its scale.""" + + def __new__( + cls, + data: torch.Tensor, + scale: torch.Tensor, + ): + return torch.Tensor._make_wrapper_subclass( + cls, + data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + dtype=data.dtype, + layout=data.layout, + requires_grad=data.requires_grad, + device=data.device, + ) + + def __init__( + self, + data: torch.Tensor, + scale: torch.Tensor, + ): + self._data = data + self._scale = scale + + def __tensor_flatten__(self): + ctx = {} + return ["_data", "_scale"], ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): + assert len(inner_tensors) == 2 + return ScaledTensor( + inner_tensors["_data"], + inner_tensors["_scale"], + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + if func in _HANDLED_FUNCTIONS: + return _HANDLED_FUNCTIONS[func](func, types, args, kwargs) + + arg_types = tuple(type(arg) for arg in args) + kwarg_types = {k: type(arg) for k, arg in kwargs.items()} + raise NotImplementedError( + f"{cls.__name__} dispatch: attempting to run unimplemented " + f"operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}" + ) + + def __repr__(self): + return f"{self._data.__repr__()}\n{self._scale.__repr__()}" + + +def _infer_quantization_config(quant_config: dict) -> dict | None: + # There's many quantization packages compatible with HF + # We initially focus on llm-compressor as it is the one used in FMS-MO + + # llm-compressor saves its checkpoints with quant_method = compressed-tensors + # quantization_status tells us whether the model has already been quantized + # We only support loading already quantized models (compressed status) + if ( + quant_config["quant_method"] == "compressed-tensors" + and quant_config["quantization_status"] == "compressed" + ): + # FP8 quantization will have FP8 weights + # We assume a single quantization group (group_0), to follow fms-mo checkpoints + # num_bits and type tells us "float" with "8" bits, aka FP8 + if ( + quant_config["config_groups"]["group_0"]["weights"]["type"] == "float" + and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8 + ): + # This is used by get_linear to decide whether a linear layer + # will be quantized or not inside the model + def fp8_linear_type(name: str) -> str: + # We need to translate HF names to FMS names + translations = { + "lm_head": "head", + } + for ignored_layer in quant_config["ignore"]: + assert isinstance(ignored_layer, str) + fms_ign_layer = translations.get(ignored_layer, ignored_layer) + if name in fms_ign_layer: + return "torch_linear" + for pattern in quant_config["config_groups"]["group_0"]["targets"]: + # Special case from llm-compressor that covers all linear layers + # not in the ignore pattern + assert isinstance(pattern, str) + if pattern == "Linear": + return "fp8" + if name in translations.get(pattern, 
pattern): + return "fp8" + return "torch_linear" + + return { + "linear_type": fp8_linear_type, + "input_activations": quant_config["config_groups"]["group_0"][ + "input_activations" + ], + "output_activations": quant_config["config_groups"]["group_0"][ + "output_activations" + ], + "weights": quant_config["config_groups"]["group_0"]["weights"], + } + return None From 30f76a90230af491a613bbd0bc59197d3a513e51 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Sat, 28 Jun 2025 00:11:09 +0000 Subject: [PATCH 03/15] Add FP8 adapter step Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_adapter.py | 55 +++++++++++++++++++++++++++ fms_mo/aiu_addons/fp8/fp8_spyre_op.py | 2 +- fms_mo/aiu_addons/fp8/fp8_utils.py | 2 +- 3 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 fms_mo/aiu_addons/fp8/fp8_adapter.py diff --git a/fms_mo/aiu_addons/fp8/fp8_adapter.py b/fms_mo/aiu_addons/fp8/fp8_adapter.py new file mode 100644 index 00000000..57e52a35 --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_adapter.py @@ -0,0 +1,55 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implement and register FMS adapters for FP8 checkpoint loading.""" + +# Standard +from typing import Any, Mapping + +# Third Party +from fms.modules.linear import get_linear_type +from fms.utils import serialization +from fms.utils.config import ModelConfig + +# NOTE: this adapter step must be registered before the adapter that uses it (such as +# the llama adapter in fms.models.llama) +# TODO: may be shared with gptq llama +# TODO: generalize across architectures if possible +def _hf_fp8_llama_check( + input_sd: Mapping[str, Any], model_config: ModelConfig | None = None, **kwargs +) -> Mapping[str, Any]: + """Implementation of adapter step for FMS Llama: ensure that when FP8 quantization + is in use, weights are unfused. + """ + + has_fused_weights = True + linear_type = "torch_linear" + if model_config: + if not model_config.fused_weights: + has_fused_weights = False + if model_config.linear_config: + linear_type = model_config.linear_config["linear_type"] + if callable(linear_type): + # Calling this with "any" guarantees "fp8" to be returned + # when loading an HF fp8 checkpoint, and never in any other condition + linear_type = get_linear_type(model_config.linear_config, "any") + + if "fp8" in linear_type and has_fused_weights: + raise ValueError( + "FP8 HF llama checkpoints cannot be loaded into a model with fused weights" + ) + + return input_sd + + +serialization.register_adapter_step("llama", "hf_fp8_llama_check", _hf_fp8_llama_check) diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py index ad91fab0..c6280d4a 100644 --- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -32,7 +32,7 @@ def sendnn_scaled_bmm( ) -> Tensor: """Implement a custom scaled attention BMM op: a batched version of _scaled_mm. 
The operations that are part of this function are not exposed to the computational - graph, but are invoked when running on non-AIU devices. + graph, but are invoked when running on non-Spyre devices. """ assert ( diff --git a/fms_mo/aiu_addons/fp8/fp8_utils.py b/fms_mo/aiu_addons/fp8/fp8_utils.py index 0dd21024..6ad38752 100644 --- a/fms_mo/aiu_addons/fp8/fp8_utils.py +++ b/fms_mo/aiu_addons/fp8/fp8_utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""FMS registration of attention BMM operation using torch-registered scaled BMM.""" +"""Utility functions and components for FP8 addon implementation.""" # Standard import functools From 496bf44c395b4caa6185b7fc3af09b2c2ca53381 Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Sat, 28 Jun 2025 00:26:52 +0000 Subject: [PATCH 04/15] Add FP8 linear to FMS addon Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_adapter.py | 4 + fms_mo/aiu_addons/fp8/fp8_linear.py | 325 +++++++++++++++++++++++++++ fms_mo/aiu_addons/fp8/fp8_utils.py | 16 +- 3 files changed, 339 insertions(+), 6 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_adapter.py b/fms_mo/aiu_addons/fp8/fp8_adapter.py index 57e52a35..0155ef39 100644 --- a/fms_mo/aiu_addons/fp8/fp8_adapter.py +++ b/fms_mo/aiu_addons/fp8/fp8_adapter.py @@ -21,6 +21,10 @@ from fms.utils import serialization from fms.utils.config import ModelConfig +# pylint: disable=unused-argument +# Retaining kwargs input arguments for consistency. + + # NOTE: this adapter step must be registered before the adapter that uses it (such as # the llama adapter in fms.models.llama) # TODO: may be shared with gptq llama diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index e69de29b..ecbf681a 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -0,0 +1,325 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
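+# FP8Linear below stores weights in float8_e4m3fn with per-tensor or per-row
+# scales; when input activations are also quantized, the matmul runs through
+# torchao's float8 inference helpers, otherwise the weight is dequantized and
+# a regular torch.nn.functional.linear is used.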
+"""Implement FP8 linear module to be loaded via FMS.""" + +# Standard +from importlib.util import find_spec +from typing import Any, Mapping + +# Third Party +from fms.modules.linear import ( + LinearModuleShardingInfo, + LinearParameterShardingInfo, + register_linear_type_to_module_map, + register_linear_type_to_sharding_map, + shard_base_linear, +) +from fms.modules.tp import ShardType, TPModule +import torch + +# pylint: disable=not-callable +# torch.nn.functional.linear not recognized as callable +# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 + + +### FP8 linear layers +if find_spec("torchao"): + TORCHAO_INSTALLED = True + + # Third Party + from torchao.dtypes.affine_quantized_tensor import ( # type: ignore + AffineQuantizedTensor, + to_affine_quantized_floatx, + to_affine_quantized_floatx_static, + ) + from torchao.dtypes.floatx.float8_layout import ( # type: ignore + Float8AQTTensorImpl, + Float8Layout, + Float8MMConfig, + preprocess_data, + preprocess_scale, + ) + from torchao.dtypes.utils import get_out_shape # type: ignore + from torchao.float8.inference import ( # type: ignore + _is_rowwise_scaled, + addmm_float8_unwrapped_inference, + ) + from torchao.quantization.granularity import PerRow, PerTensor # type: ignore + from torchao.quantization.observer import get_block_size # type: ignore + from torchao.quantization.quant_primitives import ZeroPointDomain # type: ignore +else: + TORCHAO_INSTALLED = False + + +class FP8Linear(torch.nn.Module): + """Class handles FP8 weights loading and uses torchao for the matmuls.""" + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_config: Mapping[str, Any], + ): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.has_bias = bias + self.linear_config = linear_config + + assert ( + self.linear_config["weights"] is not None + ), "Weights must always be quantized for FP8Linear" + assert self.linear_config["weights"][ + "symmetric" + ], "We only support symmetric weights for now" + assert not self.linear_config["weights"][ + "dynamic" + ], "We only support pre-quantized weights for now" + + self.weight = torch.nn.Parameter( + torch.zeros(out_features, in_features, dtype=torch.float8_e4m3fn), + requires_grad=False, + ) + + weight_scale_shape = ( + (1,) + if self.linear_config["weights"]["strategy"] == "tensor" + else (out_features, 1) + ) + self.weight_scale = torch.nn.Parameter( + torch.ones(weight_scale_shape), requires_grad=False + ) + + self.has_bias = bias + if self.has_bias: + self.bias = torch.nn.Parameter(torch.zeros((out_features,))) + + if ( + self.linear_config["input_activations"] is not None + and not self.linear_config["input_activations"]["dynamic"] + ): + input_scale_shape = ( + (1,) + if self.linear_config["input_activations"]["strategy"] == "tensor" + else (out_features, 1) + ) + self.input_scale = torch.nn.Parameter( + torch.ones(input_scale_shape), requires_grad=False + ) + + def _input_activation_quant_func_fp8( + self, + x: torch.Tensor, + activation_granularity, + activation_dtype: torch.dtype, + scale: torch.Tensor | None = None, + ): + """Quantize the input activation tensor for an aqt_float variant. + If scale is not provided, it will be dynamically calculated, otherwise the + provided scale will be used. 
+ """ + + block_size = get_block_size(x.shape, activation_granularity) + if scale is None: + activation = to_affine_quantized_floatx( + input_float=x, + block_size=block_size, + target_dtype=activation_dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=None), # Config is stored on weight + ) + else: + assert isinstance( + activation_granularity, PerTensor + ), "Static quantization only supports PerTensor granularity" + activation = to_affine_quantized_floatx_static( + input_float=x, + block_size=block_size, + scale=scale, + target_dtype=activation_dtype, + _layout=Float8Layout(mm_config=None), # Config is stored on weight + ) + return activation + + def _construct_qweight_structure(self) -> "AffineQuantizedTensor": + """Construct the torchao machinery for the fp8 matmul""" + + weight_granularity = ( + PerTensor() + if self.linear_config["weights"]["strategy"] == "tensor" + else PerRow() + ) + fp8_layout = Float8Layout(Float8MMConfig(use_fast_accum=True)) + return AffineQuantizedTensor( + Float8AQTTensorImpl.from_plain( + self.weight, + self.weight_scale.squeeze().to(torch.float32), + None, + fp8_layout, + ), + get_block_size(self.weight.shape, weight_granularity), + self.weight.shape, + zero_point_domain=ZeroPointDomain.NONE, + dtype=self.weight_scale.dtype, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """If input quantization is active, compute FP8xFP8 addmm.""" + + # fp8 weight tensor for torchao + qweight: AffineQuantizedTensor = self._construct_qweight_structure() + + if self.linear_config["input_activations"] is not None: + # activations are also fp8, quantize as required by model + act_granularity = ( + PerTensor() + if self.linear_config["input_activations"]["strategy"] == "tensor" + else PerRow() + ) + input_quant_kwargs = { + "activation_granularity": act_granularity, + "activation_dtype": torch.float8_e4m3fn, + } + if not self.linear_config["input_activations"]["dynamic"]: + input_quant_kwargs["scale"] = self.input_scale.squeeze().to( + torch.float32 + ) + qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs) + + # Copied from torchao _linear_fp8_act_fp8_weight_impl (with changes to support fp8 out) + scaled_mm_config = Float8MMConfig(use_fast_accum=True) + out_shape = get_out_shape(qx.shape, qweight.shape) + + # Weight tensor preprocessing + w_tensor_impl = qweight.tensor_impl + assert not w_tensor_impl.transposed, "Weight tensor must be contiguous" + w_data = w_tensor_impl.float8_data + w_scale = w_tensor_impl.scale + + # Input tensor preprocessing + inpt_data = qx.tensor_impl.float8_data + input_scale = qx.tensor_impl.scale + # Handle case where input tensor is more than 2D + inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) + + # Handle rowwise case + if _is_rowwise_scaled(qweight): + assert _is_rowwise_scaled(qx), "Input tensor must be rowwise block size" + w_scale = w_scale.unsqueeze(-1).T + input_scale = preprocess_scale(input_scale, qx.shape) + + # Preprocess data + inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) + + # Perform the computation + return addmm_float8_unwrapped_inference( + inpt_data, + input_scale, + w_data, + w_scale, + output_dtype=qx.dtype, + bias=getattr(self, "bias", None), + use_fast_accum=scaled_mm_config.use_fast_accum, + ).reshape(out_shape) + + # activations not quantized, dequant fp8 weight and do regular matmul + out = torch.nn.functional.linear( + x, qweight.dequantize(), self.bias if self.has_bias else None + ) + return out + + def __repr__(self) -> str: + return ( 
+ f"{self.__class__.__name__}" + f"(in={self.in_features}, out={self.out_features}, " + f"bias={self.has_bias}, fp8_config={self.linear_config})" + ) + + +def get_fp8_linear( + in_features: int, + out_features: int, + bias: bool, + linear_config: Mapping[str, Any], +) -> FP8Linear: + """Retrieve an FP8 Linear module""" + + if not TORCHAO_INSTALLED: + raise ModuleNotFoundError("You need to install torchao for FP8 support in FMS!") + + return FP8Linear(in_features, out_features, bias, linear_config) + + +def shard_fp8_linear( + tensor_values: dict[str, torch.Tensor], + tp_module: TPModule, + module_sharding_info: dict[str, LinearModuleShardingInfo], +) -> set | None: + """ + | GPU | + sharding | param | shard | dim | + ----------+----------------+-------+-----| + colwise | weight | Y | 0 | + | weight_scale | N | - | + | input_scale | N | - | + | bias | Y | 0 | + ----------+----------------+-------+-----| + rowwise | weight | Y | 1 | + | weight_scale | Y/N | 0/- | + | input_scale | Y/N | 0/- | + | bias | 0 | - | + """ + param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {} + for module_name, module_info in module_sharding_info.items(): + linear_mod: torch.nn.Module = module_info.linear_module + weight_strategy = getattr(linear_mod, "linear_config")["input_activations"][ + "strategy" + ] + # Scales are per-row or per-tensor + # Only sharding needed when row parallel and per-row + shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 1 + params: dict[str, LinearParameterShardingInfo] = { + "weight": LinearParameterShardingInfo( + module_info.sharding_dim, ShardType.SHARD + ), + "weight_scale": LinearParameterShardingInfo( + module_info.sharding_dim, + ShardType.SHARD if shard_scales else ShardType.CLONE, + ), + } + if hasattr(linear_mod, "input_scale"): + params["input_scale"] = LinearParameterShardingInfo( + module_info.sharding_dim, + ShardType.SHARD if shard_scales else ShardType.CLONE, + ) + if hasattr(linear_mod, "bias") and linear_mod.bias is not None: + params["bias"] = LinearParameterShardingInfo( + module_info.sharding_dim, + ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0, + ) + param_sharding_info[module_name] = params + + unused_keys = shard_base_linear( + tensor_values, + tp_module, + module_sharding_info, + param_sharding_info, + ) + return unused_keys + + +register_linear_type_to_module_map("fp8", get_fp8_linear) +register_linear_type_to_sharding_map("fp8", shard_fp8_linear) diff --git a/fms_mo/aiu_addons/fp8/fp8_utils.py b/fms_mo/aiu_addons/fp8/fp8_utils.py index 6ad38752..0f5fbab4 100644 --- a/fms_mo/aiu_addons/fp8/fp8_utils.py +++ b/fms_mo/aiu_addons/fp8/fp8_utils.py @@ -59,7 +59,7 @@ def __new__( device=data.device, ) - def __init__( + def __init__( # pylint: disable=super-init-not-called self, data: torch.Tensor, scale: torch.Tensor, @@ -96,12 +96,16 @@ def __repr__(self): def _infer_quantization_config(quant_config: dict) -> dict | None: - # There's many quantization packages compatible with HF - # We initially focus on llm-compressor as it is the one used in FMS-MO + """Construct linear_config dictionary carrying FP8 configuration for FMS. 
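+    Returns None when the checkpoint does not use a supported llm-compressor FP8 scheme.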
+ + There's many quantization packages compatible with HF + We initially focus on llm-compressor as it is the one used in FMS-MO + + llm-compressor saves its checkpoints with quant_method = compressed-tensors + quantization_status tells us whether the model has already been quantized + We only support loading already quantized models (compressed status) + """ - # llm-compressor saves its checkpoints with quant_method = compressed-tensors - # quantization_status tells us whether the model has already been quantized - # We only support loading already quantized models (compressed status) if ( quant_config["quant_method"] == "compressed-tensors" and quant_config["quantization_status"] == "compressed" From c931ad73d54fa7f2173e24619f853335fcddc08d Mon Sep 17 00:00:00 2001 From: Andrea Fasoli Date: Sat, 28 Jun 2025 01:44:17 +0000 Subject: [PATCH 05/15] rename fp8 attention Signed-off-by: Andrea Fasoli --- fms_mo/aiu_addons/fp8/fp8_adapter.py | 2 +- .../fp8/{fp8_bmm.py => fp8_attn.py} | 0 fms_mo/aiu_addons/fp8/fp8_linear.py | 20 ++++++++++--------- 3 files changed, 12 insertions(+), 10 deletions(-) rename fms_mo/aiu_addons/fp8/{fp8_bmm.py => fp8_attn.py} (100%) diff --git a/fms_mo/aiu_addons/fp8/fp8_adapter.py b/fms_mo/aiu_addons/fp8/fp8_adapter.py index 0155ef39..7f339f72 100644 --- a/fms_mo/aiu_addons/fp8/fp8_adapter.py +++ b/fms_mo/aiu_addons/fp8/fp8_adapter.py @@ -22,7 +22,7 @@ from fms.utils.config import ModelConfig # pylint: disable=unused-argument -# Retaining kwargs input arguments for consistency. +# Retaining kwargs input arguments for consistency with other adapter steps. # NOTE: this adapter step must be registered before the adapter that uses it (such as diff --git a/fms_mo/aiu_addons/fp8/fp8_bmm.py b/fms_mo/aiu_addons/fp8/fp8_attn.py similarity index 100% rename from fms_mo/aiu_addons/fp8/fp8_bmm.py rename to fms_mo/aiu_addons/fp8/fp8_attn.py diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index ecbf681a..3e2246ba 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -33,31 +33,31 @@ # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 -### FP8 linear layers +# Gated torchao imports for FP8 implementation if find_spec("torchao"): TORCHAO_INSTALLED = True # Third Party - from torchao.dtypes.affine_quantized_tensor import ( # type: ignore + from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, to_affine_quantized_floatx, to_affine_quantized_floatx_static, ) - from torchao.dtypes.floatx.float8_layout import ( # type: ignore + from torchao.dtypes.floatx.float8_layout import ( Float8AQTTensorImpl, Float8Layout, Float8MMConfig, preprocess_data, preprocess_scale, ) - from torchao.dtypes.utils import get_out_shape # type: ignore - from torchao.float8.inference import ( # type: ignore + from torchao.dtypes.utils import get_out_shape + from torchao.float8.inference import ( _is_rowwise_scaled, addmm_float8_unwrapped_inference, ) - from torchao.quantization.granularity import PerRow, PerTensor # type: ignore - from torchao.quantization.observer import get_block_size # type: ignore - from torchao.quantization.quant_primitives import ZeroPointDomain # type: ignore + from torchao.quantization.granularity import PerRow, PerTensor + from torchao.quantization.observer import get_block_size + from torchao.quantization.quant_primitives import ZeroPointDomain else: TORCHAO_INSTALLED = False @@ -177,7 +177,8 @@ def _construct_qweight_structure(self) -> "AffineQuantizedTensor": ) def 
forward(self, x: torch.Tensor) -> torch.Tensor: - """If input quantization is active, compute FP8xFP8 addmm.""" + """If input quantization is active, compute FP8xFP8 addmm leveraging torchao + functionalities. Otherwise compute non-quantized addmm.""" # fp8 weight tensor for torchao qweight: AffineQuantizedTensor = self._construct_qweight_structure() @@ -282,6 +283,7 @@ def shard_fp8_linear( | input_scale | Y/N | 0/- | | bias | 0 | - | """ + param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {} for module_name, module_info in module_sharding_info.items(): linear_mod: torch.nn.Module = module_info.linear_module From f05beb5590dcf59ee63186f2cb47f195e3435b92 Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Mon, 30 Jun 2025 19:22:38 +0000 Subject: [PATCH 06/15] Fix linting, add paged attention kernels Signed-off-by: Antoni Viros i Martin --- fms_mo/aiu_addons/__init__.py | 55 +++ fms_mo/aiu_addons/fp8/fp8_adapter.py | 31 +- fms_mo/aiu_addons/fp8/fp8_attn.py | 105 +++++- fms_mo/aiu_addons/fp8/fp8_linear.py | 497 +++++++++++++------------- fms_mo/aiu_addons/fp8/fp8_spyre_op.py | 146 ++++++++ fms_mo/aiu_addons/fp8/fp8_utils.py | 66 +--- fms_mo/utils/import_utils.py | 1 + pyproject.toml | 3 +- 8 files changed, 577 insertions(+), 327 deletions(-) diff --git a/fms_mo/aiu_addons/__init__.py b/fms_mo/aiu_addons/__init__.py index e69de29b..e4f30082 100644 --- a/fms_mo/aiu_addons/__init__.py +++ b/fms_mo/aiu_addons/__init__.py @@ -0,0 +1,55 @@ +def _infer_quantization_config(quant_config: dict) -> dict | None: + """Construct linear_config dictionary carrying FP8 configuration for FMS. + + There's many quantization packages compatible with HF + We initially focus on llm-compressor as it is the one used in FMS-MO + + llm-compressor saves its checkpoints with quant_method = compressed-tensors + quantization_status tells us whether the model has already been quantized + We only support loading already quantized models (compressed status) + """ + + if ( + quant_config["quant_method"] == "compressed-tensors" + and quant_config["quantization_status"] == "compressed" + ): + # FP8 quantization will have FP8 weights + # We assume a single quantization group (group_0), to follow fms-mo checkpoints + # num_bits and type tells us "float" with "8" bits, aka FP8 + if ( + quant_config["config_groups"]["group_0"]["weights"]["type"] == "float" + and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8 + ): + # This is used by get_linear to decide whether a linear layer + # will be quantized or not inside the model + def fp8_linear_type(name: str) -> str: + # We need to translate HF names to FMS names + translations = { + "lm_head": "head", + } + for ignored_layer in quant_config["ignore"]: + assert isinstance(ignored_layer, str) + fms_ign_layer = translations.get(ignored_layer, ignored_layer) + if name in fms_ign_layer: + return "torch_linear" + for pattern in quant_config["config_groups"]["group_0"]["targets"]: + # Special case from llm-compressor that covers all linear layers + # not in the ignore pattern + assert isinstance(pattern, str) + if pattern == "Linear": + return "fp8" + if name in translations.get(pattern, pattern): + return "fp8" + return "torch_linear" + + return { + "linear_type": fp8_linear_type, + "input_activations": quant_config["config_groups"]["group_0"][ + "input_activations" + ], + "output_activations": quant_config["config_groups"]["group_0"][ + "output_activations" + ], + "weights": quant_config["config_groups"]["group_0"]["weights"], + } + return 
None diff --git a/fms_mo/aiu_addons/fp8/fp8_adapter.py b/fms_mo/aiu_addons/fp8/fp8_adapter.py index 7f339f72..98c7ca0f 100644 --- a/fms_mo/aiu_addons/fp8/fp8_adapter.py +++ b/fms_mo/aiu_addons/fp8/fp8_adapter.py @@ -15,6 +15,7 @@ # Standard from typing import Any, Mapping +import functools # Third Party from fms.modules.linear import get_linear_type @@ -25,15 +26,15 @@ # Retaining kwargs input arguments for consistency with other adapter steps. -# NOTE: this adapter step must be registered before the adapter that uses it (such as -# the llama adapter in fms.models.llama) # TODO: may be shared with gptq llama -# TODO: generalize across architectures if possible -def _hf_fp8_llama_check( - input_sd: Mapping[str, Any], model_config: ModelConfig | None = None, **kwargs +def _hf_fp8_check( + input_sd: Mapping[str, Any], + model_config: ModelConfig | None = None, + checkpoint_is_fused: bool = False, + **kwargs, ) -> Mapping[str, Any]: - """Implementation of adapter step for FMS Llama: ensure that when FP8 quantization - is in use, weights are unfused. + """Implementation of adapter step for FMS: ensure that when FP8 quantization + is in use, weights are fused like the model checkpoint. """ has_fused_weights = True @@ -44,11 +45,11 @@ def _hf_fp8_llama_check( if model_config.linear_config: linear_type = model_config.linear_config["linear_type"] if callable(linear_type): - # Calling this with "any" guarantees "fp8" to be returned + # Calling this function with "any" guarantees "fp8" to be returned # when loading an HF fp8 checkpoint, and never in any other condition linear_type = get_linear_type(model_config.linear_config, "any") - if "fp8" in linear_type and has_fused_weights: + if "fp8" in linear_type and has_fused_weights != checkpoint_is_fused: raise ValueError( "FP8 HF llama checkpoints cannot be loaded into a model with fused weights" ) @@ -56,4 +57,14 @@ def _hf_fp8_llama_check( return input_sd -serialization.register_adapter_step("llama", "hf_fp8_llama_check", _hf_fp8_llama_check) +serialization.register_adapter_step( + "llama", "hf_fp8_check", functools.partial(_hf_fp8_check, checkpoint_is_fused=False) +) +serialization.extend_adapter("llama", "hf", ["hf_fp8_check"]) + +serialization.register_adapter_step( + "granite", + "hf_fp8_check", + functools.partial(_hf_fp8_check, checkpoint_is_fused=False), +) +serialization.extend_adapter("granite", "hf", ["hf_fp8_check"]) diff --git a/fms_mo/aiu_addons/fp8/fp8_attn.py b/fms_mo/aiu_addons/fp8/fp8_attn.py index e04cce9b..46fcc01c 100644 --- a/fms_mo/aiu_addons/fp8/fp8_attn.py +++ b/fms_mo/aiu_addons/fp8/fp8_attn.py @@ -14,7 +14,7 @@ """FMS registration of attention BMM operation using torch-registered scaled BMM.""" # Standard -from typing import NotRequired, Unpack +from typing import NotRequired, Optional, Unpack import math # Third Party @@ -23,6 +23,10 @@ _sdpa_update_attn_kwargs, register_attention_op, ) +from fms.utils.spyre.paged import ( + SpyrePagedAttentionKwargs, + __spyre_paged_validate_attn_kwargs_op, +) import torch # Local @@ -46,7 +50,7 @@ class MathFP8AttentionKwargs(AttentionKwargs): def _construct_fp8_cache(tensor: torch.Tensor, scale: torch.Tensor) -> ScaledTensor: """Construct the custom object to save KV cache with its scales.""" - return ScaledTensor(tensor, scale) + return ScaledTensor(tensor, scale, True) def _math_fp8_store_op( @@ -58,6 +62,7 @@ def _math_fp8_store_op( ) -> tuple[ScaledTensor, ScaledTensor, ScaledTensor, ScaledTensor]: """Implement math of KV cache storing.""" + # Grab scale from kv-cache if already 
there, compute dynamically otherwise if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor): k_scale = key_cache._scale v_scale = value_cache._scale @@ -65,6 +70,7 @@ def _math_fp8_store_op( k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32) v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32) + # Scale kv tensors for storage keys = (keys / k_scale).to(torch.float8_e4m3fn).transpose(2, 1) values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1) @@ -83,6 +89,7 @@ def _math_fp8_store_op( key_cache, value_cache, ) + # If it's a new kv cache, ensure it's contiguous for spyre use cases keys = _construct_fp8_cache(keys.contiguous(), k_scale) values = _construct_fp8_cache(values.contiguous(), v_scale) return (keys, values, keys, values) @@ -98,11 +105,12 @@ def _math_fp8_compute_op( scale_factor: float | None, **attn_kwargs: Unpack[MathFP8AttentionKwargs], ) -> torch.Tensor: - """Implement computation of attention BMM, leveraging the custom scaled attention - BMM op that was pre-registered for torch.compile.""" + """Implement computation of scaled dot product attention, leveraging + the custom scaled BMM op that was pre-registered for torch.compile.""" orig_dtype = query.dtype + # Scaling the Q tensor is optional q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device) if attn_kwargs.get("do_scale_q", False): q_scale.copy_(torch.abs(query).max() / Q_RANGE) @@ -110,23 +118,27 @@ def _math_fp8_compute_op( query = query.to(torch.float8_e4m3fn).transpose(2, 1) + # Grab kv cache and deal with cases where no store op was run if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor): + # Store op was run k_scale = key_cache._scale v_scale = value_cache._scale key_cache = key_cache._data value_cache = value_cache._data else: + # Store op wasn't run (e.g. 
encoders, use_cache=False) k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32) v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32) key_cache = (key_cache / k_scale).to(torch.float8_e4m3fn) value_cache = (value_cache / v_scale).to(torch.float8_e4m3fn) - # no longer transposing prior to store, so need to check this in case of no cache + # If store wasn't run, we need to transpose the tensors here # TODO: Refactor FMS to avoid edge cases where this fails; add use_cache param here if key_cache.shape[1] != kvheads and key_cache.shape[2] == kvheads: key_cache = key_cache.transpose(2, 1) value_cache = value_cache.transpose(2, 1) + # Most of the code that follows is a copy of Pytorch SDPA, with fp8 additions mask = attn_kwargs.get("mask", None) if mask is not None: # Our expected mask format is bs x q_len x k_len, so to make it broadcastable @@ -187,3 +199,86 @@ def _math_fp8_compute_op( _math_fp8_compute_op, update_attn_kwargs_op=_sdpa_update_attn_kwargs, ) + + +def _spyre_scaled_paged_store_op( + keys: torch.Tensor, + values: torch.Tensor, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + **attn_kwargs: Unpack[SpyrePagedAttentionKwargs], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # For paged store, we must have pre-allocated the kv-cache + assert key_cache is not None and isinstance( + key_cache, ScaledTensor + ), "kv cache must be preallocated" + assert value_cache is not None and isinstance( + value_cache, ScaledTensor + ), "kv cache must be preallocated" + if not key_cache._scaled: + key_cache._scale = (torch.abs(keys).max() / 200.0).to(dtype=torch.float32) + value_cache._scale = (torch.abs(values).max() / 100.0).to(dtype=torch.float32) + + result_key_cache_data, result_value_cache_data = ( + torch.ops.spyre.scaled_paged_attn_store( + keys, + values, + key_cache._data, + value_cache._data, + key_cache._scale, + value_cache._scale, + attn_kwargs["slot_mapping"], + ) + ) + + result_key_cache = _construct_fp8_cache(result_key_cache_data, key_cache._scale) + result_value_cache = _construct_fp8_cache( + result_value_cache_data, value_cache._scale + ) + + # for prefill, we want to return the original keys/values + if attn_kwargs.get("block_table", None) is None: + return keys, values, result_key_cache, result_value_cache + return ( + result_key_cache, + result_value_cache, + result_key_cache, + result_value_cache, + ) + + +def _spyre_scaled_paged_compute_op( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + nheads: int, + kvheads: int, + p_dropout: float, + scale_factor: Optional[float], + **attn_kwargs, +) -> torch.Tensor: + assert isinstance(key_cache, ScaledTensor), "kv cache must be scaled" + assert isinstance(value_cache, ScaledTensor), "kv cache must be scaled" + if scale_factor is None: + scale_factor = 1 / math.sqrt(query.shape[-1]) + return torch.ops.spyre.scaled_paged_attn_compute( + query, + key_cache._data, + value_cache._data, + key_cache._scale, + value_cache._scale, + scale_factor, + attn_kwargs["current_tkv_mask"], + attn_kwargs["left_padded_prompt_mask"], + attn_kwargs["block_table"], + ) + + +register_attention_op( + "spyre_paged_attn_fp8", + _spyre_scaled_paged_store_op, + compute_op=_math_fp8_compute_op, + is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None) is None, + compute_decode_op=_spyre_scaled_paged_compute_op, + validate_attn_kwargs_op=__spyre_paged_validate_attn_kwargs_op, +) diff --git 
a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index 3e2246ba..fbe36a13 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -14,7 +14,6 @@ """Implement FP8 linear module to be loaded via FMS.""" # Standard -from importlib.util import find_spec from typing import Any, Mapping # Third Party @@ -28,15 +27,15 @@ from fms.modules.tp import ShardType, TPModule import torch +# Local +from fms_mo.prep import available_packages + # pylint: disable=not-callable -# torch.nn.functional.linear not recognized as callable +# torch.nn.functional.scaled_dot_product_attention not recognized as callable # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 - # Gated torchao imports for FP8 implementation -if find_spec("torchao"): - TORCHAO_INSTALLED = True - +if available_packages["torchao"]: # Third Party from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, @@ -58,270 +57,264 @@ from torchao.quantization.granularity import PerRow, PerTensor from torchao.quantization.observer import get_block_size from torchao.quantization.quant_primitives import ZeroPointDomain -else: - TORCHAO_INSTALLED = False - - -class FP8Linear(torch.nn.Module): - """Class handles FP8 weights loading and uses torchao for the matmuls.""" - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool, - linear_config: Mapping[str, Any], - ): - super().__init__() - - self.in_features = in_features - self.out_features = out_features - self.has_bias = bias - self.linear_config = linear_config - - assert ( - self.linear_config["weights"] is not None - ), "Weights must always be quantized for FP8Linear" - assert self.linear_config["weights"][ - "symmetric" - ], "We only support symmetric weights for now" - assert not self.linear_config["weights"][ - "dynamic" - ], "We only support pre-quantized weights for now" - - self.weight = torch.nn.Parameter( - torch.zeros(out_features, in_features, dtype=torch.float8_e4m3fn), - requires_grad=False, - ) - - weight_scale_shape = ( - (1,) - if self.linear_config["weights"]["strategy"] == "tensor" - else (out_features, 1) - ) - self.weight_scale = torch.nn.Parameter( - torch.ones(weight_scale_shape), requires_grad=False - ) - self.has_bias = bias - if self.has_bias: - self.bias = torch.nn.Parameter(torch.zeros((out_features,))) + class FP8Linear(torch.nn.Module): + """Class handles FP8 weights loading and uses torchao for the matmuls.""" - if ( - self.linear_config["input_activations"] is not None - and not self.linear_config["input_activations"]["dynamic"] + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_config: Mapping[str, Any], ): - input_scale_shape = ( - (1,) - if self.linear_config["input_activations"]["strategy"] == "tensor" - else (out_features, 1) - ) - self.input_scale = torch.nn.Parameter( - torch.ones(input_scale_shape), requires_grad=False + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.has_bias = bias + self.linear_config = linear_config + + assert ( + self.linear_config["weights"] is not None + ), "Weights must always be quantized for FP8Linear" + assert self.linear_config["weights"][ + "symmetric" + ], "We only support symmetric weights for now" + assert not self.linear_config["weights"][ + "dynamic" + ], "We only support pre-quantized weights for now" + + self.weight = torch.nn.Parameter( + torch.zeros(out_features, in_features, dtype=torch.float8_e4m3fn), + 
requires_grad=False, ) - def _input_activation_quant_func_fp8( - self, - x: torch.Tensor, - activation_granularity, - activation_dtype: torch.dtype, - scale: torch.Tensor | None = None, - ): - """Quantize the input activation tensor for an aqt_float variant. - If scale is not provided, it will be dynamically calculated, otherwise the - provided scale will be used. - """ - - block_size = get_block_size(x.shape, activation_granularity) - if scale is None: - activation = to_affine_quantized_floatx( - input_float=x, - block_size=block_size, - target_dtype=activation_dtype, - scale_dtype=torch.float32, - _layout=Float8Layout(mm_config=None), # Config is stored on weight + weight_scale_shape = ( + (1,) + if self.linear_config["weights"]["strategy"] == "tensor" + else (out_features, 1) ) - else: - assert isinstance( - activation_granularity, PerTensor - ), "Static quantization only supports PerTensor granularity" - activation = to_affine_quantized_floatx_static( - input_float=x, - block_size=block_size, - scale=scale, - target_dtype=activation_dtype, - _layout=Float8Layout(mm_config=None), # Config is stored on weight + self.weight_scale = torch.nn.Parameter( + torch.ones(weight_scale_shape), requires_grad=False ) - return activation - def _construct_qweight_structure(self) -> "AffineQuantizedTensor": - """Construct the torchao machinery for the fp8 matmul""" - - weight_granularity = ( - PerTensor() - if self.linear_config["weights"]["strategy"] == "tensor" - else PerRow() - ) - fp8_layout = Float8Layout(Float8MMConfig(use_fast_accum=True)) - return AffineQuantizedTensor( - Float8AQTTensorImpl.from_plain( - self.weight, - self.weight_scale.squeeze().to(torch.float32), - None, - fp8_layout, - ), - get_block_size(self.weight.shape, weight_granularity), - self.weight.shape, - zero_point_domain=ZeroPointDomain.NONE, - dtype=self.weight_scale.dtype, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """If input quantization is active, compute FP8xFP8 addmm leveraging torchao - functionalities. Otherwise compute non-quantized addmm.""" + self.has_bias = bias + if self.has_bias: + self.bias = torch.nn.Parameter(torch.zeros((out_features,))) + + if ( + self.linear_config["input_activations"] is not None + and not self.linear_config["input_activations"]["dynamic"] + ): + input_scale_shape = ( + (1,) + if self.linear_config["input_activations"]["strategy"] == "tensor" + else (out_features, 1) + ) + self.input_scale = torch.nn.Parameter( + torch.ones(input_scale_shape), requires_grad=False + ) - # fp8 weight tensor for torchao - qweight: AffineQuantizedTensor = self._construct_qweight_structure() + def _input_activation_quant_func_fp8( + self, + x: torch.Tensor, + activation_granularity, + activation_dtype: torch.dtype, + scale: torch.Tensor | None = None, + ): + """Quantize the input activation tensor for an aqt_float variant. + If scale is not provided, it will be dynamically calculated, otherwise the + provided scale will be used. 
+ """ + block_size = get_block_size(x.shape, activation_granularity) + if scale is None: + activation = to_affine_quantized_floatx( + input_float=x, + block_size=block_size, + target_dtype=activation_dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=None), # Config is stored on weight + ) + else: + assert isinstance( + activation_granularity, PerTensor + ), "Static quantization only supports PerTensor granularity" + activation = to_affine_quantized_floatx_static( + input_float=x, + block_size=block_size, + scale=scale, + target_dtype=activation_dtype, + _layout=Float8Layout(mm_config=None), # Config is stored on weight + ) + return activation - if self.linear_config["input_activations"] is not None: - # activations are also fp8, quantize as required by model - act_granularity = ( + def _construct_qweight_structure(self) -> "AffineQuantizedTensor": + """Construct the torchao machinery for the fp8 matmul""" + weight_granularity = ( PerTensor() - if self.linear_config["input_activations"]["strategy"] == "tensor" + if self.linear_config["weights"]["strategy"] == "tensor" else PerRow() ) - input_quant_kwargs = { - "activation_granularity": act_granularity, - "activation_dtype": torch.float8_e4m3fn, - } - if not self.linear_config["input_activations"]["dynamic"]: - input_quant_kwargs["scale"] = self.input_scale.squeeze().to( - torch.float32 - ) - qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs) - - # Copied from torchao _linear_fp8_act_fp8_weight_impl (with changes to support fp8 out) - scaled_mm_config = Float8MMConfig(use_fast_accum=True) - out_shape = get_out_shape(qx.shape, qweight.shape) - - # Weight tensor preprocessing - w_tensor_impl = qweight.tensor_impl - assert not w_tensor_impl.transposed, "Weight tensor must be contiguous" - w_data = w_tensor_impl.float8_data - w_scale = w_tensor_impl.scale + fp8_layout = Float8Layout(Float8MMConfig(use_fast_accum=True)) + return AffineQuantizedTensor( + Float8AQTTensorImpl.from_plain( + self.weight, + self.weight_scale.squeeze().to(torch.float32), + None, + fp8_layout, + ), + get_block_size(self.weight.shape, weight_granularity), + self.weight.shape, + zero_point_domain=ZeroPointDomain.NONE, + dtype=self.weight_scale.dtype, + ) - # Input tensor preprocessing - inpt_data = qx.tensor_impl.float8_data - input_scale = qx.tensor_impl.scale - # Handle case where input tensor is more than 2D - inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """If input quantization is active, compute FP8xFP8 addmm leveraging torchao + functionalities. 
Otherwise compute non-quantized addmm.""" + # fp8 weight tensor for torchao + qweight: AffineQuantizedTensor = self._construct_qweight_structure() + + if self.linear_config["input_activations"] is not None: + # activations are also fp8, quantize as required by model + act_granularity = ( + PerTensor() + if self.linear_config["input_activations"]["strategy"] == "tensor" + else PerRow() + ) + input_quant_kwargs = { + "activation_granularity": act_granularity, + "activation_dtype": torch.float8_e4m3fn, + } + if not self.linear_config["input_activations"]["dynamic"]: + input_quant_kwargs["scale"] = self.input_scale.squeeze().to( + torch.float32 + ) + qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs) + + # Copied from torchao _linear_fp8_act_fp8_weight_impl + # (with changes to support fp8 out) + scaled_mm_config = Float8MMConfig(use_fast_accum=True) + out_shape = get_out_shape(qx.shape, qweight.shape) + + # Weight tensor preprocessing + w_tensor_impl = qweight.tensor_impl + assert not w_tensor_impl.transposed, "Weight tensor must be contiguous" + w_data = w_tensor_impl.float8_data + w_scale = w_tensor_impl.scale + + # Input tensor preprocessing + inpt_data = qx.tensor_impl.float8_data + input_scale = qx.tensor_impl.scale + # Handle case where input tensor is more than 2D + inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) + + # Handle rowwise case + if _is_rowwise_scaled(qweight): + assert _is_rowwise_scaled( + qx + ), "Input tensor must be rowwise block size" + w_scale = w_scale.unsqueeze(-1).T + input_scale = preprocess_scale(input_scale, qx.shape) + + # Preprocess data + inpt_data, w_data = preprocess_data( + inpt_data, w_data.T, scaled_mm_config + ) - # Handle rowwise case - if _is_rowwise_scaled(qweight): - assert _is_rowwise_scaled(qx), "Input tensor must be rowwise block size" - w_scale = w_scale.unsqueeze(-1).T - input_scale = preprocess_scale(input_scale, qx.shape) + # Perform the computation + return addmm_float8_unwrapped_inference( + inpt_data, + input_scale, + w_data, + w_scale, + output_dtype=qx.dtype, + bias=getattr(self, "bias", None), + use_fast_accum=scaled_mm_config.use_fast_accum, + ).reshape(out_shape) + + # activations not quantized, dequant fp8 weight and do regular matmul + out = torch.nn.functional.linear( + x, qweight.dequantize(), self.bias if self.has_bias else None + ) + return out - # Preprocess data - inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}" + f"(in={self.in_features}, out={self.out_features}, " + f"bias={self.has_bias}, fp8_config={self.linear_config})" + ) - # Perform the computation - return addmm_float8_unwrapped_inference( - inpt_data, - input_scale, - w_data, - w_scale, - output_dtype=qx.dtype, - bias=getattr(self, "bias", None), - use_fast_accum=scaled_mm_config.use_fast_accum, - ).reshape(out_shape) + def get_fp8_linear( + in_features: int, + out_features: int, + bias: bool, + linear_config: Mapping[str, Any], + ) -> FP8Linear: + """Retrieve an FP8 Linear module""" + return FP8Linear(in_features, out_features, bias, linear_config) + + def shard_fp8_linear( + tensor_values: dict[str, torch.Tensor], + tp_module: TPModule, + module_sharding_info: dict[str, LinearModuleShardingInfo], + ) -> set | None: + """ + | GPU | + sharding | param | shard | dim | + ----------+----------------+-------+-----| + colwise | weight | Y | 0 | + | weight_scale | N | - | + | input_scale | N | - | + | bias | Y | 0 | + 
----------+----------------+-------+-----| + rowwise | weight | Y | 1 | + | weight_scale | Y/N | 0/- | + | input_scale | Y/N | 0/- | + | bias | 0 | - | + """ - # activations not quantized, dequant fp8 weight and do regular matmul - out = torch.nn.functional.linear( - x, qweight.dequantize(), self.bias if self.has_bias else None - ) - return out + param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {} + for module_name, module_info in module_sharding_info.items(): + linear_mod: torch.nn.Module = module_info.linear_module + weight_strategy = getattr(linear_mod, "linear_config")["input_activations"][ + "strategy" + ] + # Scales are per-row or per-tensor + # Only sharding needed when row parallel and per-row + shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 1 + params: dict[str, LinearParameterShardingInfo] = { + "weight": LinearParameterShardingInfo( + module_info.sharding_dim, ShardType.SHARD + ), + "weight_scale": LinearParameterShardingInfo( + module_info.sharding_dim, + ShardType.SHARD if shard_scales else ShardType.CLONE, + ), + } + if hasattr(linear_mod, "input_scale"): + params["input_scale"] = LinearParameterShardingInfo( + module_info.sharding_dim, + ShardType.SHARD if shard_scales else ShardType.CLONE, + ) + if hasattr(linear_mod, "bias") and linear_mod.bias is not None: + params["bias"] = LinearParameterShardingInfo( + module_info.sharding_dim, + ShardType.SHARD + if module_info.sharding_dim == 0 + else ShardType.RANK0, + ) + param_sharding_info[module_name] = params - def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}" - f"(in={self.in_features}, out={self.out_features}, " - f"bias={self.has_bias}, fp8_config={self.linear_config})" + unused_keys = shard_base_linear( + tensor_values, + tp_module, + module_sharding_info, + param_sharding_info, ) + return unused_keys - -def get_fp8_linear( - in_features: int, - out_features: int, - bias: bool, - linear_config: Mapping[str, Any], -) -> FP8Linear: - """Retrieve an FP8 Linear module""" - - if not TORCHAO_INSTALLED: - raise ModuleNotFoundError("You need to install torchao for FP8 support in FMS!") - - return FP8Linear(in_features, out_features, bias, linear_config) - - -def shard_fp8_linear( - tensor_values: dict[str, torch.Tensor], - tp_module: TPModule, - module_sharding_info: dict[str, LinearModuleShardingInfo], -) -> set | None: - """ - | GPU | - sharding | param | shard | dim | - ----------+----------------+-------+-----| - colwise | weight | Y | 0 | - | weight_scale | N | - | - | input_scale | N | - | - | bias | Y | 0 | - ----------+----------------+-------+-----| - rowwise | weight | Y | 1 | - | weight_scale | Y/N | 0/- | - | input_scale | Y/N | 0/- | - | bias | 0 | - | - """ - - param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {} - for module_name, module_info in module_sharding_info.items(): - linear_mod: torch.nn.Module = module_info.linear_module - weight_strategy = getattr(linear_mod, "linear_config")["input_activations"][ - "strategy" - ] - # Scales are per-row or per-tensor - # Only sharding needed when row parallel and per-row - shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 1 - params: dict[str, LinearParameterShardingInfo] = { - "weight": LinearParameterShardingInfo( - module_info.sharding_dim, ShardType.SHARD - ), - "weight_scale": LinearParameterShardingInfo( - module_info.sharding_dim, - ShardType.SHARD if shard_scales else ShardType.CLONE, - ), - } - if hasattr(linear_mod, "input_scale"): - 
params["input_scale"] = LinearParameterShardingInfo( - module_info.sharding_dim, - ShardType.SHARD if shard_scales else ShardType.CLONE, - ) - if hasattr(linear_mod, "bias") and linear_mod.bias is not None: - params["bias"] = LinearParameterShardingInfo( - module_info.sharding_dim, - ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0, - ) - param_sharding_info[module_name] = params - - unused_keys = shard_base_linear( - tensor_values, - tp_module, - module_sharding_info, - param_sharding_info, - ) - return unused_keys - - -register_linear_type_to_module_map("fp8", get_fp8_linear) -register_linear_type_to_sharding_map("fp8", shard_fp8_linear) + register_linear_type_to_module_map("fp8", get_fp8_linear) + register_linear_type_to_sharding_map("fp8", shard_fp8_linear) diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py index c6280d4a..910073bf 100644 --- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -16,10 +16,15 @@ # Third Party from torch import Tensor import torch +import torch.nn.functional as F # pylint: disable=unused-argument # abstract op must be registered with specific I/O, even if not in use by the op function +# pylint: disable=not-callable +# torch.nn.functional.scaled_dot_product_attention not recognized as callable +# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 + @torch.library.custom_op("spyre::scaled_bmm", mutates_args=()) def sendnn_scaled_bmm( @@ -73,3 +78,144 @@ def _( dtype=out_dtype, device=mat1.device, ) + + +@torch.library.custom_op( + "spyre::scaled_paged_attn_store", mutates_args=(), device_types="cpu" +) +def scaled_paged_attn_store( + key: Tensor, + value: Tensor, + key_cache: Tensor, + value_cache: Tensor, + key_scale: Tensor, + value_scale: Tensor, + slot_mapping: Tensor, +) -> tuple[Tensor, Tensor]: + """ + FP8 CPU implementation of the Paged Attn store operation. + Scales key and value tensors, and stores them to the paged KV cache + using the same schema as vLLM. + """ + result_key_cache = key_cache.clone() + result_value_cache = value_cache.clone() + for seq_i, slot_mapping_seq in enumerate(slot_mapping): + for tok_i, slot in enumerate(slot_mapping_seq): + block_number = slot.item() // 64 + position = slot.item() % 64 + + result_key_cache[block_number, position, :, :] = ( + key[seq_i, tok_i, :, :] / key_scale + ).to(dtype=torch.float8_e4m3fn) + result_value_cache[block_number, position, :, :] = ( + value[seq_i, tok_i, :, :] / value_scale + ).to(dtype=torch.float8_e4m3fn) + return result_key_cache, result_value_cache + + +@scaled_paged_attn_store.register_fake +def scaled_paged_attn_store_meta( + key: Tensor, + value: Tensor, + key_cache: Tensor, + value_cache: Tensor, + key_scale: Tensor, + value_scale: Tensor, + slot_mapping: Tensor, +) -> tuple[Tensor, Tensor]: + """ + Fake tensor implementation of the FP8 Paged Attn store operation. + """ + return key_cache, value_cache + + +@torch.library.custom_op( + "spyre::scaled_paged_attn_compute", mutates_args={}, device_types="cpu" +) +def scaled_paged_attn_compute( + query: Tensor, + key_cache: Tensor, + value_cache: Tensor, + key_scale: Tensor, + value_scale: Tensor, + scale: float, + current_tkv_mask: Tensor, + left_padded_prompt_mask: Tensor, + block_table: Tensor, +) -> Tensor: + """ + FP8 CPU implementation of the Paged Attn compute operation. + Implements a CPU fallback to run the kernel that has been confirmed + to match the vLLM fused kernel. 
+ """ + # torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype), + output = torch.zeros_like(query) + num_query_heads = query.shape[2] + num_kv_heads = value_cache.shape[2] + head_size = value_cache.shape[3] + block_size = value_cache.shape[1] + num_seqs = query.shape[0] + + block_tables_lst = block_table.cpu().tolist() + + seq_lens_lst = current_tkv_mask.cpu().tolist() + for i in range(num_seqs): + q = query[i] + block_table = block_tables_lst[i] + start_pos = int(left_padded_prompt_mask[i].item()) + seq_len = int(seq_lens_lst[i]) + + keys_lst: list[torch.Tensor] = [] + values_lst: list[torch.Tensor] = [] + for j in range(start_pos, seq_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, block_offset, :, :] + k = k.reshape(num_kv_heads, head_size) + keys_lst.append(k) + + v = value_cache[block_number, block_offset, :, :] + values_lst.append(v) + keys = torch.stack(keys_lst, dim=0) + values = torch.stack(values_lst, dim=0) + if num_kv_heads > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_query_heads // num_kv_heads, dim=1) + values = torch.repeat_interleave( + values, num_query_heads // num_kv_heads, dim=1 + ) + + out = F.scaled_dot_product_attention( # noqa: E1102 + q.transpose(0, 1).unsqueeze(0), # format for sdpa + (keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale).to( + dtype=q.dtype + ), # format for sdpa + (values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale).to( + dtype=q.dtype + ), # format for sdpa + is_causal=False, # decode assumes no causal mask + scale=scale, + ) + + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + return output + + +@scaled_paged_attn_compute.register_fake +def scaled_paged_attn_compute_meta( + query: Tensor, + key_cache: Tensor, + value_cache: Tensor, + key_scale: Tensor, + value_scale: Tensor, + scale: float, + current_tkv_mask: Tensor, + left_padded_prompt_mask: Tensor, + block_table: Tensor, +) -> Tensor: + """ + Fake tensor implementation of the FP8 Paged Attn compute operation. + """ + return torch.zeros_like(query) diff --git a/fms_mo/aiu_addons/fp8/fp8_utils.py b/fms_mo/aiu_addons/fp8/fp8_utils.py index 0f5fbab4..0024ec4d 100644 --- a/fms_mo/aiu_addons/fp8/fp8_utils.py +++ b/fms_mo/aiu_addons/fp8/fp8_utils.py @@ -47,6 +47,7 @@ def __new__( cls, data: torch.Tensor, scale: torch.Tensor, + scaled: bool = True, ): return torch.Tensor._make_wrapper_subclass( cls, @@ -59,16 +60,19 @@ def __new__( device=data.device, ) - def __init__( # pylint: disable=super-init-not-called + def __init__( self, data: torch.Tensor, scale: torch.Tensor, + scaled: bool = True, ): + super().__init__() self._data = data self._scale = scale + self._scaled = scaled def __tensor_flatten__(self): - ctx = {} + ctx = {"scaled", self._scaled} return ["_data", "_scale"], ctx @staticmethod @@ -77,6 +81,7 @@ def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): return ScaledTensor( inner_tensors["_data"], inner_tensors["_scale"], + metadata["scaled"], ) @classmethod @@ -93,60 +98,3 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): def __repr__(self): return f"{self._data.__repr__()}\n{self._scale.__repr__()}" - - -def _infer_quantization_config(quant_config: dict) -> dict | None: - """Construct linear_config dictionary carrying FP8 configuration for FMS. 
- - There's many quantization packages compatible with HF - We initially focus on llm-compressor as it is the one used in FMS-MO - - llm-compressor saves its checkpoints with quant_method = compressed-tensors - quantization_status tells us whether the model has already been quantized - We only support loading already quantized models (compressed status) - """ - - if ( - quant_config["quant_method"] == "compressed-tensors" - and quant_config["quantization_status"] == "compressed" - ): - # FP8 quantization will have FP8 weights - # We assume a single quantization group (group_0), to follow fms-mo checkpoints - # num_bits and type tells us "float" with "8" bits, aka FP8 - if ( - quant_config["config_groups"]["group_0"]["weights"]["type"] == "float" - and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8 - ): - # This is used by get_linear to decide whether a linear layer - # will be quantized or not inside the model - def fp8_linear_type(name: str) -> str: - # We need to translate HF names to FMS names - translations = { - "lm_head": "head", - } - for ignored_layer in quant_config["ignore"]: - assert isinstance(ignored_layer, str) - fms_ign_layer = translations.get(ignored_layer, ignored_layer) - if name in fms_ign_layer: - return "torch_linear" - for pattern in quant_config["config_groups"]["group_0"]["targets"]: - # Special case from llm-compressor that covers all linear layers - # not in the ignore pattern - assert isinstance(pattern, str) - if pattern == "Linear": - return "fp8" - if name in translations.get(pattern, pattern): - return "fp8" - return "torch_linear" - - return { - "linear_type": fp8_linear_type, - "input_activations": quant_config["config_groups"]["group_0"][ - "input_activations" - ], - "output_activations": quant_config["config_groups"]["group_0"][ - "output_activations" - ], - "weights": quant_config["config_groups"]["group_0"]["weights"], - } - return None diff --git a/fms_mo/utils/import_utils.py b/fms_mo/utils/import_utils.py index 51b113ee..d2f7894f 100644 --- a/fms_mo/utils/import_utils.py +++ b/fms_mo/utils/import_utils.py @@ -32,6 +32,7 @@ "fms", "triton", "torchvision", + "torchao", ] available_packages = {} diff --git a/pyproject.toml b/pyproject.toml index 4b43cf98..85884622 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,10 +37,11 @@ dependencies = [ "safetensors", "ibm-fms>=0.0.8", "pkginfo>1.10", +"torchao" ] [project.optional-dependencies] -fp8 = ["llmcompressor"] +fp8 = ["llmcompressor", "torchao"] gptq = ["Cython", "gptqmodel>=1.7.3"] mx = ["microxcaling>=1.1"] opt = ["fms-model-optimizer[fp8, gptq, mx]"] From cf2082efb95cda6ac980091775a68f5ff5f71305 Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Mon, 30 Jun 2025 19:57:07 +0000 Subject: [PATCH 07/15] Make changes to work with fms and aftu Signed-off-by: Antoni Viros i Martin --- fms_mo/aiu_addons/__init__.py | 3 +++ fms_mo/aiu_addons/fp8/fp8_attn.py | 6 +++--- fms_mo/aiu_addons/fp8/fp8_linear.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/fms_mo/aiu_addons/__init__.py b/fms_mo/aiu_addons/__init__.py index e4f30082..b3ef66bb 100644 --- a/fms_mo/aiu_addons/__init__.py +++ b/fms_mo/aiu_addons/__init__.py @@ -20,6 +20,9 @@ def _infer_quantization_config(quant_config: dict) -> dict | None: quant_config["config_groups"]["group_0"]["weights"]["type"] == "float" and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8 ): + # First, import required FP8 linear classes from fms-mo + import fms_mo.aiu_addons.fp8.fp8_linear # pylint: 
disable=unused-import + import fms_mo.aiu_addons.fp8.fp8_adapter # pylint: disable=unused-import # This is used by get_linear to decide whether a linear layer # will be quantized or not inside the model def fp8_linear_type(name: str) -> str: diff --git a/fms_mo/aiu_addons/fp8/fp8_attn.py b/fms_mo/aiu_addons/fp8/fp8_attn.py index 46fcc01c..54188d3f 100644 --- a/fms_mo/aiu_addons/fp8/fp8_attn.py +++ b/fms_mo/aiu_addons/fp8/fp8_attn.py @@ -251,9 +251,9 @@ def _spyre_scaled_paged_compute_op( query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, - nheads: int, - kvheads: int, - p_dropout: float, + nheads: int, # pylint: disable=unused-argument + kvheads: int, # pylint: disable=unused-argument + p_dropout: float, # pylint: disable=unused-argument scale_factor: Optional[float], **attn_kwargs, ) -> torch.Tensor: diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index fbe36a13..7b62a08d 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -193,7 +193,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs) - # Copied from torchao _linear_fp8_act_fp8_weight_impl + # Copied from torchao _linear_fp8_act_fp8_weight_impl # (with changes to support fp8 out) scaled_mm_config = Float8MMConfig(use_fast_accum=True) out_shape = get_out_shape(qx.shape, qweight.shape) From b12dc586d6cec45895b642e33e942eb1efc11a2f Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Tue, 1 Jul 2025 17:18:19 +0000 Subject: [PATCH 08/15] Fixes from PR comments, unit tests Signed-off-by: Antoni Viros i Martin --- fms_mo/aiu_addons/fp8/fp8_linear.py | 2 +- fms_mo/aiu_addons/fp8/fp8_spyre_op.py | 2 ++ pyproject.toml | 3 +-- tox.ini | 2 ++ 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index 7b62a08d..4b86e729 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -31,7 +31,7 @@ from fms_mo.prep import available_packages # pylint: disable=not-callable -# torch.nn.functional.scaled_dot_product_attention not recognized as callable +# torch.nn.functional.linear not recognized as callable # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 # Gated torchao imports for FP8 implementation diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py index 8b311349..79024245 100644 --- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -99,6 +99,7 @@ def scaled_paged_attn_store( Scales key and value tensors, and stores them to the paged KV cache using the same schema as vLLM. """ + print("Should never hit") result_key_cache = key_cache.clone() result_value_cache = value_cache.clone() for seq_i, slot_mapping_seq in enumerate(slot_mapping): @@ -150,6 +151,7 @@ def scaled_paged_attn_compute( Implements a CPU fallback to run the kernel that has been confirmed to match the vLLM fused kernel. 
""" + print("Should never hit") # torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype), output = torch.zeros_like(query) num_query_heads = query.shape[2] diff --git a/pyproject.toml b/pyproject.toml index def6fd6e..06d1e229 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,8 +30,7 @@ dependencies = [ "datasets>=3.0.0,<4.0", "pandas", "safetensors", -"pkginfo>1.10", -"torchao" +"pkginfo>1.10" ] [project.optional-dependencies] diff --git a/tox.ini b/tox.ini index 2f2d0441..78d792be 100644 --- a/tox.ini +++ b/tox.ini @@ -33,6 +33,8 @@ deps = pytest pylint>=2.16.2,<4.0 pylint-pydantic + ibm-fms + torchao commands = {basepython} -m pylint --load-plugins pylint_pydantic fms_mo/ tests/ From 6a881178c91ff398f134a3f72d927d3359fa60b0 Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Tue, 1 Jul 2025 17:23:21 +0000 Subject: [PATCH 09/15] Add test Signed-off-by: Antoni Viros i Martin --- tests/aiu_addons/test_fp8_addon.py | 54 ++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 tests/aiu_addons/test_fp8_addon.py diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py new file mode 100644 index 00000000..9d921e4e --- /dev/null +++ b/tests/aiu_addons/test_fp8_addon.py @@ -0,0 +1,54 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test suite for FMS addon introducing FP8 functionalities""" + +# Third Party +import pytest +import torch + +# Local +import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import + + +def test_fp8_registration() -> None: + """ + Ensure fp8 ops are registered properly. + """ + + assert hasattr(torch.ops, "spyre") + assert hasattr(torch.ops.spyre, "scaled_bmm") + assert hasattr(torch.ops.spyre, "scaled_paged_attn_store") + assert hasattr(torch.ops.spyre, "scaled_paged_attn_compute") + + +# This test requires an H100 or higher GPU to run +@pytest.mark.skipif( + not torch.cuda.is_available() + or (torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 9)), + reason="FP8 is only available on GPUs with device level 8.9 or higher", +) +def test_fp8_op() -> None: + """Validate output shapes of GPTQ W4A16 tensors. + Note: this AIU-compatible operation only returns a zero tensor of the + expected shape, it does not perform a real W4A16 matmul operation. 
+ """ + # Local + from fms_mo.aiu_addons.fp8.fp8_attn import _math_fp8_compute_op + + query = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda") + key = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda") + value = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda") + + out = _math_fp8_compute_op(query, key, value, 32, 32, 0.0, None) + assert out.size() == query.size() From bdf1cf2f00e0c9494b5c0f54602d71c23fe2952a Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Wed, 2 Jul 2025 16:11:48 +0000 Subject: [PATCH 10/15] Gate FMS imports Signed-off-by: Antoni Viros i Martin --- fms_mo/aiu_addons/fp8/fp8_adapter.py | 91 ++--- fms_mo/aiu_addons/fp8/fp8_attn.py | 484 ++++++++++++++------------- fms_mo/aiu_addons/fp8/fp8_linear.py | 18 +- 3 files changed, 299 insertions(+), 294 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_adapter.py b/fms_mo/aiu_addons/fp8/fp8_adapter.py index 98c7ca0f..ba44fce8 100644 --- a/fms_mo/aiu_addons/fp8/fp8_adapter.py +++ b/fms_mo/aiu_addons/fp8/fp8_adapter.py @@ -17,54 +17,57 @@ from typing import Any, Mapping import functools -# Third Party -from fms.modules.linear import get_linear_type -from fms.utils import serialization -from fms.utils.config import ModelConfig +# Local +from fms_mo.prep import available_packages -# pylint: disable=unused-argument -# Retaining kwargs input arguments for consistency with other adapter steps. +if available_packages["fms"]: + # Third Party + from fms.modules.linear import get_linear_type + from fms.utils import serialization + from fms.utils.config import ModelConfig + # pylint: disable=unused-argument + # Retaining kwargs input arguments for consistency with other adapter steps. + # TODO: may be shared with gptq llama + def _hf_fp8_check( + input_sd: Mapping[str, Any], + model_config: ModelConfig | None = None, + checkpoint_is_fused: bool = False, + **kwargs, + ) -> Mapping[str, Any]: + """Implementation of adapter step for FMS: ensure that when FP8 quantization + is in use, weights are fused like the model checkpoint. + """ -# TODO: may be shared with gptq llama -def _hf_fp8_check( - input_sd: Mapping[str, Any], - model_config: ModelConfig | None = None, - checkpoint_is_fused: bool = False, - **kwargs, -) -> Mapping[str, Any]: - """Implementation of adapter step for FMS: ensure that when FP8 quantization - is in use, weights are fused like the model checkpoint. 
- """ + has_fused_weights = True + linear_type = "torch_linear" + if model_config: + if not model_config.fused_weights: + has_fused_weights = False + if model_config.linear_config: + linear_type = model_config.linear_config["linear_type"] + if callable(linear_type): + # Calling this function with "any" guarantees "fp8" to be returned + # when loading an HF fp8 checkpoint, and never in any other condition + linear_type = get_linear_type(model_config.linear_config, "any") - has_fused_weights = True - linear_type = "torch_linear" - if model_config: - if not model_config.fused_weights: - has_fused_weights = False - if model_config.linear_config: - linear_type = model_config.linear_config["linear_type"] - if callable(linear_type): - # Calling this function with "any" guarantees "fp8" to be returned - # when loading an HF fp8 checkpoint, and never in any other condition - linear_type = get_linear_type(model_config.linear_config, "any") + if "fp8" in linear_type and has_fused_weights != checkpoint_is_fused: + raise ValueError( + "FP8 HF llama checkpoints cannot be loaded into a model with fused weights" + ) - if "fp8" in linear_type and has_fused_weights != checkpoint_is_fused: - raise ValueError( - "FP8 HF llama checkpoints cannot be loaded into a model with fused weights" - ) + return input_sd - return input_sd + serialization.register_adapter_step( + "llama", + "hf_fp8_check", + functools.partial(_hf_fp8_check, checkpoint_is_fused=False), + ) + serialization.extend_adapter("llama", "hf", ["hf_fp8_check"]) - -serialization.register_adapter_step( - "llama", "hf_fp8_check", functools.partial(_hf_fp8_check, checkpoint_is_fused=False) -) -serialization.extend_adapter("llama", "hf", ["hf_fp8_check"]) - -serialization.register_adapter_step( - "granite", - "hf_fp8_check", - functools.partial(_hf_fp8_check, checkpoint_is_fused=False), -) -serialization.extend_adapter("granite", "hf", ["hf_fp8_check"]) + serialization.register_adapter_step( + "granite", + "hf_fp8_check", + functools.partial(_hf_fp8_check, checkpoint_is_fused=False), + ) + serialization.extend_adapter("granite", "hf", ["hf_fp8_check"]) diff --git a/fms_mo/aiu_addons/fp8/fp8_attn.py b/fms_mo/aiu_addons/fp8/fp8_attn.py index 5a9f7178..485a6f65 100644 --- a/fms_mo/aiu_addons/fp8/fp8_attn.py +++ b/fms_mo/aiu_addons/fp8/fp8_attn.py @@ -18,267 +18,269 @@ import math # Third Party -from fms.modules.attention import ( - AttentionKwargs, - _sdpa_update_attn_kwargs, - register_attention_op, -) -from fms.utils.spyre.paged import ( - SpyrePagedAttentionKwargs, - __spyre_paged_validate_attn_kwargs_op, -) import torch # Local from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor +from fms_mo.prep import available_packages import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import - -class MathFP8AttentionKwargs(AttentionKwargs): - """TypedDict for FP8 attention.""" - - mask: NotRequired[torch.Tensor] - do_scale_q: bool - is_causal_mask: bool - - -# TODO: Figure out better scales for AIU? 
These come from vLLM -Q_RANGE = 200.0 -K_RANGE = 200.0 -V_RANGE = 100.0 - - -def _construct_fp8_cache(tensor: torch.Tensor, scale: torch.Tensor) -> ScaledTensor: - """Construct the custom object to save KV cache with its scales.""" - return ScaledTensor(tensor, scale, True) - - -def _math_fp8_store_op( - keys: torch.Tensor, # pylint: disable=unused-argument - values: torch.Tensor, - key_cache: torch.Tensor | None, - value_cache: torch.Tensor | None, - **attn_kwargs: Unpack[MathFP8AttentionKwargs], -) -> tuple[ScaledTensor, ScaledTensor, ScaledTensor, ScaledTensor]: - """Implement math of KV cache storing.""" - - # Grab scale from kv-cache if already there, compute dynamically otherwise - if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor): - k_scale = key_cache._scale - v_scale = value_cache._scale - else: - k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32) - v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32) - - # Scale kv tensors for storage - keys = (keys / k_scale).to(torch.float8_e4m3fn).transpose(2, 1) - values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1) - - if ( - isinstance(key_cache, ScaledTensor) - and isinstance(value_cache, ScaledTensor) - and value_cache.numel() > 0 - ): - key_cache = torch.cat((key_cache._data, keys), dim=2) - value_cache = torch.cat((value_cache._data, values), dim=2) - key_cache = _construct_fp8_cache(key_cache, k_scale) - value_cache = _construct_fp8_cache(value_cache, v_scale) - return ( - key_cache, - value_cache, - key_cache, - value_cache, - ) - # If it's a new kv cache, ensure it's contiguous for spyre use cases - keys = _construct_fp8_cache(keys.contiguous(), k_scale) - values = _construct_fp8_cache(values.contiguous(), v_scale) - return (keys, values, keys, values) - - -def _math_fp8_compute_op( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - nheads: int, - kvheads: int, - p_dropout: float, - scale_factor: float | None, - **attn_kwargs: Unpack[MathFP8AttentionKwargs], -) -> torch.Tensor: - """Implement computation of scaled dot product attention, leveraging - the custom scaled BMM op that was pre-registered for torch.compile.""" - - orig_dtype = query.dtype - - # Scaling the Q tensor is optional - q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device) - if attn_kwargs.get("do_scale_q", False): - q_scale.copy_(torch.abs(query).max() / Q_RANGE) - query = query / q_scale - - query = query.to(torch.float8_e4m3fn).transpose(2, 1) - - # Grab kv cache and deal with cases where no store op was run - if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor): - # Store op was run - k_scale = key_cache._scale - v_scale = value_cache._scale - key_cache = key_cache._data - value_cache = value_cache._data - else: - # Store op wasn't run (e.g. 
encoders, use_cache=False) - k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32) - v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32) - key_cache = (key_cache / k_scale).to(torch.float8_e4m3fn) - value_cache = (value_cache / v_scale).to(torch.float8_e4m3fn) - - # If store wasn't run, we need to transpose the tensors here - # TODO: Refactor FMS to avoid edge cases where this fails; add use_cache param here - if key_cache.shape[1] != kvheads and key_cache.shape[2] == kvheads: - key_cache = key_cache.transpose(2, 1) - value_cache = value_cache.transpose(2, 1) - - # Most of the code that follows is a copy of Pytorch SDPA, with fp8 additions - mask = attn_kwargs.get("mask", None) - if mask is not None: - # Our expected mask format is bs x q_len x k_len, so to make it broadcastable - # we need to create the nheads dimension - while len(mask.size()) != 4: # expects bs (x nheads) x q_len x kv_len - mask = mask.unsqueeze(1) - - L, S = query.size(-2), key_cache.size(-2) - scale_factor = ( - 1 / math.sqrt(query.size(-1)) if scale_factor is None else scale_factor +if available_packages["fms"]: + # Third Party + from fms.modules.attention import ( + AttentionKwargs, + _sdpa_update_attn_kwargs, + register_attention_op, + ) + from fms.utils.spyre.paged import ( + SpyrePagedAttentionKwargs, + __spyre_paged_validate_attn_kwargs_op, ) - attn_bias = torch.zeros(L, S, dtype=orig_dtype, device=query.device) - if attn_kwargs.get("is_causal_mask", False): - assert mask is None - temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) - attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(torch.float32) - if mask is not None: - if mask.dtype == torch.bool: - attn_bias.masked_fill_(mask.logical_not(), float("-inf")) + class MathFP8AttentionKwargs(AttentionKwargs): + """TypedDict for FP8 attention.""" + + mask: NotRequired[torch.Tensor] + do_scale_q: bool + is_causal_mask: bool + + # TODO: Figure out better scales for AIU? 
These come from vLLM + Q_RANGE = 200.0 + K_RANGE = 200.0 + V_RANGE = 100.0 + + def _construct_fp8_cache(tensor: torch.Tensor, scale: torch.Tensor) -> ScaledTensor: + """Construct the custom object to save KV cache with its scales.""" + return ScaledTensor(tensor, scale, True) + + def _math_fp8_store_op( + keys: torch.Tensor, # pylint: disable=unused-argument + values: torch.Tensor, + key_cache: torch.Tensor | None, + value_cache: torch.Tensor | None, + **attn_kwargs: Unpack[MathFP8AttentionKwargs], + ) -> tuple[ScaledTensor, ScaledTensor, ScaledTensor, ScaledTensor]: + """Implement math of KV cache storing.""" + + # Grab scale from kv-cache if already there, compute dynamically otherwise + if isinstance(key_cache, ScaledTensor) and isinstance( + value_cache, ScaledTensor + ): + k_scale = key_cache._scale + v_scale = value_cache._scale else: - attn_bias = mask + attn_bias - - expansion = nheads // kvheads - if expansion > 1: - key_cache = key_cache.repeat_interleave( - query.size(-3) // key_cache.size(-3), -3 - ) - value_cache = value_cache.repeat_interleave( - query.size(-3) // value_cache.size(-3), -3 + k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32) + v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32) + + # Scale kv tensors for storage + keys = (keys / k_scale).to(torch.float8_e4m3fn).transpose(2, 1) + values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1) + + if ( + isinstance(key_cache, ScaledTensor) + and isinstance(value_cache, ScaledTensor) + and value_cache.numel() > 0 + ): + key_cache = torch.cat((key_cache._data, keys), dim=2) + value_cache = torch.cat((value_cache._data, values), dim=2) + key_cache = _construct_fp8_cache(key_cache, k_scale) + value_cache = _construct_fp8_cache(value_cache, v_scale) + return ( + key_cache, + value_cache, + key_cache, + value_cache, + ) + # If it's a new kv cache, ensure it's contiguous for spyre use cases + keys = _construct_fp8_cache(keys.contiguous(), k_scale) + values = _construct_fp8_cache(values.contiguous(), v_scale) + return (keys, values, keys, values) + + def _math_fp8_compute_op( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + nheads: int, + kvheads: int, + p_dropout: float, + scale_factor: float | None, + **attn_kwargs: Unpack[MathFP8AttentionKwargs], + ) -> torch.Tensor: + """Implement computation of scaled dot product attention, leveraging + the custom scaled BMM op that was pre-registered for torch.compile.""" + + orig_dtype = query.dtype + + # Scaling the Q tensor is optional + q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device) + if attn_kwargs.get("do_scale_q", False): + q_scale.copy_(torch.abs(query).max() / Q_RANGE) + query = query / q_scale + + query = query.to(torch.float8_e4m3fn).transpose(2, 1) + + # Grab kv cache and deal with cases where no store op was run + if isinstance(key_cache, ScaledTensor) and isinstance( + value_cache, ScaledTensor + ): + # Store op was run + k_scale = key_cache._scale + v_scale = value_cache._scale + key_cache = key_cache._data + value_cache = value_cache._data + else: + # Store op wasn't run (e.g. 
encoders, use_cache=False) + k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32) + v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32) + key_cache = (key_cache / k_scale).to(torch.float8_e4m3fn) + value_cache = (value_cache / v_scale).to(torch.float8_e4m3fn) + + # If store wasn't run, we need to transpose the tensors here + # TODO: Refactor FMS to avoid edge cases where this fails; add use_cache param here + if key_cache.shape[1] != kvheads and key_cache.shape[2] == kvheads: + key_cache = key_cache.transpose(2, 1) + value_cache = value_cache.transpose(2, 1) + + # Most of the code that follows is a copy of Pytorch SDPA, with fp8 additions + mask = attn_kwargs.get("mask", None) + if mask is not None: + # Our expected mask format is bs x q_len x k_len, so to make it broadcastable + # we need to create the nheads dimension + while len(mask.size()) != 4: # expects bs (x nheads) x q_len x kv_len + mask = mask.unsqueeze(1) + + L, S = query.size(-2), key_cache.size(-2) + scale_factor = ( + 1 / math.sqrt(query.size(-1)) if scale_factor is None else scale_factor ) - - attn_weight = ( - torch.ops.spyre.scaled_bmm( - query, - key_cache.transpose(-2, -1), - q_scale, - k_scale, - out_dtype=orig_dtype, - use_fast_accum=True, + attn_bias = torch.zeros(L, S, dtype=orig_dtype, device=query.device) + if attn_kwargs.get("is_causal_mask", False): + assert mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(torch.float32) + + if mask is not None: + if mask.dtype == torch.bool: + attn_bias.masked_fill_(mask.logical_not(), float("-inf")) + else: + attn_bias = mask + attn_bias + + expansion = nheads // kvheads + if expansion > 1: + key_cache = key_cache.repeat_interleave( + query.size(-3) // key_cache.size(-3), -3 + ) + value_cache = value_cache.repeat_interleave( + query.size(-3) // value_cache.size(-3), -3 + ) + + attn_weight = ( + torch.ops.spyre.scaled_bmm( + query, + key_cache.transpose(-2, -1), + q_scale, + k_scale, + out_dtype=orig_dtype, + use_fast_accum=True, + ) + * scale_factor ) - * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, p_dropout, train=True) + # Do matmul in orig_dtype + attn = attn_weight @ (value_cache.to(dtype=orig_dtype) * v_scale) + + attn = attn.to(orig_dtype).transpose(2, 1).contiguous() + return attn + + register_attention_op( + "math_fp8", + _math_fp8_store_op, + _math_fp8_compute_op, + update_attn_kwargs_op=_sdpa_update_attn_kwargs, ) - attn_weight += attn_bias - attn_weight = torch.softmax(attn_weight, dim=-1) - attn_weight = torch.dropout(attn_weight, p_dropout, train=True) - # Do matmul in orig_dtype - attn = attn_weight @ (value_cache.to(dtype=orig_dtype) * v_scale) - attn = attn.to(orig_dtype).transpose(2, 1).contiguous() - return attn - - -register_attention_op( - "math_fp8", - _math_fp8_store_op, - _math_fp8_compute_op, - update_attn_kwargs_op=_sdpa_update_attn_kwargs, -) + def _spyre_scaled_paged_store_op( + keys: torch.Tensor, + values: torch.Tensor, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + **attn_kwargs: Unpack[SpyrePagedAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # For paged store, we must have pre-allocated the kv-cache + assert key_cache is not None and isinstance( + key_cache, ScaledTensor + ), "kv cache must be preallocated" + assert value_cache is not 
None and isinstance( + value_cache, ScaledTensor + ), "kv cache must be preallocated" + if not key_cache._scaled: + key_cache._scale = (torch.abs(keys).max() / 200.0).to(dtype=torch.float32) + value_cache._scale = (torch.abs(values).max() / 100.0).to( + dtype=torch.float32 + ) + + result_key_cache_data, result_value_cache_data = ( + torch.ops.spyre.scaled_paged_attn_store( + keys, + values, + key_cache._data, + value_cache._data, + key_cache._scale, + value_cache._scale, + attn_kwargs["slot_mapping"], + ) + ) + result_key_cache = _construct_fp8_cache(result_key_cache_data, key_cache._scale) + result_value_cache = _construct_fp8_cache( + result_value_cache_data, value_cache._scale + ) -def _spyre_scaled_paged_store_op( - keys: torch.Tensor, - values: torch.Tensor, - key_cache: Optional[torch.Tensor], - value_cache: Optional[torch.Tensor], - **attn_kwargs: Unpack[SpyrePagedAttentionKwargs], -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # For paged store, we must have pre-allocated the kv-cache - assert key_cache is not None and isinstance( - key_cache, ScaledTensor - ), "kv cache must be preallocated" - assert value_cache is not None and isinstance( - value_cache, ScaledTensor - ), "kv cache must be preallocated" - if not key_cache._scaled: - key_cache._scale = (torch.abs(keys).max() / 200.0).to(dtype=torch.float32) - value_cache._scale = (torch.abs(values).max() / 100.0).to(dtype=torch.float32) + # for prefill, we want to return the original keys/values + if attn_kwargs.get("block_table", None) is None: + return keys, values, result_key_cache, result_value_cache + return ( + result_key_cache, + result_value_cache, + result_key_cache, + result_value_cache, + ) - result_key_cache_data, result_value_cache_data = ( - torch.ops.spyre.scaled_paged_attn_store( - keys, - values, + def _spyre_scaled_paged_compute_op( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + nheads: int, # pylint: disable=unused-argument + kvheads: int, # pylint: disable=unused-argument + p_dropout: float, # pylint: disable=unused-argument + scale_factor: Optional[float], + **attn_kwargs, + ) -> torch.Tensor: + assert isinstance(key_cache, ScaledTensor), "kv cache must be scaled" + assert isinstance(value_cache, ScaledTensor), "kv cache must be scaled" + if scale_factor is None: + scale_factor = 1 / math.sqrt(query.shape[-1]) + return torch.ops.spyre.scaled_paged_attn_compute( + query, key_cache._data, value_cache._data, key_cache._scale, value_cache._scale, - attn_kwargs["slot_mapping"], + scale_factor, + attn_kwargs["current_tkv_mask"], + attn_kwargs["left_padded_prompt_mask"], + attn_kwargs["block_table"], ) - ) - - result_key_cache = _construct_fp8_cache(result_key_cache_data, key_cache._scale) - result_value_cache = _construct_fp8_cache( - result_value_cache_data, value_cache._scale - ) - # for prefill, we want to return the original keys/values - if attn_kwargs.get("block_table", None) is None: - return keys, values, result_key_cache, result_value_cache - return ( - result_key_cache, - result_value_cache, - result_key_cache, - result_value_cache, + register_attention_op( + "spyre_paged_attn_fp8", + _spyre_scaled_paged_store_op, + compute_op=_math_fp8_compute_op, + is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None) + is None, + compute_decode_op=_spyre_scaled_paged_compute_op, + validate_attn_kwargs_op=__spyre_paged_validate_attn_kwargs_op, ) - - -def _spyre_scaled_paged_compute_op( - query: torch.Tensor, - key_cache: torch.Tensor, - 
value_cache: torch.Tensor, - nheads: int, # pylint: disable=unused-argument - kvheads: int, # pylint: disable=unused-argument - p_dropout: float, # pylint: disable=unused-argument - scale_factor: Optional[float], - **attn_kwargs, -) -> torch.Tensor: - assert isinstance(key_cache, ScaledTensor), "kv cache must be scaled" - assert isinstance(value_cache, ScaledTensor), "kv cache must be scaled" - if scale_factor is None: - scale_factor = 1 / math.sqrt(query.shape[-1]) - return torch.ops.spyre.scaled_paged_attn_compute( - query, - key_cache._data, - value_cache._data, - key_cache._scale, - value_cache._scale, - scale_factor, - attn_kwargs["current_tkv_mask"], - attn_kwargs["left_padded_prompt_mask"], - attn_kwargs["block_table"], - ) - - -register_attention_op( - "spyre_paged_attn_fp8", - _spyre_scaled_paged_store_op, - compute_op=_math_fp8_compute_op, - is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None) is None, - compute_decode_op=_spyre_scaled_paged_compute_op, - validate_attn_kwargs_op=__spyre_paged_validate_attn_kwargs_op, -) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index 4b86e729..de13bb3f 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -17,14 +17,6 @@ from typing import Any, Mapping # Third Party -from fms.modules.linear import ( - LinearModuleShardingInfo, - LinearParameterShardingInfo, - register_linear_type_to_module_map, - register_linear_type_to_sharding_map, - shard_base_linear, -) -from fms.modules.tp import ShardType, TPModule import torch # Local @@ -35,8 +27,16 @@ # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 # Gated torchao imports for FP8 implementation -if available_packages["torchao"]: +if available_packages["fms"] and available_packages["torchao"]: # Third Party + from fms.modules.linear import ( + LinearModuleShardingInfo, + LinearParameterShardingInfo, + register_linear_type_to_module_map, + register_linear_type_to_sharding_map, + shard_base_linear, + ) + from fms.modules.tp import ShardType, TPModule from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, to_affine_quantized_floatx, From 3a373ff4e40bad89ce263ce24ff5a3f508f1bbbe Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Wed, 2 Jul 2025 16:55:47 +0000 Subject: [PATCH 11/15] Add choice for scaled bmm Signed-off-by: Antoni Viros i Martin --- fms_mo/aiu_addons/fp8/fp8_attn.py | 41 +++++++++++++++++++------------ 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_attn.py b/fms_mo/aiu_addons/fp8/fp8_attn.py index 485a6f65..ea86e08a 100644 --- a/fms_mo/aiu_addons/fp8/fp8_attn.py +++ b/fms_mo/aiu_addons/fp8/fp8_attn.py @@ -42,6 +42,7 @@ class MathFP8AttentionKwargs(AttentionKwargs): mask: NotRequired[torch.Tensor] do_scale_q: bool + do_scaled_bmm: bool is_causal_mask: bool # TODO: Figure out better scales for AIU? 
These come from vLLM @@ -110,14 +111,17 @@ def _math_fp8_compute_op( the custom scaled BMM op that was pre-registered for torch.compile.""" orig_dtype = query.dtype + do_scaled_bmm = attn_kwargs.get("do_scaled_bmm", False) - # Scaling the Q tensor is optional - q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device) - if attn_kwargs.get("do_scale_q", False): - q_scale.copy_(torch.abs(query).max() / Q_RANGE) - query = query / q_scale + if do_scaled_bmm: + # Scaling the Q tensor is optional + q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device) + if attn_kwargs.get("do_scale_q", False): + q_scale.copy_(torch.abs(query).max() / Q_RANGE) + query = query / q_scale - query = query.to(torch.float8_e4m3fn).transpose(2, 1) + query = query.to(torch.float8_e4m3fn) + query = query.transpose(2, 1) # Grab kv cache and deal with cases where no store op was run if isinstance(key_cache, ScaledTensor) and isinstance( @@ -175,17 +179,22 @@ def _math_fp8_compute_op( query.size(-3) // value_cache.size(-3), -3 ) - attn_weight = ( - torch.ops.spyre.scaled_bmm( - query, - key_cache.transpose(-2, -1), - q_scale, - k_scale, - out_dtype=orig_dtype, - use_fast_accum=True, + if do_scaled_bmm: + attn_weight = ( + torch.ops.spyre.scaled_bmm( + query, + key_cache.transpose(-2, -1), + q_scale, + k_scale, + out_dtype=orig_dtype, + use_fast_accum=True, + ) + * scale_factor ) - * scale_factor - ) + else: + key_t = (key_cache.to(dtype=orig_dtype) * k_scale).transpose(-2, -1) + attn_weight = query @ key_t + attn_weight *= scale_factor attn_weight += attn_bias attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, p_dropout, train=True) From 43372e4b97b62796d695ac4d4f9294c635b42198 Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Wed, 2 Jul 2025 17:28:08 +0000 Subject: [PATCH 12/15] Improve package checking to allow editable builds Signed-off-by: Antoni Viros i Martin --- fms_mo/utils/import_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fms_mo/utils/import_utils.py b/fms_mo/utils/import_utils.py index 7eaddefb..a695e39d 100644 --- a/fms_mo/utils/import_utils.py +++ b/fms_mo/utils/import_utils.py @@ -17,11 +17,11 @@ """ # Standard +from importlib.util import find_spec import pkgutil import sys # Third Party -from transformers.utils.import_utils import _is_package_available import torch all_available_modules = [] @@ -47,7 +47,7 @@ available_packages = {} for package in optional_packages: available_packages[package] = ( - _is_package_available(package) or package in all_available_modules + find_spec(package) is not None or package in all_available_modules ) # cutlass is detected through torch.ops.cutlass_gemm From c5a55fc8179b7f7126a97472d97ba63e70b0b381 Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Wed, 2 Jul 2025 18:00:41 +0000 Subject: [PATCH 13/15] Add CPU fallback for scaled_mm Signed-off-by: Antoni Viros i Martin --- fms_mo/aiu_addons/fp8/fp8_spyre_op.py | 28 +++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py index 79024245..696aab25 100644 --- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -13,6 +13,9 @@ # limitations under the License. 
"""Torch registration of FP8xFP8 operation for attention BMMs.""" +# Standard +from typing import Optional + # Third Party from torch import Tensor import torch @@ -26,6 +29,31 @@ # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 +aten = torch.ops.aten +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] + + +@torch.library.register_kernel("aten::_scaled_mm", "cpu") +def _scaled_mm_cpu( + mat1: Tensor, + mat2: Tensor, + scale1: Tensor, + scale2: Tensor, + bias: Optional[Tensor] = None, + scale_result: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, +) -> Tensor: + if out_dtype is None: + out_dtype = torch.float32 + mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype) + mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype) + + if bias is not None: + return torch.addmm(bias, mat1, mat2).to(dtype=out_dtype) + return torch.mm(mat1, mat2).to(dtype=out_dtype) + + @torch.library.custom_op("spyre::scaled_bmm", mutates_args=()) def spyre_scaled_bmm( mat1: Tensor, From 2b56e0f7d83a1115d86f16a0bacbbe8c1f0fd403 Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Thu, 3 Jul 2025 14:10:31 +0000 Subject: [PATCH 14/15] Clean repr for fp8linear Signed-off-by: Antoni Viros i Martin --- fms_mo/aiu_addons/fp8/fp8_linear.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index de13bb3f..6062665b 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -244,7 +244,21 @@ def __repr__(self) -> str: return ( f"{self.__class__.__name__}" f"(in={self.in_features}, out={self.out_features}, " - f"bias={self.has_bias}, fp8_config={self.linear_config})" + f"bias={self.has_bias}, fp8_config={self._repr_fp8_config()})" + ) + + def _repr_fp8_config(self) -> str: + return ( + "(" + "acts: (" + f"dynamic: {self.linear_config['input_activations']['dynamic']}, " + f"strategy: {self.linear_config['input_activations']['strategy']}" + "), " + "weights: (" + f"dynamic: {self.linear_config['weights']['dynamic']}, " + f"strategy: {self.linear_config['weights']['strategy']}" + ")" + ")" ) def get_fp8_linear( @@ -266,14 +280,14 @@ def shard_fp8_linear( sharding | param | shard | dim | ----------+----------------+-------+-----| colwise | weight | Y | 0 | - | weight_scale | N | - | - | input_scale | N | - | - | bias | Y | 0 | + | weight_scale | N | - | + | input_scale | N | - | + | bias | Y | 0 | ----------+----------------+-------+-----| rowwise | weight | Y | 1 | - | weight_scale | Y/N | 0/- | - | input_scale | Y/N | 0/- | - | bias | 0 | - | + | weight_scale | Y/N | 0/- | + | input_scale | Y/N | 0/- | + | bias | 0 | - | """ param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {} From 42528b0906cf3a3d9fe2444bc45937183516444f Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Thu, 3 Jul 2025 17:57:23 +0000 Subject: [PATCH 15/15] Add further skips to test Signed-off-by: Antoni Viros i Martin --- tests/aiu_addons/test_fp8_addon.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py index 9d921e4e..81d263b3 100644 --- a/tests/aiu_addons/test_fp8_addon.py +++ b/tests/aiu_addons/test_fp8_addon.py @@ -18,6 +18,7 @@ import torch # Local +from fms_mo.prep import available_packages import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import @@ -33,6 +34,10 @@ def 
test_fp8_registration() -> None: # This test requires an H100 or higher GPU to run +@pytest.mark.skipif( + not available_packages["torchao"] or not available_packages["fms"], + reason="FMS and torchao required to run this test", +) @pytest.mark.skipif( not torch.cuda.is_available() or (torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 9)),
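# ---------------------------------------------------------------------------
# Illustrative usage sketch: a minimal example of driving the CPU fallback of
# the spyre::scaled_paged_attn_store op registered in
# fms_mo/aiu_addons/fp8/fp8_spyre_op.py above. The tensor shapes, the 200/100
# scale ranges (mirroring K_RANGE/V_RANGE), the 64-position block size, and
# the loose tolerance are illustrative assumptions only, not values mandated
# by the op or by FMS; it also assumes the fp8_spyre_op module imports cleanly
# in your environment so that the custom ops get registered.

# Third Party
import torch

# Local
import fms_mo.aiu_addons.fp8.fp8_spyre_op  # pylint: disable=unused-import


def _demo_scaled_paged_store() -> None:
    """Store a few tokens into a paged FP8 KV cache on CPU and check that
    dequantizing the written slots approximately recovers the inputs."""
    bs, tokens, kvheads, head_size = 1, 3, 2, 8
    num_blocks, block_size = 2, 64  # the CPU store op assumes 64 slots per block

    keys = torch.randn(bs, tokens, kvheads, head_size, dtype=torch.float32)
    values = torch.randn(bs, tokens, kvheads, head_size, dtype=torch.float32)

    # Pre-allocated paged KV cache, same schema as vLLM: block x pos x head x dim
    key_cache = torch.zeros(
        num_blocks, block_size, kvheads, head_size, dtype=torch.float8_e4m3fn
    )
    value_cache = torch.zeros_like(key_cache)

    # Per-tensor scales, mirroring the dynamic computation in the store ops above
    k_scale = (keys.abs().max() / 200.0).to(torch.float32)
    v_scale = (values.abs().max() / 100.0).to(torch.float32)

    # One slot id per token; slot // 64 selects the block, slot % 64 the position
    slot_mapping = torch.arange(tokens, dtype=torch.int64).unsqueeze(0)

    new_k_cache, _new_v_cache = torch.ops.spyre.scaled_paged_attn_store(
        keys, values, key_cache, value_cache, k_scale, v_scale, slot_mapping
    )

    # Dequantizing the written slots should approximately recover the inputs;
    # FP8 e4m3 keeps only ~3 mantissa bits, hence the loose tolerance.
    recovered = new_k_cache[0, :tokens].to(torch.float32) * k_scale
    torch.testing.assert_close(recovered, keys[0], atol=0.3, rtol=0.1)


if __name__ == "__main__":
    _demo_scaled_paged_store()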