[CI/Build] Make test_mha_attn.py run on correct platform only and check for flash_attn_varlen_func in layer.py (#29145)

Author: rasmith (committed via GitHub)
Date: 2025-12-09 14:18:17 -06:00
Parent: f8dacc66b6
Commit: 7618dc973d
2 changed files with 13 additions and 3 deletions

File: test_mha_attn.py

@@ -26,7 +26,14 @@ def clear_cache():
     _cached_get_attn_backend.cache_clear()
 
 
-@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
+devices = ["cpu"]
+if current_platform.is_cuda():
+    devices.append("cuda")
+if current_platform.is_rocm():
+    devices.append("hip")
+
+
+@pytest.mark.parametrize("device", devices)
 def test_mha_attn_platform(device: str):
     """
     Test the attention selector between different platform and device.
@@ -46,7 +53,7 @@ def test_mha_attn_platform(device: str):
             patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
         ):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
+            assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
     else:
         # Test CUDA with head_size=64 (divisible by 32)
         # - should use vLLM's FlashAttention
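
For reference, the platform-gated parametrization pattern in isolation: the device list is built at collection time from what the host actually supports, so CUDA-only or ROCm-only cases are never generated on the wrong machine. This sketch substitutes torch's own probes (torch.cuda.is_available(), torch.version.hip) for vLLM's current_platform helpers, purely as an assumption to keep the example self-contained; the real test uses current_platform.is_cuda() / is_rocm() as shown above.

import pytest
import torch

# Build the parametrize list from the detected platform so each device case
# exists only where it can run.  torch.version.hip is a string on ROCm builds
# of PyTorch and None on CUDA/CPU builds.
devices = ["cpu"]
if torch.cuda.is_available() and torch.version.hip is None:
    devices.append("cuda")
if torch.version.hip is not None:
    devices.append("hip")


@pytest.mark.parametrize("device", devices)
def test_runs_only_on_supported_devices(device: str):
    # Stand-in assertion; the real test constructs MultiHeadAttention and
    # checks which attention backend gets selected for each device.
    assert device in ("cpu", "cuda", "hip")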

File: layer.py

@@ -89,7 +89,10 @@ def maybe_get_vit_flash_attn_backend(
         if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
-            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+            try:
+                from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+            except ImportError:
+                flash_attn_varlen_func = None
     else:
         flash_attn_varlen_func = None
 
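
The same guarded-import idea in isolation, with a hypothetical caller: the optional kernel is resolved once at import time, a missing extension leaves a None sentinel instead of raising, and backend selection keys off that sentinel. The flash_attn import target and the pick_attention_fn helper below are illustrative assumptions, not the vLLM code path.

from typing import Callable

import torch
import torch.nn.functional as F

try:
    # Optional dependency: present only when the flash-attn package is installed.
    from flash_attn import flash_attn_varlen_func
except ImportError:
    flash_attn_varlen_func = None


def pick_attention_fn() -> Callable[..., torch.Tensor]:
    # Prefer the flash-attention varlen kernel when its import succeeded;
    # otherwise fall back to PyTorch's scaled_dot_product_attention.
    if flash_attn_varlen_func is not None:
        return flash_attn_varlen_func
    return F.scaled_dot_product_attention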