diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py
index 66db7509cc47..1615c23a4f71 100644
--- a/tests/kernels/test_attention_selector.py
+++ b/tests/kernels/test_attention_selector.py
@@ -49,7 +49,7 @@ def test_env(
                    RocmPlatform()):
             backend = get_attn_backend(16, torch.float16, torch.float16, 16,
                                        False)
-            EXPECTED = "ROCM_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
+            EXPECTED = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
             assert backend.get_name() == EXPECTED
     elif device == "openvino":
         with patch("vllm.attention.selector.current_platform",
diff --git a/tests/kernels/test_rocm_attention_selector.py b/tests/kernels/test_rocm_attention_selector.py
index 724f0af283f7..90b483b4a41a 100644
--- a/tests/kernels/test_rocm_attention_selector.py
+++ b/tests/kernels/test_rocm_attention_selector.py
@@ -26,7 +26,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
     # Test standard ROCm attention
     backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
     assert (backend.get_name() == "ROCM_FLASH"
-            or backend.get_name() == "ROCM_ATTN_VLLM_V1")
+            or backend.get_name() == "TRITON_ATTN_VLLM_V1")

     # mla test for deepseek related
     backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,