
Commit 0126707

SunMarc and MekkCyber authored
small cleaning of quantization class (#42633)
* small cleaning
* fix
* Apply suggestions from code review

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
1 parent 626875b commit 0126707

23 files changed (+298, -1173 lines)

docs/source/en/quantization/contribute.md

Lines changed: 11 additions & 7 deletions
@@ -46,26 +46,30 @@ Some quantization methods may require "pre-quantizing" the model through data ca

 ## Create new HFQuantizer class

+0. The best starting point would be to have a look at another quantization method such as fine-grained FP8. You will have to update or create three files in total: the [config file](https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py), the [integration file](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/finegrained_fp8.py) and the [quantizer file](https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_finegrained_fp8.py).
+
 1. Create a new quantization config class inside [src/transformers/utils/quantization_config.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/utils/quantization_config.py). Add the new quantization config to the [_import_structure](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py#L1088) inside Transformers' [src/transformers/__init__.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/__init__.py) file.

 2. Create a new file inside [src/transformers/quantizers/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/quantizers) named `quantizer_your_method.py`, and make it inherit from [`~quantizers.HfQuantizer`]. Make sure to add the new quantizer and quantization config in the quantization auto-mapping in [src/transformers/quantizers/auto.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/quantizers/auto.py).

-3. Define the following class attributes and property methods for your quantization method.
+3. Define the following class attributes and property methods for your quantization method:

 - `requires_calibration`: Whether the quantization method requires a data calibration process. If set to `True`, you can only support inference (with quantized weights) and not inference and quantization.
-- `required_packages`: A list of strings of the required packages to use the quantized weights. You might need to define some new utility methods such as `is_auto_awq_available` in [transformers/src/utils/import_utils.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/utils/import_utils.py).
-- `requires_parameters_quantization`: Only required if your quantization method requires extra attention to the underlying [nn.Parameter](https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html) object. For example, bitsandbytes uses [`~bitsandbytes.nn.Params4bit`] and [`~bitsandbytes.nn.Int8Params`], which requires some extra attention when quantizing the model. Most of the recent quantization method packs int2 and int4 weights inside [torch.uint8](https://pytorch.org/docs/stable/tensors.html) weights, so this flag should not be really required (set to `False` by default).
 - `is_serializable`: A property method to determine whether the method is serializable or not.
 - `is_trainable`: A property method to determine whether you can fine-tune models on top of the quantization method (with or without PEFT approaches).

 4. Write the `validate_environment` and `update_dtype` methods. These methods are called before creating the quantized model to ensure users use the right configuration. Refer to other quantizers for an example of how this is implemented.

 5. Write the `_process_model_before_weight_loading` method. In Transformers, the quantized models are initialized first on the `"meta"` device before loading the weights. This means the `_process_model_before_weight_loading` method takes care of manipulating the model skeleton to replace some modules ([nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html)) with the target modules (quantization modules).

-You can define module replacement logic or any other utility method by creating a new file in [transformers/src/integrations/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/integrations) and exposing the relevant methods in that folder's `__init__.py` file. The best starting point would be to have a look at another quantization method such as [quantizer_awq.py](https://github.com/huggingface/transformers/blob/abbffc4525566a48a9733639797c812301218b83/src/transformers/quantizers/quantizer_awq.py).
+You can define module replacement logic or any other utility method by creating a new file in [transformers/src/integrations/](https://github.com/huggingface/transformers/tree/abbffc4525566a48a9733639797c812301218b83/src/transformers/integrations) and exposing the relevant methods in that folder's `__init__.py` file.
+
+6. Add the `get_quantize_ops` method to the quantizer class if the quantization method supports quantizing on the fly. In Transformers, we materialize each tensor and apply a sequence of different operations on it; in this case, the quantization operation happens at the end. You need to create an `XXXQuantize` class, a subclass of `ConversionOps`, and add a `convert` method. In the `convert` method, you need to quantize the weights and return a dictionary of quantized params.
+
+7. Add the `get_weight_conversions` method to the quantizer class if the quantization method supports loading pre-quantized weights. In Transformers, we can collect multiple tensors and apply operations on them. This is particularly useful when tensors in the checkpoint need to be regrouped to re-create the quantized tensors.

-6. Write the `_process_model_after_weight_loading` method. This method enables implementing additional features that require manipulating the model after loading the weights.
+8. Write the `_process_model_after_weight_loading` method if needed. This method enables implementing additional features that require manipulating the model after loading the weights.

-7. Document everything! Make sure your quantization method is documented by adding a new file under `docs/source/en/quantization`.
+9. Document everything! Make sure your quantization method is documented by adding a new file under `docs/source/en/quantization`.

-8. You should add tests by adding the package in our nightly Dockerfile inside `docker/transformers-quantization-latest-gpu` and then adding a new test file in `tests/quantization/xxx`. Feel free to check out existing quantization methods to see how it is implemented.
+10. You should add tests by adding the package to our nightly Dockerfile inside `docker/transformers-quantization-latest-gpu` and then adding a new test file in `tests/quantization/xxx`. Feel free to check out existing quantization methods to see how they are implemented.
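The numbered steps above map fairly directly onto a small amount of code. Below is a minimal, illustrative sketch of what a new quantizer might look like: `MyMethodConfig`, `MyMethodHfQuantizer`, and the `my_method_backend` package are hypothetical names, and the exact hook signatures should be checked against the existing quantizers in `src/transformers/quantizers/`.

```python
# Hypothetical sketch only -- names are invented and the hooks are stubbed.
from transformers.quantizers.base import HfQuantizer
from transformers.utils.quantization_config import QuantizationConfigMixin


class MyMethodConfig(QuantizationConfigMixin):
    """Step 1: config class, normally added to src/transformers/utils/quantization_config.py."""

    def __init__(self, bits: int = 4, **kwargs):
        self.bits = bits
        self.quant_method = "my_method"  # assumed field name, mirroring existing configs


class MyMethodHfQuantizer(HfQuantizer):
    """Step 2: quantizer class, normally added as quantizer_my_method.py and registered in auto.py."""

    # Step 3: class attributes.
    requires_calibration = False  # weights can be quantized on the fly

    def validate_environment(self, *args, **kwargs):
        # Step 4: fail early if the (hypothetical) backend package is missing.
        try:
            import my_method_backend  # noqa: F401
        except ImportError as exc:
            raise ImportError("my_method quantization requires the `my_method_backend` package.") from exc

    def _process_model_before_weight_loading(self, model, **kwargs):
        # Step 5: swap nn.Linear modules for quantized modules while the model is still on the
        # "meta" device. The replacement helper would typically live in src/transformers/integrations/.
        ...

    def _process_model_after_weight_loading(self, model, **kwargs):
        # Step 8: optional post-loading fixups.
        return model

    # Step 3: property methods.
    @property
    def is_trainable(self) -> bool:
        return False

    def is_serializable(self, safe_serialization=None):
        return True
```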
src/transformers/integrations/bitsandbytes.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def convert(
         we need to store some parameters to create the quantized weight. For example, bnb requires 6 values that are stored in the checkpoint to recover the quantized weight. So we store them in a dict that is stored in hf_quantizer for now, as we can't save it in the op since we create an op per tensor.
         """
         value = list(input_dict.values())[0]
-        value = value[0] if isinstance(value, list) else value
+        value = value[0]

         # update param name to get the weights instead of the quantized stats
         module, _ = get_module_from_name(model, full_layer_name)
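For the on-the-fly path described in step 6 of the contribution guide above, the quantize op's `convert` receives the materialized tensor(s) and returns every tensor needed to rebuild the quantized parameter (as the bitsandbytes docstring above notes, several extra values may be required). The following is a rough, self-contained sketch of that shape: `MyMethodQuantize` is hypothetical, it uses plain symmetric int8 quantization rather than the real bitsandbytes scheme, and a real implementation would subclass `ConversionOps` and follow the exact input layout the loader passes.

```python
import torch


class MyMethodQuantize:
    """Hypothetical quantize op; a real one would subclass ConversionOps (see step 6 above)."""

    def convert(self, input_dict: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        # Assume one full-precision weight per op, keyed by its target parameter name.
        target_name, weight = next(iter(input_dict.items()))

        # Simple symmetric per-tensor int8 quantization (illustrative only).
        scale = weight.abs().max().clamp(min=1e-8) / 127.0
        q_weight = torch.round(weight / scale).clamp(-128, 127).to(torch.int8)

        # Return every tensor needed to rebuild the quantized parameter at load time.
        return {
            target_name: q_weight,
            target_name.replace("weight", "weight_scale"): scale,
        }
```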
src/transformers/quantizers/base.py

Lines changed: 7 additions & 61 deletions
@@ -75,26 +75,14 @@ class HfQuantizer(ABC):
     Attributes
         quantization_config (`transformers.utils.quantization_config.QuantizationConfigMixin`):
             The quantization config that defines the quantization parameters of your model that you want to quantize.
-        modules_to_not_convert (`list[str]`, *optional*):
-            The list of module names to not convert when quantizing the model.
-        required_packages (`list[str]`, *optional*):
-            The list of required pip packages to install prior to using the quantizer
         requires_calibration (`bool`):
             Whether the quantization method requires to calibrate the model before using it.
-        requires_parameters_quantization (`bool`):
-            Whether the quantization method requires to create a new Parameter. For example, for bitsandbytes, it is
-            required to create a new xxxParameter in order to properly quantize the model.
     """

     requires_calibration = False
-    required_packages = None
-    requires_parameters_quantization = False

     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         self.quantization_config = quantization_config
-
-        # -- Handle extra kwargs below --
-        self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
         self.pre_quantized = kwargs.pop("pre_quantized", True)

         if not self.pre_quantized and self.requires_calibration:
@@ -157,53 +145,16 @@ def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "
             return mapping[custom_dtype]
         return param.element_size()

-    def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
-        """
-        Override this method if you want to adjust the `missing_keys`.
-
-        Args:
-            missing_keys (`list[str]`, *optional*):
-                The list of missing keys in the checkpoint compared to the state dict of the model
-        """
-        return missing_keys
-
-    def update_expected_keys(self, model, expected_keys: list[str], loaded_keys: list[str]) -> list[str]:
-        """
-        Override this method if you want to adjust the `update_expected_keys`.
-
-        Args:
-            expected_keys (`list[str]`, *optional*):
-                The list of the expected keys in the initialized model.
-            loaded_keys (`list[str]`, *optional*):
-                The list of the loaded keys in the checkpoint.
-        """
-        return expected_keys
-
-    def update_unexpected_keys(self, model, unexpected_keys: list[str]) -> list[str]:
-        return unexpected_keys
-
     def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
         """adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
         return max_memory

     def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
         """
-        Check whether a given param needs quantization as defined by `create_quantized_param`.
+        Check whether a given param needs to be quantized.
         """
         return False

-    def create_quantized_param(self, *args, **kwargs):
-        """
-        Take needed components from state_dict (those from which `param_needs_quantization` is True) and create
-        quantized param.
-        It usually also load the new param directly in the `model`.
-        Note: only applicable if requires_parameters_quantization == True.
-        """
-        if not self.requires_parameters_quantization:
-            raise AttributeError(
-                f"`.create_quantized_param()` method is not supported by quantizer class {self.__class__.__name__}."
-            )
-
     def validate_environment(self, *args, **kwargs):
         """
         This method is used to potentially check for potential conflicts with arguments that are
@@ -263,6 +214,11 @@ def postprocess_model(self, model: "PreTrainedModel", **kwargs):
             kwargs (`dict`, *optional*):
                 The keyword arguments that are passed along `_process_model_after_weight_loading`.
         """
+        model.config.quantization_config = self.quantization_config
+
+        if self.pre_quantized and getattr(self.quantization_config, "dequantize", False):
+            self.remove_quantization_config(model)
+
         return self._process_model_after_weight_loading(model, **kwargs)

     def remove_quantization_config(self, model):
@@ -285,13 +241,7 @@ def dequantize(self, model):
         Note not all quantization schemes support this.
         """
         model = self._dequantize(model)
-
-        # Delete quantizer and quantization config
-        del model.hf_quantizer
-        del model.config.quantization_config
-        del model.config._pre_quantization_dtype
-        del model.quantization_method
-        model.is_quantized = False
+        self.remove_quantization_config(model)

         return model

@@ -353,10 +303,6 @@ def get_state_dict_and_metadata(self, model, safe_serialization=False):
         """Get state dict and metadata. Useful when we need to modify a bit the state dict due to quantization"""
         return None, {}

-    def update_state_dict_with_metadata(self, state_dict, metadata):
-        """Update state dict with metadata. Default behaviour returns state_dict"""
-        return state_dict
-
     @abstractmethod
     def is_serializable(self, safe_serialization=None): ...

src/transformers/quantizers/quantizer_aqlm.py

Lines changed: 1 addition & 5 deletions
@@ -39,12 +39,9 @@ class AqlmHfQuantizer(HfQuantizer):
     """

     requires_calibration = True
-    required_packages = ["aqlm"]
-    optimum_quantizer = None

     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         super().__init__(quantization_config, **kwargs)
-        self.quantization_config = quantization_config

     def validate_environment(self, *args, **kwargs):
         if not is_accelerate_available():
@@ -77,7 +74,6 @@ def _process_model_before_weight_loading(
             quantization_config=self.quantization_config,
             linear_weights_not_to_quantize=self.quantization_config.linear_weights_not_to_quantize,
         )
-        model.config.quantization_config = self.quantization_config

     @property
     def is_trainable(self) -> bool:
@@ -90,5 +86,5 @@ def is_trainable(self) -> bool:
             )
         return False

-    def is_serializable(self, safe_serialization=None):
+    def is_serializable(self, **kwargs):
         return True

src/transformers/quantizers/quantizer_auto_round.py

Lines changed: 0 additions & 1 deletion
@@ -36,7 +36,6 @@ class AutoRoundQuantizer(HfQuantizer):

     # AutoRound requires data calibration - we support only inference
     requires_calibration = True
-    required_packages = ["auto_round"]

     def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
         super().__init__(quantization_config, **kwargs)

src/transformers/quantizers/quantizer_awq.py

Lines changed: 0 additions & 2 deletions
@@ -40,8 +40,6 @@ class AwqQuantizer(HfQuantizer):
     # AWQ requires data calibration - we support only inference
     requires_calibration = True

-    required_packages = ["awq", "accelerate"]
-
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)

src/transformers/quantizers/quantizer_bitnet.py

Lines changed: 2 additions & 6 deletions
@@ -37,14 +37,10 @@ class BitNetHfQuantizer(HfQuantizer):
     Check out the paper introducing this method: https://huggingface.co/papers/2402.17764
     """

-    requires_parameters_quantization = False
     requires_calibration = True

-    required_packages = ["accelerate"]
-
     def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)
-        self.quantization_config = quantization_config

     def validate_environment(self, *args, **kwargs):
         if not is_accelerate_available():
@@ -62,8 +58,8 @@ def validate_environment(self, *args, **kwargs):
                 "You have loaded a BitNet model on CPU and have a CUDA device available, make sure to set "
                 "your model on a GPU device in order to run your model."
             )
-        elif device_map is not None:
-            if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
+        elif isinstance(device_map, dict):
+            if len(device_map) > 1 and "cpu" in device_map.values() or "disk" in device_map.values():
                 raise ValueError(
                     "You are attempting to load a BitNet model with a device_map that contains a CPU or disk device."
                     "This is not supported. Please remove the CPU or disk device from the device_map."
