
Commit abb99f1

Add DS/QWEN Examples (#2333)
* add qwen-ds example Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent e77a309 commit abb99f1

15 files changed (+1110 −0 lines)


examples/README.md

Lines changed: 12 additions & 0 deletions
```diff
@@ -15,6 +15,18 @@ Intel® Neural Compressor validated examples with multiple compression technique
   </tr>
 </thead>
 <tbody>
+  <tr>
+    <td>deepseek-ai/DeepSeek-R1</td>
+    <td>Natural Language Processing</td>
+    <td>Quantization (MXFP8/MXFP4)</td>
+    <td><a href="./pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/deepseek">link</a></td>
+  </tr>
+  <tr>
+    <td>Qwen/Qwen3-235B-A22B</td>
+    <td>Natural Language Processing</td>
+    <td>Quantization (MXFP8/MXFP4)</td>
+    <td><a href="./pytorch/nlp/huggingface_models/language-modeling/quantization/auto_round/qwen">link</a></td>
+  </tr>
 <tr>
   <td>Framepack</td>
   <td>Image + Text to Video</td>
```
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
This example provides an end-to-end workflow to quantize DeepSeek models to MXFP4/MXFP8 and evaluate them using a custom vLLM fork.

## Requirements
```bash
pip install neural-compressor-pt==3.7
# auto-round
pip install auto-round==0.9.2
# vLLM
git clone -b fused-moe-ar --single-branch --quiet https://github.com/yiliu30/vllm-fork.git && cd vllm-fork
VLLM_USE_PRECOMPILED=1 pip install --editable . -vvv
# other requirements
pip install -r requirements.txt
pip uninstall flash_attn
```
## Quantize Model
- Export the model path
```bash
export MODEL=deepseek-ai/DeepSeek-R1
```
- MXFP8
```bash
bash run_quant.sh --model $MODEL -t mxfp8 --output_dir ./qmodels
```
- MXFP4
```bash
bash run_quant.sh --model $MODEL -t mxfp4 --output_dir ./qmodels
```
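For orientation, the sketch below shows the quantization flow these commands drive, mirroring the quantization script added later in this commit; the model id and output path are just the README's example values, not fixed requirements.

```python
# Minimal sketch of the underlying flow (mirrors the quantization script in
# this commit); the model id and output path are placeholders from the README.
from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.torch.quantization import AutoRoundConfig, convert, prepare

model_name = "deepseek-ai/DeepSeek-R1"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

quant_config = AutoRoundConfig(
    tokenizer=tokenizer,
    scheme="MXFP8",              # "MXFP4" for the 4-bit variant
    iters=0,                     # no tuning iterations (RTN-style rounding)
    fp_layers="lm_head",         # layers kept in original precision
    export_format="auto_round",
    output_dir="./qmodels/quantized_model_mxfp8",
)
model = prepare(model=model, quant_config=quant_config)
model = convert(model)           # quantizes and saves to output_dir
```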
## Evaluation

### Prompt Tests

Usage:
```bash
bash ./run_generate.sh -s [mxfp4|mxfp8] -tp [tensor_parallel_size] -m [model_path]
```
- MXFP8
```bash
bash ./run_generate.sh -s mxfp8 -tp 8 -m /path/to/ds_mxfp8
```
- MXFP4
```bash
bash ./run_generate.sh -s mxfp4 -tp 8 -m /path/to/ds_mxfp4
```
### Accuracy Evaluation

Usage:
```bash
bash run_evaluation.sh -m [model_path] -s [mxfp4|mxfp8] -t [task_name] -tp [tensor_parallel_size] -b [batch_size]
```
- MXFP8
```bash
bash run_evaluation.sh -s mxfp8 -t piqa,hellaswag,mmlu -tp 8 -b 512 -m /path/to/ds_mxfp8
bash run_evaluation.sh -s mxfp8 -t gsm8k -tp 8 -b 256 -m /path/to/ds_mxfp8
```
- MXFP4
```bash
bash run_evaluation.sh -s mxfp4 -t piqa,hellaswag,mmlu -tp 8 -b 512 -m /path/to/ds_mxfp4
bash run_evaluation.sh -s mxfp4 -t gsm8k -tp 8 -b 256 -m /path/to/ds_mxfp4
```
Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copied from https://github.com/vllm-project/vllm/

from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser


def create_parser():
    parser = FlexibleArgumentParser()
    # Add engine args
    EngineArgs.add_cli_args(parser)
    parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
    # Add sampling params
    sampling_group = parser.add_argument_group("Sampling parameters")
    sampling_group.add_argument("--max-tokens", type=int)
    sampling_group.add_argument("--temperature", type=float)
    sampling_group.add_argument("--top-p", type=float)
    sampling_group.add_argument("--top-k", type=int)

    return parser


def main(args: dict):
    # Pop arguments not used by LLM
    max_tokens = args.pop("max_tokens")
    temperature = args.pop("temperature")
    top_p = args.pop("top_p")
    top_k = args.pop("top_k")

    # Create an LLM
    llm = LLM(**args)

    # Create a sampling params object
    sampling_params = llm.get_default_sampling_params()
    if max_tokens is not None:
        sampling_params.max_tokens = max_tokens
    if temperature is not None:
        sampling_params.temperature = temperature
    if top_p is not None:
        sampling_params.top_p = top_p
    if top_k is not None:
        sampling_params.top_k = top_k

    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)


if __name__ == "__main__":
    parser = create_parser()
    args: dict = vars(parser.parse_args())
    main(args)
```
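The README's `run_generate.sh` presumably wraps this script. As a hypothetical standalone check (the checkpoint path is a placeholder; the engine flags come from vLLM's standard `EngineArgs` CLI), the helpers above can be driven directly:

```python
# Hypothetical direct use of create_parser()/main(); path is a placeholder.
parser = create_parser()
cli_args = [
    "--model", "/path/to/ds_mxfp8",
    "--tensor-parallel-size", "8",
    "--max-tokens", "32",
]
main(vars(parser.parse_args(cli_args)))
```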
Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
```python
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


topologies_config = {
    "mxfp8": {
        "scheme": "MXFP8",
        "fp_layers": "lm_head",
        "iters": 0,
    },
    "mxfp4": {
        "scheme": "MXFP4",
        "fp_layers": "lm_head,self_attn",
        "iters": 0,
    },
}


def get_model_and_tokenizer(model_name):
    # Load model and tokenizer
    fp32_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="cpu",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
    )
    return fp32_model, tokenizer


def quant_model(args):
    from neural_compressor.torch.quantization import (
        AutoRoundConfig,
        convert,
        prepare,
    )

    config = topologies_config[args.t]
    export_format = "auto_round" if args.use_autoround_format else "llm_compressor"
    output_dir = f"{args.output_dir}/quantized_model_{args.t}"
    fp32_model, tokenizer = get_model_and_tokenizer(args.model)
    quant_config = AutoRoundConfig(
        tokenizer=tokenizer,
        scheme=config["scheme"],
        enable_torch_compile=args.enable_torch_compile,
        iters=config["iters"],
        fp_layers=config["fp_layers"],
        export_format=export_format,
        output_dir=output_dir,
    )

    # Quantize: prepare the model, then convert (quantize and save).
    model = prepare(model=fp32_model, quant_config=quant_config)
    inc_model = convert(model)
    logger.info(f"Quantized model saved to {output_dir}")


if __name__ == "__main__":
    import argparse

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Select a quantization scheme.")
    parser.add_argument(
        "--model",
        type=str,
        help="Path to the pre-trained model or model identifier from Hugging Face Hub.",
    )
    parser.add_argument(
        "-t",
        type=str,
        choices=topologies_config.keys(),
        default="mxfp4",
        help="Quantization scheme to use. Available options: " + ", ".join(topologies_config.keys()),
    )

    parser.add_argument(
        "--enable_torch_compile",
        action="store_true",
        help="Enable torch.compile for the model.",
    )
    parser.add_argument(
        "--use_autoround_format",
        action="store_true",
        help="Use the AutoRound format for saving the quantized model.",
    )

    parser.add_argument(
        "--skip_attn",
        action="store_true",
        help="Skip quantizing attention layers.",
    )
    parser.add_argument(
        "--iters",
        type=int,
        default=0,
        help="Number of iterations for quantization.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./",
        help="Directory to save the quantized model.",
    )

    args = parser.parse_args()

    quant_model(args)
```
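For completeness, here is a hedged sketch of calling `quant_model()` programmatically with an `argparse.Namespace` that mirrors the README's `run_quant.sh` invocation (`run_quant.sh` itself is the supported entry point and is not shown in this excerpt; the field values are the README's examples).

```python
# Hypothetical programmatic call, mirroring
# `bash run_quant.sh --model $MODEL -t mxfp8 --output_dir ./qmodels`.
import argparse

ns = argparse.Namespace(
    model="deepseek-ai/DeepSeek-R1",
    t="mxfp8",
    enable_torch_compile=False,
    use_autoround_format=True,  # assumption: export in AutoRound format
    skip_attn=False,
    iters=0,
    output_dir="./qmodels",
)
quant_model(ns)
```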
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
```
lm-eval==0.4.9.1
loguru
```
