Add FLASHINFER_MLA to backend selector test (#24753)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
parent 7ba32aa60b
commit 5fe643fc26
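The diff below makes three related changes: FLASHINFER_MLA is added to the CUDA entry of DEVICE_MLA_BACKENDS, a new test branch covers that backend's V1-only and block-size constraints, and every get_attn_backend call in these tests now passes None for the kv-cache dtype argument instead of repeating torch.float16. The following sketch summarizes the call pattern the tests exercise; the per-argument comments are illustrative assumptions, since the tests pass these values positionally.

# Minimal sketch (assumed argument meanings) of the call pattern used in the
# updated tests; not taken verbatim from the test file.
import torch

from vllm.attention.selector import get_attn_backend

backend = get_attn_backend(
    16,             # head size
    torch.float16,  # model dtype
    None,           # kv-cache dtype (None = unspecified)
    64,             # block size
    False,          # attention-free model?
    use_mla=True,   # request an MLA backend
)
print(backend.get_name())  # e.g. "FLASHMLA", depending on platform and env vars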
@@ -22,7 +22,10 @@ def clear_cache():
 
 # Define MLA and non-MLA backends separately
 DEVICE_MLA_BACKENDS = {
-    "cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
+    "cuda": [
+        "TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA",
+        "CUTLASS_MLA"
+    ],
     "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
     "cpu": [],
 }
@@ -90,8 +93,8 @@ def test_env(
 
         with patch("vllm.attention.selector.current_platform",
                    CpuPlatform()):
-            backend = get_attn_backend(16, torch.float16, torch.float16,
-                                        block_size, False)
+            backend = get_attn_backend(16, torch.float16, None, block_size,
+                                        False)
             assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
 
     elif device == "hip":
@@ -109,7 +112,7 @@ def test_env(
                 with pytest.raises(ValueError) as exc_info:
                     get_attn_backend(16,
                                      torch.float16,
-                                     torch.float16,
+                                     None,
                                      block_size,
                                      False,
                                      use_mla=use_mla)

@@ -120,7 +123,7 @@ def test_env(
                 with pytest.raises(ValueError) as exc_info:
                     get_attn_backend(16,
                                      torch.float16,
-                                     torch.float16,
+                                     None,
                                      block_size,
                                      False,
                                      use_mla=use_mla)

@@ -130,7 +133,7 @@ def test_env(
                 # Valid backend-block_size combination
                 backend = get_attn_backend(16,
                                            torch.float16,
-                                           torch.float16,
+                                           None,
                                            block_size,
                                            False,
                                            use_mla=use_mla)

@@ -139,7 +142,7 @@ def test_env(
             else:
                 backend = get_attn_backend(16,
                                            torch.float16,
-                                           torch.float16,
+                                           None,
                                            block_size,
                                            False,
                                            use_mla=use_mla)
@@ -153,6 +156,8 @@ def test_env(
             # CUDA MLA backend logic:
             # - CUTLASS_MLA: only supported with block_size == 128
             #   and Blackwell GPUs (SM 10.0), V1 only
+            # - FLASHINFER_MLA: only supported on Blackwell GPUs
+            #   (SM 10.0+), V1 only
             # - FLASHMLA: only supported with block_size == 64
             # - FLASH_ATTN_MLA: V1 only
             # - TRITON_MLA: fallback for other cases
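The comment block above, together with the block-size checks in the new FLASHINFER_MLA branch in the next hunk, pins down when each CUDA MLA backend is eligible. A simplified, hypothetical restatement of those rules as a standalone helper (not vLLM's actual selector code) could look like this:

# Hypothetical helper restating the constraints listed in the comment above;
# the real selection logic lives in vLLM's platform/attention selector code.
def mla_backend_is_eligible(name: str, block_size: int, use_v1: bool,
                            is_blackwell: bool) -> bool:
    if name == "CUTLASS_MLA":
        return use_v1 and is_blackwell and block_size == 128
    if name == "FLASHINFER_MLA":
        # Block-size limit taken from the new test branch below.
        return use_v1 and is_blackwell and block_size in (32, 64)
    if name == "FLASHMLA":
        return block_size == 64
    if name == "FLASH_ATTN_MLA":
        return use_v1
    # TRITON_MLA is the documented fallback for other cases.
    return name == "TRITON_MLA"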
@@ -169,12 +174,31 @@ def test_env(
                     else:
                         backend = get_attn_backend(16,
                                                    torch.float16,
-                                                   torch.float16,
+                                                   None,
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)
                         expected = "CUTLASS_MLA_VLLM_V1"
                         assert backend.get_name() == expected
+                elif name == "FLASHINFER_MLA":
+                    if not use_v1:
+                        # FlashInfer MLA only supported on V1 engine
+                        pytest.skip(
+                            "FlashInfer MLA only supported on V1 engine")
+                    elif block_size not in [32, 64]:
+                        # FlashInfer MLA only supports block_size 32 or 64
+                        pytest.skip(
+                            "FlashInfer MLA only supports block_size 32 "
+                            "or 64")
+                    else:
+                        backend = get_attn_backend(16,
+                                                   torch.float16,
+                                                   None,
+                                                   block_size,
+                                                   False,
+                                                   use_mla=use_mla)
+                        expected = "FLASHINFER_MLA"
+                        assert backend.get_name() == expected
                 elif name == "FLASHMLA":
                     if block_size != 64:
                         # FlashMLA only supports block_size == 64
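The new FLASHINFER_MLA branch above follows the same skip-or-assert pattern as the existing MLA cases: unsupported combinations are skipped rather than failed, and supported ones assert on the selected backend name. Note that the expected name is plain "FLASHINFER_MLA", without the "_VLLM_V1" suffix used for CUTLASS_MLA_VLLM_V1. A self-contained, hypothetical illustration of that pattern (not part of the test file):

# Hypothetical helper mirroring the skip-or-assert pattern used above.
import pytest


def check_flashinfer_mla_case(use_v1: bool, block_size: int,
                              selected_name: str) -> None:
    if not use_v1:
        pytest.skip("FlashInfer MLA only supported on V1 engine")
    if block_size not in (32, 64):
        pytest.skip("FlashInfer MLA only supports block_size 32 or 64")
    assert selected_name == "FLASHINFER_MLA"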
@@ -189,7 +213,7 @@ def test_env(
                     else:
                         backend = get_attn_backend(16,
                                                    torch.float16,
-                                                   torch.float16,
+                                                   None,
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)

@@ -204,7 +228,7 @@ def test_env(
                     else:
                         backend = get_attn_backend(16,
                                                    torch.float16,
-                                                   torch.float16,
+                                                   None,
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)

@@ -214,7 +238,7 @@ def test_env(
                     # TRITON_MLA or other fallback
                     backend = get_attn_backend(16,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)

@@ -224,7 +248,7 @@ def test_env(
                 elif name == "FLASHINFER":
                     backend = get_attn_backend(16,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)

@@ -233,7 +257,7 @@ def test_env(
                 else:
                     backend = get_attn_backend(32,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)

@@ -243,7 +267,7 @@ def test_env(
            if use_v1:
                backend = get_attn_backend(16,
                                           torch.float16,
-                                          torch.float16,
+                                          None,
                                           block_size,
                                           False,
                                           use_mla=use_mla)
@@ -269,15 +293,13 @@ def test_fp32_fallback(
 
         with patch("vllm.attention.selector.current_platform",
                    CpuPlatform()):
-            backend = get_attn_backend(16, torch.float32, torch.float32,
-                                        16, False)
+            backend = get_attn_backend(16, torch.float32, None, 16, False)
             assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
 
     elif device == "cuda":
         with patch("vllm.attention.selector.current_platform",
                    CudaPlatform()):
-            backend = get_attn_backend(16, torch.float32, torch.float32,
-                                        16, False)
+            backend = get_attn_backend(16, torch.float32, None, 16, False)
             assert (backend.get_name() == "FLEX_ATTENTION"
                     if use_v1 else "XFORMERS")
 
@@ -331,7 +353,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
        assert backend.get_name() != STR_FLASH_ATTN_VAL
 
        # Attention-free models should bypass env and use PlaceholderAttention
-       backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+       backend = get_attn_backend(16, torch.float16, None, 16, True)
        assert backend.get_name() != STR_FLASH_ATTN_VAL
 
 
@@ -141,6 +141,8 @@ def get_attention_backend(backend_name: _Backend):
         "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
         _Backend.FLASH_ATTN_MLA:
         "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
+        _Backend.FLASHINFER_MLA:
+        "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",
         _Backend.TRITON_MLA_VLLM_V1:
         "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
     }
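The last hunk registers the new backend in the test utility's mapping from _Backend enum members to dotted backend class paths. A minimal sketch of how such a dotted path can be resolved to a class (the resolve_backend helper is hypothetical, not vLLM API):

# Hypothetical resolver for dotted backend paths like the one registered above.
import importlib


def resolve_backend(dotted_path: str):
    module_name, _, class_name = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

# Example (importing this module requires a vLLM build with FlashInfer):
# backend_cls = resolve_backend(
#     "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend")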