
Commit a6da068

Add convert_autoawq.py script.
Signed-off-by: Muti Chung <mtchung037@gmail.com>
1 parent 6cf8d29 commit a6da068

1 file changed: convert_autoawq.py (+288, -0)

@@ -0,0 +1,288 @@
import glob
import os
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Literal, cast

import torch
import transformers
from auto_round.export.export_to_awq.utils import (
    reverse_awq_order,
    unpack_awq,
)
from compressed_tensors import ModelCompressor
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    QuantizationStatus,
    QuantizationStrategy,
    QuantizationType,
)
from huggingface_hub import load_state_dict_from_file, snapshot_download


def is_autoawq_model(model_path: Path) -> bool:
    config = transformers.AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    if not hasattr(config, "quantization_config"):
        return False

    quantization_config = cast(dict[str, Any], config.quantization_config)
    return quantization_config.get("quant_method") == "awq"


def resolve_model_path(model_name_or_path: str) -> Path:
    """Locate the model path.

    If the input is a repository ID, download the model from the Hugging Face Hub and
    return the path to the local directory.
    """
    if os.path.isdir(model_name_or_path):
        return Path(model_name_or_path)

    return Path(snapshot_download(model_name_or_path))


def load_state_dict_from_model_dir(model_path: Path) -> dict[str, torch.Tensor]:
    weight_files = glob.glob(str(model_path / "*.safetensors"))
    if not weight_files:
        weight_files = glob.glob(str(model_path / "*.bin"))

    state_dict = {}
    for weight_file in weight_files:
        state_dict.update(
            load_state_dict_from_file(
                weight_file, map_location="cpu", weights_only=True
            )
        )
    return state_dict


def dequantize_gemm(
    state_dict: dict[str, torch.Tensor], prefix: str, autoawq_config: dict[str, Any]
) -> None:
    num_bits = cast(int, autoawq_config.get("bits"))
    group_size = cast(int, autoawq_config.get("group_size"))

    qweight = state_dict.pop(f"{prefix}.qweight")
    scales = state_dict.pop(f"{prefix}.scales")
    qzeros = state_dict.pop(f"{prefix}.qzeros")

    def dequantize_gemm_original(
        qweight: torch.Tensor,
        qzeros: torch.Tensor,
        scales: torch.Tensor,
        bits: int,
        group_size: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Modified from auto_round.export.export_to_awq.utils.dequantize_gemm."""
        # Unpack the qweight and qzeros tensors
        iweight, izeros = unpack_awq(qweight, qzeros, bits)
        # Reverse the order of the iweight and izeros tensors
        iweight, izeros = reverse_awq_order(iweight, izeros, bits)

        # overflow checks
        iweight = torch.bitwise_and(iweight, (2**bits) - 1)
        izeros = torch.bitwise_and(izeros, (2**bits) - 1)

        # fp16 weights
        scales_interleaved = scales.repeat_interleave(group_size, dim=0)
        izeros_interleaved = izeros.repeat_interleave(group_size, dim=0)
        fweight = (iweight - izeros_interleaved) * scales_interleaved

        return fweight, izeros

    weight, zero_point = dequantize_gemm_original(
        qweight, qzeros, scales, num_bits, group_size
    )

    # AutoAWQ uses [0, 2^bits - 1], e.g., [0, 15], for quantized weights, but
    # compressed-tensors uses [-2^(bits - 1), 2^(bits - 1) - 1], e.g., [-8, 7].
    # Therefore, we need to shift the zero point by 2^(bits - 1) to match the range
    # of compressed-tensors and to allow correct quant/dequantization.
    shifted_zero_point = zero_point - 2 ** (num_bits - 1)

    state_dict.update(
        {
            f"{prefix}.weight": weight.T,
            f"{prefix}.weight_scale": scales.T,
            f"{prefix}.weight_zero_point": shifted_zero_point.T,
        }
    )
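

# A small worked check of the zero-point shift above (illustrative numbers only,
# not part of the conversion path): for bits=4 the shift is 2 ** (4 - 1) = 8,
# and subtracting the same offset from both the quantized value and the zero
# point leaves the dequantized result unchanged:
#
#   q, z, s = 11, 8, 0.02
#   assert (q - z) * s == ((q - 8) - (z - 8)) * s  # 0.06 in both conventions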


def dequantize_autoawq_state_dict(
    state_dict: dict[str, torch.Tensor], autoawq_config: dict[str, Any]
) -> dict[str, torch.Tensor]:
    version = cast(str, autoawq_config.get("version"))

    # TODO: maybe add support for other versions?
    match version:
        case "gemm":
            dequantize_fn = dequantize_gemm
        case _:
            raise ValueError(f"Unsupported version: {version}")

    keys = list(state_dict.keys())
    for key in filter(lambda k: k.endswith("qweight"), keys):
        prefix = key.removesuffix(".qweight")
        dequantize_fn(state_dict, prefix, autoawq_config)

    return state_dict
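

# For a hypothetical checkpoint key such as
# "model.layers.0.self_attn.q_proj.qweight", the loop above derives the prefix
# "model.layers.0.self_attn.q_proj", and dequantize_gemm replaces its
# qweight/qzeros/scales entries with weight, weight_scale, and weight_zero_point
# (the layer name is illustrative; actual keys depend on the model architecture).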


def convert_and_save(
    model_name_or_path: str,
    output_dir: str,
    quantization_format: str,
    overwrite: bool = False,
    trust_remote_code: bool = False,
) -> None:
    """Convert an AutoAWQ model to a compressed model and save it.

    Steps:

    1. Load the model weights directly.
    2. Dequantize the weights accordingly.
    3. Load the model with the dequantized weights.
    4. Add the quantization parameters to the model.
    5. Re-pack the weights using `ModelCompressor` with the correct configuration.
    6. Save the model to the output directory.

    :param model_name_or_path: Model ID on the Hugging Face Hub or a local model path.
    :param output_dir: Path to save the converted model.
    :param quantization_format: Compression format in which to save the model.
    :param overwrite: Overwrite the existing output directory if it exists.
    :param trust_remote_code: Whether to trust remote code.
    """
    if os.path.exists(output_dir) and not overwrite:
        raise FileExistsError(
            f"Output directory {output_dir} already exists. Set `overwrite=True` to"
            " overwrite the existing directory."
        )

    model_path = resolve_model_path(model_name_or_path)
    if not is_autoawq_model(model_path):
        raise ValueError("Model is not an AutoAWQ model")

    config = transformers.AutoConfig.from_pretrained(
        model_path, trust_remote_code=trust_remote_code
    )
    autoawq_config = cast(dict[str, Any], config.quantization_config)
    num_bits = cast(int, autoawq_config.get("bits"))
    is_symmetric = not autoawq_config.get("zero_point")
    group_size = cast(int, autoawq_config.get("group_size"))

    # TODO: check syntax of modules_to_not_convert
    ignore = autoawq_config.get("modules_to_not_convert")
    if ignore is None:
        ignore = ["lm_head"]

    # 1. Load the model weights directly.
    state_dict = load_state_dict_from_model_dir(model_path)

    # 2. Dequantize the weights accordingly.
    state_dict = dequantize_autoawq_state_dict(state_dict, autoawq_config)

    # 3. Load the model with the dequantized weights.
    del config.quantization_config  # remove to avoid loading with AutoAWQ.
    with transformers.modeling_utils.no_init_weights():
        model = transformers.AutoModelForCausalLM.from_config(
            config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
        )

    model.load_state_dict(state_dict, strict=False)

    # 4. Add the quantization parameters to the model.
    quantization_scheme = QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(
            num_bits=num_bits,
            type=QuantizationType.INT,
            symmetric=is_symmetric,
            group_size=group_size,
            strategy=QuantizationStrategy.GROUP,
        ),
    )

    for key in filter(lambda k: k.endswith("weight_zero_point"), state_dict.keys()):
        module_name = key.removesuffix(".weight_zero_point")
        setattr(
            model.get_submodule(module_name), "quantization_scheme", quantization_scheme
        )

    quant_config = QuantizationConfig(
        config_groups={"group_0": quantization_scheme},
        quant_method="compressed-tensors",
        quantization_status=QuantizationStatus.COMPRESSED,
        format=quantization_format,
        ignore=ignore,
    )

    # 5. Re-pack the weights using `ModelCompressor`.
    compressor = ModelCompressor(quantization_config=quant_config)
    compressed_state_dict = compressor.compress(model, state_dict, show_progress=True)

    # 6. Save the model.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=trust_remote_code
    )
    model.save_pretrained(output_dir, state_dict=compressed_state_dict)
    tokenizer.save_pretrained(output_dir)
    compressor.update_config(output_dir)
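

# Example invocation (a sketch; the model ID and output path are placeholders
# for any AutoAWQ checkpoint and target directory):
#
#   convert_and_save(
#       "TheBloke/Llama-2-7B-AWQ",
#       output_dir="./llama-2-7b-compressed",
#       quantization_format="naive-quantized",
#   )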


def load_and_convert_from_autoawq(
    model_name_or_path: str,
    quantization_format: str = "pack-quantized",
    trust_remote_code: bool = False,
) -> transformers.modeling_utils.PreTrainedModel:
    """
    Load an AutoAWQ checkpoint and convert it to a compressed model.

    :param model_name_or_path: Model ID on the Hugging Face Hub or a local model path.
    :param quantization_format: Compression format in which to save the model.
    :param trust_remote_code: Whether to trust remote code.
    :return: A compressed model.
    """
    with TemporaryDirectory() as temp_dir:
        convert_and_save(
            model_name_or_path,
            temp_dir,
            quantization_format,
            # TemporaryDirectory() already created temp_dir, so allow overwriting.
            overwrite=True,
            trust_remote_code=trust_remote_code,
        )
        return transformers.AutoModelForCausalLM.from_pretrained(
            temp_dir, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
        )
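

# Example (a sketch with a placeholder model ID): the conversion runs in a
# temporary directory and only the re-loaded compressed model is returned.
#
#   model = load_and_convert_from_autoawq("TheBloke/Llama-2-7B-AWQ")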


@dataclass
class ConversionArgs:
    model_name_or_path: str = field(
        metadata={"help": "Model ID on the Hugging Face Hub or a local model path."},
    )
    output_dir: str = field(
        metadata={"help": "Path to save the converted model."},
    )
    quantization_format: Literal["naive-quantized", "pack-quantized"] = field(
        default="naive-quantized",
        metadata={"help": "Compression format in which to save the model."},
    )  # TODO: switch default to pack-quantized once supported by llm-compressor.
    overwrite: bool = field(
        default=False,
        metadata={"help": "Overwrite the existing output directory if it exists."},
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Whether to trust remote code."},
    )


__all__ = ["convert_and_save", "load_and_convert_from_autoawq"]


if __name__ == "__main__":
    parser = transformers.HfArgumentParser(ConversionArgs)
    args = parser.parse_args_into_dataclasses()[0]
    convert_and_save(
        args.model_name_or_path,
        args.output_dir,
        args.quantization_format,
        args.overwrite,
        args.trust_remote_code,
    )
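

# Command-line usage (a sketch; the model ID and paths are placeholders, and
# HfArgumentParser derives the flags from the ConversionArgs field names):
#
#   python convert_autoawq.py \
#       --model_name_or_path TheBloke/Llama-2-7B-AWQ \
#       --output_dir ./llama-2-7b-compressed \
#       --quantization_format naive-quantized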
