[ROCm][CI] Fix test_max_len.py for ROCm (#29916)

Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Charlie Fu <Charlie.Fu@amd.com>
Charlie Fu authored 2025-12-08 15:58:30 -06:00, committed by GitHub
parent ae0f69b16a
commit 6af70e11a0
5 changed files with 15 additions and 8 deletions


@@ -13,12 +13,15 @@ import pytest
 import torch

 from vllm import LLM
+from vllm.platforms import current_platform
 from vllm.v1.engine.llm_engine import LLMEngine

 from ..conftest import HfRunner, VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test

+ATTN_BACKEND = ["ROCM_ATTN"] if current_platform.is_rocm() else ["FLASH_ATTN"]
+
 MODELS = [
     "hmellor/tiny-random-Gemma2ForCausalLM",
     "meta-llama/Llama-3.2-1B-Instruct",
@@ -57,7 +60,7 @@ def _fix_prompt_embed_outputs(
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
+@pytest.mark.parametrize("backend", ATTN_BACKEND)
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False])
 @pytest.mark.parametrize("async_scheduling", [True, False])


@@ -1225,9 +1225,9 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
         try:
             import aiter  # noqa: F401

-            attn_backend_list.append("FLASH_ATTN")
+            attn_backend_list.append("ROCM_AITER_FA")
         except Exception:
-            print("Skip FLASH_ATTN on ROCm as aiter is not installed")
+            print("Skip ROCM_AITER_FA on ROCm as aiter is not installed")
         return attn_backend_list
     elif current_platform.is_xpu():
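The helper now advertises the backend under its real name, ROCM_AITER_FA, and only when the optional aiter package imports cleanly. A condensed sketch of that guarded-import pattern (the function name and the TRITON_ATTN baseline are assumptions for illustration, not shown in this diff):

def rocm_attn_backends() -> list[str]:  # hypothetical condensation
    attn_backend_list = ["TRITON_ATTN"]  # assumed baseline for illustration
    try:
        # Probe for the optional AITER package; F401 because the import
        # itself is the capability check.
        import aiter  # noqa: F401

        attn_backend_list.append("ROCM_AITER_FA")
    except Exception:
        print("Skip ROCM_AITER_FA on ROCm as aiter is not installed")
    return attn_backend_list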


@@ -417,9 +417,9 @@ def test_eagle_correctness(
             "multi-token eagle spec decode on current platform"
         )
-    if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
+    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
         if "deepseek" in model_setup[1].lower():
-            pytest.skip("FLASH_ATTN for deepseek not supported on ROCm platform")
+            pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform")
         else:
             m.setenv("VLLM_ROCM_USE_AITER", "1")
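The same ROCm guard recurs in the eagle tests below (test_load_model, test_propose, test_eagle_max_len): skip deepseek models, which the AITER flash-attention backend does not support on ROCm, and enable AITER kernels for everything else. Schematically, as a hypothetical helper (`m` is a pytest MonkeyPatch and `model_setup[1]` the model name, both per the diff):

import pytest

from vllm.platforms import current_platform


def _setup_rocm_aiter(m: pytest.MonkeyPatch, attn_backend: str, model_setup):
    # Condensed from the hunk above: on ROCm, ROCM_AITER_FA does not
    # support deepseek models, so those cases are skipped; every other
    # model gets AITER kernels enabled via the env flag.
    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
        if "deepseek" in model_setup[1].lower():
            pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform")
        else:
            m.setenv("VLLM_ROCM_USE_AITER", "1")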


@@ -339,7 +339,7 @@ def test_load_model(
             "multi-token eagle spec decode on current platform"
         )
-    if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
+    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

     # Setup draft model mock
@@ -434,7 +434,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
             "because it requires special input mocking."
         )
-    if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
+    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

     # Use GPU device
@@ -541,6 +541,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
         attn_metadata_builder_cls, _ = try_get_attention_backend(
             AttentionBackendEnum.TREE_ATTN
         )
+    elif attn_backend == "ROCM_AITER_FA":
+        attn_metadata_builder_cls, _ = try_get_attention_backend(
+            AttentionBackendEnum.ROCM_AITER_FA
+        )
     else:
         raise ValueError(f"Unsupported attention backend: {attn_backend}")
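The new elif extends the name-to-enum dispatch so that parametrizing on "ROCM_AITER_FA" resolves a metadata builder the same way "TREE_ATTN" does. An illustrative condensation (the dict and function are hypothetical; try_get_attention_backend and AttentionBackendEnum are the names the test file already has in scope, per the diff):

# Hypothetical condensation of the if/elif chain the new branch extends.
BACKEND_ENUMS = {
    "TREE_ATTN": AttentionBackendEnum.TREE_ATTN,
    "ROCM_AITER_FA": AttentionBackendEnum.ROCM_AITER_FA,
}


def resolve_builder(attn_backend: str):
    # Unknown names fail loudly, matching the `else` branch above.
    if attn_backend not in BACKEND_ENUMS:
        raise ValueError(f"Unsupported attention backend: {attn_backend}")
    attn_metadata_builder_cls, _ = try_get_attention_backend(
        BACKEND_ENUMS[attn_backend]
    )
    return attn_metadata_builder_cls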


@@ -47,7 +47,7 @@ def test_eagle_max_len(
             "multi-token eagle spec decode on current platform"
         )
-    if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
+    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
         m.setenv("VLLM_ROCM_USE_AITER", "1")

     llm = LLM(
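Here `m` is a pytest MonkeyPatch used as a context manager, so VLLM_ROCM_USE_AITER is restored when the test exits and cannot leak into later tests. A minimal sketch of the pattern (test name hypothetical, body elided):

import pytest

from vllm.platforms import current_platform


def test_eagle_max_len_sketch(attn_backend: str = "ROCM_AITER_FA"):
    # MonkeyPatch.context() undoes the setenv on exit, keeping the AITER
    # flag scoped to this one test.
    with pytest.MonkeyPatch.context() as m:
        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
            m.setenv("VLLM_ROCM_USE_AITER", "1")
        # llm = LLM(model=..., speculative_config=...)  # as in the real test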