[ROCm] Use aiter.topk_sigmoid in llama4 #30255
base: main
Conversation
This fuses a set of kernels that are bound by launch latency. Signed-off-by: Tres Popp <tres.popp@amd.com>
Code Review
This pull request introduces an optimization for Llama-4 models on ROCm platforms by using aiter.topk_sigmoid to fuse the topk and sigmoid operations. This is a good optimization for latency-bound kernels. The changes are well-implemented, adding a new rocm_aiter_topk_sigmoid op and conditionally using it when aiter is available. The implementation is consistent with existing aiter ops in the codebase. The provided benchmarks demonstrate a performance improvement. The code is clean and correct, and I have no suggestions for improvement.
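For readers unfamiliar with the routing path being fused, here is a minimal sketch of the idea, not the PR's actual code: `aiter.topk_sigmoid`'s exact signature and return values are assumptions, and the unfused branch follows the usual Llama-4 routing order of top-k on the logits first, then sigmoid on the selected scores.

```python
import torch


def llama4_topk_sigmoid(gating_output: torch.Tensor,
                        topk: int,
                        use_aiter: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
    """Top-k routing followed by sigmoid on the selected scores."""
    if use_aiter:
        import aiter  # ROCm-only; the signature of topk_sigmoid is assumed here
        return aiter.topk_sigmoid(gating_output, topk)
    # Unfused reference path: two separate kernel launches (topk, then sigmoid),
    # which is the launch-latency-bound pattern the fused op removes.
    topk_weights, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
    return torch.sigmoid(topk_weights.float()), topk_ids
```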
I attempted this as a separate optimization pass, which I believe is better in general, but did not have any luck. I can gladly change it to that approach, but I would need guidance. I see topk+sigmoid in traces, but I do not see those operations in dumped PyTorch graphs, and I couldn't figure out how to check whether the custom_routing function is included in a PyTorch graph. It is part of a HIP graph, but I don't know whether those are per PyTorch graph or can include more.
@tpopp please include accuracy results like gsm8k.
@tjtanaa I think I included what you are asking for in the two code blocks of the description. Or are you asking for accuracy numbers from a benchmark other than gsm8k?
tjtanaa left a comment:
LGTM
```python
    topk: int,
    renormalize: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    if rocm_aiter_ops.is_enabled():
```
NIT: Since this is related to MoE, let's change this to use rocm_aiter_ops.is_fused_moe_enabled() instead.
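A sketch of how the routing function might read with the suggested gate; `rocm_aiter_ops`, `is_fused_moe_enabled`, and the `rocm_aiter_topk_sigmoid` op name come from the diff and comments above, while the call signature and the fallback branch are illustrative assumptions rather than the PR's exact code.

```python
import torch


def custom_routing_function(hidden_states: torch.Tensor,
                            gating_output: torch.Tensor,
                            topk: int,
                            renormalize: bool) -> tuple[torch.Tensor, torch.Tensor]:
    # Reviewer-suggested gate: the MoE-specific flag rather than the global aiter flag.
    if rocm_aiter_ops.is_fused_moe_enabled():
        # Fused path; op name taken from the PR, signature assumed.
        return rocm_aiter_topk_sigmoid(gating_output, topk)
    # Unfused fallback: top-k, then sigmoid on the selected scores.
    topk_weights, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
    return torch.sigmoid(topk_weights.float()), topk_ids
```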
|
@houseroad If you are free could you take a quick look. Thank you. |
This pull request has merge conflicts that must be resolved before it can be merged.

This fuses a set of kernels that are bound by launch latency and is used whenever aiter is available.
Test Plan:
Server Launch:
```
vllm serve ${MODEL} --port ${PORT} --swap-space 32 --max-model-len ${MAX_MODEL_LEN} --tensor-parallel-size ${TP} --max-num-seqs ${MAX_NUM_SEQS} --gpu-memory-utilization 0.93 --kv-cache-dtype fp8 --max-num-batched-tokens ${MAX_NUM_BATCHED_TOKENS} --compilation-config "{\"custom_ops\": [\"-rms_norm\", \"-quant_fp8\", \"-silu_and_mul\"] }" --no-enable-prefix-caching --async-scheduling
```

Example benchmark command:

```
vllm bench serve --model meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 --host localhost --dataset-name random --random-input-len 1024 --random-output-len 1024 --max-concurrency 4 --num-prompts 48 --ignore-eos
```

Example correctness check:

```
lm_eval --model local-completions --tasks gsm8k --model_args base_url=http://0.0.0.0:${PORT}/v1/completions,model=meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8,tokenized_requests=False,tokenizer_backend=None,num_concurrent=128,timeout=120,max_retries=5
```

Test Result (using mi325x8):
Throughput increases by about 4%, and accuracy is comparable or better.
Before:
After: