Add FLASHINFER_MLA to backend selector test (#24753)

Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Matthew Bonanni 2025-09-12 18:30:07 -04:00 committed by GitHub
parent 7ba32aa60b
commit 5fe643fc26
2 changed files with 43 additions and 19 deletions

View File

@@ -22,7 +22,10 @@ def clear_cache():
 # Define MLA and non-MLA backends separately
 DEVICE_MLA_BACKENDS = {
-    "cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
+    "cuda": [
+        "TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA",
+        "CUTLASS_MLA"
+    ],
     "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
     "cpu": [],
 }
@@ -90,8 +93,8 @@ def test_env(
         with patch("vllm.attention.selector.current_platform",
                    CpuPlatform()):
-            backend = get_attn_backend(16, torch.float16, torch.float16,
-                                        block_size, False)
+            backend = get_attn_backend(16, torch.float16, None, block_size,
+                                        False)
             assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
     elif device == "hip":
@@ -109,7 +112,7 @@ def test_env(
             with pytest.raises(ValueError) as exc_info:
                 get_attn_backend(16,
                                  torch.float16,
-                                 torch.float16,
+                                 None,
                                  block_size,
                                  False,
                                  use_mla=use_mla)
@@ -120,7 +123,7 @@ def test_env(
             with pytest.raises(ValueError) as exc_info:
                 get_attn_backend(16,
                                  torch.float16,
-                                 torch.float16,
+                                 None,
                                  block_size,
                                  False,
                                  use_mla=use_mla)
@@ -130,7 +133,7 @@ def test_env(
                 # Valid backend-block_size combination
                 backend = get_attn_backend(16,
                                            torch.float16,
-                                           torch.float16,
+                                           None,
                                            block_size,
                                            False,
                                            use_mla=use_mla)
@@ -139,7 +142,7 @@ def test_env(
             else:
                 backend = get_attn_backend(16,
                                            torch.float16,
-                                           torch.float16,
+                                           None,
                                            block_size,
                                            False,
                                            use_mla=use_mla)
@@ -153,6 +156,8 @@ def test_env(
                 # CUDA MLA backend logic:
                 # - CUTLASS_MLA: only supported with block_size == 128
                 #   and Blackwell GPUs (SM 10.0), V1 only
+                # - FLASHINFER_MLA: only supported on Blackwell GPUs
+                #   (SM 10.0+), V1 only
                 # - FLASHMLA: only supported with block_size == 64
                 # - FLASH_ATTN_MLA: V1 only
                 # - TRITON_MLA: fallback for other cases
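
For readers skimming the diff, the constraints listed in the comment above can be restated as a small standalone predicate. This is an illustrative sketch only, not vLLM code: the helper name and its (name, is_blackwell, use_v1, block_size) inputs are hypothetical, and the test below enforces the same rules through pytest.skip().

def cuda_mla_backend_is_supported(name: str, is_blackwell: bool,
                                  use_v1: bool, block_size: int) -> bool:
    """Hypothetical restatement of the constraints in the comment above."""
    if name == "CUTLASS_MLA":
        return use_v1 and is_blackwell and block_size == 128
    if name == "FLASHINFER_MLA":
        # The new test branch additionally restricts block_size to 32 or 64.
        return use_v1 and is_blackwell and block_size in (32, 64)
    if name == "FLASHMLA":
        return block_size == 64
    if name == "FLASH_ATTN_MLA":
        return use_v1
    return True  # TRITON_MLA is the fallback for other cases
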
@@ -169,12 +174,31 @@ def test_env(
                     else:
                         backend = get_attn_backend(16,
                                                    torch.float16,
-                                                   torch.float16,
+                                                   None,
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)
                         expected = "CUTLASS_MLA_VLLM_V1"
                         assert backend.get_name() == expected
+                elif name == "FLASHINFER_MLA":
+                    if not use_v1:
+                        # FlashInfer MLA only supported on V1 engine
+                        pytest.skip(
+                            "FlashInfer MLA only supported on V1 engine")
+                    elif block_size not in [32, 64]:
+                        # FlashInfer MLA only supports block_size 32 or 64
+                        pytest.skip(
+                            "FlashInfer MLA only supports block_size 32 "
+                            "or 64")
+                    else:
+                        backend = get_attn_backend(16,
+                                                   torch.float16,
+                                                   None,
+                                                   block_size,
+                                                   False,
+                                                   use_mla=use_mla)
+                        expected = "FLASHINFER_MLA"
+                        assert backend.get_name() == expected
                 elif name == "FLASHMLA":
                     if block_size != 64:
                         # FlashMLA only supports block_size == 64
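
Outside the parametrized test, the new FLASHINFER_MLA branch corresponds roughly to the following standalone call. This is a hedged sketch, not part of the commit: it assumes a Blackwell-class GPU, V1 engine defaults, and that the selector honours the VLLM_ATTENTION_BACKEND environment variable, whereas the test itself mocks the platform and sets the override via monkeypatch.

import os

# Assumed override mechanism; set before importing vLLM so it is visible
# whenever the selector reads it. The test uses monkeypatch instead.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_MLA"

import torch
from vllm.attention.selector import get_attn_backend

# head_size=16, dtype=fp16, kv_cache_dtype=None, block_size=64,
# not attention-free, MLA enabled -- mirroring the arguments in the diff.
backend = get_attn_backend(16, torch.float16, None, 64, False, use_mla=True)
print(backend.get_name())  # the test expects "FLASHINFER_MLA" here
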
@@ -189,7 +213,7 @@ def test_env(
                     else:
                         backend = get_attn_backend(16,
                                                    torch.float16,
-                                                   torch.float16,
+                                                   None,
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)
@@ -204,7 +228,7 @@ def test_env(
                     else:
                         backend = get_attn_backend(16,
                                                    torch.float16,
-                                                   torch.float16,
+                                                   None,
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)
@@ -214,7 +238,7 @@ def test_env(
                     # TRITON_MLA or other fallback
                     backend = get_attn_backend(16,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)
@@ -224,7 +248,7 @@ def test_env(
             elif name == "FLASHINFER":
                 backend = get_attn_backend(16,
                                            torch.float16,
-                                           torch.float16,
+                                           None,
                                            block_size,
                                            False,
                                            use_mla=use_mla)
@@ -233,7 +257,7 @@ def test_env(
             else:
                 backend = get_attn_backend(32,
                                            torch.float16,
-                                           torch.float16,
+                                           None,
                                            block_size,
                                            False,
                                            use_mla=use_mla)
@@ -243,7 +267,7 @@ def test_env(
         if use_v1:
             backend = get_attn_backend(16,
                                        torch.float16,
-                                       torch.float16,
+                                       None,
                                        block_size,
                                        False,
                                        use_mla=use_mla)
@@ -269,15 +293,13 @@ def test_fp32_fallback(
         with patch("vllm.attention.selector.current_platform",
                    CpuPlatform()):
-            backend = get_attn_backend(16, torch.float32, torch.float32,
-                                        16, False)
+            backend = get_attn_backend(16, torch.float32, None, 16, False)
             assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
     elif device == "cuda":
         with patch("vllm.attention.selector.current_platform",
                    CudaPlatform()):
-            backend = get_attn_backend(16, torch.float32, torch.float32,
-                                        16, False)
+            backend = get_attn_backend(16, torch.float32, None, 16, False)
             assert (backend.get_name() == "FLEX_ATTENTION"
                     if use_v1 else "XFORMERS")
@@ -331,7 +353,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
     assert backend.get_name() != STR_FLASH_ATTN_VAL
     # Attention-free models should bypass env and use PlaceholderAttention
-    backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+    backend = get_attn_backend(16, torch.float16, None, 16, True)
     assert backend.get_name() != STR_FLASH_ATTN_VAL

View File

@@ -141,6 +141,8 @@ def get_attention_backend(backend_name: _Backend):
         "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
         _Backend.FLASH_ATTN_MLA:
         "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
+        _Backend.FLASHINFER_MLA:
+        "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",
         _Backend.TRITON_MLA_VLLM_V1:
         "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
     }
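
The second file extends the name-to-class map consulted by get_attention_backend with the FlashInfer MLA backend path. For context, a dotted path like the entry added here can be resolved to its class with standard importlib machinery; the helper below is a sketch for illustration only (resolve_backend_cls is a hypothetical name, not necessarily how vLLM performs the lookup).

import importlib


def resolve_backend_cls(qualname: str):
    """Resolve 'pkg.module.ClassName' to the class object it names."""
    module_name, _, cls_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, cls_name)


# Example: look up the newly registered FlashInfer MLA backend class.
backend_cls = resolve_backend_cls(
    "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend")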