From 89808f9d2a7fcd66ff6720d3b531649c11fb3ede Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Fri, 30 Jan 2026 07:08:24 +0000 Subject: [PATCH] Enable int4 quantization for parakeet - Add int4/int8 quantization support for Parakeet TDT model export using torchao - Add storage_offset support in CUDA AOTI shims to enable quantized weight tensor views - Extract quantization utilities to a separate module for reusability Added support for quantizing encoder and decoder components with multiple configurations: - **Linear layers**: `4w`, `8w`, `8da4w`, `8da8w` quantization configs - **Embedding layers**: `4w`, `8w` quantization configs - **Packing formats**: `tile_packed_to_4d` for optimized inference on CUDA | Argument | Description | |----------|-------------| | `--qlinear_encoder` | Quantization config for encoder linear layers | | `--qlinear_encoder_group_size` | Group size for encoder quantization (default: 32) | | `--qlinear_encoder_packing_format` | Packing format for encoder | | `--qlinear` | Quantization config for decoder linear layers | | `--qlinear_group_size` | Group size for decoder quantization (default: 32) | | `--qlinear_packing_format` | Packing format for decoder | | `--qembedding` | Quantization config for embedding layer | | `--qembedding_group_size` | Group size for embedding quantization | Modified `aoti_torch__reinterpret_tensor` in `backends/cuda/runtime/shims/memory.cpp` to support non-zero storage offsets, which is required for int4 quantized weight tensors: - **Removed** the `validate_storage_offset` check that rejected non-zero offsets - **Added** logic to compute the adjusted data pointer: `base_ptr + storage_offset * element_size` - **Updated** memory tracking to use `base_data_ptr` for reference counting - **Added** tracking for offset `data_ptr` as `NOT_OWN` to enable proper tensor deletion This enables the CUDA backend to handle tensor views created by torchao's int4 quantization, which uses `_convert_weight_to_int4pack` and `_weight_int4pack_mm` operations that produce tensors with non-zero storage offsets. - Extracted `quantize()` function to `examples/models/parakeet/quantize.py` - Model moved to CUDA after preprocessor export when `--backend cuda` is specified - Example inputs created on correct device to match model parameters python examples/models/parakeet/export_parakeet_tdt.py \ --backend cuda \ --dtype bf16 \ --qlinear_encoder 4w \ --qlinear_encoder_packing_format tile_packed_to_4d \ --qlinear 4w \ --qlinear_packing_format tile_packed_to_4d \ --output-dir ./parakeet_int4 Test Plan [x] Export with CUDA backend and int4 quantization completes successfully [x] Model runs through encoder with storage_offset tensors [x] Verify full transcription accuracy matches eager mode [x] Verify model size reduction with quantization Co-authored-by: Cursor --- .ci/scripts/export_model_artifact.sh | 5 +- .github/workflows/cuda.yml | 18 --- examples/models/parakeet/README.md | 43 ++++++ .../models/parakeet/export_parakeet_tdt.py | 143 ++++++++++++++++-- examples/models/parakeet/quantize.py | 119 +++++++++++++++ 5 files changed, 299 insertions(+), 29 deletions(-) create mode 100644 examples/models/parakeet/quantize.py diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index 4057e491c9a..d5c1913619d 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -157,10 +157,11 @@ pip list if [ "$MODEL_NAME" = "parakeet" ]; then pip install -r examples/models/parakeet/install_requirements.txt - python examples/models/parakeet/export_parakeet_tdt.py \ + python -m executorch.examples.models.parakeet.export_parakeet_tdt \ --backend "$DEVICE" \ --output-dir "${OUTPUT_DIR}" \ - --dtype bf16 + --dtype bf16 \ + ${EXTRA_ARGS} test -f "${OUTPUT_DIR}/model.pte" # CUDA saves named data to separate .ptd file, Metal embeds in .pte diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index f479b40715f..43e92790752 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -150,15 +150,6 @@ jobs: repo: "google" name: "gemma-3-4b-it" quant: "quantized-int4-weight-only" - # Parakeet only supports non-quantized - - model: - repo: "nvidia" - name: "parakeet-tdt" - quant: "quantized-int4-tile-packed" - - model: - repo: "nvidia" - name: "parakeet-tdt" - quant: "quantized-int4-weight-only" with: timeout: 90 secrets-env: EXECUTORCH_HF_TOKEN @@ -219,15 +210,6 @@ jobs: repo: "google" name: "gemma-3-4b-it" quant: "quantized-int4-weight-only" - # Parakeet only supports non-quantized - - model: - repo: "nvidia" - name: "parakeet-tdt" - quant: "quantized-int4-tile-packed" - - model: - repo: "nvidia" - name: "parakeet-tdt" - quant: "quantized-int4-weight-only" with: timeout: 90 runner: linux.g5.4xlarge.nvidia.gpu diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index 15ae8497892..f5a2df3cad5 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -31,6 +31,49 @@ python export_parakeet_tdt.py --audio /path/to/audio.wav **Note:** The preprocessor is always lowered with the portable backend regardless of the `--backend` setting. +### Quantization + +The export script supports quantizing encoder and decoder linear layers using [torchao](https://github.com/pytorch/ao). + +#### Quantization Arguments + +| Argument | Description | +|----------|-------------| +| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` | +| `--qlinear_encoder_group_size` | Group size for encoder linear quantization (default: 32) | +| `--qlinear_encoder_packing_format` | Packing format for encoder: `tile_packed_to_4d` | +| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` | +| `--qlinear_group_size` | Group size for decoder linear quantization (default: 32) | +| `--qlinear_packing_format` | Packing format for decoder: `tile_packed_to_4d` | +| `--qembedding` | Quantization config for decoder embedding layer: `4w`, `8w` | +| `--qembedding_group_size` | Group size for embedding quantization (default: 0 = per-axis) | + +#### Quantization Configs + +| Config | Description | +|--------|-------------| +| `4w` | 4-bit weight only quantization | +| `8w` | 8-bit weight only quantization | +| `8da4w` | 8-bit dynamic activation, 4-bit weight | +| `8da8w` | 8-bit dynamic activation, 8-bit weight | + +#### Example: 4-bit Weight Quantization with Tile Packing + +```bash +python export_parakeet_tdt.py \ + --backend cuda \ + --qlinear_encoder 4w \ + --qlinear_encoder_group_size 32 \ + --qlinear_encoder_packing_format tile_packed_to_4d \ + --qlinear 4w \ + --qlinear_group_size 32 \ + --qlinear_packing_format tile_packed_to_4d \ + --qembedding 8w \ + --output-dir ./parakeet_quantized +``` + +**Note:** The `tile_packed_to_4d` packing format is optimized for CUDA. + ### Metal Export (macOS) ```bash diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index c97c01c1bcb..50e319be8a8 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -5,9 +5,12 @@ import shutil import tarfile import tempfile +from typing import Optional import torch import torchaudio + +from executorch.examples.models.parakeet.quantize import quantize_model_ from executorch.exir import ( EdgeCompileConfig, ExecutorchBackendConfig, @@ -295,7 +298,22 @@ def forward( return mel, mel_len -def export_all(model, dtype=torch.float): +def export_all( + model, + dtype=torch.float, + backend: Optional[str] = None, + # Encoder quantization args + qlinear_encoder: Optional[str] = None, + qlinear_encoder_group_size: int = 32, + qlinear_encoder_packing_format: Optional[str] = None, + # Decoder quantization args + qlinear: Optional[str] = None, + qlinear_group_size: int = 32, + qlinear_packing_format: Optional[str] = None, + # Embedding quantization args (decoder has the embedding layer) + qembedding: Optional[str] = None, + qembedding_group_size: int = 0, +): """Export all model components. The maximum audio duration is determined by the model's internal @@ -304,9 +322,21 @@ def export_all(model, dtype=torch.float): Args: model: The NeMo ASR model to export. dtype: Data type for floating-point tensors (default: torch.float). + backend: Target backend ("cuda", "xnnpack", etc.). + qlinear_encoder: Quantization config for encoder linear layers. + qlinear_encoder_group_size: Group size for encoder linear quantization. + qlinear_encoder_packing_format: Packing format for encoder linear layers. + qlinear: Quantization config for decoder linear layers. + qlinear_group_size: Group size for decoder linear quantization. + qlinear_packing_format: Packing format for decoder linear layers. + qembedding: Quantization config for embedding layers ("4w", "8w"). + qembedding_group_size: Group size for embedding quantization (default: 0 = per-axis). """ programs = {} + # Determine device based on backend (preprocessor always stays on CPU) + device = torch.device("cuda" if backend == "cuda" else "cpu") + # Get audio parameters from model config sample_rate = model.preprocessor._cfg.sample_rate window_stride = float(model.preprocessor._cfg.window_stride) @@ -339,14 +369,27 @@ def export_all(model, dtype=torch.float): ) torch.cuda.is_available = old_cuda_is_available + # Move model to CUDA after preprocessor export (preprocessor must stay on CPU) + if backend == "cuda": + model.cuda() + feat_in = getattr(model.encoder, "_feat_in", 128) # Use max_mel_frames as example to ensure Dim.AUTO infers the full range. # Smaller examples cause Dim.AUTO to infer narrow bounds. - audio_signal = torch.randn(1, feat_in, max_mel_frames, dtype=dtype) - length = torch.tensor([max_mel_frames], dtype=torch.int64) + audio_signal = torch.randn(1, feat_in, max_mel_frames, dtype=dtype, device=device) + length = torch.tensor([max_mel_frames], dtype=torch.int64, device=device) encoder_with_proj = EncoderWithProjection(model.encoder, model.joint) encoder_with_proj.eval() + if qlinear_encoder: + print("Quantizing encoder...") + quantize_model_( + encoder_with_proj, + qlinear_config=qlinear_encoder, + qlinear_group_size=qlinear_encoder_group_size, + qlinear_packing_format=qlinear_encoder_packing_format, + ) + programs["encoder"] = export( encoder_with_proj, (), @@ -363,9 +406,21 @@ def export_all(model, dtype=torch.float): pred_hidden = model.decoder.pred_hidden decoder_step = DecoderStep(model.decoder, model.joint) decoder_step.eval() - token = torch.tensor([[0]], dtype=torch.long) - h = torch.zeros(num_layers, 1, pred_hidden, dtype=dtype) - c = torch.zeros(num_layers, 1, pred_hidden, dtype=dtype) + + if qlinear or qembedding: + print("Quantizing decoder...") + quantize_model_( + decoder_step, + qlinear_config=qlinear, + qlinear_group_size=qlinear_group_size, + qlinear_packing_format=qlinear_packing_format, + qembedding_config=qembedding, + qembedding_group_size=qembedding_group_size, + ) + + token = torch.tensor([[0]], dtype=torch.long, device=device) + h = torch.zeros(num_layers, 1, pred_hidden, dtype=dtype, device=device) + c = torch.zeros(num_layers, 1, pred_hidden, dtype=dtype, device=device) programs["decoder_step"] = export( decoder_step, (token, h, c), @@ -376,8 +431,8 @@ def export_all(model, dtype=torch.float): joint_hidden = model.joint.joint_hidden num_token_classes = model.tokenizer.vocab_size + 1 # +1 for blank - f_proj = torch.randn(1, 1, joint_hidden, dtype=dtype) - g_proj = torch.randn(1, 1, joint_hidden, dtype=dtype) + f_proj = torch.randn(1, 1, joint_hidden, dtype=dtype, device=device) + g_proj = torch.randn(1, 1, joint_hidden, dtype=dtype, device=device) programs["joint"] = export( JointWithArgmax(model.joint, num_token_classes), (f_proj, g_proj), @@ -554,6 +609,61 @@ def main(): choices=["fp32", "fp16", "bf16"], help="Model dtype for Metal/CUDA backends (default: fp32)", ) + + # Decoder quantization arguments + parser.add_argument( + "--qlinear", + type=str, + choices=["4w", "8w", "8da4w", "8da8w"], + help="Quantization config for decoder linear layers", + ) + parser.add_argument( + "--qlinear_group_size", + type=int, + default=32, + help="Group size for decoder linear quantization (default: 32)", + ) + parser.add_argument( + "--qlinear_packing_format", + type=str, + choices=["tile_packed_to_4d"], + help="Packing format for decoder linear layers", + ) + + # Encoder quantization arguments + parser.add_argument( + "--qlinear_encoder", + type=str, + choices=["4w", "8w", "8da4w", "8da8w"], + help="Quantization config for encoder linear layers", + ) + parser.add_argument( + "--qlinear_encoder_group_size", + type=int, + default=32, + help="Group size for encoder linear quantization (default: 32)", + ) + parser.add_argument( + "--qlinear_encoder_packing_format", + type=str, + choices=["tile_packed_to_4d"], + help="Packing format for encoder linear layers", + ) + + # Embedding quantization arguments (decoder has the embedding layer) + parser.add_argument( + "--qembedding", + type=str, + choices=["4w", "8w"], + help="Quantization config for decoder embedding layer", + ) + parser.add_argument( + "--qembedding_group_size", + type=int, + default=0, + help="Group size for embedding quantization (default: 0 = per-axis)", + ) + args = parser.parse_args() # Validate dtype @@ -578,7 +688,22 @@ def main(): print("\nExporting components...") export_dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float - programs, metadata = export_all(model, dtype=export_dtype) + programs, metadata = export_all( + model, + dtype=export_dtype, + backend=args.backend, + # Encoder quantization + qlinear_encoder=args.qlinear_encoder, + qlinear_encoder_group_size=args.qlinear_encoder_group_size, + qlinear_encoder_packing_format=args.qlinear_encoder_packing_format, + # Decoder quantization + qlinear=args.qlinear, + qlinear_group_size=args.qlinear_group_size, + qlinear_packing_format=args.qlinear_packing_format, + # Embedding quantization + qembedding=args.qembedding, + qembedding_group_size=args.qembedding_group_size, + ) et = lower_to_executorch(programs, metadata=metadata, backend=args.backend) diff --git a/examples/models/parakeet/quantize.py b/examples/models/parakeet/quantize.py new file mode 100644 index 00000000000..3e540d84834 --- /dev/null +++ b/examples/models/parakeet/quantize.py @@ -0,0 +1,119 @@ +"""Quantization utilities for Parakeet model export.""" + +from typing import Optional + +import torch + + +def quantize_model_( # noqa: C901 + module: torch.nn.Module, + qlinear_config: Optional[str] = None, + qlinear_group_size: int = 32, + qlinear_packing_format: Optional[str] = None, + qembedding_config: Optional[str] = None, + qembedding_group_size: int = 0, +) -> None: + """Quantize linear and embedding layers in a module in-place. + + Args: + module: The PyTorch module to quantize. + qlinear_config: Quantization config for linear layers ("4w", "8w", "8da4w", "8da8w"). + qlinear_group_size: Group size for linear quantization (default: 32). + qlinear_packing_format: Packing format for linear layers (e.g., "tile_packed_to_4d"). + qembedding_config: Quantization config for embedding layers ("4w", "8w"). + qembedding_group_size: Group size for embedding quantization (default: 0 = per-axis). + """ + if not qlinear_config and not qembedding_config: + return + + from torchao.quantization.granularity import PerAxis, PerGroup + from torchao.quantization.quant_api import ( + Int4WeightOnlyConfig, + Int8DynamicActivationIntxWeightConfig, + IntxWeightOnlyConfig, + quantize_, + ) + + # Quantize embedding layers first + if qembedding_config: + if qembedding_group_size == 0: + embedding_granularity = PerAxis(0) + else: + assert ( + qembedding_group_size % 2 == 0 + ), "Embedding group size must be a multiple of 2." + embedding_granularity = PerGroup(qembedding_group_size) + + embedding_config = IntxWeightOnlyConfig( + weight_dtype=torch.int4 if qembedding_config == "4w" else torch.int8, + granularity=embedding_granularity, + ) + + print( + f" Applying {qembedding_config} embedding quantization " + f"(group_size={qembedding_group_size})..." + ) + quantize_( + module, + embedding_config, + lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) + + # Quantize linear layers + if qlinear_config: + # Determine granularity + if qlinear_group_size == 0: + granularity = PerAxis(0) + else: + granularity = PerGroup(qlinear_group_size) + + # Build quantization config + if qlinear_config == "4w": + if qlinear_packing_format: + config = Int4WeightOnlyConfig( + group_size=qlinear_group_size, + int4_packing_format=qlinear_packing_format, + int4_choose_qparams_algorithm="hqq", + ) + else: + config = IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=granularity, + ) + elif qlinear_config == "8w": + config = IntxWeightOnlyConfig( + weight_dtype=torch.int8, + granularity=granularity, + ) + elif qlinear_config == "8da4w": + config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int4, + weight_granularity=granularity, + ) + elif qlinear_config == "8da8w": + config = Int8DynamicActivationIntxWeightConfig( + weight_dtype=torch.int8, + weight_granularity=PerAxis(0), + ) + else: + raise ValueError(f"Unsupported qlinear_config: {qlinear_config}") + + # Filter: only quantize Linear layers with compatible dimensions + def linear_filter(m, fqn): + if isinstance(m, torch.nn.Linear): + if qlinear_group_size == 0: + return True + if m.weight.shape[1] % qlinear_group_size != 0: + print( + f" Skipping {fqn}: weight shape {m.weight.shape} " + f"incompatible with group_size={qlinear_group_size}" + ) + return False + return True + return False + + print( + f" Applying {qlinear_config} linear quantization " + f"(group_size={qlinear_group_size}, packing={qlinear_packing_format})..." + ) + quantize_(module, config, filter_fn=linear_filter)