Skip to content

Commit 6af70e1

Browse files
authored
[ROCm][CI] Fix test_max_len.py for Rocm (#29916)
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Charlie Fu <Charlie.Fu@amd.com>
1 parent ae0f69b commit 6af70e1

File tree

5 files changed

+15
-8
lines changed

5 files changed

+15
-8
lines changed

tests/basic_correctness/test_basic_correctness.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,12 +13,15 @@
1313
import torch
1414

1515
from vllm import LLM
16+
from vllm.platforms import current_platform
1617
from vllm.v1.engine.llm_engine import LLMEngine
1718

1819
from ..conftest import HfRunner, VllmRunner
1920
from ..models.utils import check_outputs_equal
2021
from ..utils import multi_gpu_test
2122

23+
ATTN_BACKEND = ["ROCM_ATTN"] if current_platform.is_rocm() else ["FLASH_ATTN"]
24+
2225
MODELS = [
2326
"hmellor/tiny-random-Gemma2ForCausalLM",
2427
"meta-llama/Llama-3.2-1B-Instruct",
@@ -57,7 +60,7 @@ def _fix_prompt_embed_outputs(
5760

5861

5962
@pytest.mark.parametrize("model", MODELS)
60-
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
63+
@pytest.mark.parametrize("backend", ATTN_BACKEND)
6164
@pytest.mark.parametrize("max_tokens", [5])
6265
@pytest.mark.parametrize("enforce_eager", [False])
6366
@pytest.mark.parametrize("async_scheduling", [True, False])

tests/utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1225,9 +1225,9 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
12251225
try:
12261226
import aiter # noqa: F401
12271227

1228-
attn_backend_list.append("FLASH_ATTN")
1228+
attn_backend_list.append("ROCM_AITER_FA")
12291229
except Exception:
1230-
print("Skip FLASH_ATTN on ROCm as aiter is not installed")
1230+
print("Skip ROCM_AITER_FA on ROCm as aiter is not installed")
12311231

12321232
return attn_backend_list
12331233
elif current_platform.is_xpu():

tests/v1/e2e/test_spec_decode.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -417,9 +417,9 @@ def test_eagle_correctness(
417417
"multi-token eagle spec decode on current platform"
418418
)
419419

420-
if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
420+
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
421421
if "deepseek" in model_setup[1].lower():
422-
pytest.skip("FLASH_ATTN for deepseek not supported on ROCm platform")
422+
pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform")
423423
else:
424424
m.setenv("VLLM_ROCM_USE_AITER", "1")
425425

tests/v1/spec_decode/test_eagle.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -339,7 +339,7 @@ def test_load_model(
339339
"multi-token eagle spec decode on current platform"
340340
)
341341

342-
if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
342+
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
343343
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
344344

345345
# Setup draft model mock
@@ -434,7 +434,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
434434
"because it requires special input mocking."
435435
)
436436

437-
if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
437+
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
438438
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
439439

440440
# Use GPU device
@@ -541,6 +541,10 @@ def create_deterministic_logits(token_ids):
541541
attn_metadata_builder_cls, _ = try_get_attention_backend(
542542
AttentionBackendEnum.TREE_ATTN
543543
)
544+
elif attn_backend == "ROCM_AITER_FA":
545+
attn_metadata_builder_cls, _ = try_get_attention_backend(
546+
AttentionBackendEnum.ROCM_AITER_FA
547+
)
544548
else:
545549
raise ValueError(f"Unsupported attention backend: {attn_backend}")
546550

tests/v1/spec_decode/test_max_len.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ def test_eagle_max_len(
4747
"multi-token eagle spec decode on current platform"
4848
)
4949

50-
if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
50+
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
5151
m.setenv("VLLM_ROCM_USE_AITER", "1")
5252

5353
llm = LLM(

0 commit comments

Comments (0)