diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index 20061ad2f8bf7..48a42ce6ffab5 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -84,12 +84,12 @@ def test_env(
         m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
 
         if device == "cpu":
-            with patch("vllm.attention.selector.current_platform", CpuPlatform()):
+            with patch("vllm.platforms.current_platform", CpuPlatform()):
                 backend = get_attn_backend(16, torch.float16, None, block_size)
                 assert backend.get_name() == "TORCH_SDPA"
 
         elif device == "hip":
-            with patch("vllm.attention.selector.current_platform", RocmPlatform()):
+            with patch("vllm.platforms.current_platform", RocmPlatform()):
                 if use_mla:
                     # ROCm MLA backend logic:
                     # - TRITON_MLA: supported when block_size != 1
@@ -126,7 +126,7 @@ def test_env(
                     assert backend.get_name() == expected
 
         elif device == "cuda":
-            with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+            with patch("vllm.platforms.current_platform", CudaPlatform()):
                 if use_mla:
                     # CUDA MLA backend logic:
                     # - CUTLASS_MLA: only supported with block_size == 128
@@ -214,12 +214,12 @@ def test_env(
 def test_fp32_fallback(device: str):
     """Test attention backend selection with fp32."""
     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform", CpuPlatform()):
+        with patch("vllm.platforms.current_platform", CpuPlatform()):
             backend = get_attn_backend(16, torch.float32, None, 16)
             assert backend.get_name() == "TORCH_SDPA"
 
     elif device == "cuda":
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+        with patch("vllm.platforms.current_platform", CudaPlatform()):
             backend = get_attn_backend(16, torch.float32, None, 16)
             assert backend.get_name() == "FLEX_ATTENTION"
 
@@ -277,7 +277,7 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
     """Test that invalid attention backend names raise ValueError."""
     with (
         monkeypatch.context() as m,
-        patch("vllm.attention.selector.current_platform", CudaPlatform()),
+        patch("vllm.platforms.current_platform", CudaPlatform()),
     ):
         m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
 
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 53677372e0551..7dfe6ffda6a80 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -14,7 +14,6 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.backends.registry import _Backend, backend_name_to_enum
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
 
 logger = init_logger(__name__)
@@ -192,6 +191,8 @@ def _cached_get_attn_backend(
     )
 
     # get device-specific attn_backend
+    from vllm.platforms import current_platform
+
     attention_cls = current_platform.get_attn_backend_cls(
         selected_backend,
         head_size,
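
Why the patch target moves: the selector change removes the module-level "from vllm.platforms import current_platform" and re-imports it inside _cached_get_attn_backend, so vllm.attention.selector no longer holds its own current_platform alias for the tests to patch; the attribute is now looked up on vllm.platforms at call time, which is exactly where the tests patch it. The sketch below is a minimal, self-contained illustration of that behavior and is not vLLM code; the module names toy_platforms and toy_selector and the select_* functions are made up for the example.

import sys
import types
from unittest.mock import patch

# Hypothetical stand-in for vllm.platforms: a module exposing current_platform.
toy_platforms = types.ModuleType("toy_platforms")
toy_platforms.current_platform = "REAL_PLATFORM"
sys.modules["toy_platforms"] = toy_platforms

# Hypothetical stand-in for the selector module, showing both import styles.
selector_src = """
from toy_platforms import current_platform as _ALIAS  # eager module-level alias

def select_eager():
    # Returns the alias captured at import time; patching
    # toy_platforms.current_platform does not rebind this name.
    return _ALIAS

def select_lazy():
    # Deferred import: re-reads toy_platforms.current_platform at call time,
    # so a patch applied to the source module is visible here.
    from toy_platforms import current_platform
    return current_platform
"""
toy_selector = types.ModuleType("toy_selector")
exec(selector_src, toy_selector.__dict__)

with patch("toy_platforms.current_platform", "FAKE_PLATFORM"):
    assert toy_selector.select_eager() == "REAL_PLATFORM"  # patch not visible
    assert toy_selector.select_lazy() == "FAKE_PLATFORM"   # patch visible

With an eager alias, a test has to patch the alias in the consuming module (the old "vllm.attention.selector.current_platform" target); with the deferred import, patching the source module ("vllm.platforms.current_platform") is sufficient, which is what every updated patch() call in the test diff does.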