diff --git a/fms_mo/aiu_addons/fp8/__init__.py b/fms_mo/aiu_addons/fp8/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fms_mo/aiu_addons/fp8/fp8_adapter.py b/fms_mo/aiu_addons/fp8/fp8_adapter.py new file mode 100644 index 00000000..7f339f72 --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_adapter.py @@ -0,0 +1,59 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implement and register FMS adapters for FP8 checkpoint loading.""" + +# Standard +from typing import Any, Mapping + +# Third Party +from fms.modules.linear import get_linear_type +from fms.utils import serialization +from fms.utils.config import ModelConfig + +# pylint: disable=unused-argument +# Retaining kwargs input arguments for consistency with other adapter steps. + + +# NOTE: this adapter step must be registered before the adapter that uses it (such as +# the llama adapter in fms.models.llama) +# TODO: may be shared with gptq llama +# TODO: generalize across architectures if possible +def _hf_fp8_llama_check( + input_sd: Mapping[str, Any], model_config: ModelConfig | None = None, **kwargs +) -> Mapping[str, Any]: + """Implementation of adapter step for FMS Llama: ensure that when FP8 quantization + is in use, weights are unfused. + """ + + has_fused_weights = True + linear_type = "torch_linear" + if model_config: + if not model_config.fused_weights: + has_fused_weights = False + if model_config.linear_config: + linear_type = model_config.linear_config["linear_type"] + if callable(linear_type): + # Calling this with "any" guarantees "fp8" to be returned + # when loading an HF fp8 checkpoint, and never in any other condition + linear_type = get_linear_type(model_config.linear_config, "any") + + if "fp8" in linear_type and has_fused_weights: + raise ValueError( + "FP8 HF llama checkpoints cannot be loaded into a model with fused weights" + ) + + return input_sd + + +serialization.register_adapter_step("llama", "hf_fp8_llama_check", _hf_fp8_llama_check) diff --git a/fms_mo/aiu_addons/fp8/fp8_attn.py b/fms_mo/aiu_addons/fp8/fp8_attn.py new file mode 100644 index 00000000..e04cce9b --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_attn.py @@ -0,0 +1,189 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
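+#
+# Illustrative sketch (not executed) of the KV-cache quantization scheme used
+# below: keys/values are divided by a per-tensor scale derived from a fixed
+# range and cast to float8_e4m3fn, and each scale travels with its cache tensor
+# inside a ScaledTensor wrapper so it can be reapplied at compute time, roughly:
+#
+#     k_scale = keys.abs().max() / K_RANGE                  # float32 scalar
+#     k_fp8 = (keys / k_scale).to(torch.float8_e4m3fn)
+#     key_cache = ScaledTensor(k_fp8, k_scale)
+#
+# The names mirror the code in this module; values and shapes are examples only.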
+"""FMS registration of attention BMM operation using torch-registered scaled BMM.""" + +# Standard +from typing import NotRequired, Unpack +import math + +# Third Party +from fms.modules.attention import ( + AttentionKwargs, + _sdpa_update_attn_kwargs, + register_attention_op, +) +import torch + +# Local +from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor +import fms_mo.aiu_addons.fp8.fp8_spyre_op # pylint: disable=unused-import + + +class MathFP8AttentionKwargs(AttentionKwargs): + """TypedDict for FP8 attention.""" + + mask: NotRequired[torch.Tensor] + do_scale_q: bool + is_causal_mask: bool + + +# TODO: Figure out better scales for AIU? These come from vLLM +Q_RANGE = 200.0 +K_RANGE = 200.0 +V_RANGE = 100.0 + + +def _construct_fp8_cache(tensor: torch.Tensor, scale: torch.Tensor) -> ScaledTensor: + """Construct the custom object to save KV cache with its scales.""" + return ScaledTensor(tensor, scale) + + +def _math_fp8_store_op( + keys: torch.Tensor, # pylint: disable=unused-argument + values: torch.Tensor, + key_cache: torch.Tensor | None, + value_cache: torch.Tensor | None, + **attn_kwargs: Unpack[MathFP8AttentionKwargs], +) -> tuple[ScaledTensor, ScaledTensor, ScaledTensor, ScaledTensor]: + """Implement math of KV cache storing.""" + + if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor): + k_scale = key_cache._scale + v_scale = value_cache._scale + else: + k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32) + v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32) + + keys = (keys / k_scale).to(torch.float8_e4m3fn).transpose(2, 1) + values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1) + + if ( + isinstance(key_cache, ScaledTensor) + and isinstance(value_cache, ScaledTensor) + and value_cache.numel() > 0 + ): + key_cache = torch.cat((key_cache._data, keys), dim=2) + value_cache = torch.cat((value_cache._data, values), dim=2) + key_cache = _construct_fp8_cache(key_cache, k_scale) + value_cache = _construct_fp8_cache(value_cache, v_scale) + return ( + key_cache, + value_cache, + key_cache, + value_cache, + ) + keys = _construct_fp8_cache(keys.contiguous(), k_scale) + values = _construct_fp8_cache(values.contiguous(), v_scale) + return (keys, values, keys, values) + + +def _math_fp8_compute_op( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + nheads: int, + kvheads: int, + p_dropout: float, + scale_factor: float | None, + **attn_kwargs: Unpack[MathFP8AttentionKwargs], +) -> torch.Tensor: + """Implement computation of attention BMM, leveraging the custom scaled attention + BMM op that was pre-registered for torch.compile.""" + + orig_dtype = query.dtype + + q_scale = torch.tensor(1.0, dtype=torch.float32, device=query.device) + if attn_kwargs.get("do_scale_q", False): + q_scale.copy_(torch.abs(query).max() / Q_RANGE) + query = query / q_scale + + query = query.to(torch.float8_e4m3fn).transpose(2, 1) + + if isinstance(key_cache, ScaledTensor) and isinstance(value_cache, ScaledTensor): + k_scale = key_cache._scale + v_scale = value_cache._scale + key_cache = key_cache._data + value_cache = value_cache._data + else: + k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32) + v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32) + key_cache = (key_cache / k_scale).to(torch.float8_e4m3fn) + value_cache = (value_cache / v_scale).to(torch.float8_e4m3fn) + + # no longer transposing prior to store, so need to check this in case of no cache + # TODO: 
Refactor FMS to avoid edge cases where this fails; add use_cache param here + if key_cache.shape[1] != kvheads and key_cache.shape[2] == kvheads: + key_cache = key_cache.transpose(2, 1) + value_cache = value_cache.transpose(2, 1) + + mask = attn_kwargs.get("mask", None) + if mask is not None: + # Our expected mask format is bs x q_len x k_len, so to make it broadcastable + # we need to create the nheads dimension + while len(mask.size()) != 4: # expects bs (x nheads) x q_len x kv_len + mask = mask.unsqueeze(1) + + L, S = query.size(-2), key_cache.size(-2) + scale_factor = ( + 1 / math.sqrt(query.size(-1)) if scale_factor is None else scale_factor + ) + attn_bias = torch.zeros(L, S, dtype=orig_dtype, device=query.device) + if attn_kwargs.get("is_causal_mask", False): + assert mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(torch.float32) + + if mask is not None: + if mask.dtype == torch.bool: + attn_bias.masked_fill_(mask.logical_not(), float("-inf")) + else: + attn_bias = mask + attn_bias + + expansion = nheads // kvheads + if expansion > 1: + key_cache = key_cache.repeat_interleave( + query.size(-3) // key_cache.size(-3), -3 + ) + value_cache = value_cache.repeat_interleave( + query.size(-3) // value_cache.size(-3), -3 + ) + + attn_weight = ( + torch.ops.sendnn.scaled_bmm( + query, + key_cache.transpose(-2, -1), + q_scale, + k_scale, + out_dtype=orig_dtype, + use_fast_accum=True, + ) + * scale_factor + ) + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, p_dropout, train=True) + # Do matmul in orig_dtype + attn = attn_weight @ (value_cache.to(dtype=orig_dtype) * v_scale) + + attn = attn.to(orig_dtype).transpose(2, 1).contiguous() + return attn + + +register_attention_op( + "math_fp8", + _math_fp8_store_op, + _math_fp8_compute_op, + update_attn_kwargs_op=_sdpa_update_attn_kwargs, +) diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py new file mode 100644 index 00000000..3e2246ba --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -0,0 +1,327 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
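+#
+# Illustrative sketch (example values, not an authoritative schema) of the
+# linear_config mapping consumed by FP8Linear below; only the keys read by this
+# module are shown:
+#
+#     linear_config = {
+#         "linear_type": "fp8",
+#         "weights": {"symmetric": True, "dynamic": False, "strategy": "tensor"},
+#         "input_activations": {"dynamic": True, "strategy": "tensor"},
+#     }
+#
+# In practice this mapping is derived from an llm-compressor checkpoint config
+# (see fp8_utils._infer_quantization_config).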
+"""Implement FP8 linear module to be loaded via FMS.""" + +# Standard +from importlib.util import find_spec +from typing import Any, Mapping + +# Third Party +from fms.modules.linear import ( + LinearModuleShardingInfo, + LinearParameterShardingInfo, + register_linear_type_to_module_map, + register_linear_type_to_sharding_map, + shard_base_linear, +) +from fms.modules.tp import ShardType, TPModule +import torch + +# pylint: disable=not-callable +# torch.nn.functional.linear not recognized as callable +# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 + + +# Gated torchao imports for FP8 implementation +if find_spec("torchao"): + TORCHAO_INSTALLED = True + + # Third Party + from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + to_affine_quantized_floatx, + to_affine_quantized_floatx_static, + ) + from torchao.dtypes.floatx.float8_layout import ( + Float8AQTTensorImpl, + Float8Layout, + Float8MMConfig, + preprocess_data, + preprocess_scale, + ) + from torchao.dtypes.utils import get_out_shape + from torchao.float8.inference import ( + _is_rowwise_scaled, + addmm_float8_unwrapped_inference, + ) + from torchao.quantization.granularity import PerRow, PerTensor + from torchao.quantization.observer import get_block_size + from torchao.quantization.quant_primitives import ZeroPointDomain +else: + TORCHAO_INSTALLED = False + + +class FP8Linear(torch.nn.Module): + """Class handles FP8 weights loading and uses torchao for the matmuls.""" + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool, + linear_config: Mapping[str, Any], + ): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.has_bias = bias + self.linear_config = linear_config + + assert ( + self.linear_config["weights"] is not None + ), "Weights must always be quantized for FP8Linear" + assert self.linear_config["weights"][ + "symmetric" + ], "We only support symmetric weights for now" + assert not self.linear_config["weights"][ + "dynamic" + ], "We only support pre-quantized weights for now" + + self.weight = torch.nn.Parameter( + torch.zeros(out_features, in_features, dtype=torch.float8_e4m3fn), + requires_grad=False, + ) + + weight_scale_shape = ( + (1,) + if self.linear_config["weights"]["strategy"] == "tensor" + else (out_features, 1) + ) + self.weight_scale = torch.nn.Parameter( + torch.ones(weight_scale_shape), requires_grad=False + ) + + self.has_bias = bias + if self.has_bias: + self.bias = torch.nn.Parameter(torch.zeros((out_features,))) + + if ( + self.linear_config["input_activations"] is not None + and not self.linear_config["input_activations"]["dynamic"] + ): + input_scale_shape = ( + (1,) + if self.linear_config["input_activations"]["strategy"] == "tensor" + else (out_features, 1) + ) + self.input_scale = torch.nn.Parameter( + torch.ones(input_scale_shape), requires_grad=False + ) + + def _input_activation_quant_func_fp8( + self, + x: torch.Tensor, + activation_granularity, + activation_dtype: torch.dtype, + scale: torch.Tensor | None = None, + ): + """Quantize the input activation tensor for an aqt_float variant. + If scale is not provided, it will be dynamically calculated, otherwise the + provided scale will be used. 
+ """ + + block_size = get_block_size(x.shape, activation_granularity) + if scale is None: + activation = to_affine_quantized_floatx( + input_float=x, + block_size=block_size, + target_dtype=activation_dtype, + scale_dtype=torch.float32, + _layout=Float8Layout(mm_config=None), # Config is stored on weight + ) + else: + assert isinstance( + activation_granularity, PerTensor + ), "Static quantization only supports PerTensor granularity" + activation = to_affine_quantized_floatx_static( + input_float=x, + block_size=block_size, + scale=scale, + target_dtype=activation_dtype, + _layout=Float8Layout(mm_config=None), # Config is stored on weight + ) + return activation + + def _construct_qweight_structure(self) -> "AffineQuantizedTensor": + """Construct the torchao machinery for the fp8 matmul""" + + weight_granularity = ( + PerTensor() + if self.linear_config["weights"]["strategy"] == "tensor" + else PerRow() + ) + fp8_layout = Float8Layout(Float8MMConfig(use_fast_accum=True)) + return AffineQuantizedTensor( + Float8AQTTensorImpl.from_plain( + self.weight, + self.weight_scale.squeeze().to(torch.float32), + None, + fp8_layout, + ), + get_block_size(self.weight.shape, weight_granularity), + self.weight.shape, + zero_point_domain=ZeroPointDomain.NONE, + dtype=self.weight_scale.dtype, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """If input quantization is active, compute FP8xFP8 addmm leveraging torchao + functionalities. Otherwise compute non-quantized addmm.""" + + # fp8 weight tensor for torchao + qweight: AffineQuantizedTensor = self._construct_qweight_structure() + + if self.linear_config["input_activations"] is not None: + # activations are also fp8, quantize as required by model + act_granularity = ( + PerTensor() + if self.linear_config["input_activations"]["strategy"] == "tensor" + else PerRow() + ) + input_quant_kwargs = { + "activation_granularity": act_granularity, + "activation_dtype": torch.float8_e4m3fn, + } + if not self.linear_config["input_activations"]["dynamic"]: + input_quant_kwargs["scale"] = self.input_scale.squeeze().to( + torch.float32 + ) + qx = self._input_activation_quant_func_fp8(x, **input_quant_kwargs) + + # Copied from torchao _linear_fp8_act_fp8_weight_impl (with changes to support fp8 out) + scaled_mm_config = Float8MMConfig(use_fast_accum=True) + out_shape = get_out_shape(qx.shape, qweight.shape) + + # Weight tensor preprocessing + w_tensor_impl = qweight.tensor_impl + assert not w_tensor_impl.transposed, "Weight tensor must be contiguous" + w_data = w_tensor_impl.float8_data + w_scale = w_tensor_impl.scale + + # Input tensor preprocessing + inpt_data = qx.tensor_impl.float8_data + input_scale = qx.tensor_impl.scale + # Handle case where input tensor is more than 2D + inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) + + # Handle rowwise case + if _is_rowwise_scaled(qweight): + assert _is_rowwise_scaled(qx), "Input tensor must be rowwise block size" + w_scale = w_scale.unsqueeze(-1).T + input_scale = preprocess_scale(input_scale, qx.shape) + + # Preprocess data + inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) + + # Perform the computation + return addmm_float8_unwrapped_inference( + inpt_data, + input_scale, + w_data, + w_scale, + output_dtype=qx.dtype, + bias=getattr(self, "bias", None), + use_fast_accum=scaled_mm_config.use_fast_accum, + ).reshape(out_shape) + + # activations not quantized, dequant fp8 weight and do regular matmul + out = torch.nn.functional.linear( + x, qweight.dequantize(), self.bias if 
self.has_bias else None + ) + return out + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}" + f"(in={self.in_features}, out={self.out_features}, " + f"bias={self.has_bias}, fp8_config={self.linear_config})" + ) + + +def get_fp8_linear( + in_features: int, + out_features: int, + bias: bool, + linear_config: Mapping[str, Any], +) -> FP8Linear: + """Retrieve an FP8 Linear module""" + + if not TORCHAO_INSTALLED: + raise ModuleNotFoundError("You need to install torchao for FP8 support in FMS!") + + return FP8Linear(in_features, out_features, bias, linear_config) + + +def shard_fp8_linear( + tensor_values: dict[str, torch.Tensor], + tp_module: TPModule, + module_sharding_info: dict[str, LinearModuleShardingInfo], +) -> set | None: + """ + | GPU | + sharding | param | shard | dim | + ----------+----------------+-------+-----| + colwise | weight | Y | 0 | + | weight_scale | N | - | + | input_scale | N | - | + | bias | Y | 0 | + ----------+----------------+-------+-----| + rowwise | weight | Y | 1 | + | weight_scale | Y/N | 0/- | + | input_scale | Y/N | 0/- | + | bias | 0 | - | + """ + + param_sharding_info: dict[str, dict[str, LinearParameterShardingInfo]] = {} + for module_name, module_info in module_sharding_info.items(): + linear_mod: torch.nn.Module = module_info.linear_module + weight_strategy = getattr(linear_mod, "linear_config")["input_activations"][ + "strategy" + ] + # Scales are per-row or per-tensor + # Only sharding needed when row parallel and per-row + shard_scales = weight_strategy != "tensor" and module_info.sharding_dim == 1 + params: dict[str, LinearParameterShardingInfo] = { + "weight": LinearParameterShardingInfo( + module_info.sharding_dim, ShardType.SHARD + ), + "weight_scale": LinearParameterShardingInfo( + module_info.sharding_dim, + ShardType.SHARD if shard_scales else ShardType.CLONE, + ), + } + if hasattr(linear_mod, "input_scale"): + params["input_scale"] = LinearParameterShardingInfo( + module_info.sharding_dim, + ShardType.SHARD if shard_scales else ShardType.CLONE, + ) + if hasattr(linear_mod, "bias") and linear_mod.bias is not None: + params["bias"] = LinearParameterShardingInfo( + module_info.sharding_dim, + ShardType.SHARD if module_info.sharding_dim == 0 else ShardType.RANK0, + ) + param_sharding_info[module_name] = params + + unused_keys = shard_base_linear( + tensor_values, + tp_module, + module_sharding_info, + param_sharding_info, + ) + return unused_keys + + +register_linear_type_to_module_map("fp8", get_fp8_linear) +register_linear_type_to_sharding_map("fp8", shard_fp8_linear) diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py new file mode 100644 index 00000000..c6280d4a --- /dev/null +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -0,0 +1,75 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
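+#
+# Shape sketch (illustrative only): for attention inputs laid out as
+# (bs, heads, seq, head_dim), the op below behaves like a batched
+# torch._scaled_mm, e.g.
+#
+#     scores = scaled_bmm(q_fp8, k_fp8.transpose(-2, -1), q_scale, k_scale,
+#                         out_dtype=torch.bfloat16)  # -> (bs, heads, q_len, kv_len)
+#
+# where each 2D matmul is rescaled by the two scales and emitted in out_dtype.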
+"""Torch registration of FP8xFP8 operation for attention BMMs."""
+
+# Third Party
+from torch import Tensor
+import torch
+
+# pylint: disable=unused-argument
+# abstract op must be registered with specific I/O, even if not in use by the op function
+
+
+@torch.library.custom_op("sendnn::scaled_bmm", mutates_args=())
+def sendnn_scaled_bmm(
+    mat1: Tensor,
+    mat2: Tensor,
+    scale1: Tensor,
+    scale2: Tensor,
+    out_dtype: torch.dtype | None = None,
+    use_fast_accum: bool = False,
+) -> Tensor:
+    """Implement a custom scaled attention BMM op: a batched version of _scaled_mm.
+    The operations that are part of this function are not exposed to the computational
+    graph, but are invoked when running on non-Spyre devices.
+    """
+
+    assert (
+        mat1.shape[:-2] == mat2.shape[:-2]
+    ), "batch dimensions must match for mat1 and mat2"
+    batch_dims = mat1.shape[:-2]  # preserved to restore the output shape below
+    mat1 = mat1.view(-1, *mat1.shape[-2:])
+    mat2 = mat2.view(-1, *mat2.shape[-2:])
+    out = torch.empty(
+        (mat1.shape[0], mat1.shape[1], mat2.shape[2]),
+        dtype=out_dtype,
+        device=mat1.device,
+    )
+    for b_idx in range(mat1.shape[0]):
+        out[b_idx] = torch._scaled_mm(
+            mat1[b_idx],
+            mat2[b_idx],
+            scale1,
+            scale2,
+            out_dtype=out_dtype,
+            use_fast_accum=use_fast_accum,
+        )
+    # restore the original batch dimensions so the output matches the registered fake
+    return out.view(*batch_dims, mat1.shape[-2], mat2.shape[-1])
+
+
+@sendnn_scaled_bmm.register_fake
+def _(
+    mat1: Tensor,
+    mat2: Tensor,
+    scale1: Tensor,
+    scale2: Tensor,
+    out_dtype: torch.dtype | None = None,
+    use_fast_accum: bool = False,
+) -> Tensor:
+    """Template for scaled attention BMM operation. I/O retain the expected size."""
+
+    return torch.empty(
+        (*mat1.shape[:-2], mat1.shape[-2], mat2.shape[-1]),
+        dtype=out_dtype,
+        device=mat1.device,
+    )
diff --git a/fms_mo/aiu_addons/fp8/fp8_utils.py b/fms_mo/aiu_addons/fp8/fp8_utils.py
new file mode 100644
index 00000000..0f5fbab4
--- /dev/null
+++ b/fms_mo/aiu_addons/fp8/fp8_utils.py
@@ -0,0 +1,152 @@
+# Copyright The FMS Model Optimizer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
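+#
+# Usage sketch (illustrative only): ScaledTensor below simply keeps quantized
+# data and its scale together so both travel through the KV cache, e.g.
+#
+#     t = ScaledTensor(torch.zeros(2, 4).to(torch.float8_e4m3fn),
+#                      torch.tensor(0.05))
+#     t._data.dtype    # torch.float8_e4m3fn
+#     t._scale         # tensor(0.0500)
+#
+# Torch ops on it are unsupported unless explicitly registered via _implements.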
+"""Utility functions and components for FP8 addon implementation."""
+
+# Standard
+import functools
+
+# Third Party
+import torch
+
+# pylint: disable=unused-argument
+# unused arguments are needed for templates
+
+
+_HANDLED_FUNCTIONS = {}
+
+
+def _implements(torch_function):
+    """Register a torch function override"""
+
+    def decorator(func):
+        @functools.wraps(torch_function)
+        def wrapper(f, types, args, kwargs):
+            return func(f, types, args, kwargs)
+
+        _HANDLED_FUNCTIONS[torch_function] = wrapper
+        return func
+
+    return decorator
+
+
+class ScaledTensor(torch.Tensor):
+    """Representation of a quantized tensor and its scale."""
+
+    def __new__(
+        cls,
+        data: torch.Tensor,
+        scale: torch.Tensor,
+    ):
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            data.size(),
+            strides=data.stride(),
+            storage_offset=data.storage_offset(),
+            dtype=data.dtype,
+            layout=data.layout,
+            requires_grad=data.requires_grad,
+            device=data.device,
+        )
+
+    def __init__(  # pylint: disable=super-init-not-called
+        self,
+        data: torch.Tensor,
+        scale: torch.Tensor,
+    ):
+        self._data = data
+        self._scale = scale
+
+    def __tensor_flatten__(self):
+        ctx = {}
+        return ["_data", "_scale"], ctx
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
+        assert len(inner_tensors) == 2
+        return ScaledTensor(
+            inner_tensors["_data"],
+            inner_tensors["_scale"],
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs=None):
+        if func in _HANDLED_FUNCTIONS:
+            return _HANDLED_FUNCTIONS[func](func, types, args, kwargs)
+
+        arg_types = tuple(type(arg) for arg in args)
+        kwarg_types = {k: type(arg) for k, arg in (kwargs or {}).items()}
+        raise NotImplementedError(
+            f"{cls.__name__} dispatch: attempting to run unimplemented "
+            f"operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}"
+        )
+
+    def __repr__(self):
+        return f"{self._data.__repr__()}\n{self._scale.__repr__()}"
+
+
+def _infer_quantization_config(quant_config: dict) -> dict | None:
+    """Construct linear_config dictionary carrying FP8 configuration for FMS.
+
+    There are many quantization packages compatible with HF.
+    We initially focus on llm-compressor, as it is the one used in FMS-MO.
+
+    llm-compressor saves its checkpoints with quant_method = compressed-tensors.
+    quantization_status tells us whether the model has already been quantized.
+    We only support loading already-quantized models ("compressed" status).
+    """
+
+    if (
+        quant_config["quant_method"] == "compressed-tensors"
+        and quant_config["quantization_status"] == "compressed"
+    ):
+        # FP8 quantization will have FP8 weights
+        # We assume a single quantization group (group_0), to follow fms-mo checkpoints
+        # num_bits and type tell us "float" with "8" bits, aka FP8
+        if (
+            quant_config["config_groups"]["group_0"]["weights"]["type"] == "float"
+            and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8
+        ):
+            # This is used by get_linear to decide whether a linear layer
+            # will be quantized or not inside the model
+            def fp8_linear_type(name: str) -> str:
+                # We need to translate HF names to FMS names
+                translations = {
+                    "lm_head": "head",
+                }
+                for ignored_layer in quant_config["ignore"]:
+                    assert isinstance(ignored_layer, str)
+                    fms_ign_layer = translations.get(ignored_layer, ignored_layer)
+                    if name in fms_ign_layer:
+                        return "torch_linear"
+                for pattern in quant_config["config_groups"]["group_0"]["targets"]:
+                    # Special case from llm-compressor that covers all linear layers
+                    # not in the ignore pattern
+                    assert isinstance(pattern, str)
+                    if pattern == "Linear":
+                        return "fp8"
+                    if name in translations.get(pattern, pattern):
+                        return "fp8"
+                return "torch_linear"
+
+            return {
+                "linear_type": fp8_linear_type,
+                "input_activations": quant_config["config_groups"]["group_0"][
+                    "input_activations"
+                ],
+                "output_activations": quant_config["config_groups"]["group_0"][
+                    "output_activations"
+                ],
+                "weights": quant_config["config_groups"]["group_0"]["weights"],
+            }
+    return None
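+
+
+# Illustrative sketch (example values only, not a reference checkpoint) of the
+# HF quantization_config that _infer_quantization_config above expects from an
+# llm-compressor FP8 checkpoint:
+#
+#     {
+#         "quant_method": "compressed-tensors",
+#         "quantization_status": "compressed",
+#         "ignore": ["lm_head"],
+#         "config_groups": {
+#             "group_0": {
+#                 "targets": ["Linear"],
+#                 "weights": {"type": "float", "num_bits": 8, ...},
+#                 "input_activations": {...},
+#                 "output_activations": None,
+#             },
+#         },
+#     }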