diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 521d6c33dd39..9e1cc309edd1 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -13,12 +13,15 @@
 import pytest
 import torch
 from vllm import LLM
+from vllm.platforms import current_platform
 from vllm.v1.engine.llm_engine import LLMEngine
 
 from ..conftest import HfRunner, VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test
 
+ATTN_BACKEND = ["ROCM_ATTN"] if current_platform.is_rocm() else ["FLASH_ATTN"]
+
 MODELS = [
     "hmellor/tiny-random-Gemma2ForCausalLM",
     "meta-llama/Llama-3.2-1B-Instruct",
@@ -57,7 +60,7 @@ def _fix_prompt_embed_outputs(
 
 
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
+@pytest.mark.parametrize("backend", ATTN_BACKEND)
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False])
 @pytest.mark.parametrize("async_scheduling", [True, False])
diff --git a/tests/utils.py b/tests/utils.py
index 539f67c47ac1..ea3675b1461b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1225,9 +1225,9 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
         try:
             import aiter  # noqa: F401
 
-            attn_backend_list.append("FLASH_ATTN")
+            attn_backend_list.append("ROCM_AITER_FA")
         except Exception:
-            print("Skip FLASH_ATTN on ROCm as aiter is not installed")
+            print("Skip ROCM_AITER_FA on ROCm as aiter is not installed")
 
         return attn_backend_list
     elif current_platform.is_xpu():
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 575a6a151f57..416b582dfaa6 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -417,9 +417,9 @@ def test_eagle_correctness(
                 "multi-token eagle spec decode on current platform"
             )
 
-        if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
+        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
             if "deepseek" in model_setup[1].lower():
-                pytest.skip("FLASH_ATTN for deepseek not supported on ROCm platform")
+                pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform")
             else:
                 m.setenv("VLLM_ROCM_USE_AITER", "1")
 
diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py
index 616e57de339e..55e9b4d0660f 100644
--- a/tests/v1/spec_decode/test_eagle.py
+++ b/tests/v1/spec_decode/test_eagle.py
@@ -339,7 +339,7 @@ def test_load_model(
             "multi-token eagle spec decode on current platform"
         )
 
-    if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
+    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
 
     # Setup draft model mock
@@ -434,7 +434,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
             "because it requires special input mocking."
         )
 
-    if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
+    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
 
     # Use GPU device
@@ -541,6 +541,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
         attn_metadata_builder_cls, _ = try_get_attention_backend(
             AttentionBackendEnum.TREE_ATTN
         )
+    elif attn_backend == "ROCM_AITER_FA":
+        attn_metadata_builder_cls, _ = try_get_attention_backend(
+            AttentionBackendEnum.ROCM_AITER_FA
+        )
     else:
         raise ValueError(f"Unsupported attention backend: {attn_backend}")
 
diff --git a/tests/v1/spec_decode/test_max_len.py b/tests/v1/spec_decode/test_max_len.py
index fa1d0437f7c7..81da8609aa6c 100644
--- a/tests/v1/spec_decode/test_max_len.py
+++ b/tests/v1/spec_decode/test_max_len.py
@@ -47,7 +47,7 @@ def test_eagle_max_len(
                 "multi-token eagle spec decode on current platform"
             )
 
-        if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
+        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
             m.setenv("VLLM_ROCM_USE_AITER", "1")
 
         llm = LLM(
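Note (not part of the patch): the sketch below illustrates the platform-conditional backend parametrization that the first hunk introduces, assuming only what the diff shows (`current_platform` from `vllm.platforms`, the `ATTN_BACKEND` list, and `pytest.mark.parametrize`). The test name and assertion are hypothetical stand-ins; the real tests run end-to-end generation and compare against HF outputs.

```python
# Minimal sketch of the selection pattern from the patched
# tests/basic_correctness/test_basic_correctness.py. The test body here is a
# hypothetical placeholder, not the project's actual test.
import pytest

from vllm.platforms import current_platform

# Mirrors ATTN_BACKEND in the patch: ROCm runs exercise ROCM_ATTN,
# every other platform keeps FLASH_ATTN.
ATTN_BACKEND = ["ROCM_ATTN"] if current_platform.is_rocm() else ["FLASH_ATTN"]


@pytest.mark.parametrize("backend", ATTN_BACKEND)
def test_backend_selection(backend: str) -> None:
    # Stand-in check: the selected backend is one of the two values
    # the patch allows for this parametrization.
    assert backend in ("ROCM_ATTN", "FLASH_ATTN")
```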