feat: addons for FP8 attention bmm, paged attention, and linear in FMS #154
Merged
Changes from all commits (16 commits)
b2c0d54  Addons for FP8 attention bmm in FMS (andrea-fasoli)
6f289b0  Update FP8 bmm (andrea-fasoli)
30f76a9  Add FP8 adapter step (andrea-fasoli)
496bf44  Add FP8 linear to FMS addon (andrea-fasoli)
c931ad7  rename fp8 attention (andrea-fasoli)
f05beb5  Fix linting, add paged attention kernels (ani300)
cf2082e  Make changes to work with fms and aftu (ani300)
f4ec836  fix merge conflicts (ani300)
b12dc58  Fixes from PR comments, unit tests (ani300)
6a88117  Add test (ani300)
bdf1cf2  Gate FMS imports (ani300)
3a373ff  Add choice for scaled bmm (ani300)
43372e4  Improve package checking to allow editable builds (ani300)
c5a55fc  Add CPU fallback for scaled_mm (ani300)
2b56e0f  Clean repr for fp8linear (ani300)
42528b0  Add further skips to test (ani300)
New file added (+60 lines):

```python
def _infer_quantization_config(quant_config: dict) -> dict | None:
    """Construct linear_config dictionary carrying FP8 configuration for FMS.

    There's many quantization packages compatible with HF
    We initially focus on llm-compressor as it is the one used in FMS-MO

    llm-compressor saves its checkpoints with quant_method = compressed-tensors
    quantization_status tells us whether the model has already been quantized
    We only support loading already quantized models (compressed status)
    """

    if (
        quant_config["quant_method"] == "compressed-tensors"
        and quant_config["quantization_status"] == "compressed"
    ):
        # FP8 quantization will have FP8 weights
        # We assume a single quantization group (group_0), to follow fms-mo checkpoints
        # num_bits and type tells us "float" with "8" bits, aka FP8
        if (
            quant_config["config_groups"]["group_0"]["weights"]["type"] == "float"
            and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8
        ):
            # First, import required FP8 linear classes from fms-mo
            # Local
            import fms_mo.aiu_addons.fp8.fp8_adapter  # pylint: disable=unused-import
            import fms_mo.aiu_addons.fp8.fp8_linear  # pylint: disable=unused-import

            # This is used by get_linear to decide whether a linear layer
            # will be quantized or not inside the model
            def fp8_linear_type(name: str) -> str:
                # We need to translate HF names to FMS names
                translations = {
                    "lm_head": "head",
                }
                for ignored_layer in quant_config["ignore"]:
                    assert isinstance(ignored_layer, str)
                    fms_ign_layer = translations.get(ignored_layer, ignored_layer)
                    if name in fms_ign_layer:
                        return "torch_linear"
                for pattern in quant_config["config_groups"]["group_0"]["targets"]:
                    # Special case from llm-compressor that covers all linear layers
                    # not in the ignore pattern
                    assert isinstance(pattern, str)
                    if pattern == "Linear":
                        return "fp8"
                    if name in translations.get(pattern, pattern):
                        return "fp8"
                return "torch_linear"

            return {
                "linear_type": fp8_linear_type,
                "input_activations": quant_config["config_groups"]["group_0"][
                    "input_activations"
                ],
                "output_activations": quant_config["config_groups"]["group_0"][
                    "output_activations"
                ],
                "weights": quant_config["config_groups"]["group_0"]["weights"],
            }
    return None
```
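To make the mapping above concrete, here is a minimal usage sketch (not part of the PR). The `quant_config` dict imitates the `quantization_config` section that llm-compressor writes into an HF `config.json`; its exact values and the layer names probed at the end are illustrative assumptions. Running it also assumes `fms_mo` (with these FP8 addons) and `fms` are installed, since the helper imports them internally, and that `_infer_quantization_config` is importable from wherever the new file lands (its path is not shown in this view).

```python
# from <new module in this PR> import _infer_quantization_config  # path not shown in the diff

# Illustrative llm-compressor style config for an FP8 checkpoint (assumed values).
quant_config = {
    "quant_method": "compressed-tensors",
    "quantization_status": "compressed",
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {"type": "float", "num_bits": 8},
            "input_activations": {"type": "float", "num_bits": 8},
            "output_activations": None,
        }
    },
}

linear_config = _infer_quantization_config(quant_config)
assert linear_config is not None

# "linear_type" is a callable mapping FMS layer names to "fp8" or "torch_linear";
# "lm_head" is translated to the FMS name "head" and is therefore left unquantized.
fp8_linear_type = linear_config["linear_type"]
print(fp8_linear_type("attn.query"))  # "fp8": covered by the blanket "Linear" target
print(fp8_linear_type("head"))        # "torch_linear": listed under "ignore"
```

The returned dict is what the comment in the file describes: `get_linear` can call the `linear_type` callable per layer name to decide whether that layer becomes an FP8 linear or stays a plain `torch_linear`.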
A new, empty file is also added.
New file added (+73 lines):

```python
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implement and register FMS adapters for FP8 checkpoint loading."""

# Standard
from typing import Any, Mapping
import functools

# Local
from fms_mo.prep import available_packages

if available_packages["fms"]:
    # Third Party
    from fms.modules.linear import get_linear_type
    from fms.utils import serialization
    from fms.utils.config import ModelConfig

    # pylint: disable=unused-argument
    # Retaining kwargs input arguments for consistency with other adapter steps.
    # TODO: may be shared with gptq llama
    def _hf_fp8_check(
        input_sd: Mapping[str, Any],
        model_config: ModelConfig | None = None,
        checkpoint_is_fused: bool = False,
        **kwargs,
    ) -> Mapping[str, Any]:
        """Implementation of adapter step for FMS: ensure that when FP8 quantization
        is in use, weights are fused like the model checkpoint.
        """

        has_fused_weights = True
        linear_type = "torch_linear"
        if model_config:
            if not model_config.fused_weights:
                has_fused_weights = False
            if model_config.linear_config:
                linear_type = model_config.linear_config["linear_type"]
                if callable(linear_type):
                    # Calling this function with "any" guarantees "fp8" to be returned
                    # when loading an HF fp8 checkpoint, and never in any other condition
                    linear_type = get_linear_type(model_config.linear_config, "any")

        if "fp8" in linear_type and has_fused_weights != checkpoint_is_fused:
            raise ValueError(
                "FP8 HF llama checkpoints cannot be loaded into a model with fused weights"
            )

        return input_sd

    serialization.register_adapter_step(
        "llama",
        "hf_fp8_check",
        functools.partial(_hf_fp8_check, checkpoint_is_fused=False),
    )
    serialization.extend_adapter("llama", "hf", ["hf_fp8_check"])

    serialization.register_adapter_step(
        "granite",
        "hf_fp8_check",
        functools.partial(_hf_fp8_check, checkpoint_is_fused=False),
    )
    serialization.extend_adapter("granite", "hf", ["hf_fp8_check"])
```
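The registration calls above wire `hf_fp8_check` into the `llama` and `granite` HF adapters with `checkpoint_is_fused=False`. Below is a minimal sketch (not in the PR) of the guard that step enforces, assuming this file is the `fms_mo.aiu_addons.fp8.fp8_adapter` module imported earlier and that `fms` is installed so the function is actually defined; the `SimpleNamespace` stand-in for `ModelConfig` is a test convenience, since the check only reads `.fused_weights` and `.linear_config`.

```python
from types import SimpleNamespace

from fms_mo.aiu_addons.fp8.fp8_adapter import _hf_fp8_check  # assumed module path

# Hypothetical config: a model built with fused projections that requests FP8 linears.
fake_config = SimpleNamespace(
    fused_weights=True,
    linear_config={"linear_type": "fp8"},
)

try:
    # The registered adapter step passes checkpoint_is_fused=False (HF FP8
    # checkpoints are unfused), so a fused model must be rejected.
    _hf_fp8_check({}, model_config=fake_config, checkpoint_is_fused=False)
except ValueError as err:
    print(err)  # "FP8 HF llama checkpoints cannot be loaded into a model with fused weights"

# With an unfused model the same state dict passes through untouched.
fake_config.fused_weights = False
sd = _hf_fp8_check({}, model_config=fake_config, checkpoint_is_fused=False)
assert sd == {}
```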