[ROCm][CI] Fix test_max_len.py for Rocm (#29916)
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Charlie Fu <Charlie.Fu@amd.com>
parent ae0f69b16a
commit 6af70e11a0
@@ -13,12 +13,15 @@ import pytest
 import torch
 
 from vllm import LLM
+from vllm.platforms import current_platform
 from vllm.v1.engine.llm_engine import LLMEngine
 
 from ..conftest import HfRunner, VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test
 
+ATTN_BACKEND = ["ROCM_ATTN"] if current_platform.is_rocm() else ["FLASH_ATTN"]
+
 MODELS = [
     "hmellor/tiny-random-Gemma2ForCausalLM",
     "meta-llama/Llama-3.2-1B-Instruct",
@@ -57,7 +60,7 @@ def _fix_prompt_embed_outputs(
 
 
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
+@pytest.mark.parametrize("backend", ATTN_BACKEND)
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False])
 @pytest.mark.parametrize("async_scheduling", [True, False])
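The two hunks above make the backend parametrization platform-aware: on ROCm the test now runs with the "ROCM_ATTN" backend, everywhere else it keeps "FLASH_ATTN". A minimal, self-contained sketch of the pattern, assuming the backend is pinned through the VLLM_ATTENTION_BACKEND environment variable (the test function below is illustrative, not the file's actual test):

import pytest

from vllm.platforms import current_platform

# Platform-conditional backend list, mirroring the diff above.
ATTN_BACKEND = ["ROCM_ATTN"] if current_platform.is_rocm() else ["FLASH_ATTN"]


@pytest.mark.parametrize("backend", ATTN_BACKEND)
def test_backend_is_platform_appropriate(backend: str, monkeypatch: pytest.MonkeyPatch):
    # Assumption: the suite applies the attention backend via the
    # VLLM_ATTENTION_BACKEND environment variable before building the engine.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
    assert backend == ("ROCM_ATTN" if current_platform.is_rocm() else "FLASH_ATTN")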
@@ -1225,9 +1225,9 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
         try:
             import aiter  # noqa: F401
 
-            attn_backend_list.append("FLASH_ATTN")
+            attn_backend_list.append("ROCM_AITER_FA")
         except Exception:
-            print("Skip FLASH_ATTN on ROCm as aiter is not installed")
+            print("Skip ROCM_AITER_FA on ROCm as aiter is not installed")
 
         return attn_backend_list
     elif current_platform.is_xpu():
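This hunk renames the backend that the helper advertises when the aiter package is importable on ROCm. A simplified, hypothetical sketch of the probe pattern (not the exact body of get_attn_backend_list_based_on_platform, which covers more platforms and backends):

from vllm.platforms import current_platform


def attn_backends_for_platform() -> list[str]:
    # Hypothetical, simplified stand-in for the helper touched above.
    backends: list[str] = []
    if current_platform.is_rocm():
        try:
            import aiter  # noqa: F401

            # aiter is importable, so the AITER flash-attention path is usable.
            backends.append("ROCM_AITER_FA")
        except Exception:
            print("Skip ROCM_AITER_FA on ROCm as aiter is not installed")
    else:
        backends.append("FLASH_ATTN")
    return backends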
@@ -417,9 +417,9 @@ def test_eagle_correctness(
                 "multi-token eagle spec decode on current platform"
             )
 
-        if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
+        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
             if "deepseek" in model_setup[1].lower():
-                pytest.skip("FLASH_ATTN for deepseek not supported on ROCm platform")
+                pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform")
             else:
                 m.setenv("VLLM_ROCM_USE_AITER", "1")
 
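The guard above presumably sits inside a `with monkeypatch.context() as m:` block (hence `m.setenv`), and `attn_backend` is typically parametrized from the platform-aware helper, so "ROCM_AITER_FA" only appears on ROCm runs in the first place. A hedged sketch of how the pieces fit together; the parametrize wiring, the stand-in helper, and the deepseek model name are illustrative assumptions, not code from this diff:

import pytest

from vllm.platforms import current_platform


def attn_backend_list() -> list[str]:
    # Stand-in for the platform-aware helper from the previous hunk.
    return ["ROCM_AITER_FA"] if current_platform.is_rocm() else ["FLASH_ATTN"]


@pytest.mark.parametrize("attn_backend", attn_backend_list())
def test_eagle_rocm_guard(attn_backend: str, monkeypatch: pytest.MonkeyPatch):
    model_name = "deepseek-ai/DeepSeek-V2-Lite"  # illustrative model setup
    with monkeypatch.context() as m:
        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
            if "deepseek" in model_name.lower():
                pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm")
            else:
                # Route attention through the AITER kernels on ROCm.
                m.setenv("VLLM_ROCM_USE_AITER", "1")
        # ... the real test builds the engine and compares eagle outputs here.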
@@ -339,7 +339,7 @@ def test_load_model(
             "multi-token eagle spec decode on current platform"
         )
 
-    if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
+    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
 
     # Setup draft model mock
@@ -434,7 +434,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
             "because it requires special input mocking."
         )
 
-    if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
+    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
 
     # Use GPU device
@@ -541,6 +541,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
         attn_metadata_builder_cls, _ = try_get_attention_backend(
             AttentionBackendEnum.TREE_ATTN
         )
+    elif attn_backend == "ROCM_AITER_FA":
+        attn_metadata_builder_cls, _ = try_get_attention_backend(
+            AttentionBackendEnum.ROCM_AITER_FA
+        )
     else:
         raise ValueError(f"Unsupported attention backend: {attn_backend}")
 
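The new elif branch maps the "ROCM_AITER_FA" string onto its AttentionBackendEnum member before fetching the metadata-builder class. The same dispatch can also be written as a lookup table; a self-contained sketch with a stand-in enum (the real enum and try_get_attention_backend come from vLLM's test utilities and are not reproduced here):

from enum import Enum, auto


class AttentionBackendEnum(Enum):
    # Stand-in for vLLM's enum; only the two members exercised by the
    # surrounding test are reproduced here.
    TREE_ATTN = auto()
    ROCM_AITER_FA = auto()


# Table-driven equivalent of the if/elif chain in the hunk above.
_BACKENDS = {
    "TREE_ATTN": AttentionBackendEnum.TREE_ATTN,
    "ROCM_AITER_FA": AttentionBackendEnum.ROCM_AITER_FA,
}


def resolve_backend(attn_backend: str) -> AttentionBackendEnum:
    try:
        return _BACKENDS[attn_backend]
    except KeyError:
        raise ValueError(f"Unsupported attention backend: {attn_backend}") from None

With the table, supporting another backend is a one-line addition rather than another elif branch.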
@@ -47,7 +47,7 @@ def test_eagle_max_len(
                 "multi-token eagle spec decode on current platform"
             )
 
-        if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
+        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
             m.setenv("VLLM_ROCM_USE_AITER", "1")
 
         llm = LLM(
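This last hunk is the test_max_len.py change the commit title refers to: on ROCm the guard now keys off "ROCM_AITER_FA" and turns on the AITER kernels before the LLM is constructed. A hedged sketch of how such a guard wraps engine construction; the helper name, model choice, and engine arguments below are illustrative assumptions, not the test's actual configuration:

import pytest

from vllm import LLM
from vllm.platforms import current_platform


def build_llm_for_backend(attn_backend: str, monkeypatch: pytest.MonkeyPatch) -> LLM:
    # Hypothetical helper; the real test constructs the LLM inline with its
    # own speculative-decoding configuration.
    with monkeypatch.context() as m:
        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
            # Enable the AITER kernels before the engine is built.
            m.setenv("VLLM_ROCM_USE_AITER", "1")
        return LLM(
            model="meta-llama/Llama-3.2-1B-Instruct",  # illustrative; from the MODELS list above
            enforce_eager=True,
        )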