[ROCm][CI] Fix test_max_len.py for Rocm (#29916)
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Charlie Fu <Charlie.Fu@amd.com>

parent: ae0f69b16a
commit: 6af70e11a0
@@ -13,12 +13,15 @@ import pytest
 import torch
 
 from vllm import LLM
+from vllm.platforms import current_platform
 from vllm.v1.engine.llm_engine import LLMEngine
 
 from ..conftest import HfRunner, VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test
 
+ATTN_BACKEND = ["ROCM_ATTN"] if current_platform.is_rocm() else ["FLASH_ATTN"]
+
 MODELS = [
     "hmellor/tiny-random-Gemma2ForCausalLM",
     "meta-llama/Llama-3.2-1B-Instruct",
@@ -57,7 +60,7 @@ def _fix_prompt_embed_outputs(
 
 
 @pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
+@pytest.mark.parametrize("backend", ATTN_BACKEND)
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False])
 @pytest.mark.parametrize("async_scheduling", [True, False])

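The first two hunks follow a common pytest pattern: resolve the backend list once at module import time from the platform, then hand it to @pytest.mark.parametrize so the test body stays platform-agnostic. A minimal, self-contained sketch of that pattern (the _is_rocm stub is a stand-in for vllm.platforms.current_platform.is_rocm(), not vllm's real API):

import pytest


def _is_rocm() -> bool:
    # Stand-in for current_platform.is_rocm(); hard-coded for the sketch.
    return False


# Resolved once at collection time, so every parametrized case gets a
# backend that is valid on the machine running the tests.
ATTN_BACKEND = ["ROCM_ATTN"] if _is_rocm() else ["FLASH_ATTN"]


@pytest.mark.parametrize("backend", ATTN_BACKEND)
def test_backend_matches_platform(backend: str) -> None:
    assert backend in ("ROCM_ATTN", "FLASH_ATTN")
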
@@ -1225,9 +1225,9 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
         try:
             import aiter  # noqa: F401
 
-            attn_backend_list.append("FLASH_ATTN")
+            attn_backend_list.append("ROCM_AITER_FA")
         except Exception:
-            print("Skip FLASH_ATTN on ROCm as aiter is not installed")
+            print("Skip ROCM_AITER_FA on ROCm as aiter is not installed")
 
         return attn_backend_list
     elif current_platform.is_xpu():

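The aiter guard above is the standard optional-dependency probe: attempt the import, advertise the extra backend only when it succeeds, and degrade quietly otherwise. A hedged sketch of the same probe in isolation (the starting TRITON_ATTN entry and the function name are illustrative, not vllm's actual helper):

def available_rocm_backends() -> list[str]:
    # Backends that need no optional packages are always listed.
    backends = ["TRITON_ATTN"]
    try:
        import aiter  # noqa: F401  # optional ROCm attention kernels
    except Exception:
        print("Skip ROCM_AITER_FA on ROCm as aiter is not installed")
    else:
        backends.append("ROCM_AITER_FA")
    return backends


if __name__ == "__main__":
    print(available_rocm_backends())
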
@@ -417,9 +417,9 @@ def test_eagle_correctness(
                 "multi-token eagle spec decode on current platform"
             )
 
-        if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
+        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
             if "deepseek" in model_setup[1].lower():
-                pytest.skip("FLASH_ATTN for deepseek not supported on ROCm platform")
+                pytest.skip("ROCM_AITER_FA for deepseek not supported on ROCm platform")
             else:
                 m.setenv("VLLM_ROCM_USE_AITER", "1")
 

@@ -339,7 +339,7 @@ def test_load_model(
             "multi-token eagle spec decode on current platform"
         )
 
-    if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
+    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
 
     # Setup draft model mock
@@ -434,7 +434,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
             "because it requires special input mocking."
         )
 
-    if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
+    if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
         monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
 
     # Use GPU device
@@ -541,6 +541,10 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
         attn_metadata_builder_cls, _ = try_get_attention_backend(
             AttentionBackendEnum.TREE_ATTN
         )
+    elif attn_backend == "ROCM_AITER_FA":
+        attn_metadata_builder_cls, _ = try_get_attention_backend(
+            AttentionBackendEnum.ROCM_AITER_FA
+        )
     else:
         raise ValueError(f"Unsupported attention backend: {attn_backend}")
 

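The four added lines extend an explicit string-to-enum dispatch: every supported backend name maps to exactly one AttentionBackendEnum member, and anything unrecognized raises. The shape of that dispatch, shown with a hypothetical enum standing in for vllm's AttentionBackendEnum and try_get_attention_backend:

from enum import Enum, auto


class BackendEnum(Enum):
    # Hypothetical stand-ins for the members referenced in the diff.
    TREE_ATTN = auto()
    ROCM_AITER_FA = auto()


def resolve_backend(attn_backend: str) -> BackendEnum:
    if attn_backend == "TREE_ATTN":
        return BackendEnum.TREE_ATTN
    elif attn_backend == "ROCM_AITER_FA":
        return BackendEnum.ROCM_AITER_FA
    else:
        raise ValueError(f"Unsupported attention backend: {attn_backend}")


print(resolve_backend("ROCM_AITER_FA"))  # BackendEnum.ROCM_AITER_FA
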
@@ -47,7 +47,7 @@ def test_eagle_max_len(
                 "multi-token eagle spec decode on current platform"
             )
 
-        if attn_backend == "FLASH_ATTN" and current_platform.is_rocm():
+        if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
             m.setenv("VLLM_ROCM_USE_AITER", "1")
 
         llm = LLM(
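
The spec-decode tests above all gate the ROCM_AITER_FA path behind VLLM_ROCM_USE_AITER=1, set through pytest's monkeypatch so the variable never leaks past the test. A minimal sketch of that enable-on-ROCm branching (the _is_rocm stub again stands in for current_platform.is_rocm(); the test body is a placeholder):

import pytest


def _is_rocm() -> bool:
    # Stand-in for current_platform.is_rocm(); hard-coded for the sketch.
    return False


@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "ROCM_AITER_FA"])
def test_backend_env(attn_backend: str, monkeypatch: pytest.MonkeyPatch) -> None:
    if attn_backend == "ROCM_AITER_FA" and _is_rocm():
        # Enable the aiter kernels only for this test; monkeypatch restores
        # the original environment automatically when the test ends.
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    # A real test would construct an LLM and run generation here.
    assert attn_backend in ("FLASH_ATTN", "ROCM_AITER_FA")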