import glob
import os
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Literal, cast

import torch
import transformers
from auto_round.export.export_to_awq.utils import (
    reverse_awq_order,
    unpack_awq,
)
from compressed_tensors import ModelCompressor
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    QuantizationStatus,
    QuantizationStrategy,
    QuantizationType,
)
from huggingface_hub import load_state_dict_from_file, snapshot_download


def is_autoawq_model(model_path: Path) -> bool:
    config = transformers.AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    if not hasattr(config, "quantization_config"):
        return False

    quantization_config = cast(dict[str, Any], config.quantization_config)
    return quantization_config.get("quant_method") == "awq"

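# For reference, an AutoAWQ checkpoint typically carries a `quantization_config`
# entry like the following in its config.json (illustrative values only, not taken
# from any particular model):
#
#   "quantization_config": {
#       "quant_method": "awq",
#       "bits": 4,
#       "group_size": 128,
#       "zero_point": true,
#       "version": "gemm",
#       "modules_to_not_convert": null
#   }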

def resolve_model_path(model_name_or_path: str) -> Path:
    """Locate the model path.

    If the input is a repository ID, download the model from the Hugging Face Hub and
    return the path to the local directory.
    """
    if os.path.isdir(model_name_or_path):
        return Path(model_name_or_path)

    return Path(snapshot_download(model_name_or_path))


def load_state_dict_from_model_dir(model_path: Path) -> dict[str, torch.Tensor]:
    weight_files = glob.glob(str(model_path / "*.safetensors"))
    if not weight_files:
        weight_files = glob.glob(str(model_path / "*.bin"))

    state_dict = {}
    for weight_file in weight_files:
        state_dict.update(
            load_state_dict_from_file(
                weight_file, map_location="cpu", weights_only=True
            )
        )
    return state_dict


def dequantize_gemm(
    state_dict: dict[str, torch.Tensor], prefix: str, autoawq_config: dict[str, Any]
) -> None:
    num_bits = cast(int, autoawq_config.get("bits"))
    group_size = cast(int, autoawq_config.get("group_size"))

    qweight = state_dict.pop(f"{prefix}.qweight")
    scales = state_dict.pop(f"{prefix}.scales")
    qzeros = state_dict.pop(f"{prefix}.qzeros")
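    # For a Linear layer with in_features I and out_features O, the AWQ "gemm"
    # layout packs (32 // bits) quantized values into each int32, so the popped
    # tensors are expected to look roughly like (illustrative shapes; nothing here
    # checks them beyond what unpack_awq / reverse_awq_order assume):
    #   qweight: int32 of shape [I, O * bits // 32]
    #   qzeros:  int32 of shape [I // group_size, O * bits // 32]
    #   scales:  fp16  of shape [I // group_size, O]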

    def dequantize_gemm_original(
        qweight: torch.Tensor,
        qzeros: torch.Tensor,
        scales: torch.Tensor,
        bits: int,
        group_size: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Modified from auto_round.export.export_to_awq.utils.dequantize_gemm."""
        # Unpack the qweight and qzeros tensors
        iweight, izeros = unpack_awq(qweight, qzeros, bits)
        # Reverse the order of the iweight and izeros tensors
        iweight, izeros = reverse_awq_order(iweight, izeros, bits)

        # overflow checks
        iweight = torch.bitwise_and(iweight, (2**bits) - 1)
        izeros = torch.bitwise_and(izeros, (2**bits) - 1)

        # fp16 weights
        scales_interleaved = scales.repeat_interleave(group_size, dim=0)
        izeros_interleaved = izeros.repeat_interleave(group_size, dim=0)
        fweight = (iweight - izeros_interleaved) * scales_interleaved

        return fweight, izeros

    weight, zero_point = dequantize_gemm_original(
        qweight, qzeros, scales, num_bits, group_size
    )

    # AutoAWQ uses [0, 2^bits - 1], e.g., [0, 15], for quantized weights, but
    # compressed-tensors uses [-2^(bits - 1), 2^(bits - 1) - 1], e.g., [-8, 7].
    # Therefore, we need to shift the zero point by 2^(bits - 1) to match the range
    # of compressed-tensors and to allow correct quant/dequantization.
    shifted_zero_point = zero_point - 2 ** (num_bits - 1)
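    # For example, with num_bits=4 an AWQ zero point of 8 (the midpoint of [0, 15])
    # becomes 0 (the midpoint of [-8, 7]). Re-quantizing the float weight with the
    # shifted zero point later shifts the integer weights by the same constant, so
    # (q - z) * scale still reproduces the original values.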

    state_dict.update(
        {
            f"{prefix}.weight": weight.T,
            f"{prefix}.weight_scale": scales.T,
            f"{prefix}.weight_zero_point": shifted_zero_point.T,
        }
    )


def dequantize_autoawq_state_dict(
    state_dict: dict[str, torch.Tensor], autoawq_config: dict[str, Any]
) -> dict[str, torch.Tensor]:
    version = cast(str, autoawq_config.get("version"))

    # TODO: maybe add support for other versions?
    match version:
        case "gemm":
            dequantize_fn = dequantize_gemm
        case _:
            raise ValueError(f"Unsupported version: {version}")

    keys = list(state_dict.keys())
    for key in filter(lambda k: k.endswith("qweight"), keys):
        prefix = key.removesuffix(".qweight")
        dequantize_fn(state_dict, prefix, autoawq_config)

    return state_dict

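# Illustration of the renaming performed above, for a hypothetical key prefix
# "model.layers.0.self_attn.q_proj":
#
#   removed: ...q_proj.qweight, ...q_proj.qzeros, ...q_proj.scales (packed AutoAWQ)
#   added:   ...q_proj.weight (dequantized fp16), ...q_proj.weight_scale,
#            ...q_proj.weight_zero_point (shifted to the signed range)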

def convert_and_save(
    model_name_or_path: str,
    output_dir: str,
    quantization_format: str,
    overwrite: bool = False,
    trust_remote_code: bool = False,
) -> None:
    """Convert an AutoAWQ model to a compressed model and save it.

    Steps:

    1. Load the model weights directly from the checkpoint files.
    2. Dequantize the AWQ-packed weights.
    3. Load the model with the dequantized weights.
    4. Add the quantization parameters to the model.
    5. Re-pack the weights using `ModelCompressor` with the correct configuration.
    6. Save the model to the output directory.

    :param model_name_or_path: Model ID on the Hugging Face Hub or path to a local
        model.
    :param output_dir: Path to save the converted model.
    :param quantization_format: Compression format to save the model in.
    :param overwrite: Overwrite the existing output directory if it exists.
    :param trust_remote_code: Whether to trust remote code.
    """
    if os.path.exists(output_dir) and not overwrite:
        raise FileExistsError(
            f"Output directory {output_dir} already exists. Set `overwrite=True` to"
            " overwrite the existing directory."
        )

    model_path = resolve_model_path(model_name_or_path)
    if not is_autoawq_model(model_path):
        raise ValueError(f"{model_name_or_path} is not an AutoAWQ model")

    config = transformers.AutoConfig.from_pretrained(
        model_path, trust_remote_code=trust_remote_code
    )
    autoawq_config = cast(dict[str, Any], config.quantization_config)
    num_bits = cast(int, autoawq_config.get("bits"))
    is_symmetric = not autoawq_config.get("zero_point")
    group_size = cast(int, autoawq_config.get("group_size"))

    # TODO: check syntax of modules_to_not_convert
    ignore = autoawq_config.get("modules_to_not_convert")
    if ignore is None:
        ignore = ["lm_head"]

    # 1. Load the model weights directly from the checkpoint files.
    state_dict = load_state_dict_from_model_dir(model_path)

    # 2. Dequantize the AWQ-packed weights.
    state_dict = dequantize_autoawq_state_dict(state_dict, autoawq_config)

    # 3. Load the model with the dequantized weights.
    del config.quantization_config  # remove to avoid loading with AutoAWQ.
    with transformers.modeling_utils.no_init_weights():
        model = transformers.AutoModelForCausalLM.from_config(
            config, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
        )

    model.load_state_dict(state_dict, strict=False)

    # 4. Add the quantization parameters to the model.
    quantization_scheme = QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(
            num_bits=num_bits,
            type=QuantizationType.INT,
            symmetric=is_symmetric,
            group_size=group_size,
            strategy=QuantizationStrategy.GROUP,
        ),
    )

    for key in filter(lambda k: k.endswith("weight_zero_point"), state_dict.keys()):
        module_name = key.removesuffix(".weight_zero_point")
        module = model.get_submodule(module_name)
        module.quantization_scheme = quantization_scheme

    quant_config = QuantizationConfig(
        config_groups={"group_0": quantization_scheme},
        quant_method="compressed-tensors",
        quantization_status=QuantizationStatus.COMPRESSED,
        format=quantization_format,
        ignore=ignore,
    )

    # 5. Re-pack the weights using `ModelCompressor`.
    compressor = ModelCompressor(quantization_config=quant_config)
    compressed_state_dict = compressor.compress(model, state_dict, show_progress=True)

    # 6. Save the model to the output directory.
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_name_or_path, trust_remote_code=trust_remote_code
    )
    model.save_pretrained(output_dir, state_dict=compressed_state_dict)
    tokenizer.save_pretrained(output_dir)
    compressor.update_config(output_dir)

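# Minimal usage sketch for convert_and_save (the repo id and output path are
# placeholders, not real artifacts):
#
#   convert_and_save(
#       "org/model-AWQ",
#       output_dir="./model-compressed",
#       quantization_format="naive-quantized",
#   )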

def load_and_convert_from_autoawq(
    model_name_or_path: str,
    quantization_format: str = "pack-quantized",
    trust_remote_code: bool = False,
) -> transformers.modeling_utils.PreTrainedModel:
    """
    Load an AutoAWQ checkpoint and convert it to a compressed model.

    :param model_name_or_path: Model ID on the Hugging Face Hub or path to a local
        model.
    :param quantization_format: Compression format to convert the model to.
    :param trust_remote_code: Whether to trust remote code.
    :return: A compressed model.
    """
    with TemporaryDirectory() as temp_dir:
        # The temporary directory already exists, so `overwrite=True` is required;
        # also forward `trust_remote_code` for models with custom code.
        convert_and_save(
            model_name_or_path,
            temp_dir,
            quantization_format,
            overwrite=True,
            trust_remote_code=trust_remote_code,
        )
        return transformers.AutoModelForCausalLM.from_pretrained(
            temp_dir, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
        )

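# Usage sketch (placeholder repo id); the converted model is packed with the
# default "pack-quantized" format:
#
#   model = load_and_convert_from_autoawq("org/model-AWQ")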

@dataclass
class ConversionArgs:
    model_name_or_path: str = field(
        metadata={"help": "Model ID on the Hugging Face Hub or path to a local model."},
    )
    output_dir: str = field(
        metadata={"help": "Path to save the converted model."},
    )
    quantization_format: Literal["naive-quantized", "pack-quantized"] = field(
        default="naive-quantized",
        metadata={"help": "Compression format to save the model in."},
    )  # TODO: switch the default to pack-quantized once supported by llm-compressor.
    overwrite: bool = field(
        default=False,
        metadata={"help": "Overwrite the existing output directory if it exists."},
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Whether to trust remote code."},
    )


__all__ = ["convert_and_save", "load_and_convert_from_autoawq"]


if __name__ == "__main__":
    parser = transformers.HfArgumentParser(ConversionArgs)
    args = parser.parse_args_into_dataclasses()[0]
    convert_and_save(
        args.model_name_or_path,
        args.output_dir,
        args.quantization_format,
        overwrite=args.overwrite,
        trust_remote_code=args.trust_remote_code,
    )
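
# Example invocation (placeholder file name and arguments; HfArgumentParser derives
# the flags from the ConversionArgs field names):
#
#   python convert_from_autoawq.py \
#       --model_name_or_path org/model-AWQ \
#       --output_dir ./model-compressed \
#       --quantization_format naive-quantized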