From 7618dc973dd1e56a46162bc7bd6e7625143bead0 Mon Sep 17 00:00:00 2001
From: rasmith
Date: Tue, 9 Dec 2025 14:18:17 -0600
Subject: [PATCH] [CI/Build] Make test_mha_attn.py run on correct platform
 only and check for flash_attn_varlen_func in layer.py (#29145)

---
 tests/kernels/attention/test_mha_attn.py | 11 +++++++++--
 vllm/attention/layer.py                  |  5 ++++-
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py
index ae3c63cc62d6b..639abdf6f0487 100644
--- a/tests/kernels/attention/test_mha_attn.py
+++ b/tests/kernels/attention/test_mha_attn.py
@@ -26,7 +26,14 @@ def clear_cache():
     _cached_get_attn_backend.cache_clear()
 
 
-@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
+devices = ["cpu"]
+if current_platform.is_cuda():
+    devices.append("cuda")
+if current_platform.is_rocm():
+    devices.append("hip")
+
+
+@pytest.mark.parametrize("device", devices)
 def test_mha_attn_platform(device: str):
     """
     Test the attention selector between different platform and device.
@@ -46,7 +53,7 @@ def test_mha_attn_platform(device: str):
             patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
         ):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
+            assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
     else:
         # Test CUDA with head_size=64 (divisible by 32)
         # - should use vLLM's FlashAttention
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 7e5adfe0742d3..c77fc0fad0038 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -89,7 +89,10 @@ def maybe_get_vit_flash_attn_backend(
         if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
-            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+            try:
+                from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+            except ImportError:
+                flash_attn_varlen_func = None
     else:
         flash_attn_varlen_func = None
 
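
Note (not part of the patch): the two patterns the diff applies generalize beyond vLLM, and a minimal sketch follows. It assumes only pytest and torch are installed; flash_attn is used purely as an example of an optional dependency, and test_device_smoke is a hypothetical test name, not something from the vLLM test suite.

# Sketch of the patch's two ideas, outside vLLM (hypothetical example):
# 1) guard an optional import so the module still loads when the extra
#    package is missing, and
# 2) parametrize a test only over devices the current host can actually run.
import pytest
import torch

# Optional-import guard: fall back to None if flash-attn is not installed.
try:
    from flash_attn import flash_attn_varlen_func
except ImportError:
    flash_attn_varlen_func = None

# Build the device list from what is really available instead of listing
# every backend unconditionally.
devices = ["cpu"]
if torch.cuda.is_available():
    devices.append("cuda")


@pytest.mark.parametrize("device", devices)
def test_device_smoke(device: str) -> None:
    # Every parametrized device should at least support tensor allocation.
    x = torch.ones(4, device=device)
    assert x.sum().item() == pytest.approx(4.0)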