diff --git a/src/llmcompressor/modifiers/awq/convert_autoawq.py b/src/llmcompressor/modifiers/awq/convert_autoawq.py new file mode 100644 index 000000000..0ae65f134 --- /dev/null +++ b/src/llmcompressor/modifiers/awq/convert_autoawq.py @@ -0,0 +1,344 @@ +""" +Convert AutoAWQ models to llmcompressor-compatible models. + +This module offers the functionality to convert models quantized with AutoAWQ into +compressed models in llmcompressor's format, which can then be served with vLLM. +This module can be used as a CLI tool or as a Python API. + +## CLI Usage + +```sh +python -m llmcompressor.modifiers.awq.convert_autoawq \ + --model-name-or-path /path/to/model \ + --output-dir /path/to/compressed/model \ + --quantization-format naive-quantized +``` + +For more information, run `python -m llmcompressor.modifiers.awq.convert_autoawq --help` +or refer to the `ConversionArgs` dataclass below. + +## Python API Usage + +```python +from llmcompressor.modifiers.awq.convert_autoawq import load_and_convert_from_autoawq + +awq_model_path = "/path/to/model" # can also be model_id on huggingface hub +model = load_and_convert_from_autoawq(awq_model_path) +model.generate(...) # the converted model is now ready to be used. +``` +""" + +import glob +import os +import re +from dataclasses import dataclass, field +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Any, Literal, cast + +import torch +import transformers +from auto_round.export.export_to_awq.utils import ( + reverse_awq_order, + unpack_awq, +) +from compressed_tensors import ModelCompressor +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationConfig, + QuantizationScheme, + QuantizationStatus, + QuantizationStrategy, + QuantizationType, +) +from huggingface_hub import load_state_dict_from_file, snapshot_download + + +def is_autoawq_model(model_path: Path, trust_remote_code: bool = False) -> bool: + config = transformers.AutoConfig.from_pretrained( + model_path, trust_remote_code=trust_remote_code + ) + if not hasattr(config, "quantization_config"): + return False + + quantization_config = cast(dict[str, Any], config.quantization_config) + return quantization_config.get("quant_method") == "awq" + + +def resolve_model_path(model_name_or_path: str) -> Path: + if os.path.isdir(model_name_or_path): + return Path(model_name_or_path) + else: + # If the input is a model ID, download the model from the Hugging Face Hub and + # return the path to the local directory. 
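+        # snapshot_download reuses the local Hugging Face cache when the files are
+        # already present, so repeated conversions do not re-download the model.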
+ return Path(snapshot_download(model_name_or_path)) + + +def load_state_dict_from_model_dir(model_path: Path) -> dict[str, torch.Tensor]: + weight_files = glob.glob(str(model_path / "*.safetensors")) + if not weight_files: + weight_files = glob.glob(str(model_path / "*.bin")) + + state_dict = {} + for weight_file in weight_files: + state_dict.update( + load_state_dict_from_file( + weight_file, map_location="cpu", weights_only=True + ) + ) + return state_dict + + +def dequantize_gemm( + state_dict: dict[str, torch.Tensor], prefix: str, autoawq_config: dict[str, Any] +) -> None: + num_bits = cast(int, autoawq_config.get("bits")) + group_size = cast(int, autoawq_config.get("group_size")) + + qweight = state_dict.pop(f"{prefix}.qweight") + scales = state_dict.pop(f"{prefix}.scales") + qzeros = state_dict.pop(f"{prefix}.qzeros") + + def dequantize_gemm_original( + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + bits: int, + group_size: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Modified from auto_round.export.export_to_awq.utils.dequantize_gemm.""" + # Unpack the qweight and qzeros tensors + iweight, izeros = unpack_awq(qweight, qzeros, bits) + # Reverse the order of the iweight and izeros tensors + iweight, izeros = reverse_awq_order(iweight, izeros, bits) + + # overflow checks + iweight = torch.bitwise_and(iweight, (2**bits) - 1) + izeros = torch.bitwise_and(izeros, (2**bits) - 1) + + # fp16 weights + scales_interleaved = scales.repeat_interleave(group_size, dim=0) + izeros_interleaved = izeros.repeat_interleave(group_size, dim=0) + fweight = (iweight - izeros_interleaved) * scales_interleaved + + return fweight, izeros + + weight, zero_point = dequantize_gemm_original( + qweight, qzeros, scales, num_bits, group_size + ) + + # AutoAWQ uses [0, 2^bits - 1], e.g., [0, 15], for quantized weights, but + # compressed-tensors uses [-2^(bits - 1), 2^(bits - 1) - 1], e.g., [-8, 7]. + # Therefore, we need to shift the zero point by 2^(bits - 1) to match the range + # of compressed-tensors and to allow correct quant/dequantization. + shifted_zero_point = zero_point - 2 ** (num_bits - 1) + + state_dict.update( + { + f"{prefix}.weight": weight.T, + f"{prefix}.weight_scale": scales.T, + f"{prefix}.weight_zero_point": shifted_zero_point.T, + } + ) + + +def dequantize_autoawq_state_dict( + state_dict: dict[str, torch.Tensor], autoawq_config: dict[str, Any] +) -> dict[str, torch.Tensor]: + version = cast(str, autoawq_config.get("version")) + + # TODO: maybe add support for other versions? + match version: + case "gemm": + dequantize_fn = dequantize_gemm + case _: + raise ValueError(f"Unsupported version: {version}") + + keys = list(state_dict.keys()) + for key in filter(lambda k: k.endswith("qweight"), keys): + prefix = key.removesuffix(".qweight") + dequantize_fn(state_dict, prefix, autoawq_config) + + return state_dict + + +def convert_and_save( + model_name_or_path: str, + output_dir: str, + quantization_format: str, + overwrite: bool = False, + trust_remote_code: bool = False, +) -> None: + """Convert an AutoAWQ model to a compressed model and save it. + + Steps: + + 1. Load the model weights directly. + 2. Dequantize the weights accordingly. + 3. Load the model with the dequantized weights. + 4. Add the quantization parameters to the model. + 5. Re-pack the weights using `ModelCompressor` with the correct configuration. + 6. Save the model to the output directory. + + :param model_name_or_path: Model ID on huggingface hub or path to local model. 
+ :param output_dir: Path to save the converted model. + :param quantization_format: Compression format to be saved. + :param overwrite: Overwrite the existing output directory if it exists. + :param trust_remote_code: Whether to trust remote code. + """ + output_exists = os.path.exists(output_dir) + is_directory = os.path.isdir(output_dir) if output_exists else False + is_empty_dir = False + if output_exists and is_directory: + is_empty_dir = not any(os.scandir(output_dir)) + + if not output_exists: + # Safe: output_dir does not exist + pass + elif not is_directory or (not is_empty_dir and not overwrite): + raise FileExistsError( + f"{output_dir=} already exists. Set `overwrite=True` to" + " overwrite the existing directory." + ) + + model_path = resolve_model_path(model_name_or_path) + if not is_autoawq_model(model_path, trust_remote_code): + raise ValueError("Model is not an AutoAWQ model") + + config = transformers.AutoConfig.from_pretrained( + model_path, trust_remote_code=trust_remote_code + ) + autoawq_config = cast(dict[str, Any], config.quantization_config) + num_bits = cast(int, autoawq_config.get("bits")) + is_symmetric = not autoawq_config.get("zero_point") + group_size = cast(int, autoawq_config.get("group_size")) + + # Convert AutoAWQ's substring-based ignore list to llm-compressor's regex format + # Usage in AutoAWQ: + # ```python + # if any(key in name for key in modules_to_not_convert): ... + # ``` + # See https://github.com/casper-hansen/AutoAWQ/blob/88e4c76b20755db275574e6a03c83c84ba3bece5/awq/utils/module.py#L62 + modules_to_not_convert = autoawq_config.get("modules_to_not_convert", None) + ignore = [] + if modules_to_not_convert is not None: + # Convert each substring pattern to a regex pattern that matches it anywhere + for module in modules_to_not_convert: + ignore.append(f"re:.*{re.escape(module)}.*") + + ignore.append("lm_head") # AutoAWQ ignores lm_head by default + + # 1. Load the model weights directly. + state_dict = load_state_dict_from_model_dir(model_path) + + # 2. Dequantize the weights accordingly. + state_dict = dequantize_autoawq_state_dict(state_dict, autoawq_config) + + # 3. Load the model with the dequantized weights. + del config.quantization_config # remove to avoid loading with AutoAWQ. + with transformers.modeling_utils.no_init_weights(): + model = transformers.AutoModelForCausalLM.from_config( + config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code + ) + + model.load_state_dict(state_dict, strict=False) + + # 4. Add the quantization parameters to the model. + quantization_scheme = QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + num_bits=num_bits, + type=QuantizationType.INT, + symmetric=is_symmetric, + group_size=group_size, + strategy=QuantizationStrategy.GROUP, + ), + ) + + for key in filter(lambda k: k.endswith("weight_zero_point"), state_dict.keys()): + module_name = key.removesuffix(".weight_zero_point") + setattr( + model.get_submodule(module_name), "quantization_scheme", quantization_scheme + ) + + quant_config = QuantizationConfig( + config_groups={"group_0": quantization_scheme}, + quant_method="compressed-tensors", + quantization_status=QuantizationStatus.COMPRESSED, + format=quantization_format, + ignore=ignore, + ) + + # 5. Re-pack the weights using `ModelCompressor`. + compressor = ModelCompressor(quantization_config=quant_config) + compressed_state_dict = compressor.compress(model, state_dict, show_progress=True) + + # 6. Save the model. 
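+    # The tokenizer is saved next to the compressed weights, and
+    # `compressor.update_config` writes the compressed-tensors quantization
+    # config into the saved model config so the checkpoint can be reloaded
+    # (e.g., by vLLM) with the correct compression metadata.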
+ tokenizer = transformers.AutoTokenizer.from_pretrained( + model_name_or_path, trust_remote_code=trust_remote_code + ) + model.save_pretrained(output_dir, state_dict=compressed_state_dict) + tokenizer.save_pretrained(output_dir) + compressor.update_config(output_dir) + + +def load_and_convert_from_autoawq( + model_name_or_path: str, + quantization_format: str = "naive-quantized", + trust_remote_code: bool = False, +) -> transformers.modeling_utils.PreTrainedModel: + """ + Load an AutoAWQ checkpoint and convert it to a compressed model. + + :param model_name_or_path: Model ID on huggingface hub or path to local model. + :param quantization_format: Compression format to be saved. + :param trust_remote_code: Whether to trust remote code. + :return: A compressed model. + """ + with TemporaryDirectory() as temp_dir: + convert_and_save( + model_name_or_path, + temp_dir, + quantization_format, + trust_remote_code=trust_remote_code, + ) + return transformers.AutoModelForCausalLM.from_pretrained( + temp_dir, torch_dtype=torch.float16, trust_remote_code=trust_remote_code + ) + + +@dataclass +class ConversionArgs: + model_name_or_path: str = field( + metadata={"help": "Model ID on huggingface hub or path to local model."}, + ) + output_dir: str = field( + metadata={"help": "Path to save the converted model."}, + ) + quantization_format: Literal["naive-quantized", "pack-quantized"] = field( + default="naive-quantized", + metadata={"help": "Compression format to be saved."}, + ) # TODO: switch default to packed-quantized once supported by llm-compressor. + overwrite: bool = field( + default=False, + metadata={"help": "Overwrite the existing output directory if it exists."}, + ) + trust_remote_code: bool = field( + default=False, + metadata={"help": "Whether to trust remote code."}, + ) + + +__all__ = ["convert_and_save", "load_and_convert_from_autoawq", "ConversionArgs"] + + +if __name__ == "__main__": + parser = transformers.HfArgumentParser(ConversionArgs) + args = parser.parse_args_into_dataclasses()[0] + convert_and_save( + args.model_name_or_path, + args.output_dir, + args.quantization_format, + args.overwrite, + args.trust_remote_code, + ) diff --git a/tests/llmcompressor/modifiers/awq/test_convert_autoawq.py b/tests/llmcompressor/modifiers/awq/test_convert_autoawq.py new file mode 100644 index 000000000..7a9723918 --- /dev/null +++ b/tests/llmcompressor/modifiers/awq/test_convert_autoawq.py @@ -0,0 +1,56 @@ +from tempfile import TemporaryDirectory + +from lm_eval.evaluator import simple_evaluate + +from llmcompressor.modifiers.awq.convert_autoawq import convert_and_save +from tests.testing_utils import requires_gpu + + +def run_lm_eval(model_name_or_path: str): + results = simple_evaluate( + model="hf", + model_args=f"pretrained={model_name_or_path},dtype=float16", + tasks=["arc_challenge", "arc_easy"], + num_fewshot=5, + batch_size=16, + ) + + return results + + +def compare_models(model_name_or_path: str): + autoawq_result = run_lm_eval(model_name_or_path) + with TemporaryDirectory() as converted_model_dir: + convert_and_save(model_name_or_path, converted_model_dir, "naive-quantized") + converted_result = run_lm_eval(converted_model_dir) + + arc_c_autoawq = autoawq_result["results"]["arc_challenge"]["acc_norm,none"] + arc_c_converted = converted_result["results"]["arc_challenge"]["acc_norm,none"] + arc_e_autoawq = autoawq_result["results"]["arc_easy"]["acc_norm,none"] + arc_e_converted = converted_result["results"]["arc_easy"]["acc_norm,none"] + + assert abs(arc_e_autoawq - arc_e_converted) < 
1e-2, ( + f"Arc Easy: autoawq={arc_e_autoawq} != converted={arc_e_converted}." + ) + assert abs(arc_c_autoawq - arc_c_converted) < 1e-2, ( + f"Arc Challenge: autoawq={arc_c_autoawq} != converted={arc_c_converted}." + ) + + +@requires_gpu +def test_mistral(): + compare_models( + "fbaldassarri/mistralai_Mistral-7B-Instruct-v0.3-autoawq-int4-gs128-asym" + ) + + +@requires_gpu +def test_qwen(): + compare_models( + "ruikangliu/DeepSeek-R1-Distill-Qwen-1.5B-quantized.awq-autoawq-w4g128" + ) + + +@requires_gpu +def test_llama(): + compare_models("AMead10/Llama-3.2-3B-Instruct-AWQ")
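
Usage note (not part of the patch): once a checkpoint has been converted with `convert_and_save`, the output directory should be loadable by vLLM directly, as stated in the module docstring. A minimal sketch, assuming vLLM is installed and the converted model was written to `/path/to/compressed/model` (the `output_dir` from the CLI example above):

```python
from vllm import LLM, SamplingParams

# Load the converted, compressed-tensors checkpoint produced by convert_and_save.
llm = LLM(model="/path/to/compressed/model")

# Generate with a small sampling budget to sanity-check that the checkpoint loads.
outputs = llm.generate(
    ["Explain activation-aware weight quantization in one sentence."],
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```

This only exercises the load-and-generate path; the lm-eval comparison in the test file above remains the accuracy check for the conversion itself.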