[CI/Build][AMD] Use ROCM_ATTN instead of FLASH_ATTN test for test_register_kv_caches for ROCm and update test for TRITON_ATTN (#29985)
Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
parent 40a046cd82
commit b12f4a9830
@@ -41,6 +41,7 @@ from vllm.distributed.kv_transfer.kv_transfer_state import (
     has_kv_transfer_group,
 )
 from vllm.forward_context import ForwardContext
+from vllm.platforms import current_platform
 from vllm.platforms.interface import Platform
 from vllm.sampling_params import SamplingParams
 from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
@@ -1111,7 +1112,26 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
     llm.llm_engine.engine_core.shutdown()


-@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "TRITON_ATTN"])
+@pytest.mark.parametrize(
+    "attn_backend",
+    [
+        pytest.param(
+            "FLASH_ATTN",
+            marks=pytest.mark.skipif(
+                current_platform.is_rocm(),
+                reason="Attention backend FLASH_ATTN is not supported on ROCm",
+            ),
+        ),
+        pytest.param(
+            "ROCM_ATTN",
+            marks=pytest.mark.skipif(
+                not current_platform.is_rocm(),
+                reason="Attention backend ROCM_ATTN is only supported on ROCm",
+            ),
+        ),
+        "TRITON_ATTN",
+    ],
+)
 def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
     """
     Test that register_kv_caches() properly calls nixl_wrapper methods with
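The key change in this hunk: the flat `["FLASH_ATTN", "TRITON_ATTN"]` list becomes `pytest.param` entries carrying `skipif` marks, so each backend is skipped rather than failed on hardware that does not support it. A minimal, self-contained sketch of the same pattern, with a hypothetical `is_rocm()` standing in for `current_platform.is_rocm()`:

import pytest


def is_rocm() -> bool:
    # Hypothetical stand-in for vllm.platforms.current_platform.is_rocm().
    return False


@pytest.mark.parametrize(
    "attn_backend",
    [
        pytest.param(
            "ROCM_ATTN",
            # Marked params are reported as SKIPPED, not FAILED,
            # when the skipif condition holds.
            marks=pytest.mark.skipif(not is_rocm(), reason="ROCm only"),
        ),
        "TRITON_ATTN",  # plain string: runs on every platform
    ],
)
def test_backend_param(attn_backend):
    assert attn_backend in ("ROCM_ATTN", "TRITON_ATTN")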
@@ -1133,6 +1153,10 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
         from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

         backend_cls = FlashAttentionBackend
+    elif attn_backend == "ROCM_ATTN":
+        from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend
+
+        backend_cls = RocmAttentionBackend
     else:  # TRITON_ATTN
         from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend

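Note that each branch imports its backend lazily inside the `if`/`elif`, so the module for an unavailable backend is never imported on the wrong platform; the new ROCM_ATTN branch follows the pattern the FLASH_ATTN and TRITON_ATTN branches already used. A hedged sketch of the same dispatch factored into a helper (the helper name is hypothetical; the import paths are taken from the diff):

def pick_backend_cls(attn_backend: str):
    # Deferred imports: only the selected backend's module is loaded.
    if attn_backend == "FLASH_ATTN":
        from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

        return FlashAttentionBackend
    if attn_backend == "ROCM_ATTN":
        from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend

        return RocmAttentionBackend
    # TRITON_ATTN is the fallback, mirroring the test's else branch.
    from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend

    return TritonAttentionBackend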
@@ -1151,6 +1175,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
     }

     # Store tensor info for validation
+
     test_shape = backend_cls.get_kv_cache_shape(
         num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
     )
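The shape call above is what the test later validates against; each backend class reports its own KV-cache tensor layout through the same classmethod. A quick hedged probe (arguments taken verbatim from the diff; the exact tuple is backend-specific, so the printed value is illustrative, not guaranteed):

from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

# Same arguments the test uses; returns the per-layer KV cache shape.
shape = FlashAttentionBackend.get_kv_cache_shape(
    num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
# Assumption: flash_attn-style backends report something like
# (2, num_blocks, block_size, num_kv_heads, head_size), i.e. (2, 1, 16, 1, 1).
print(shape)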
@@ -1175,17 +1200,18 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
     ]
     expected_num_entries = 4

+    nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
     with (
-        patch(
-            "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
-        ) as mock_nixl_wrapper,
-        patch(
-            "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"
-        ),
-        patch(
-            "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"
-        ) as mock_thread,
-    ):  # noqa: E501
+        patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper,
+        patch(f"{nixl_module}.threading.Event"),
+        patch(f"{nixl_module}.threading.Thread") as mock_thread,
+        patch(f"{nixl_module}.get_attn_backend") as mock_get_attn_backend,
+    ):
+        # Ensure get_attn_backend returns the correct value due to
+        # _cached_get_attn_backend returning the backend from previous
+        # test run if not mocking.
+        mock_get_attn_backend.return_value = backend_cls
+
         # Create connector
         connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
         connector.connector_worker = FakeNixlConnectorWorker(
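The new `mock_get_attn_backend` patch deserves a note: per the in-diff comment, vLLM memoizes backend resolution (`_cached_get_attn_backend`), so without the mock, a backend resolved during an earlier parametrized run can leak into the next one. A toy, self-contained sketch of that caching behavior and how patching at the use site sidesteps it, using `functools.lru_cache` as a stand-in for vLLM's cache:

from functools import lru_cache
from unittest.mock import patch


@lru_cache  # stand-in for vLLM's _cached_get_attn_backend
def get_attn_backend(name: str) -> str:
    return f"resolved:{name}"


def resolve(name: str) -> str:
    # Looks the function up in module globals at call time,
    # so patching the module attribute below actually takes effect.
    return get_attn_backend(name)


assert resolve("FLASH_ATTN") == "resolved:FLASH_ATTN"  # populates the cache

# A later "test" that patches at the use site never consults the stale
# cache at all -- mirroring what the vLLM test does with get_attn_backend.
with patch(f"{__name__}.get_attn_backend", return_value="mocked") as m:
    assert resolve("FLASH_ATTN") == "mocked"
    m.assert_called_once_with("FLASH_ATTN")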