mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 12:22:16 +08:00
[CI/Build] Make test_mha_attn.py run on correct platform only and check for flash_attn_varlen_func in layer.py (#29145)
This commit is contained in:
parent
f8dacc66b6
commit
7618dc973d
@ -26,7 +26,14 @@ def clear_cache():
|
|||||||
_cached_get_attn_backend.cache_clear()
|
_cached_get_attn_backend.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
|
devices = ["cpu"]
|
||||||
|
if current_platform.is_cuda():
|
||||||
|
devices.append("cuda")
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
devices.append("hip")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device", devices)
|
||||||
def test_mha_attn_platform(device: str):
|
def test_mha_attn_platform(device: str):
|
||||||
"""
|
"""
|
||||||
Test the attention selector between different platform and device.
|
Test the attention selector between different platform and device.
|
||||||
@ -46,7 +53,7 @@ def test_mha_attn_platform(device: str):
|
|||||||
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
|
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
|
||||||
):
|
):
|
||||||
attn = MultiHeadAttention(16, 64, scale=1)
|
attn = MultiHeadAttention(16, 64, scale=1)
|
||||||
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
|
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
|
||||||
else:
|
else:
|
||||||
# Test CUDA with head_size=64 (divisible by 32)
|
# Test CUDA with head_size=64 (divisible by 32)
|
||||||
# - should use vLLM's FlashAttention
|
# - should use vLLM's FlashAttention
|
||||||
|
|||||||
@ -89,7 +89,10 @@ def maybe_get_vit_flash_attn_backend(
|
|||||||
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
|
if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
|
||||||
from aiter import flash_attn_varlen_func
|
from aiter import flash_attn_varlen_func
|
||||||
else:
|
else:
|
||||||
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
try:
|
||||||
|
from vllm.attention.utils.fa_utils import flash_attn_varlen_func
|
||||||
|
except ImportError:
|
||||||
|
flash_attn_varlen_func = None
|
||||||
else:
|
else:
|
||||||
flash_attn_varlen_func = None
|
flash_attn_varlen_func = None
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user