5 changes: 3 additions & 2 deletions .ci/scripts/export_model_artifact.sh
@@ -157,10 +157,11 @@ pip list
if [ "$MODEL_NAME" = "parakeet" ]; then
pip install -r examples/models/parakeet/install_requirements.txt

python examples/models/parakeet/export_parakeet_tdt.py \
python -m executorch.examples.models.parakeet.export_parakeet_tdt \
--backend "$DEVICE" \
--output-dir "${OUTPUT_DIR}" \
--dtype bf16
--dtype bf16 \
${EXTRA_ARGS}

test -f "${OUTPUT_DIR}/model.pte"
# CUDA saves named data to separate .ptd file, Metal embeds in .pte
18 changes: 0 additions & 18 deletions .github/workflows/cuda.yml
@@ -150,15 +150,6 @@ jobs:
repo: "google"
name: "gemma-3-4b-it"
quant: "quantized-int4-weight-only"
# Parakeet only supports non-quantized
- model:
repo: "nvidia"
name: "parakeet-tdt"
quant: "quantized-int4-tile-packed"
- model:
repo: "nvidia"
name: "parakeet-tdt"
quant: "quantized-int4-weight-only"
with:
timeout: 90
secrets-env: EXECUTORCH_HF_TOKEN
@@ -219,15 +210,6 @@ jobs:
repo: "google"
name: "gemma-3-4b-it"
quant: "quantized-int4-weight-only"
# Parakeet only supports non-quantized
- model:
repo: "nvidia"
name: "parakeet-tdt"
quant: "quantized-int4-tile-packed"
- model:
repo: "nvidia"
name: "parakeet-tdt"
quant: "quantized-int4-weight-only"
with:
timeout: 90
runner: linux.g5.4xlarge.nvidia.gpu
43 changes: 43 additions & 0 deletions examples/models/parakeet/README.md
@@ -31,6 +31,49 @@ python export_parakeet_tdt.py --audio /path/to/audio.wav

**Note:** The preprocessor is always lowered with the portable backend regardless of the `--backend` setting.

### Quantization

The export script supports quantizing encoder and decoder linear layers using [torchao](https://github.com/pytorch/ao).

#### Quantization Arguments

| Argument | Description |
|----------|-------------|
| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` |
| `--qlinear_encoder_group_size` | Group size for encoder linear quantization (default: 32) |
| `--qlinear_encoder_packing_format` | Packing format for encoder: `tile_packed_to_4d` |
| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` |
| `--qlinear_group_size` | Group size for decoder linear quantization (default: 32) |
| `--qlinear_packing_format` | Packing format for decoder: `tile_packed_to_4d` |
| `--qembedding` | Quantization config for decoder embedding layer: `4w`, `8w` |
| `--qembedding_group_size` | Group size for embedding quantization (default: 0 = per-axis) |

#### Quantization Configs

| Config | Description |
|--------|-------------|
| `4w` | 4-bit weight only quantization |
| `8w` | 8-bit weight only quantization |
| `8da4w` | 8-bit dynamic activation, 4-bit weight |
| `8da8w` | 8-bit dynamic activation, 8-bit weight |
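
For orientation, these config names follow the usual torchao pattern of applying a quantization config to a module's `nn.Linear` layers in place via `quantize_`. The sketch below only illustrates that pattern; the actual mapping (including the `tile_packed_to_4d` packing format) lives in `examples/models/parakeet/quantize.py`, and the specific torchao config classes shown here are assumptions rather than the exact ones used there.

```python
# Hedged sketch only -- the real mapping is implemented in
# examples/models/parakeet/quantize.py; the config classes below are assumed.
import torch
from torchao.quantization import (
    quantize_,
    Int4WeightOnlyConfig,                   # assumed to correspond to "4w"
    Int8WeightOnlyConfig,                   # assumed to correspond to "8w"
    Int8DynamicActivationInt4WeightConfig,  # assumed to correspond to "8da4w"
    Int8DynamicActivationInt8WeightConfig,  # assumed to correspond to "8da8w"
)


def quantize_linears_sketch(module: torch.nn.Module, config: str, group_size: int = 32) -> None:
    """Apply a torchao config to the nn.Linear layers of `module`, in place."""
    if config == "4w":
        cfg = Int4WeightOnlyConfig(group_size=group_size)
    elif config == "8w":
        cfg = Int8WeightOnlyConfig()
    elif config == "8da4w":
        cfg = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
    elif config == "8da8w":
        cfg = Int8DynamicActivationInt8WeightConfig()
    else:
        raise ValueError(f"Unknown quantization config: {config}")
    # quantize_ swaps the weights of matching modules (nn.Linear by default) in place.
    quantize_(module, cfg)
```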

#### Example: 4-bit Weight Quantization with Tile Packing

```bash
python export_parakeet_tdt.py \
--backend cuda \
--qlinear_encoder 4w \
--qlinear_encoder_group_size 32 \
--qlinear_encoder_packing_format tile_packed_to_4d \
--qlinear 4w \
--qlinear_group_size 32 \
--qlinear_packing_format tile_packed_to_4d \
--qembedding 8w \
--output-dir ./parakeet_quantized
```

**Note:** The `tile_packed_to_4d` packing format is optimized for CUDA.
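
The packing-format flags are optional. Below is a minimal sketch of a decoder-only quantization run without tile packing; whether a given non-CUDA backend supports this exact combination is an assumption, not something verified here.

```bash
# Hedged example: quantize only the decoder's linear and embedding layers,
# leaving the encoder in full precision. Backend support for this exact
# combination is assumed, not verified.
python export_parakeet_tdt.py \
    --qlinear 8da4w \
    --qlinear_group_size 32 \
    --qembedding 4w \
    --output-dir ./parakeet_decoder_quantized
```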

### Metal Export (macOS)

```bash
143 changes: 134 additions & 9 deletions examples/models/parakeet/export_parakeet_tdt.py
@@ -5,9 +5,12 @@
import shutil
import tarfile
import tempfile
from typing import Optional

import torch
import torchaudio

from executorch.examples.models.parakeet.quantize import quantize_model_
from executorch.exir import (
EdgeCompileConfig,
ExecutorchBackendConfig,
@@ -295,7 +298,22 @@ def forward(
return mel, mel_len


def export_all(model, dtype=torch.float):
def export_all(
model,
dtype=torch.float,
backend: Optional[str] = None,
# Encoder quantization args
qlinear_encoder: Optional[str] = None,
qlinear_encoder_group_size: int = 32,
qlinear_encoder_packing_format: Optional[str] = None,
# Decoder quantization args
qlinear: Optional[str] = None,
qlinear_group_size: int = 32,
qlinear_packing_format: Optional[str] = None,
# Embedding quantization args (decoder has the embedding layer)
qembedding: Optional[str] = None,
qembedding_group_size: int = 0,
):
"""Export all model components.

The maximum audio duration is determined by the model's internal
@@ -304,9 +322,21 @@ def export_all(model, dtype=torch.float):
Args:
model: The NeMo ASR model to export.
dtype: Data type for floating-point tensors (default: torch.float).
backend: Target backend ("cuda", "xnnpack", etc.).
qlinear_encoder: Quantization config for encoder linear layers.
qlinear_encoder_group_size: Group size for encoder linear quantization.
qlinear_encoder_packing_format: Packing format for encoder linear layers.
qlinear: Quantization config for decoder linear layers.
qlinear_group_size: Group size for decoder linear quantization.
qlinear_packing_format: Packing format for decoder linear layers.
qembedding: Quantization config for embedding layers ("4w", "8w").
qembedding_group_size: Group size for embedding quantization (default: 0 = per-axis).
"""
programs = {}

# Determine device based on backend (preprocessor always stays on CPU)
device = torch.device("cuda" if backend == "cuda" else "cpu")

# Get audio parameters from model config
sample_rate = model.preprocessor._cfg.sample_rate
window_stride = float(model.preprocessor._cfg.window_stride)
@@ -339,14 +369,27 @@
)
torch.cuda.is_available = old_cuda_is_available

# Move model to CUDA after preprocessor export (preprocessor must stay on CPU)
if backend == "cuda":
model.cuda()

feat_in = getattr(model.encoder, "_feat_in", 128)
# Use max_mel_frames as example to ensure Dim.AUTO infers the full range.
# Smaller examples cause Dim.AUTO to infer narrow bounds.
audio_signal = torch.randn(1, feat_in, max_mel_frames, dtype=dtype)
length = torch.tensor([max_mel_frames], dtype=torch.int64)
audio_signal = torch.randn(1, feat_in, max_mel_frames, dtype=dtype, device=device)
length = torch.tensor([max_mel_frames], dtype=torch.int64, device=device)
encoder_with_proj = EncoderWithProjection(model.encoder, model.joint)
encoder_with_proj.eval()

if qlinear_encoder:
print("Quantizing encoder...")
quantize_model_(
encoder_with_proj,
qlinear_config=qlinear_encoder,
qlinear_group_size=qlinear_encoder_group_size,
qlinear_packing_format=qlinear_encoder_packing_format,
)

programs["encoder"] = export(
encoder_with_proj,
(),
@@ -363,9 +406,21 @@
pred_hidden = model.decoder.pred_hidden
decoder_step = DecoderStep(model.decoder, model.joint)
decoder_step.eval()
token = torch.tensor([[0]], dtype=torch.long)
h = torch.zeros(num_layers, 1, pred_hidden, dtype=dtype)
c = torch.zeros(num_layers, 1, pred_hidden, dtype=dtype)

if qlinear or qembedding:
print("Quantizing decoder...")
quantize_model_(
decoder_step,
qlinear_config=qlinear,
qlinear_group_size=qlinear_group_size,
qlinear_packing_format=qlinear_packing_format,
qembedding_config=qembedding,
qembedding_group_size=qembedding_group_size,
)

token = torch.tensor([[0]], dtype=torch.long, device=device)
h = torch.zeros(num_layers, 1, pred_hidden, dtype=dtype, device=device)
c = torch.zeros(num_layers, 1, pred_hidden, dtype=dtype, device=device)
programs["decoder_step"] = export(
decoder_step,
(token, h, c),
@@ -376,8 +431,8 @@
joint_hidden = model.joint.joint_hidden
num_token_classes = model.tokenizer.vocab_size + 1 # +1 for blank

f_proj = torch.randn(1, 1, joint_hidden, dtype=dtype)
g_proj = torch.randn(1, 1, joint_hidden, dtype=dtype)
f_proj = torch.randn(1, 1, joint_hidden, dtype=dtype, device=device)
g_proj = torch.randn(1, 1, joint_hidden, dtype=dtype, device=device)
programs["joint"] = export(
JointWithArgmax(model.joint, num_token_classes),
(f_proj, g_proj),
@@ -554,6 +609,61 @@ def main():
choices=["fp32", "fp16", "bf16"],
help="Model dtype for Metal/CUDA backends (default: fp32)",
)

# Decoder quantization arguments
parser.add_argument(
"--qlinear",
type=str,
choices=["4w", "8w", "8da4w", "8da8w"],
help="Quantization config for decoder linear layers",
)
parser.add_argument(
"--qlinear_group_size",
type=int,
default=32,
help="Group size for decoder linear quantization (default: 32)",
)
parser.add_argument(
"--qlinear_packing_format",
type=str,
choices=["tile_packed_to_4d"],
help="Packing format for decoder linear layers",
)

# Encoder quantization arguments
parser.add_argument(
"--qlinear_encoder",
type=str,
choices=["4w", "8w", "8da4w", "8da8w"],
help="Quantization config for encoder linear layers",
)
parser.add_argument(
"--qlinear_encoder_group_size",
type=int,
default=32,
help="Group size for encoder linear quantization (default: 32)",
)
parser.add_argument(
"--qlinear_encoder_packing_format",
type=str,
choices=["tile_packed_to_4d"],
help="Packing format for encoder linear layers",
)

# Embedding quantization arguments (decoder has the embedding layer)
parser.add_argument(
"--qembedding",
type=str,
choices=["4w", "8w"],
help="Quantization config for decoder embedding layer",
)
parser.add_argument(
"--qembedding_group_size",
type=int,
default=0,
help="Group size for embedding quantization (default: 0 = per-axis)",
)

args = parser.parse_args()

# Validate dtype
@@ -578,7 +688,22 @@ def main():

print("\nExporting components...")
export_dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float
programs, metadata = export_all(model, dtype=export_dtype)
programs, metadata = export_all(
model,
dtype=export_dtype,
backend=args.backend,
# Encoder quantization
qlinear_encoder=args.qlinear_encoder,
qlinear_encoder_group_size=args.qlinear_encoder_group_size,
qlinear_encoder_packing_format=args.qlinear_encoder_packing_format,
# Decoder quantization
qlinear=args.qlinear,
qlinear_group_size=args.qlinear_group_size,
qlinear_packing_format=args.qlinear_packing_format,
# Embedding quantization
qembedding=args.qembedding,
qembedding_group_size=args.qembedding_group_size,
)

et = lower_to_executorch(programs, metadata=metadata, backend=args.backend)
