diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py
index 3c2aaabacae8c..4d969cf992d23 100644
--- a/tests/kernels/attention/test_attention_selector.py
+++ b/tests/kernels/attention/test_attention_selector.py
@@ -22,7 +22,10 @@ def clear_cache():
 
 # Define MLA and non-MLA backends separately
 DEVICE_MLA_BACKENDS = {
-    "cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
+    "cuda": [
+        "TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA",
+        "CUTLASS_MLA"
+    ],
     "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
     "cpu": [],
 }
@@ -90,8 +93,8 @@ def test_env(
 
             with patch("vllm.attention.selector.current_platform",
                        CpuPlatform()):
-                backend = get_attn_backend(16, torch.float16, torch.float16,
-                                           block_size, False)
+                backend = get_attn_backend(16, torch.float16, None, block_size,
+                                           False)
                 assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
 
         elif device == "hip":
@@ -109,7 +112,7 @@ def test_env(
                         with pytest.raises(ValueError) as exc_info:
                             get_attn_backend(16,
                                              torch.float16,
-                                             torch.float16,
+                                             None,
                                              block_size,
                                              False,
                                              use_mla=use_mla)
@@ -120,7 +123,7 @@ def test_env(
                         with pytest.raises(ValueError) as exc_info:
                             get_attn_backend(16,
                                              torch.float16,
-                                             torch.float16,
+                                             None,
                                              block_size,
                                              False,
                                              use_mla=use_mla)
@@ -130,7 +133,7 @@ def test_env(
                         # Valid backend-block_size combination
                         backend = get_attn_backend(16,
                                                    torch.float16,
-                                                   torch.float16,
+                                                   None,
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)
@@ -139,7 +142,7 @@ def test_env(
                 else:
                     backend = get_attn_backend(16,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)
@@ -153,6 +156,8 @@ def test_env(
                     # CUDA MLA backend logic:
                     # - CUTLASS_MLA: only supported with block_size == 128
                     #   and Blackwell GPUs (SM 10.0), V1 only
+                    # - FLASHINFER_MLA: only supported on Blackwell GPUs
+                    #   (SM 10.0+), V1 only
                     # - FLASHMLA: only supported with block_size == 64
                     # - FLASH_ATTN_MLA: V1 only
                     # - TRITON_MLA: fallback for other cases
@@ -169,12 +174,31 @@ def test_env(
                         else:
                             backend = get_attn_backend(16,
                                                        torch.float16,
-                                                       torch.float16,
+                                                       None,
                                                        block_size,
                                                        False,
                                                        use_mla=use_mla)
                             expected = "CUTLASS_MLA_VLLM_V1"
                             assert backend.get_name() == expected
+                    elif name == "FLASHINFER_MLA":
+                        if not use_v1:
+                            # FlashInfer MLA only supported on V1 engine
+                            pytest.skip(
+                                "FlashInfer MLA only supported on V1 engine")
+                        elif block_size not in [32, 64]:
+                            # FlashInfer MLA only supports block_size 32 or 64
+                            pytest.skip(
+                                "FlashInfer MLA only supports block_size 32 "
+                                "or 64")
+                        else:
+                            backend = get_attn_backend(16,
+                                                       torch.float16,
+                                                       None,
+                                                       block_size,
+                                                       False,
+                                                       use_mla=use_mla)
+                            expected = "FLASHINFER_MLA"
+                            assert backend.get_name() == expected
                     elif name == "FLASHMLA":
                         if block_size != 64:
                             # FlashMLA only supports block_size == 64
@@ -189,7 +213,7 @@ def test_env(
                         else:
                             backend = get_attn_backend(16,
                                                        torch.float16,
-                                                       torch.float16,
+                                                       None,
                                                        block_size,
                                                        False,
                                                        use_mla=use_mla)
@@ -204,7 +228,7 @@ def test_env(
                         else:
                             backend = get_attn_backend(16,
                                                        torch.float16,
-                                                       torch.float16,
+                                                       None,
                                                        block_size,
                                                        False,
                                                        use_mla=use_mla)
@@ -214,7 +238,7 @@ def test_env(
                         # TRITON_MLA or other fallback
                         backend = get_attn_backend(16,
                                                    torch.float16,
-                                                   torch.float16,
+                                                   None,
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)
@@ -224,7 +248,7 @@ def test_env(
                 elif name == "FLASHINFER":
                     backend = get_attn_backend(16,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)
@@ -233,7 +257,7 @@ def test_env(
                 else:
                     backend = get_attn_backend(32,
                                                torch.float16,
-                                               torch.float16,
+                                               None,
                                                block_size,
                                                False,
                                                use_mla=use_mla)
@@ -243,7 +267,7 @@ def test_env(
             if use_v1:
                 backend = get_attn_backend(16,
                                            torch.float16,
-                                           torch.float16,
+                                           None,
                                            block_size,
                                            False,
                                            use_mla=use_mla)
@@ -269,15 +293,13 @@ def test_fp32_fallback(
 
         with patch("vllm.attention.selector.current_platform",
                    CpuPlatform()):
-            backend = get_attn_backend(16, torch.float32, torch.float32,
-                                       16, False)
+            backend = get_attn_backend(16, torch.float32, None, 16, False)
             assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
 
     elif device == "cuda":
         with patch("vllm.attention.selector.current_platform",
                    CudaPlatform()):
-            backend = get_attn_backend(16, torch.float32, torch.float32,
-                                       16, False)
+            backend = get_attn_backend(16, torch.float32, None, 16, False)
             assert (backend.get_name() == "FLEX_ATTENTION"
                     if use_v1 else "XFORMERS")
 
@@ -331,7 +353,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
         assert backend.get_name() != STR_FLASH_ATTN_VAL
 
         # Attention-free models should bypass env and use PlaceholderAttention
-        backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+        backend = get_attn_backend(16, torch.float16, None, 16, True)
         assert backend.get_name() != STR_FLASH_ATTN_VAL
 
 
diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py
index 5c49566240df4..f07c6eb0ea4da 100644
--- a/tests/v1/attention/utils.py
+++ b/tests/v1/attention/utils.py
@@ -141,6 +141,8 @@ def get_attention_backend(backend_name: _Backend):
         "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
         _Backend.FLASH_ATTN_MLA:
         "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
+        _Backend.FLASHINFER_MLA:
+        "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",
         _Backend.TRITON_MLA_VLLM_V1:
         "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
     }
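Taken together, the two hunks register FLASHINFER_MLA as a CUDA MLA choice in the selector tests, map `_Backend.FLASHINFER_MLA` to `FlashInferMLABackend` in the V1 test utils, and switch every `get_attn_backend` call to pass `None` for the kv-cache dtype argument. The sketch below shows how the new path could be exercised standalone, in the same style as `test_env`. It is a minimal sketch, not part of the diff: the test name, the `VLLM_ATTENTION_BACKEND` environment variable, and the import locations for `get_attn_backend` and `CudaPlatform` are assumptions, and it presumes a Blackwell (SM 10.0+) GPU with the V1 engine, per the backend-logic comment added above.

```python
# Hypothetical standalone sketch; assumes a Blackwell (SM 10.0+) GPU, the V1
# engine, and that the selector honors the VLLM_ATTENTION_BACKEND env var.
from unittest.mock import patch

import pytest
import torch

from vllm.attention.selector import get_attn_backend
from vllm.platforms.cuda import CudaPlatform


@pytest.mark.parametrize("block_size", [32, 64])
def test_flashinfer_mla_selection_sketch(monkeypatch: pytest.MonkeyPatch,
                                         block_size: int):
    with monkeypatch.context() as m:
        # Request the FlashInfer MLA backend explicitly (env var name is an
        # assumption; test_env goes through a parametrized backend name).
        m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER_MLA")
        with patch("vllm.attention.selector.current_platform",
                   CudaPlatform()):
            # Same call shape as in test_env; the third positional argument
            # is the kv-cache dtype, which this diff changes to None.
            backend = get_attn_backend(16,
                                       torch.float16,
                                       None,
                                       block_size,
                                       False,
                                       use_mla=True)
            assert backend.get_name() == "FLASHINFER_MLA"
```

The parametrization over block sizes 32 and 64 mirrors the only block sizes the new `FLASHINFER_MLA` branch accepts before skipping in `test_env`.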