diff --git a/fms_mo/aiu_addons/__init__.py b/fms_mo/aiu_addons/__init__.py index e69de29b..30367a40 100644 --- a/fms_mo/aiu_addons/__init__.py +++ b/fms_mo/aiu_addons/__init__.py @@ -0,0 +1,60 @@ +def _infer_quantization_config(quant_config: dict) -> dict | None: + """Construct linear_config dictionary carrying FP8 configuration for FMS. + + There are many quantization packages compatible with HF. + We initially focus on llm-compressor, as it is the one used in FMS-MO. + + llm-compressor saves its checkpoints with quant_method = compressed-tensors. + quantization_status tells us whether the model has already been quantized. + We only support loading already quantized models (compressed status). + """ + + if ( + quant_config["quant_method"] == "compressed-tensors" + and quant_config["quantization_status"] == "compressed" + ): + # FP8 quantization will have FP8 weights + # We assume a single quantization group (group_0), to follow fms-mo checkpoints + # num_bits and type tell us "float" with "8" bits, aka FP8 + if ( + quant_config["config_groups"]["group_0"]["weights"]["type"] == "float" + and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8 + ): + # First, import required FP8 linear classes from fms-mo + # Local + import fms_mo.aiu_addons.fp8.fp8_adapter # pylint: disable=unused-import + import fms_mo.aiu_addons.fp8.fp8_linear # pylint: disable=unused-import + + # This is used by get_linear to decide whether a linear layer + # will be quantized or not inside the model + def fp8_linear_type(name: str) -> str: + # We need to translate HF names to FMS names + translations = { + "lm_head": "head", + } + for ignored_layer in quant_config["ignore"]: + assert isinstance(ignored_layer, str) + fms_ign_layer = translations.get(ignored_layer, ignored_layer) + if name in fms_ign_layer: + return "torch_linear" + for pattern in quant_config["config_groups"]["group_0"]["targets"]: + # Special case from llm-compressor that covers all linear layers + # not in the ignore pattern + assert isinstance(pattern, str) + if pattern == "Linear": + return "fp8" + if name in translations.get(pattern, pattern): + return "fp8" + return "torch_linear" + + return { + "linear_type": fp8_linear_type, + "input_activations": quant_config["config_groups"]["group_0"][ + "input_activations" + ], + "output_activations": quant_config["config_groups"]["group_0"][ + "output_activations" + ], + "weights": quant_config["config_groups"]["group_0"]["weights"], + } + return None diff --git a/fms_mo/aiu_addons/fp8/__init__.py b/fms_mo/aiu_addons/fp8/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fms_mo/aiu_addons/fp8/fp8_adapter.py b/fms_mo/aiu_addons/fp8/fp8_adapter.py new file mode 100644 index 00000000..ba44fce8 --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_adapter.py @@ -0,0 +1,73 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
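# Illustrative usage sketch (not part of the changeset): a minimal example of how
# _infer_quantization_config (defined above in fms_mo/aiu_addons/__init__.py) is
# meant to be used. The quantization_config below is an assumption modeled on a
# typical llm-compressor FP8 (compressed-tensors) checkpoint, not copied from a
# real model.
from fms_mo.aiu_addons import _infer_quantization_config

example_quant_config = {
    "quant_method": "compressed-tensors",
    "quantization_status": "compressed",
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {
                "type": "float",
                "num_bits": 8,
                "symmetric": True,
                "dynamic": False,
                "strategy": "tensor",
            },
            "input_activations": {
                "type": "float",
                "num_bits": 8,
                "dynamic": True,
                "strategy": "tensor",
            },
            "output_activations": None,
        }
    },
}

linear_config = _infer_quantization_config(example_quant_config)
assert linear_config is not None
# linear_type is a callable mapping FMS layer names to a linear implementation:
# the ignored lm_head (translated to "head") stays unquantized, while layers
# covered by the "Linear" target become "fp8".
assert linear_config["linear_type"]("head") == "torch_linear"
assert linear_config["linear_type"]("attn.in_proj.query") == "fp8"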
+"""Implement and register FMS adapters for FP8 checkpoint loading.""" + +# Standard +from typing import Any, Mapping +import functools + +# Local +from fms_mo.prep import available_packages + +if available_packages["fms"]: + # Third Party + from fms.modules.linear import get_linear_type + from fms.utils import serialization + from fms.utils.config import ModelConfig + + # pylint: disable=unused-argument + # Retaining kwargs input arguments for consistency with other adapter steps. + # TODO: may be shared with gptq llama + def _hf_fp8_check( + input_sd: Mapping[str, Any], + model_config: ModelConfig | None = None, + checkpoint_is_fused: bool = False, + **kwargs, + ) -> Mapping[str, Any]: + """Implementation of adapter step for FMS: ensure that when FP8 quantization + is in use, weights are fused like the model checkpoint. + """ + + has_fused_weights = True + linear_type = "torch_linear" + if model_config: + if not model_config.fused_weights: + has_fused_weights = False + if model_config.linear_config: + linear_type = model_config.linear_config["linear_type"] + if callable(linear_type): + # Calling this function with "any" guarantees "fp8" to be returned + # when loading an HF fp8 checkpoint, and never in any other condition + linear_type = get_linear_type(model_config.linear_config, "any") + + if "fp8" in linear_type and has_fused_weights != checkpoint_is_fused: + raise ValueError( + "FP8 HF llama checkpoints cannot be loaded into a model with fused weights" + ) + + return input_sd + + serialization.register_adapter_step( + "llama", + "hf_fp8_check", + functools.partial(_hf_fp8_check, checkpoint_is_fused=False), + ) + serialization.extend_adapter("llama", "hf", ["hf_fp8_check"]) + + serialization.register_adapter_step( + "granite", + "hf_fp8_check", + functools.partial(_hf_fp8_check, checkpoint_is_fused=False), + ) + serialization.extend_adapter("granite", "hf", ["hf_fp8_check"]) diff --git a/fms_mo/aiu_addons/fp8/fp8_attn.py b/fms_mo/aiu_addons/fp8/fp8_attn.py new file mode 100644 index 00000000..ea86e08a --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_attn.py @@ -0,0 +1,295 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""FMS registration of attention BMM operation using torch-registered scaled BMM.""" + +# Standard +from typing import NotRequired, Optional, Unpack +import math + +# Third Party +import torch + +# Local +from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor +from fms_mo.prep import available_packages +import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import + +if available_packages["fms"]: + # Third Party + from fms.modules.attention import ( + AttentionKwargs, + _sdpa_update_attn_kwargs, + register_attention_op, + ) + from fms.utils.spyre.paged import ( + SpyrePagedAttentionKwargs, + __spyre_paged_validate_attn_kwargs_op, + ) + + class MathFP8AttentionKwargs(AttentionKwargs): + """TypedDict for FP8 attention.""" + + mask: NotRequired[torch.Tensor] + do_scale_q: bool + do_scaled_bmm: bool + is_causal_mask: bool + + # TODO: Figure out better scales for AIU? These come from vLLM + Q_RANGE = 200.0 + K_RANGE = 200.0 + V_RANGE = 100.0 + + def _construct_fp8_cache(tensor: torch.Tensor, scale: torch.Tensor) -> ScaledTensor: + """Construct the custom object to save KV cache with its scales.""" + return ScaledTensor(tensor, scale, True) + + def _math_fp8_store_op( + keys: torch.Tensor, # pylint: disable=unused-argument + values: torch.Tensor, + key_cache: torch.Tensor | None, + value_cache: torch.Tensor | None, + **attn_kwargs: Unpack[MathFP8AttentionKwargs], + ) -> tuple[ScaledTensor, ScaledTensor, ScaledTensor, ScaledTensor]: + """Implement math of KV cache storing.""" + + # Grab scale from kv-cache if already there, compute dynamically otherwise + if isinstance(key_cache, ScaledTensor) and isinstance( + value_cache, ScaledTensor + ): + k_scale = key_cache._scale + v_scale = value_cache._scale + else: + k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32) + v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32) + + # Scale kv tensors for storage + keys = (keys / k_scale).to(torch.float8_e4m3fn).transpose(2, 1) + values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1) + + if ( + isinstance(key_cache, ScaledTensor) + and isinstance(value_cache, ScaledTensor) + and value_cache.numel() > 0 + ): + key_cache = torch.cat((key_cache._data, keys), dim=2) + value_cache = torch.cat((value_cache._data, values), dim=2) + key_cache = _construct_fp8_cache(key_cache, k_scale) + value_cache = _construct_fp8_cache(value_cache, v_scale) + return ( + key_cache, + value_cache, + key_cache, + value_cache, + ) + # If it's a new kv cache, ensure it's contiguous for spyre use cases + keys = _construct_fp8_cache(keys.contiguous(), k_scale) + values = _construct_fp8_cache(values.contiguous(), v_scale) + return (keys, values, keys, values) + + def _math_fp8_compute_op( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + nheads: int, + kvheads: int, + p_dropout: float, + scale_factor: float | None, + **attn_kwargs: Unpack[MathFP8AttentionKwargs], + ) -> torch.Tensor: + """Implement computation of scaled dot product attention, leveraging + the custom scaled BMM op that was pre-registered for torch.compile.""" + + orig_dtype = query.dtype + do_scaled_bmm = attn_kwargs.get("do_scaled_bmm", False) + + if do_scaled_bmm: + # Scaling the Q tensor is optional + q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device) + if attn_kwargs.get("do_scale_q", False): + q_scale.copy_(torch.abs(query).max() / Q_RANGE) + query = query / q_scale + + query = query.to(torch.float8_e4m3fn) + query = query.transpose(2, 1) + + # Grab kv cache 
and deal with cases where no store op was run + if isinstance(key_cache, ScaledTensor) and isinstance( + value_cache, ScaledTensor + ): + # Store op was run + k_scale = key_cache._scale + v_scale = value_cache._scale + key_cache = key_cache._data + value_cache = value_cache._data + else: + # Store op wasn't run (e.g. encoders, use_cache=False) + k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32) + v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32) + key_cache = (key_cache / k_scale).to(torch.float8_e4m3fn) + value_cache = (value_cache / v_scale).to(torch.float8_e4m3fn) + + # If store wasn't run, we need to transpose the tensors here + # TODO: Refactor FMS to avoid edge cases where this fails; add use_cache param here + if key_cache.shape[1] != kvheads and key_cache.shape[2] == kvheads: + key_cache = key_cache.transpose(2, 1) + value_cache = value_cache.transpose(2, 1) + + # Most of the code that follows is a copy of Pytorch SDPA, with fp8 additions + mask = attn_kwargs.get("mask", None) + if mask is not None: + # Our expected mask format is bs x q_len x k_len, so to make it broadcastable + # we need to create the nheads dimension + while len(mask.size()) != 4: # expects bs (x nheads) x q_len x kv_len + mask = mask.unsqueeze(1) + + L, S = query.size(-2), key_cache.size(-2) + scale_factor = ( + 1 / math.sqrt(query.size(-1)) if scale_factor is None else scale_factor + ) + attn_bias = torch.zeros(L, S, dtype=orig_dtype, device=query.device) + if attn_kwargs.get("is_causal_mask", False): + assert mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(torch.float32) + + if mask is not None: + if mask.dtype == torch.bool: + attn_bias.masked_fill_(mask.logical_not(), float("-inf")) + else: + attn_bias = mask + attn_bias + + expansion = nheads // kvheads + if expansion > 1: + key_cache = key_cache.repeat_interleave( + query.size(-3) // key_cache.size(-3), -3 + ) + value_cache = value_cache.repeat_interleave( + query.size(-3) // value_cache.size(-3), -3 + ) + + if do_scaled_bmm: + attn_weight = ( + torch.ops.spyre.scaled_bmm( + query, + key_cache.transpose(-2, -1), + q_scale, + k_scale, + out_dtype=orig_dtype, + use_fast_accum=True, + ) + * scale_factor + ) + else: + key_t = (key_cache.to(dtype=orig_dtype) * k_scale).transpose(-2, -1) + attn_weight = query @ key_t + attn_weight *= scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, p_dropout, train=True) + # Do matmul in orig_dtype + attn = attn_weight @ (value_cache.to(dtype=orig_dtype) * v_scale) + + attn = attn.to(orig_dtype).transpose(2, 1).contiguous() + return attn + + register_attention_op( + "math_fp8", + _math_fp8_store_op, + _math_fp8_compute_op, + update_attn_kwargs_op=_sdpa_update_attn_kwargs, + ) + + def _spyre_scaled_paged_store_op( + keys: torch.Tensor, + values: torch.Tensor, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + **attn_kwargs: Unpack[SpyrePagedAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # For paged store, we must have pre-allocated the kv-cache + assert key_cache is not None and isinstance( + key_cache, ScaledTensor + ), "kv cache must be preallocated" + assert value_cache is not None and isinstance( + value_cache, ScaledTensor + ), "kv cache must be preallocated" + if not key_cache._scaled: + key_cache._scale = 
(torch.abs(keys).max() / 200.0).to(dtype=torch.float32) + value_cache._scale = (torch.abs(values).max() / 100.0).to( + dtype=torch.float32 + ) + + result_key_cache_data, result_value_cache_data = ( + torch.ops.spyre.scaled_paged_attn_store( + keys, + values, + key_cache._data, + value_cache._data, + key_cache._scale, + value_cache._scale, + attn_kwargs["slot_mapping"], + ) + ) + + result_key_cache = _construct_fp8_cache(result_key_cache_data, key_cache._scale) + result_value_cache = _construct_fp8_cache( + result_value_cache_data, value_cache._scale + ) + + # for prefill, we want to return the original keys/values + if attn_kwargs.get("block_table", None) is None: + return keys, values, result_key_cache, result_value_cache + return ( + result_key_cache, + result_value_cache, + result_key_cache, + result_value_cache, + ) + + def _spyre_scaled_paged_compute_op( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + nheads: int, # pylint: disable=unused-argument + kvheads: int, # pylint: disable=unused-argument + p_dropout: float, # pylint: disable=unused-argument + scale_factor: Optional[float], + **attn_kwargs, + ) -> torch.Tensor: + assert isinstance(key_cache, ScaledTensor), "kv cache must be scaled" + assert isinstance(value_cache, ScaledTensor), "kv cache must be scaled" + if scale_factor is None: + scale_factor = 1 / math.sqrt(query.shape[-1]) + return torch.ops.spyre.scaled_paged_attn_compute( + query, + key_cache._data, + value_cache._data, + key_cache._scale, + value_cache._scale, + scale_factor, + attn_kwargs["current_tkv_mask"], + attn_kwargs["left_padded_prompt_mask"], + attn_kwargs["block_table"], + ) + + register_attention_op( + "spyre_paged_attn_fp8", + _spyre_scaled_paged_store_op, + compute_op=_math_fp8_compute_op, + is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None) + is None, + compute_decode_op=_spyre_scaled_paged_compute_op, + validate_attn_kwargs_op=__spyre_paged_validate_attn_kwargs_op, + ) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py new file mode 100644 index 00000000..6062665b --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -0,0 +1,334 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
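# Illustrative sketch (not part of the changeset) of the per-tensor KV-cache
# scaling used by the store/compute ops above: tensors are divided by
# scale = abs(x).max() / RANGE, stored as float8_e4m3fn, and multiplied back by
# the scale at compute time. The RANGE value mirrors K_RANGE above; the
# tolerance is only illustrative of FP8 round-trip error.
import torch

K_RANGE = 200.0
keys = torch.randn(1, 64, 8, 128)

k_scale = (keys.abs().max() / K_RANGE).to(torch.float32)
keys_fp8 = (keys / k_scale).to(torch.float8_e4m3fn)  # stored form (the op also transposes)
keys_dequant = keys_fp8.to(torch.float32) * k_scale  # what the math compute op multiplies back

# float8_e4m3fn has a 3-bit mantissa, so the round trip is close but lossy
assert torch.allclose(keys, keys_dequant, atol=0.1, rtol=0.1)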
+"""Implement FP8 linear module to be loaded via FMS.""" + +# Standard +from typing import Any, Mapping + +# Third Party +import torch + +# Local +from fms_mo.prep import available_packages + +# pylint: disable=not-callable +# torch.nn.functional.linear not recognized as callable +# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 + +# Gated torchao imports for FP8 implementation +if available_packages["fms"] and available_packages["torchao"]: + # Third Party + from fms.modules.linear import ( + LinearModuleShardingInfo, + LinearParameterShardingInfo, + register_linear_type_to_module_map, + register_linear_type_to_sharding_map, + shard_base_linear, + ) + from fms.modules.tp import ShardType, TPModule + from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + to_affine_quantized_floatx, + to_affine_quantized_floatx_static, + ) + from torchao.dtypes.floatx.float8_layout import ( + Float8AQTTensorImpl, + Float8Layout, + Float8MMConfig, + preprocess_data, + preprocess_scale, + ) + from torchao.dtypes.utils import get_out_shape + from torchao.float8.inference import ( + _is_rowwise_scaled, + addmm_float8_unwrapped_inference, + ) + from torchao.quantization.granularity import PerRow, PerTensor + from torchao.quantization.observer import get_block_size + from torchao.quantization.quant_primitives import ZeroPointDomain + + class FP8Linear(torch.nn.Module): + """Class handles FP8 weights loading and uses torchao for the matmuls.""" + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_config: Mapping[str, Any], + ): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.has_bias = bias + self.linear_config = linear_config + + assert ( + self.linear_config["weights"] is not None + ), "Weights must always be quantized for FP8Linear" + assert self.linear_config["weights"][ + "symmetric" + ], "We only support symmetric weights for now" + assert not self.linear_config["weights"][ + "dynamic" + ], "We only support pre-quantized weights for now" + + self.weight = torch.nn.Parameter( + torch.zeros(out_features, in_features, dtype=torch.float8_e4m3fn), + requires_grad=False, + ) + + weight_scale_shape = ( + (1,) + if self.linear_config["weights"]["strategy"] == "tensor" + else (out_features, 1) + ) + self.weight_scale = torch.nn.Parameter( + torch.ones(weight_scale_shape), requires_grad=False + ) + + self.has_bias = bias + if self.has_bias: + self.bias = torch.nn.Parameter(torch.zeros((out_features,))) + + if ( + self.linear_config["input_activations"] is not None + and not self.linear_config["input_activations"]["dynamic"] + ): + input_scale_shape = ( + (1,) + if self.linear_config["input_activations"]["strategy"] == "tensor" + else (out_features, 1) + ) + self.input_scale = torch.nn.Parameter( + torch.ones(input_scale_shape), requires_grad=False + ) + + def _input_activation_quant_func_fp8( + self, + x: torch.Tensor, + activation_granularity, + activation_dtype: torch.dtype, + scale: torch.Tensor | None = None, + ): + """Quantize the input activation tensor for an aqt_float variant. + If scale is not provided, it will be dynamically calculated, otherwise the + provided scale will be used. 
+ """ + block_size = get_block_size(x.shape, activation_granularity) + if scale is None: + activation = to_affine_quantized_floatx( + input_float=x, + block_size=block_size, + target_dtype=activation_dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=None), # Config is stored on weight + ) + else: + assert isinstance( + activation_granularity, PerTensor + ), "Static quantization only supports PerTensor granularity" + activation = to_affine_quantized_floatx_static( + input_float=x, + block_size=block_size, + scale=scale, + target_dtype=activation_dtype, + _layout=Float8Layout(mm_config=None), # Config is stored on weight + ) + return activation + + def _construct_qweight_structure(self) -> "AffineQuantizedTensor": + """Construct the torchao machinery for the fp8 matmul""" + weight_granularity = ( + PerTensor() + if self.linear_config["weights"]["strategy"] == "tensor" + else PerRow() + ) + fp8_layout = Float8Layout(Float8MMConfig(use_fast_accum=True)) + return AffineQuantizedTensor( + Float8AQTTensorImpl.from_plain( + self.weight, + self.weight_scale.squeeze().to(torch.float32), + None, + fp8_layout, + ), + get_block_size(self.weight.shape, weight_granularity), + self.weight.shape, + zero_point_domain=ZeroPointDomain.NONE, + dtype=self.weight_scale.dtype, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """If input quantization is active, compute FP8xFP8 addmm leveraging torchao + functionalities. Otherwise compute non-quantized addmm.""" + # fp8 weight tensor for torchao + qweight: AffineQuantizedTensor = self._construct_qweight_structure() + + if self.linear_config["input_activations"] is not None: + # activations are also fp8, quantize as required by model + act_granularity = ( + PerTensor() + if self.linear_config["input_activations"]["strategy"] == "tensor" + else PerRow() + ) + input_quant_kwargs = { + "activation_granularity": act_granularity, + "activation_dtype": torch.float8_e4m3fn, + } + if not self.linear_config["input_activations"]["dynamic"]: + input_quant_kwargs["scale"] = self.input_scale.squeeze().to( + torch.float32 + ) + qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs) + + # Copied from torchao _linear_fp8_act_fp8_weight_impl + # (with changes to support fp8 out) + scaled_mm_config = Float8MMConfig(use_fast_accum=True) + out_shape = get_out_shape(qx.shape, qweight.shape) + + # Weight tensor preprocessing + w_tensor_impl = qweight.tensor_impl + assert not w_tensor_impl.transposed, "Weight tensor must be contiguous" + w_data = w_tensor_impl.float8_data + w_scale = w_tensor_impl.scale + + # Input tensor preprocessing + inpt_data = qx.tensor_impl.float8_data + input_scale = qx.tensor_impl.scale + # Handle case where input tensor is more than 2D + inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) + + # Handle rowwise case + if _is_rowwise_scaled(qweight): + assert _is_rowwise_scaled( + qx + ), "Input tensor must be rowwise block size" + w_scale = w_scale.unsqueeze(-1).T + input_scale = preprocess_scale(input_scale, qx.shape) + + # Preprocess data + inpt_data, w_data = preprocess_data( + inpt_data, w_data.T, scaled_mm_config + ) + + # Perform the computation + return addmm_float8_unwrapped_inference( + inpt_data, + input_scale, + w_data, + w_scale, + output_dtype=qx.dtype, + bias=getattr(self, "bias", None), + use_fast_accum=scaled_mm_config.use_fast_accum, + ).reshape(out_shape) + + # activations not quantized, dequant fp8 weight and do regular matmul + out = torch.nn.functional.linear( + x, qweight.dequantize(), self.bias 
if self.has_bias else None + ) + return out + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}" + f"(in={self.in_features}, out={self.out_features}, " + f"bias={self.has_bias}, fp8_config={self._repr_fp8_config()})" + ) + + def _repr_fp8_config(self) -> str: + return ( + "(" + "acts: (" + f"dynamic: {self.linear_config['input_activations']['dynamic']}, " + f"strategy: {self.linear_config['input_activations']['strategy']}" + "), " + "weights: (" + f"dynamic: {self.linear_config['weights']['dynamic']}, " + f"strategy: {self.linear_config['weights']['strategy']}" + ")" + ")" + ) + + def get_fp8_linear( + in_features: int, + out_features: int, + bias: bool, + linear_config: Mapping[str, Any], + ) -> FP8Linear: + """Retrieve an FP8 Linear module""" + return FP8Linear(in_features, out_features, bias, linear_config) + + def shard_fp8_linear( + tensor_values: dict[str, torch.Tensor], + tp_module: TPModule, + module_sharding_info: dict[str, LinearModuleShardingInfo], + ) -> set | None: + """ + | GPU | + sharding | param | shard | dim | + ----------+----------------+-------+-----| + colwise | weight | Y | 0 | + | weight_scale | N | - | + | input_scale | N | - | + | bias | Y | 0 | + ----------+----------------+-------+-----| + rowwise | weight | Y | 1 | + | weight_scale | Y/N | 0/- | + | input_scale | Y/N | 0/- | + | bias | 0 | - | + """ + + param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {} + for module_name, module_info in module_sharding_info.items(): + linear_mod: torch.nn.Module = module_info.linear_module + weight_strategy = getattr(linear_mod, "linear_config")["input_activations"][ + "strategy" + ] + # Scales are per-row or per-tensor + # Only sharding needed when row parallel and per-row + shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 1 + params: dict[str, LinearParameterShardingInfo] = { + "weight": LinearParameterShardingInfo( + module_info.sharding_dim, ShardType.SHARD + ), + "weight_scale": LinearParameterShardingInfo( + module_info.sharding_dim, + ShardType.SHARD if shard_scales else ShardType.CLONE, + ), + } + if hasattr(linear_mod, "input_scale"): + params["input_scale"] = LinearParameterShardingInfo( + module_info.sharding_dim, + ShardType.SHARD if shard_scales else ShardType.CLONE, + ) + if hasattr(linear_mod, "bias") and linear_mod.bias is not None: + params["bias"] = LinearParameterShardingInfo( + module_info.sharding_dim, + ShardType.SHARD + if module_info.sharding_dim == 0 + else ShardType.RANK0, + ) + param_sharding_info[module_name] = params + + unused_keys = shard_base_linear( + tensor_values, + tp_module, + module_sharding_info, + param_sharding_info, + ) + return unused_keys + + register_linear_type_to_module_map("fp8", get_fp8_linear) + register_linear_type_to_sharding_map("fp8", shard_fp8_linear) diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py new file mode 100644 index 00000000..696aab25 --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -0,0 +1,253 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
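# Illustrative sketch (not part of the changeset), assuming ibm-fms and torchao
# are installed: instantiate FP8Linear directly with a linear_config shaped like
# the one _infer_quantization_config builds (static, per-tensor activation
# scales). All config values here are assumptions for illustration.
import torch

from fms_mo.aiu_addons.fp8.fp8_linear import FP8Linear

example_linear_config = {
    "linear_type": "fp8",
    "weights": {"symmetric": True, "dynamic": False, "strategy": "tensor"},
    "input_activations": {"dynamic": False, "strategy": "tensor"},
    "output_activations": None,
}
fp8_linear = FP8Linear(
    in_features=4096, out_features=4096, bias=False, linear_config=example_linear_config
)

# Checkpoint tensors must line up with these parameter names, shapes and dtypes:
# weight       -> (4096, 4096), torch.float8_e4m3fn
# weight_scale -> (1,), torch.float32
# input_scale  -> (1,), torch.float32  (only present for static activation scales)
for name, param in fp8_linear.named_parameters():
    print(name, tuple(param.shape), param.dtype)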
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Torch registration of FP8xFP8 operation for attention BMMs.""" + +# Standard +from typing import Optional + +# Third Party +from torch import Tensor +import torch +import torch.nn.functional as F + +# pylint: disable=unused-argument +# abstract op must be registered with specific I/O, even if not in use by the op function + +# pylint: disable=not-callable +# torch.nn.functional.scaled_dot_product_attention not recognized as callable +# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 + + +aten = torch.ops.aten +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] + + +@torch.library.register_kernel("aten::_scaled_mm", "cpu") +def _scaled_mm_cpu( + mat1: Tensor, + mat2: Tensor, + scale1: Tensor, + scale2: Tensor, + bias: Optional[Tensor] = None, + scale_result: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, +) -> Tensor: + if out_dtype is None: + out_dtype = torch.float32 + mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype) + mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype) + + if bias is not None: + return torch.addmm(bias, mat1, mat2).to(dtype=out_dtype) + return torch.mm(mat1, mat2).to(dtype=out_dtype) + + +@torch.library.custom_op("spyre::scaled_bmm", mutates_args=()) +def spyre_scaled_bmm( + mat1: Tensor, + mat2: Tensor, + scale1: Tensor, + scale2: Tensor, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> Tensor: + """Implement a custom scaled attention BMM op: a batched version of _scaled_mm. + The operations that are part of this function are not exposed to the computational + graph, but are invoked when running on non-Spyre devices. + """ + + assert ( + mat1.shape[:-2] == mat2.shape[:-2] + ), "batch dimensions must match for mat1 and mat2" + assert scale1.numel() == 1, "only per-tensor scales supported" + assert scale2.numel() == 1, "only per-tensor scales supported" + mat1 = mat1.view(-1, *mat1.shape[-2:]) + mat2 = mat2.view(-1, *mat2.shape[-2:]) + out = torch.empty( + (mat1.shape[0], mat1.shape[1], mat2.shape[2]), + dtype=out_dtype, + device=mat1.device, + ) + for b_idx in range(mat1.shape[0]): + out[b_idx] = torch._scaled_mm( + mat1[b_idx], + mat2[b_idx], + scale1, + scale2, + out_dtype=out_dtype, + use_fast_accum=use_fast_accum, + ) + return out.view(*mat1.shape[:-2], mat1.shape[1], mat2.shape[2]) + + +@spyre_scaled_bmm.register_fake +def _( + mat1: Tensor, + mat2: Tensor, + scale1: Tensor, + scale2: Tensor, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> Tensor: + """Template for scaled attention BMM operation. I/O retain the expected size.""" + + return torch.empty( + (*mat1.shape[:-2], mat1.shape[-2], mat2.shape[-1]), + dtype=out_dtype, + device=mat1.device, + ) + + +@torch.library.custom_op( + "spyre::scaled_paged_attn_store", mutates_args=(), device_types="cpu" +) +def scaled_paged_attn_store( + key: Tensor, + value: Tensor, + key_cache: Tensor, + value_cache: Tensor, + key_scale: Tensor, + value_scale: Tensor, + slot_mapping: Tensor, +) -> tuple[Tensor, Tensor]: + """ + FP8 CPU implementation of the Paged Attn store operation. + Scales key and value tensors, and stores them to the paged KV cache + using the same schema as vLLM. 
+ """ + print("Should never hit") + result_key_cache = key_cache.clone() + result_value_cache = value_cache.clone() + for seq_i, slot_mapping_seq in enumerate(slot_mapping): + for tok_i, slot in enumerate(slot_mapping_seq): + block_number = slot.item() // 64 + position = slot.item() % 64 + + result_key_cache[block_number, position, :, :] = ( + key[seq_i, tok_i, :, :] / key_scale + ).to(dtype=torch.float8_e4m3fn) + result_value_cache[block_number, position, :, :] = ( + value[seq_i, tok_i, :, :] / value_scale + ).to(dtype=torch.float8_e4m3fn) + return result_key_cache, result_value_cache + + +@scaled_paged_attn_store.register_fake +def scaled_paged_attn_store_meta( + key: Tensor, + value: Tensor, + key_cache: Tensor, + value_cache: Tensor, + key_scale: Tensor, + value_scale: Tensor, + slot_mapping: Tensor, +) -> tuple[Tensor, Tensor]: + """ + Fake tensor implementation of the FP8 Paged Attn store operation. + """ + return key_cache, value_cache + + +@torch.library.custom_op( + "spyre::scaled_paged_attn_compute", mutates_args={}, device_types="cpu" +) +def scaled_paged_attn_compute( + query: Tensor, + key_cache: Tensor, + value_cache: Tensor, + key_scale: Tensor, + value_scale: Tensor, + scale: float, + current_tkv_mask: Tensor, + left_padded_prompt_mask: Tensor, + block_table: Tensor, +) -> Tensor: + """ + FP8 CPU implementation of the Paged Attn compute operation. + Implements a CPU fallback to run the kernel that has been confirmed + to match the vLLM fused kernel. + """ + print("Should never hit") + # torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype), + output = torch.zeros_like(query) + num_query_heads = query.shape[2] + num_kv_heads = value_cache.shape[2] + head_size = value_cache.shape[3] + block_size = value_cache.shape[1] + num_seqs = query.shape[0] + + block_tables_lst = block_table.cpu().tolist() + + seq_lens_lst = current_tkv_mask.cpu().tolist() + for i in range(num_seqs): + q = query[i] + block_table = block_tables_lst[i] + start_pos = int(left_padded_prompt_mask[i].item()) + seq_len = int(seq_lens_lst[i]) + + keys_lst: list[torch.Tensor] = [] + values_lst: list[torch.Tensor] = [] + for j in range(start_pos, seq_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, block_offset, :, :] + k = k.reshape(num_kv_heads, head_size) + keys_lst.append(k) + + v = value_cache[block_number, block_offset, :, :] + values_lst.append(v) + keys = torch.stack(keys_lst, dim=0) + values = torch.stack(values_lst, dim=0) + if num_kv_heads > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_query_heads // num_kv_heads, dim=1) + values = torch.repeat_interleave( + values, num_query_heads // num_kv_heads, dim=1 + ) + + out = F.scaled_dot_product_attention( # noqa: E1102 + q.transpose(0, 1).unsqueeze(0), # format for sdpa + (keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale).to( + dtype=q.dtype + ), # format for sdpa + (values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale).to( + dtype=q.dtype + ), # format for sdpa + is_causal=False, # decode assumes no causal mask + scale=scale, + ) + + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + return output + + +@scaled_paged_attn_compute.register_fake +def scaled_paged_attn_compute_meta( + query: Tensor, + key_cache: Tensor, + value_cache: Tensor, + key_scale: Tensor, + value_scale: Tensor, + scale: float, + current_tkv_mask: Tensor, + left_padded_prompt_mask: Tensor, + block_table: 
Tensor, +) -> Tensor: + """ + Fake tensor implementation of the FP8 Paged Attn compute operation. + """ + return torch.zeros_like(query) diff --git a/fms_mo/aiu_addons/fp8/fp8_utils.py b/fms_mo/aiu_addons/fp8/fp8_utils.py new file mode 100644 index 00000000..0af7a2a9 --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_utils.py @@ -0,0 +1,100 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility functions and components for FP8 addon implementation.""" + +# Standard +import functools + +# Third Party +import torch + +# pylint: disable=unused-argument +# unused arguments are needed for templates + + +_HANDLED_FUNCTIONS = {} + + +def _implements(torch_function): + """Register a torch function override""" + + def decorator(func): + @functools.wraps(torch_function) + def wrapper(f, types, args, kwargs): + return func(f, types, args, kwargs) + + _HANDLED_FUNCTIONS[torch_function] = wrapper + return func + + return decorator + + +class ScaledTensor(torch.Tensor): + """Representation of a quantized tensor and its scale.""" + + def __new__( + cls, + data: torch.Tensor, + scale: torch.Tensor, + scaled: bool = True, + ): + return torch.Tensor._make_wrapper_subclass( + cls, + data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + dtype=data.dtype, + layout=data.layout, + requires_grad=data.requires_grad, + device=data.device, + ) + + def __init__( + self, + data: torch.Tensor, + scale: torch.Tensor, + scaled: bool = True, + ): + super().__init__() + self._data = data + self._scale = scale + self._scaled = scaled + + def __tensor_flatten__(self): + ctx = {"scaled": self._scaled} + return ["_data", "_scale"], ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): + assert len(inner_tensors) == 2 + return ScaledTensor( + inner_tensors["_data"], + inner_tensors["_scale"], + metadata["scaled"], + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + if func in _HANDLED_FUNCTIONS: + return _HANDLED_FUNCTIONS[func](func, types, args, kwargs) + + arg_types = tuple(type(arg) for arg in args) + kwarg_types = {k: type(arg) for k, arg in kwargs.items()} + raise NotImplementedError( + f"{cls.__name__} dispatch: attempting to run unimplemented " + f"operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}" + ) + + def __repr__(self): + return f"{self._data.__repr__()}\n{self._scale.__repr__()}" diff --git a/fms_mo/utils/import_utils.py b/fms_mo/utils/import_utils.py index 958f06e9..a695e39d 100644 --- a/fms_mo/utils/import_utils.py +++ b/fms_mo/utils/import_utils.py @@ -17,11 +17,11 @@ """ # Standard +from importlib.util import find_spec import pkgutil import sys # Third Party -from transformers.utils.import_utils import _is_package_available import torch all_available_modules = [] @@ -41,12 +41,13 @@ "triton", "torchvision", "huggingface_hub", + "torchao", ] available_packages = {} for package in optional_packages: available_packages[package] = ( -
_is_package_available(package) or package in all_available_modules + find_spec(package) is not None or package in all_available_modules ) # cutlass is detected through torch.ops.cutlass_gemm diff --git a/pyproject.toml b/pyproject.toml index 8c7cbfe7..06d1e229 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,12 +30,12 @@ dependencies = [ "datasets>=3.0.0,<4.0", "pandas", "safetensors", -"pkginfo>1.10", +"pkginfo>1.10" ] [project.optional-dependencies] examples = ["ninja>=1.11.1.1,<2.0", "evaluate", "huggingface_hub"] -fp8 = ["llmcompressor"] +fp8 = ["llmcompressor", "torchao"] gptq = ["Cython", "gptqmodel>=1.7.3"] mx = ["microxcaling>=1.1"] opt = ["fms-model-optimizer[fp8, gptq, mx]"] diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py new file mode 100644 index 00000000..81d263b3 --- /dev/null +++ b/tests/aiu_addons/test_fp8_addon.py @@ -0,0 +1,59 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test suite for the FMS addon introducing FP8 functionalities.""" + +# Third Party +import pytest +import torch + +# Local +from fms_mo.prep import available_packages +import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import + + +def test_fp8_registration() -> None: + """ + Ensure fp8 ops are registered properly. + """ + + assert hasattr(torch.ops, "spyre") + assert hasattr(torch.ops.spyre, "scaled_bmm") + assert hasattr(torch.ops.spyre, "scaled_paged_attn_store") + assert hasattr(torch.ops.spyre, "scaled_paged_attn_compute") + + +# This test requires a GPU with FP8 support (compute capability 8.9 or higher) +@pytest.mark.skipif( + not available_packages["torchao"] or not available_packages["fms"], + reason="FMS and torchao required to run this test", +) +@pytest.mark.skipif( + not torch.cuda.is_available() + or (torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 9)), + reason="FP8 is only available on GPUs with compute capability 8.9 or higher", +) +def test_fp8_op() -> None: + """Validate the output shape of the FP8 math attention compute op. + Note: key/value tensors are dynamically quantized to FP8 inside the op, + and the attention output must match the query shape. + """ + # Local + from fms_mo.aiu_addons.fp8.fp8_attn import _math_fp8_compute_op + + query = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda") + key = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda") + value = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda") + + out = _math_fp8_compute_op(query, key, value, 32, 32, 0.0, None) + assert out.size() == query.size() diff --git a/tox.ini b/tox.ini index 2f2d0441..78d792be 100644 --- a/tox.ini +++ b/tox.ini @@ -33,6 +33,8 @@ deps = pytest pylint>=2.16.2,<4.0 pylint-pydantic + ibm-fms + torchao commands = {basepython} -m pylint --load-plugins pylint_pydantic fms_mo/ tests/
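# Illustrative sketch (not part of the changeset): exercise the CPU fallback of
# the spyre::scaled_bmm op registered in fp8_spyre_op.py and compare it against
# a plain bmm over the dequantized FP8 inputs. Shapes, scales and the tolerance
# are illustrative; a recent PyTorch with float8 dtypes is assumed.
import torch

import fms_mo.aiu_addons.fp8.fp8_spyre_op  # registers spyre ops + CPU _scaled_mm kernel

a = torch.randn(2, 4, 8) / 10
b = torch.randn(2, 8, 4) / 10
scale = torch.tensor(1.0)

a_fp8 = a.to(torch.float8_e4m3fn)
b_fp8 = b.to(torch.float8_e4m3fn)

out = torch.ops.spyre.scaled_bmm(a_fp8, b_fp8, scale, scale, out_dtype=torch.float32)
ref = torch.bmm(a_fp8.to(torch.float32), b_fp8.to(torch.float32))

assert out.shape == (2, 4, 4)
assert torch.allclose(out, ref, atol=1e-5)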