[CI/Build][AMD] Use ROCM_ATTN instead of FLASH_ATTN in test_register_kv_caches on ROCm, and update the test for TRITON_ATTN (#29985)

Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
rasmith 2025-12-05 22:57:38 -06:00 committed by GitHub
parent 40a046cd82
commit b12f4a9830

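In short, the change parametrizes test_register_kv_caches over attention backends and uses pytest.param with a skipif mark to gate each backend on platform support, so FLASH_ATTN is skipped on ROCm and ROCM_ATTN is skipped everywhere else. A minimal standalone sketch of that gating pattern, assuming only that vllm.platforms.current_platform exposes is_rocm() as used in the diff below (the test name and body here are placeholders, not part of the commit):

```python
import pytest

from vllm.platforms import current_platform


@pytest.mark.parametrize(
    "attn_backend",
    [
        # Collected everywhere, but skipped at runtime on ROCm hosts.
        pytest.param(
            "FLASH_ATTN",
            marks=pytest.mark.skipif(
                current_platform.is_rocm(),
                reason="FLASH_ATTN is not supported on ROCm",
            ),
        ),
        # Skipped everywhere except ROCm hosts.
        pytest.param(
            "ROCM_ATTN",
            marks=pytest.mark.skipif(
                not current_platform.is_rocm(),
                reason="ROCM_ATTN is only supported on ROCm",
            ),
        ),
        # Runs on every platform.
        "TRITON_ATTN",
    ],
)
def test_backend_gating(attn_backend):
    # Placeholder body; the real test exercises register_kv_caches().
    assert attn_backend in {"FLASH_ATTN", "ROCM_ATTN", "TRITON_ATTN"}
```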

@@ -41,6 +41,7 @@ from vllm.distributed.kv_transfer.kv_transfer_state import (
     has_kv_transfer_group,
 )
 from vllm.forward_context import ForwardContext
+from vllm.platforms import current_platform
 from vllm.platforms.interface import Platform
 from vllm.sampling_params import SamplingParams
 from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
@@ -1111,7 +1112,26 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
     llm.llm_engine.engine_core.shutdown()


-@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "TRITON_ATTN"])
+@pytest.mark.parametrize(
+    "attn_backend",
+    [
+        pytest.param(
+            "FLASH_ATTN",
+            marks=pytest.mark.skipif(
+                current_platform.is_rocm(),
+                reason="Attention backend FLASH_ATTN is not supported on ROCm",
+            ),
+        ),
+        pytest.param(
+            "ROCM_ATTN",
+            marks=pytest.mark.skipif(
+                not current_platform.is_rocm(),
+                reason="Attention backend ROCM_ATTN is only supported on ROCm",
+            ),
+        ),
+        "TRITON_ATTN",
+    ],
+)
 def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
     """
     Test that register_kv_caches() properly calls nixl_wrapper methods with
@@ -1133,6 +1153,10 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
         from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

         backend_cls = FlashAttentionBackend
+    elif attn_backend == "ROCM_ATTN":
+        from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend
+
+        backend_cls = RocmAttentionBackend
     else:  # TRITON_ATTN
         from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
@@ -1151,6 +1175,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
     }

     # Store tensor info for validation
     test_shape = backend_cls.get_kv_cache_shape(
         num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
     )
@@ -1175,17 +1200,18 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
     ]
     expected_num_entries = 4

+    nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
     with (
-        patch(
-            "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
-        ) as mock_nixl_wrapper,
-        patch(
-            "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"
-        ),
-        patch(
-            "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"
-        ) as mock_thread,
-    ):  # noqa: E501
+        patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper,
+        patch(f"{nixl_module}.threading.Event"),
+        patch(f"{nixl_module}.threading.Thread") as mock_thread,
+        patch(f"{nixl_module}.get_attn_backend") as mock_get_attn_backend,
+    ):
+        # Ensure get_attn_backend returns the correct value due to
+        # _cached_get_attn_backend returning the backend from previous
+        # test run if not mocking.
+        mock_get_attn_backend.return_value = backend_cls
+
         # Create connector
         connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
         connector.connector_worker = FakeNixlConnectorWorker(
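A note on the get_attn_backend mock added in the last hunk: vLLM memoizes backend selection, so without the mock a backend resolved by an earlier parametrized run can be handed back for a later one. An alternative pattern used elsewhere in the vLLM test suite is to clear the selector cache around each test. A minimal sketch, assuming `_cached_get_attn_backend` in `vllm.attention.selector` is the functools-cached resolver (true on current main, but an internal API that may move):

```python
import pytest

from vllm.attention.selector import _cached_get_attn_backend


@pytest.fixture(autouse=True)
def clear_attn_backend_cache():
    # Drop any backend memoized by a previous parametrized run so each
    # test resolves its own backend (e.g. ROCM_ATTN after FLASH_ATTN).
    _cached_get_attn_backend.cache_clear()
    yield
    _cached_get_attn_backend.cache_clear()
```

Mocking get_attn_backend, as this commit does, achieves the same isolation without reaching into that private cache.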