60 changes: 60 additions & 0 deletions fms_mo/aiu_addons/__init__.py
@@ -0,0 +1,60 @@
def _infer_quantization_config(quant_config: dict) -> dict | None:
    """Construct the linear_config dictionary carrying FP8 configuration for FMS.

    There are many quantization packages compatible with HF checkpoints.
    We initially focus on llm-compressor, as it is the one used in FMS-MO.

    llm-compressor saves its checkpoints with quant_method = compressed-tensors.
    quantization_status tells us whether the model has already been quantized.
    We only support loading already-quantized models ("compressed" status).
    """

    if (
        quant_config["quant_method"] == "compressed-tensors"
        and quant_config["quantization_status"] == "compressed"
    ):
        # FP8 quantization will have FP8 weights
        # We assume a single quantization group (group_0), to follow fms-mo checkpoints
        # num_bits and type tell us "float" with 8 bits, i.e. FP8
        if (
            quant_config["config_groups"]["group_0"]["weights"]["type"] == "float"
            and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8
        ):
            # First, import the required FP8 linear classes from fms-mo
            # Local
            import fms_mo.aiu_addons.fp8.fp8_adapter  # pylint: disable=unused-import
            import fms_mo.aiu_addons.fp8.fp8_linear  # pylint: disable=unused-import

            # This function is used by get_linear to decide whether a linear layer
            # inside the model will be quantized or not
            def fp8_linear_type(name: str) -> str:
                # We need to translate HF module names to FMS names
                translations = {
                    "lm_head": "head",
                }
                for ignored_layer in quant_config["ignore"]:
                    assert isinstance(ignored_layer, str)
                    fms_ign_layer = translations.get(ignored_layer, ignored_layer)
                    if name in fms_ign_layer:
                        return "torch_linear"
                for pattern in quant_config["config_groups"]["group_0"]["targets"]:
                    # Special case from llm-compressor that covers all linear layers
                    # not in the ignore pattern
                    assert isinstance(pattern, str)
                    if pattern == "Linear":
                        return "fp8"
                    if name in translations.get(pattern, pattern):
                        return "fp8"
                return "torch_linear"

            return {
                "linear_type": fp8_linear_type,
                "input_activations": quant_config["config_groups"]["group_0"][
                    "input_activations"
                ],
                "output_activations": quant_config["config_groups"]["group_0"][
                    "output_activations"
                ],
                "weights": quant_config["config_groups"]["group_0"]["weights"],
            }
    return None
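For reference, a minimal usage sketch of this helper. The quant_config dict below carries only the keys the function actually reads (quant_method, quantization_status, config_groups, ignore); a real llm-compressor config.json contains additional fields, and the exact values here are illustrative assumptions rather than a specific checkpoint.

# Hypothetical usage sketch (values are illustrative, not from a real checkpoint)
from fms_mo.aiu_addons import _infer_quantization_config

quant_config = {
    "quant_method": "compressed-tensors",
    "quantization_status": "compressed",
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],  # llm-compressor shorthand for all non-ignored linears
            "weights": {"type": "float", "num_bits": 8},
            "input_activations": {"type": "float", "num_bits": 8},
            "output_activations": None,
        }
    },
}

linear_config = _infer_quantization_config(quant_config)
assert linear_config is not None
# "lm_head" is translated to the FMS name "head" and kept unquantized
print(linear_config["linear_type"]("head"))        # -> "torch_linear"
print(linear_config["linear_type"]("attn.query"))  # -> "fp8"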
73 changes: 73 additions & 0 deletions fms_mo/aiu_addons/fp8/fp8_adapter.py
@@ -0,0 +1,73 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implement and register FMS adapters for FP8 checkpoint loading."""

# Standard
from typing import Any, Mapping
import functools

# Local
from fms_mo.prep import available_packages

if available_packages["fms"]:
    # Third Party
    from fms.modules.linear import get_linear_type
    from fms.utils import serialization
    from fms.utils.config import ModelConfig

    # pylint: disable=unused-argument
    # Retaining kwargs input arguments for consistency with other adapter steps.
    # TODO: may be shared with gptq llama
    def _hf_fp8_check(
        input_sd: Mapping[str, Any],
        model_config: ModelConfig | None = None,
        checkpoint_is_fused: bool = False,
        **kwargs,
    ) -> Mapping[str, Any]:
        """Implementation of an FMS adapter step: ensure that when FP8 quantization
        is in use, the model's weight fusion matches that of the checkpoint.
        """

        has_fused_weights = True
        linear_type = "torch_linear"
        if model_config:
            if not model_config.fused_weights:
                has_fused_weights = False
            if model_config.linear_config:
                linear_type = model_config.linear_config["linear_type"]
                if callable(linear_type):
                    # Calling this function with "any" guarantees that "fp8" is returned
                    # when loading an HF fp8 checkpoint, and never in any other condition
                    linear_type = get_linear_type(model_config.linear_config, "any")

        if "fp8" in linear_type and has_fused_weights != checkpoint_is_fused:
            raise ValueError(
                "FP8 HF llama checkpoints cannot be loaded into a model with fused weights"
            )

        return input_sd

    serialization.register_adapter_step(
        "llama",
        "hf_fp8_check",
        functools.partial(_hf_fp8_check, checkpoint_is_fused=False),
    )
    serialization.extend_adapter("llama", "hf", ["hf_fp8_check"])

    serialization.register_adapter_step(
        "granite",
        "hf_fp8_check",
        functools.partial(_hf_fp8_check, checkpoint_is_fused=False),
    )
    serialization.extend_adapter("granite", "hf", ["hf_fp8_check"])
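The register_adapter_step / extend_adapter calls above hook the check into the existing "hf" adapter chains for llama and granite, so it runs automatically whenever an HF checkpoint is loaded through FMS serialization. The snippet below is a minimal sketch of the check's behavior in isolation, assuming fms is installed; SimpleNamespace is a hypothetical stand-in for a real fms ModelConfig (only the two attributes the step reads are provided), and the state dict is empty because the step only validates configuration.

# Hypothetical direct invocation of the adapter step (stand-in config, not a real ModelConfig)
from types import SimpleNamespace

from fms_mo.aiu_addons.fp8.fp8_adapter import _hf_fp8_check

cfg = SimpleNamespace(
    fused_weights=True,                    # model built with fused weights...
    linear_config={"linear_type": "fp8"},  # ...but configured for FP8 linears
)

try:
    _hf_fp8_check({}, model_config=cfg, checkpoint_is_fused=False)
except ValueError as err:
    print(err)  # FP8 HF llama checkpoints cannot be loaded into a model with fused weights

# With unfused weights the same state dict passes through unchanged
cfg.fused_weights = False
assert _hf_fp8_check({}, model_config=cfg, checkpoint_is_fused=False) == {}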