Conversation
@tpopp tpopp commented Dec 8, 2025

This fuses a set of kernels that are bound by launch latency into a single aiter.topk_sigmoid call; the fused path is used whenever aiter is available.
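For reference, a minimal sketch of the unfused routing path that the fused call replaces, assuming top-k selection plus sigmoid scoring; the helper name and the renormalize handling here are illustrative, not the PR's exact code:

```python
import torch

def topk_sigmoid_reference(
    router_logits: torch.Tensor, topk: int, renormalize: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    # Two to three separate kernel launches (topk, sigmoid, optional
    # renormalization); each launch pays the latency that the fused op avoids.
    topk_values, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
    topk_weights = torch.sigmoid(topk_values)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
```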

Test Plan:
Server Launch:
vllm serve ${MODEL} --port ${PORT} --swap-space 32 --max-model-len ${MAX_MODEL_LEN} --tensor-parallel-size ${TP} --max-num-seqs ${MAX_NUM_SEQS} --gpu-memory-utilization 0.93 --kv-cache-dtype fp8 --max-num-batched-tokens ${MAX_NUM_BATCHED_TOKENS} --compilation-config "{\"custom_ops\": [\"-rms_norm\", \"-quant_fp8\", \"-silu_and_mul\"] }" --no-enable-prefix-caching --async-scheduling

Example benchmark command:
vllm bench serve --model meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 --host localhost --dataset-name random --random-input-len 1024 --random-output-len 1024 --max-concurrency 4 --num-prompts 48 --ignore-eos

Example correctness check:
lm_eval --model local-completions --tasks gsm8k --model_args base_url=http://0.0.0.0:${PORT}/v1/completions,model=meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8,tokenized_requests=False,tokenizer_backend=None,num_concurrent=128,timeout=120,max_retries=5

Test Result (using 8x MI325X):

Throughput increases by about 4%, and accuracy is comparable or better.

Before:

============ Serving Benchmark Result ============
Successful requests:                     48
Failed requests:                         0
Maximum request concurrency:             4
Benchmark duration (s):                  119.76
Total input tokens:                      49104
Total generated tokens:                  49152
Request throughput (req/s):              0.40
Output token throughput (tok/s):         410.41
Peak output token throughput (tok/s):    428.00
Peak concurrent requests:                8.00
Total Token throughput (tok/s):          820.42
---------------Time to First Token----------------
Mean TTFT (ms):                          88.79
Median TTFT (ms):                        99.22
P99 TTFT (ms):                           111.84
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          9.67
Median TPOT (ms):                        9.65
P99 TPOT (ms):                           9.80
---------------Inter-token Latency----------------
Mean ITL (ms):                           9.67
Median ITL (ms):                         9.46
P99 ITL (ms):                            14.17
==================================================

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9333|±  |0.0069|
|     |       |strict-match    |     5|exact_match|↑  |0.9356|±  |0.0068|

After:

============ Serving Benchmark Result ============
Successful requests:                     48
Failed requests:                         0
Maximum request concurrency:             4
Benchmark duration (s):                  115.14
Total input tokens:                      49104
Total generated tokens:                  49152
Request throughput (req/s):              0.42
Output token throughput (tok/s):         426.89
Peak output token throughput (tok/s):    444.00
Peak concurrent requests:                8.00
Total Token throughput (tok/s):          853.36
---------------Time to First Token----------------
Mean TTFT (ms):                          87.66
Median TTFT (ms):                        96.94
P99 TTFT (ms):                           118.98
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          9.29
Median TPOT (ms):                        9.29
P99 TPOT (ms):                           9.37
---------------Inter-token Latency----------------
Mean ITL (ms):                           9.29
Median ITL (ms):                         9.09
P99 ITL (ms):                            13.76
==================================================

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9287|±  |0.0071|
|     |       |strict-match    |     5|exact_match|↑  |0.9310|±  |0.0070|


Signed-off-by: Tres Popp <tres.popp@amd.com>
@mergify mergify bot added llama Related to Llama models rocm Related to AMD ROCm labels Dec 8, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment
Code Review

This pull request introduces an optimization for Llama-4 models on ROCm platforms by using aiter.topk_sigmoid to fuse the topk and sigmoid operations. This is a good optimization for latency-bound kernels. The changes are well-implemented, adding a new rocm_aiter_topk_sigmoid op and conditionally using it when aiter is available. The implementation is consistent with existing aiter ops in the codebase. The provided benchmarks demonstrate a performance improvement. The code is clean and correct, and I have no suggestions for improvement.
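A sketch of the conditional dispatch the review describes, using the unfused reference path above as the fallback (the import path, the torch.ops namespace, and the exact call signature are assumptions, not the PR's code):

```python
import torch
from vllm._aiter_ops import rocm_aiter_ops  # import path is an assumption

def topk_sigmoid(
    router_logits: torch.Tensor, topk: int, renormalize: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    if rocm_aiter_ops.is_enabled():
        # One fused launch on ROCm; the op name comes from the review text,
        # but its namespace and signature here are assumptions.
        return torch.ops.vllm.rocm_aiter_topk_sigmoid(
            router_logits, topk, renormalize
        )
    # Fall back to the unfused sigmoid/topk path sketched in the description.
    return topk_sigmoid_reference(router_logits, topk, renormalize)
```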

Author

tpopp commented Dec 8, 2025

I attempted this as a separate optimization pass, which I believe is the better approach in general, but had no luck. I can gladly change it to that, but I would need guidance. I see topk+sigmoid in traces, but I do not see those operations in dumped PyTorch graphs, and I couldn't figure out how to check whether custom_routing is included in a PyTorch graph. It is part of a hipGraph, but I don't know whether hipGraphs are per PyTorch graph or can include more.
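One generic way to check whether an op survives into a captured graph, independent of any vLLM-specific dump tooling (this is a plain torch.compile sketch, not vLLM's mechanism), is a pass-through backend that prints every FX node:

```python
import torch

def graph_inspector(gm: torch.fx.GraphModule, example_inputs):
    # Print every captured node so topk/sigmoid (or a custom routing op)
    # can be grepped for in the FX graph torch.compile actually sees.
    for node in gm.graph.nodes:
        print(node.op, node.target)
    return gm.forward  # run the graph unmodified

model = torch.nn.Sequential(torch.nn.Linear(8, 8))
compiled = torch.compile(model, backend=graph_inspector)
compiled(torch.randn(2, 8))
```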

12521ae

Author

tpopp commented Dec 8, 2025

Proof of the kernel being used on this PR's branch:

Screenshot 2025-12-08 113300

@tpopp tpopp marked this pull request as ready for review December 8, 2025 11:37
@tpopp tpopp requested a review from tjtanaa as a code owner December 8, 2025 11:37
Collaborator

tjtanaa commented Dec 8, 2025

@tpopp please include accuracy results like gsm8k.

Author

tpopp commented Dec 8, 2025

@tjtanaa I think I included what you are asking for in the two code blocks of the description. Or are you asking for accuracy numbers from a benchmark other than gsm8k?

Collaborator

@tjtanaa tjtanaa left a comment

LGTM

topk: int,
renormalize: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
if rocm_aiter_ops.is_enabled():
Collaborator

Nit: since this is related to MoE, let's change this to use rocm_aiter_ops.is_fused_moe_enabled().
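Applied, the highlighted guard would read as follows (a suggestion-style sketch; this thread does not show whether is_fused_moe_enabled() takes arguments):

```python
if rocm_aiter_ops.is_fused_moe_enabled():
```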

Collaborator

tjtanaa commented Dec 9, 2025

@houseroad If you are free, could you take a quick look? Thank you.


mergify bot commented Dec 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tpopp.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 10, 2025