
Commit ef45976

kylesayrs and dsikka authored
[Example] Attention and R3 Examples (#2132)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
1 parent a105dee commit ef45976

File tree: 4 files changed, +151 -0 lines

README.md

Lines changed: 3 additions & 0 deletions
@@ -81,6 +81,9 @@ Applying quantization with `llmcompressor`:
  * [Weight only quantization to `int4` using GPTQ](examples/quantization_w4a16/README.md)
  * [Weight only quantization to `int4` using AWQ](examples/awq/README.md)
  * [Weight only quantization to `int4` using AutoRound](examples/autoround/README.md)
+ * [KV Cache quantization to `fp8`](examples/quantization_kv_cache/README.md)
+ * [Attention quantization to `fp8` (experimental)](experimental/attention/README.md)
+ * [Attention quantization to `nvfp4` with SpinQuant (experimental)](experimental/attention/README.md)
  * [Quantizing MoE LLMs](examples/quantizing_moe/README.md)
  * [Quantizing Vision-Language Models](examples/multimodal_vision/README.md)
  * [Quantizing Audio-Language Models](examples/multimodal_audio/README.md)

experimental/attention/README.md

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
# Attention Quantization in LLM Compressor #

LLM Compressor supports applying static attention quantization to models. Note that support for attention quantization in vLLM is still in progress and is not fully available as of this writing.

## FP8 Attention Example ##

For an example of applying attention quantization, see [llama3_attention.py](/experimental/attention/llama3_attention.py).

```python
recipe = QuantizationModifier(
    config_groups={
        "attention": QuantizationScheme(
            targets=["LlamaAttention"],
            input_activations=QuantizationArgs(
                num_bits=8, type="float", strategy="attn_head"
            ),
        )
    }
)
```

Note that attention quantization also implicitly applies KV cache quantization with the same quantization arguments.

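As a rough end-to-end sketch of how this recipe might be applied (the model choice, the `open_platypus` calibration dataset alias, and the save path below are illustrative assumptions rather than part of the example), the flow mirrors the full script shown further down this page:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

# Illustrative model choice (an assumption); any Llama-style model is targeted the same way.
model_id = "meta-llama/Llama-3.2-1B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# FP8 attention recipe from above.
recipe = QuantizationModifier(
    config_groups={
        "attention": QuantizationScheme(
            targets=["LlamaAttention"],
            input_activations=QuantizationArgs(
                num_bits=8, type="float", strategy="attn_head"
            ),
        )
    }
)

# Static activation quantization needs calibration data to compute scales;
# "open_platypus" is used here as an assumed stand-in for your own dataset.
oneshot(
    model=model,
    dataset="open_platypus",
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=512,
)

# Save the compressed checkpoint alongside the tokenizer (path is illustrative).
save_dir = model_id.split("/")[-1] + "-attention-fp8"
model.save_pretrained(save_dir, save_compressed=True)
tokenizer.save_pretrained(save_dir)
```

Calibration data is required because the attention input scales are computed statically during the `oneshot` pass.
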
## NVFP4 Attention + R3 Example ##

Attention quantization can be improved using the R3 transform, as described in [SpinQuant](https://arxiv.org/abs/2405.16406). This transform reduces the presence of outliers in the attention activation distribution, thereby improving accuracy recovery.

```python
recipe = [
    SpinQuantModifier(rotations=["R3"]),
    QuantizationModifier(
        config_groups={
            "attention": QuantizationScheme(
                targets=["LlamaAttention"],
                input_activations=NVFP4["input_activations"],
            )
        }
    ),
]
```

### Evaluations ###

Utilizing the R3 transform has been shown to improve accuracy recovery for the `meta-llama/Llama-3.2-1B-Instruct` model when using NVFP4 attention quantization.

Without R3 Transform

```
../llm-compressor/Llama-3.2-1B-Instruct-attention-nvfp4/
| Tasks |Version| Filter |n-shot| Metric | |Value | |Stderr|
|--------------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k_platinum| 3|flexible-extract| 5|exact_match|↑ |0.2680|± |0.0127|
| | |strict-match | 5|exact_match|↑ |0.1836|± |0.0111|
```

With R3 Transform

```
../llm-compressor/Llama-3.2-1B-Instruct-r3-attention-nvfp4/
| Tasks |Version| Filter |n-shot| Metric | |Value | |Stderr|
|--------------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k_platinum| 3|flexible-extract| 5|exact_match|↑ |0.2961|± |0.0131|
| | |strict-match | 5|exact_match|↑ |0.2283|± |0.0121|
```
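
These tables are standard lm-evaluation-harness output. As a rough sketch only (assuming a recent `lm-eval` installation that provides the `gsm8k_platinum` task, and that the quantized checkpoint was saved under the directory name shown above), an evaluation along these lines could produce comparable tables:

```python
# Hedged sketch: score a saved checkpoint with lm-evaluation-harness.
# The checkpoint directory and task name are assumptions taken from the tables above.
import lm_eval
from lm_eval.utils import make_table

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=Llama-3.2-1B-Instruct-r3-attention-nvfp4,dtype=auto",
    tasks=["gsm8k_platinum"],
    num_fewshot=5,
    batch_size="auto",
)
print(make_table(results))  # prints the same markdown-style table format as above
```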
File renamed without changes.
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import SpinQuantModifier
from llmcompressor.utils import dispatch_for_generation
from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.quantization.quant_scheme import NVFP4

# Select model and load it.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
recipe = [
    SpinQuantModifier(rotations=["R3"]),
    QuantizationModifier(
        config_groups={
            "attention": QuantizationScheme(
                targets=["LlamaAttention"],
                input_activations=NVFP4["input_activations"],
            )
        }
    ),
]

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-r3-attention-nvfp4"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
