From 998856b373e7beb7be349f6ce5bd4ab920d891c8 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Tue, 9 Dec 2025 07:44:00 +0000 Subject: [PATCH 1/4] Support KIMI K2 Thinking int4 checkpoint PTQ Signed-off-by: Chenjie Luo --- examples/llm_ptq/example_utils.py | 7 ++++ examples/llm_ptq/hf_ptq.py | 38 ++++++++++++------- .../llm_ptq/scripts/huggingface_example.sh | 4 +- .../torch/quantization/plugins/huggingface.py | 28 ++++++++++++++ 4 files changed, 61 insertions(+), 16 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index ce3fb0853..c13b9f897 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -327,6 +327,13 @@ def get_model( device_map=device_map, **model_kwargs, ) + elif hf_config.quantization_config.get("format", None) == "pack-quantized": + model = AutoModelForCausalLM.from_pretrained( + ckpt_path, + torch_dtype=torch.float16, + device_map="auto", + trust_remote_code=trust_remote_code, + ) else: architecture = hf_config.architectures[0] diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 57f0b5a89..9697adf08 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -14,6 +14,7 @@ # limitations under the License. import argparse +import gc import random import time import warnings @@ -510,19 +511,26 @@ def main(args): "input_features" if model_type == "whisper" else "input_ids" ][0:1] - # Generate preview before quantization - if is_nemotron_vl_model and tokenizer is not None: - generated_ids_before_ptq = run_nemotron_vl_preview( - full_model, - tokenizer, - input_ids, - args.pyt_ckpt_path, - "before quantization", - allow_fallback=True, - ) - else: - # Standard generation for non-Nemotron VL models - generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) + try: + # Generate preview before quantization + if is_nemotron_vl_model and tokenizer is not None: + generated_ids_before_ptq = run_nemotron_vl_preview( + full_model, + tokenizer, + input_ids, + args.pyt_ckpt_path, + "before quantization", + allow_fallback=True, + ) + else: + # Standard generation for non-Nemotron VL models + generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) + except torch.OutOfMemoryError: + print("Out of memory. Skipping preview generation.") + generated_ids_before_ptq = None + gc.collect() + torch.cuda.empty_cache() + if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": print("Applying nvfp4 quantization (MoE only) for gpt-oss") @@ -542,7 +550,9 @@ def main(args): # Run some samples torch.cuda.empty_cache() generated_ids_after_ptq = None - if model_type != "llama4" and not is_nemotron_vl_model: + if generated_ids_before_ptq is None: + pass + elif model_type != "llama4" and not is_nemotron_vl_model: # Our fake quantizer may not be fully compatible with torch.compile. 
generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100) elif is_nemotron_vl_model and tokenizer is not None: diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh index 043b690e5..3ea85de9e 100755 --- a/examples/llm_ptq/scripts/huggingface_example.sh +++ b/examples/llm_ptq/scripts/huggingface_example.sh @@ -53,9 +53,9 @@ esac IFS="," for qformat in $QFORMAT; do case $qformat in - fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;; + fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only) ;; *) - echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2 + echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only]" >&2 exit 1 ;; esac diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 31ac2bbbd..458c72bce 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -22,6 +22,8 @@ from typing import TYPE_CHECKING import torch +from torch import Tensor +from torch.nn.functional import linear try: from torch.distributed.tensor import Shard @@ -501,6 +503,22 @@ def top_k(self, value): self.router.moe_top_k = value +class _QuantCompressedLinear(QuantModule): + def _setup(self): + self.input_quantizer = TensorQuantizer() + self.weight_quantizer = TensorQuantizer() + + def forward(self, input: Tensor) -> Tensor: + from compressed_tensors.quantization import QuantizationStatus + + if self.quantization_status == QuantizationStatus.COMPRESSED: + weight_data = self.compressor.decompress_module(self) + else: + weight_data = self.weight + + return linear(self.input_quantizer(input), self.weight_quantizer(weight_data), self.bias) + + try: from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe @@ -576,6 +594,16 @@ def top_k(self, value): except ImportError: pass +try: + from compressed_tensors.linear.compressed_linear import CompressedLinear + + if CompressedLinear not in QuantModuleRegistry: + QuantModuleRegistry.register({CompressedLinear: "hf.CompressedLinear"})( + _QuantCompressedLinear + ) +except ImportError: + pass + class _QuantGptOssExperts(_QuantFunctionalMixin): """Quantized wrapper for `transformers.GptOssExperts`. From 3aebac77c72f83f29640d05e8fc70f92474f7f6f Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Tue, 9 Dec 2025 17:27:45 +0000 Subject: [PATCH 2/4] Fix Signed-off-by: Chenjie Luo --- examples/llm_ptq/hf_ptq.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 9697adf08..d34f9fdbb 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -14,7 +14,6 @@ # limitations under the License. 
import argparse -import gc import random import time import warnings @@ -511,25 +510,24 @@ def main(args): "input_features" if model_type == "whisper" else "input_ids" ][0:1] - try: - # Generate preview before quantization - if is_nemotron_vl_model and tokenizer is not None: - generated_ids_before_ptq = run_nemotron_vl_preview( - full_model, - tokenizer, - input_ids, - args.pyt_ckpt_path, - "before quantization", - allow_fallback=True, - ) - else: - # Standard generation for non-Nemotron VL models - generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) - except torch.OutOfMemoryError: - print("Out of memory. Skipping preview generation.") + # Generate preview before quantization + if model_type == "deepseek": + print( + "Deepseek model may hit OOM during preview generation. Skipping preview generation." + ) generated_ids_before_ptq = None - gc.collect() - torch.cuda.empty_cache() + elif is_nemotron_vl_model and tokenizer is not None: + generated_ids_before_ptq = run_nemotron_vl_preview( + full_model, + tokenizer, + input_ids, + args.pyt_ckpt_path, + "before quantization", + allow_fallback=True, + ) + else: + # Standard generation for non-Nemotron VL models + generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": print("Applying nvfp4 quantization (MoE only) for gpt-oss") From 95ee2752e4d2156f248c5898b92f8bee4bce740f Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Wed, 10 Dec 2025 00:12:04 +0000 Subject: [PATCH 3/4] Fix Signed-off-by: Chenjie Luo --- examples/llm_ptq/example_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index c13b9f897..f31d11b8f 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -330,7 +330,6 @@ def get_model( elif hf_config.quantization_config.get("format", None) == "pack-quantized": model = AutoModelForCausalLM.from_pretrained( ckpt_path, - torch_dtype=torch.float16, device_map="auto", trust_remote_code=trust_remote_code, ) @@ -353,9 +352,9 @@ def get_model( from_config = auto_model_module._from_config with init_empty_weights(): - # When computing the device_map, assuming half precision by default, + # When computing the device_map, assuming bfloat16 precision by default, # unless specified by the hf_config. 
- torch_dtype = getattr(hf_config, "torch_dtype", torch.float16) + torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) model_kwargs2 = model_kwargs.copy() if auto_model_module != AutoModelForCausalLM: model_kwargs2.pop("trust_remote_code", None) From bf367f10b0a75ef5f7baa35dddd5a2e1f360dca9 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Tue, 20 Jan 2026 07:30:28 +0000 Subject: [PATCH 4/4] Test int4 quant --- examples/deepseek/int4_kernel.py | 51 ++++++++++++++++++++++++++++++++ examples/deepseek/test.py | 17 +++++++++++ 2 files changed, 68 insertions(+) create mode 100644 examples/deepseek/int4_kernel.py create mode 100644 examples/deepseek/test.py diff --git a/examples/deepseek/int4_kernel.py b/examples/deepseek/int4_kernel.py new file mode 100644 index 000000000..14303e4a0 --- /dev/null +++ b/examples/deepseek/int4_kernel.py @@ -0,0 +1,51 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def int4_dequant_kernel( + x_ptr, # pointer to int32 input [M, N] + s_ptr, # pointer to bf16 scale [M, 8*N//BLOCK_SIZE] + y_ptr, # pointer to bf16 output [M, 8N] + M: tl.constexpr, + N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + m = tl.program_id(axis=0) + n_block = tl.program_id(axis=1) + + # Load int32 values, unroll to int4. + NUM_INT32_PER_BLOCK = BLOCK_SIZE // 8 + int32_vals = tl.load( + x_ptr + m * N + n_block * NUM_INT32_PER_BLOCK + tl.arange(0, BLOCK_SIZE) // 8 + ) + + offset = (tl.arange(0, BLOCK_SIZE) % 8) * 4 + vals = ((int32_vals >> offset) & 0xF) - 8 + + # # Compute scale per block + # # Each scale covers block_size contiguous y + scale = tl.load(s_ptr + m * 8 * N // BLOCK_SIZE + n_block) + + vals = vals.to(tl.float32) * scale.to(tl.float32) + tl.store(y_ptr + m * N * 8 + n_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE), vals) + + +def int4_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 32) -> torch.Tensor: + """ + Dequantizes a packed int4 tensor to bf16. + + Args: + x: int32 tensor of shape [M, N] with packed int4. + s: bf16 tensor of shape [M, 8 * N // BLOCK_SIZE]. + block_size: number of output columns per block. + + Returns: bf16 tensor of shape [M, 8 * N] + """ + m, n = x.shape + y = torch.empty((m, 8 * n), dtype=torch.get_default_dtype(), device=x.device) + + grid = (m, 8 * n // block_size) + int4_dequant_kernel[grid](x, s, y, m, n, BLOCK_SIZE=block_size) + return y diff --git a/examples/deepseek/test.py b/examples/deepseek/test.py new file mode 100644 index 000000000..33d55f840 --- /dev/null +++ b/examples/deepseek/test.py @@ -0,0 +1,17 @@ +import pdb + +import safetensors +import torch +from int4_kernel import int4_dequant + +tensors = safetensors.safe_open("model-00001-of-00527.safetensors", framework="pt", device="cuda") + +bf16 = tensors.get_tensor("model.layers.1.mlp.experts.0.down_proj.weight") +int32 = tensors.get_tensor("model.layers.1.mlp.experts.0.down_proj.weight_packed") +ws = tensors.get_tensor("model.layers.1.mlp.experts.0.down_proj.weight_scale") +torch.set_default_dtype(torch.bfloat16) +bf16_2 = int4_dequant(int32, ws, block_size=32) + + +if not torch.allclose(bf16_2, bf16, rtol=1e-4, atol=1e-4): + pdb.set_trace()
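
A note on the packed layout the Triton kernel assumes, useful when cross-checking int4_dequant: each int32 in weight_packed holds eight 4-bit values, lowest nibble first, stored with an offset of 8 (so (nibble & 0xF) - 8 recovers the signed value), and each entry of weight_scale covers block_size contiguous output columns. The sketch below is a pure-PyTorch equivalent of that layout, inferred from the kernel above rather than from any compressed-tensors specification; it is only a reference for validating the Triton path, and the helper name int4_dequant_reference is illustrative, not part of the patch.

import torch


def int4_dequant_reference(x: torch.Tensor, s: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    """Pure-PyTorch reference for int4_dequant, mirroring the kernel's packing assumptions."""
    m, n = x.shape
    # Unpack eight 4-bit values per int32: shift by 0, 4, ..., 28 bits and keep the low nibble.
    shifts = torch.arange(0, 32, 4, device=x.device, dtype=torch.int32)
    nibbles = (x.unsqueeze(-1) >> shifts) & 0xF  # [M, N, 8], lowest nibble first
    vals = (nibbles - 8).to(torch.float32).reshape(m, 8 * n)
    # Broadcast one scale over each block of block_size contiguous output columns.
    scales = s.to(torch.float32).repeat_interleave(block_size, dim=1)  # [M, 8 * N]
    return (vals * scales).to(torch.get_default_dtype())

Comparing int4_dequant_reference(int32.cpu(), ws.cpu()) against the kernel output (or against the unpacked bf16 weight, as test.py does) gives a quick sanity check of both the bit unpacking order and the per-block scale indexing.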