Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[CI/Build] Make test_mha_attn.py run on correct platform only and check for flash_attn_varlen_func in layer.py (#29145)
parent f8dacc66b6
commit 7618dc973d
test_mha_attn.py

@@ -26,7 +26,14 @@ def clear_cache():
     _cached_get_attn_backend.cache_clear()
 
 
-@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
+devices = ["cpu"]
+if current_platform.is_cuda():
+    devices.append("cuda")
+if current_platform.is_rocm():
+    devices.append("hip")
+
+
+@pytest.mark.parametrize("device", devices)
 def test_mha_attn_platform(device: str):
     """
     Test the attention selector between different platform and device.
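For context, the new parametrization builds the device list at collection time, so test cases for hardware that is not present are never generated. A minimal sketch of the same pattern, using plain torch availability checks as a stand-in for the current_platform helpers used in the diff (the test name below is invented for illustration):

# Sketch only: gate pytest parametrization on what the host actually has,
# using torch's own detection instead of vLLM's current_platform helpers.
import pytest
import torch

devices = ["cpu"]
if torch.cuda.is_available():
    # torch.version.hip is a string on ROCm builds and None on CUDA builds.
    devices.append("hip" if torch.version.hip else "cuda")


@pytest.mark.parametrize("device", devices)
def test_device_cases_exist_only_when_supported(device: str) -> None:
    # Cases for missing hardware are never collected, so nothing needs skipping.
    assert device in {"cpu", "cuda", "hip"}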
@@ -46,7 +53,7 @@ def test_mha_attn_platform(device: str):
             patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
         ):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
+            assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
     else:
         # Test CUDA with head_size=64 (divisible by 32)
         # - should use vLLM's FlashAttention
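The updated expectation is exercised through unittest.mock.patch, which swaps module-level platform objects so the selector behaves as if it were on ROCm regardless of the host. A hypothetical, self-contained sketch of that patching technique; platform_holder and select_backend are invented names, not vLLM's API:

# Hypothetical illustration of testing platform-dependent selection by
# patching a module-level platform attribute.
from types import SimpleNamespace
from unittest.mock import patch

platform_holder = SimpleNamespace(current_platform="cuda")


def select_backend() -> str:
    # Pretend selector that reads the (possibly patched) platform value.
    return "FLASH_ATTN" if platform_holder.current_platform in {"cuda", "hip"} else "TORCH_SDPA"


def test_backend_selection_under_patched_platform() -> None:
    with patch.object(platform_holder, "current_platform", "cpu"):
        assert select_backend() == "TORCH_SDPA"
    # Outside the patch context the original value is restored.
    assert select_backend() == "FLASH_ATTN"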
layer.py

@@ -89,7 +89,10 @@ def maybe_get_vit_flash_attn_backend(
         if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
-            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+            try:
+                from vllm.attention.utils.fa_utils import flash_attn_varlen_func
+            except ImportError:
+                flash_attn_varlen_func = None
     else:
         flash_attn_varlen_func = None
 
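The layer.py change is the standard optional-dependency guard: attempt the import and fall back to None, so callers can branch on availability instead of crashing at import time. A minimal generic sketch of that pattern, assuming the upstream flash-attn package as the optional dependency; has_varlen_attn is an invented helper:

# Generic optional-import guard, not vLLM's exact code: if the accelerator
# kernel package is missing, expose None and let callers pick a fallback.
try:
    from flash_attn import flash_attn_varlen_func  # provided by the flash-attn package
except ImportError:
    flash_attn_varlen_func = None


def has_varlen_attn() -> bool:
    # Callers check availability here rather than failing at import time.
    return flash_attn_varlen_func is not None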