Add FLASHINFER_MLA to backend selector test (#24753)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
parent 7ba32aa60b
commit 5fe643fc26
@@ -22,7 +22,10 @@ def clear_cache():
 
 # Define MLA and non-MLA backends separately
 DEVICE_MLA_BACKENDS = {
-    "cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
+    "cuda": [
+        "TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA",
+        "CUTLASS_MLA"
+    ],
     "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
     "cpu": [],
 }
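A minimal sketch of how a parametrized test could enumerate (device, backend) pairs from this table; illustrative only, it does not reproduce the test file's actual parametrization helpers:

# Illustrative only: flatten DEVICE_MLA_BACKENDS (as defined in the new
# version above) into (device, backend_name) pairs for parametrization.
DEVICE_MLA_BACKENDS = {
    "cuda": [
        "TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA",
        "CUTLASS_MLA"
    ],
    "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
    "cpu": [],
}

MLA_TEST_PARAMS = [(device, name)
                   for device, names in DEVICE_MLA_BACKENDS.items()
                   for name in names]
# ("cuda", "FLASHINFER_MLA") now shows up in the generated matrix.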
@@ -90,8 +93,8 @@ def test_env(
 
         with patch("vllm.attention.selector.current_platform",
                    CpuPlatform()):
-            backend = get_attn_backend(16, torch.float16, torch.float16,
-                                        block_size, False)
+            backend = get_attn_backend(16, torch.float16, None, block_size,
+                                        False)
             assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
 
     elif device == "hip":
@@ -109,7 +112,7 @@ def test_env(
                     with pytest.raises(ValueError) as exc_info:
                         get_attn_backend(16,
                                          torch.float16,
-                                         torch.float16,
+                                         None,
                                          block_size,
                                          False,
                                          use_mla=use_mla)
@@ -120,7 +123,7 @@ def test_env(
                     with pytest.raises(ValueError) as exc_info:
                         get_attn_backend(16,
                                          torch.float16,
-                                         torch.float16,
+                                         None,
                                          block_size,
                                          False,
                                          use_mla=use_mla)
@@ -130,7 +133,7 @@ def test_env(
                         # Valid backend-block_size combination
                         backend = get_attn_backend(16,
                                                    torch.float16,
-                                                   torch.float16,
+                                                   None,
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)
@@ -139,7 +142,7 @@ def test_env(
                 else:
                     backend = get_attn_backend(16,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)
@@ -153,6 +156,8 @@ def test_env(
                     # CUDA MLA backend logic:
                     # - CUTLASS_MLA: only supported with block_size == 128
                     #   and Blackwell GPUs (SM 10.0), V1 only
+                    # - FLASHINFER_MLA: only supported on Blackwell GPUs
+                    #   (SM 10.0+), V1 only
                     # - FLASHMLA: only supported with block_size == 64
                     # - FLASH_ATTN_MLA: V1 only
                     # - TRITON_MLA: fallback for other cases
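The comment block above is the contract the branches below assert. A minimal sketch restating it as a predicate; illustrative only: `has_blackwell` (SM 10.0+) is an assumed input rather than something the test computes, and the 32/64 block-size limit for FLASHINFER_MLA comes from the new test branch further down, not from this comment.

def cuda_mla_backend_supported(name: str, block_size: int, use_v1: bool,
                               has_blackwell: bool) -> bool:
    # Restates the per-backend constraints listed in the comment block above.
    if name == "CUTLASS_MLA":
        return use_v1 and has_blackwell and block_size == 128
    if name == "FLASHINFER_MLA":
        return use_v1 and has_blackwell and block_size in (32, 64)
    if name == "FLASHMLA":
        return block_size == 64
    if name == "FLASH_ATTN_MLA":
        return use_v1
    return name == "TRITON_MLA"  # fallback for other cases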
@@ -169,12 +174,31 @@ def test_env(
                         else:
                             backend = get_attn_backend(16,
                                                        torch.float16,
-                                                       torch.float16,
+                                                       None,
                                                        block_size,
                                                        False,
                                                        use_mla=use_mla)
                             expected = "CUTLASS_MLA_VLLM_V1"
                             assert backend.get_name() == expected
+                    elif name == "FLASHINFER_MLA":
+                        if not use_v1:
+                            # FlashInfer MLA only supported on V1 engine
+                            pytest.skip(
+                                "FlashInfer MLA only supported on V1 engine")
+                        elif block_size not in [32, 64]:
+                            # FlashInfer MLA only supports block_size 32 or 64
+                            pytest.skip(
+                                "FlashInfer MLA only supports block_size 32 "
+                                "or 64")
+                        else:
+                            backend = get_attn_backend(16,
+                                                       torch.float16,
+                                                       None,
+                                                       block_size,
+                                                       False,
+                                                       use_mla=use_mla)
+                            expected = "FLASHINFER_MLA"
+                            assert backend.get_name() == expected
                     elif name == "FLASHMLA":
                         if block_size != 64:
                             # FlashMLA only supports block_size == 64
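Condensed into a standalone test, the happy path of the new branch looks roughly like the sketch below. The import locations and the VLLM_ATTENTION_BACKEND / VLLM_USE_V1 environment variables are assumptions based on names used elsewhere in the file (the import block is not part of this diff), and the assertion only holds on Blackwell-class (SM 10.0+) GPUs:

# Sketch only, not part of the commit: a condensed FLASHINFER_MLA case.
from unittest.mock import patch

import pytest
import torch

from vllm.attention.selector import get_attn_backend  # assumed import path
from vllm.platforms.cuda import CudaPlatform          # assumed import path


def test_flashinfer_mla_selected(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")                        # V1 engine only
        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER_MLA")
        with patch("vllm.attention.selector.current_platform",
                   CudaPlatform()):
            backend = get_attn_backend(16, torch.float16, None, 64, False,
                                       use_mla=True)        # block_size 32/64
            assert backend.get_name() == "FLASHINFER_MLA"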
@@ -189,7 +213,7 @@ def test_env(
                         else:
                             backend = get_attn_backend(16,
                                                        torch.float16,
-                                                       torch.float16,
+                                                       None,
                                                        block_size,
                                                        False,
                                                        use_mla=use_mla)
@@ -204,7 +228,7 @@ def test_env(
                         else:
                             backend = get_attn_backend(16,
                                                        torch.float16,
-                                                       torch.float16,
+                                                       None,
                                                        block_size,
                                                        False,
                                                        use_mla=use_mla)
@@ -214,7 +238,7 @@ def test_env(
                         # TRITON_MLA or other fallback
                         backend = get_attn_backend(16,
                                                    torch.float16,
-                                                   torch.float16,
+                                                   None,
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)
@@ -224,7 +248,7 @@ def test_env(
                 elif name == "FLASHINFER":
                     backend = get_attn_backend(16,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)
@@ -233,7 +257,7 @@ def test_env(
                 else:
                     backend = get_attn_backend(32,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)
@@ -243,7 +267,7 @@ def test_env(
                 if use_v1:
                     backend = get_attn_backend(16,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)
@@ -269,15 +293,13 @@ def test_fp32_fallback(
 
         with patch("vllm.attention.selector.current_platform",
                    CpuPlatform()):
-            backend = get_attn_backend(16, torch.float32, torch.float32,
-                                       16, False)
+            backend = get_attn_backend(16, torch.float32, None, 16, False)
             assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
 
     elif device == "cuda":
         with patch("vllm.attention.selector.current_platform",
                    CudaPlatform()):
-            backend = get_attn_backend(16, torch.float32, torch.float32,
-                                       16, False)
+            backend = get_attn_backend(16, torch.float32, None, 16, False)
             assert (backend.get_name() == "FLEX_ATTENTION"
                     if use_v1 else "XFORMERS")
 
@@ -331,7 +353,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
         assert backend.get_name() != STR_FLASH_ATTN_VAL
 
         # Attention-free models should bypass env and use PlaceholderAttention
-        backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+        backend = get_attn_backend(16, torch.float16, None, 16, True)
         assert backend.get_name() != STR_FLASH_ATTN_VAL
 
 
@@ -141,6 +141,8 @@ def get_attention_backend(backend_name: _Backend):
         "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
         _Backend.FLASH_ATTN_MLA:
         "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
+        _Backend.FLASHINFER_MLA:
+        "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",
         _Backend.TRITON_MLA_VLLM_V1:
         "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
     }
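For reference, the dotted strings in this map resolve to backend classes. The helper probably uses a vLLM utility for that; the sketch below is a generic importlib equivalent, not the repo's code:

import importlib


def resolve_backend_class(qualname: str):
    # "package.module.ClassName" -> import the module, return the class.
    module_name, class_name = qualname.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)


# e.g. the entry added above:
# resolve_backend_class(
#     "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend")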