From f80e7866c096c478021e04911fcdaedbd3d69930 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 7 Oct 2025 22:09:34 -0400 Subject: [PATCH] [Misc] Clean up cruft from previous FlashMLA sparse implementation (#26125) Signed-off-by: Lucas Wilkinson --- .../attention/test_attention_selector.py | 4 +- tests/kernels/attention/test_flashmla.py | 10 ++- .../kernels/attention/test_flashmla_sparse.py | 25 ++---- .../v1/attention/test_sparse_mla_backends.py | 83 ++++--------------- tests/v1/attention/utils.py | 4 + vllm/attention/ops/flashmla.py | 46 +++++++--- vllm/platforms/cuda.py | 8 +- vllm/v1/attention/backends/mla/flashmla.py | 4 +- .../attention/backends/mla/flashmla_sparse.py | 37 --------- 9 files changed, 80 insertions(+), 141 deletions(-) diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 5c607f921536e..fa95c3b2d39ea 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -165,10 +165,10 @@ def test_env( pytest.skip("FlashMLA only supports block_size 64") else: from vllm.v1.attention.backends.mla.flashmla import ( - is_flashmla_supported, + is_flashmla_dense_supported, ) - is_supported, _ = is_flashmla_supported() + is_supported, _ = is_flashmla_dense_supported() if not is_supported: pytest.skip("FlashMLA not supported on this platform") else: diff --git a/tests/kernels/attention/test_flashmla.py b/tests/kernels/attention/test_flashmla.py index 2b6fd38e4f58f..2151933a610d8 100644 --- a/tests/kernels/attention/test_flashmla.py +++ b/tests/kernels/attention/test_flashmla.py @@ -10,7 +10,7 @@ import torch from vllm.attention.ops.flashmla import ( flash_mla_with_kvcache, get_mla_metadata, - is_flashmla_supported, + is_flashmla_dense_supported, ) from vllm.triton_utils import triton @@ -27,13 +27,15 @@ def cal_diff( FLASH_MLA_UNSUPPORTED_REASON = ( - is_flashmla_supported()[1] - if not is_flashmla_supported()[0] + is_flashmla_dense_supported()[1] + if not is_flashmla_dense_supported()[0] else "FlashMLA is supported" ) -@pytest.mark.skipif(not is_flashmla_supported()[0], reason=FLASH_MLA_UNSUPPORTED_REASON) +@pytest.mark.skipif( + not is_flashmla_dense_supported()[0], reason=FLASH_MLA_UNSUPPORTED_REASON +) @pytest.mark.parametrize("b", [128]) @pytest.mark.parametrize("s_q", [1, 2]) @pytest.mark.parametrize("mean_sk", [4096, 8192, 16384]) diff --git a/tests/kernels/attention/test_flashmla_sparse.py b/tests/kernels/attention/test_flashmla_sparse.py index 562ae3009e41d..7ee6f4b07b4a9 100644 --- a/tests/kernels/attention/test_flashmla_sparse.py +++ b/tests/kernels/attention/test_flashmla_sparse.py @@ -4,19 +4,12 @@ import pytest import torch -def _cuda_sm90_available() -> bool: - if not torch.cuda.is_available(): - return False - major, _ = torch.cuda.get_device_capability() - return major == 9 - - def test_sparse_flashmla_metadata_smoke(): import vllm.attention.ops.flashmla as fm - ok, reason = fm.is_flashmla_supported() - if not ok or not _cuda_sm90_available(): - pytest.skip(reason or "SM90 not available") + ok, reason = fm.is_flashmla_sparse_supported() + if not ok: + pytest.skip(reason) device = torch.device("cuda") batch_size = 1 @@ -43,9 +36,9 @@ def test_sparse_flashmla_metadata_smoke(): def test_sparse_flashmla_decode_smoke(): import vllm.attention.ops.flashmla as fm - ok, reason = fm.is_flashmla_supported() - if not ok or not _cuda_sm90_available(): - pytest.skip(reason or "SM90 not available") + ok, reason = fm.is_flashmla_sparse_supported() 
+ if not ok: + pytest.skip(reason) device = torch.device("cuda") batch_size = 1 @@ -106,9 +99,9 @@ def test_sparse_flashmla_decode_smoke(): def test_sparse_flashmla_prefill_smoke(): import vllm.attention.ops.flashmla as fm - ok, reason = fm.is_flashmla_supported() - if not ok or not _cuda_sm90_available(): - pytest.skip(reason or "SM90 not available") + ok, reason = fm.is_flashmla_sparse_supported() + if not ok: + pytest.skip(reason) device = torch.device("cuda") s_q = 1 diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py index f84951485310f..25de65a56b379 100644 --- a/tests/v1/attention/test_sparse_mla_backends.py +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -24,12 +24,7 @@ from vllm import _custom_ops as ops from vllm.attention.ops import flashmla from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.utils import cdiv -from vllm.v1.attention.backends.mla.flashmla_sparse import ( - FlashMLASparseBackend, - FlashMLASparseDecodeAndContextMetadata, - FlashMLASparseImpl, - FlashMLASparseMetadata, -) +from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseBackend from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks SPARSE_BACKEND_BATCH_SPECS = { @@ -116,59 +111,6 @@ def _quantize_dequantize_fp8_ds_mla( return dequant_kv_c, dequant_k_pe -def test_sparse_backend_metadata_registration(): - backend = FlashMLASparseBackend - - assert backend.get_name() == "FLASHMLA_SPARSE_VLLM_V1" - assert backend.get_metadata_cls() is FlashMLASparseMetadata - assert backend.get_impl_cls() is FlashMLASparseImpl - - dtype_list = backend.get_supported_dtypes() - assert torch.bfloat16 in dtype_list - - shape = backend.get_kv_cache_shape( - num_blocks=2, block_size=64, num_kv_heads=1, head_size=576 - ) - assert shape == (2, 64, 576) - - -def test_sparse_decode_metadata_filters_prefill_indices(): - prefill_context_lengths = torch.tensor([4, 2], dtype=torch.int32) - metadata = FlashMLASparseDecodeAndContextMetadata( - scheduler_metadata=torch.tensor([[0]], dtype=torch.int32), - num_splits=torch.tensor([1, 1], dtype=torch.int32), - cache_lens=torch.tensor([10, 12], dtype=torch.int32), - prefill_context_lengths=prefill_context_lengths, - ) - - indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32) - - context_indices, new_token_indices = metadata.filter_prefill_indices(indices) - - expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]], dtype=torch.int32) - expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]], dtype=torch.int32) - - assert torch.equal(context_indices, expected_context) - assert torch.equal(new_token_indices, expected_new_tokens) - - -def test_sparse_impl_zero_fills_when_metadata_missing(): - impl = FlashMLASparseImpl.__new__(FlashMLASparseImpl) - dummy_layer = object() - q = torch.zeros((2, 1, 3)) - k_c = torch.zeros((2, 3)) - k_pe = torch.zeros((2, 1, 1)) - kv_cache = torch.zeros((1, 1, 1)) - output = torch.ones((2, 4)) - - result = FlashMLASparseImpl.forward( - impl, dummy_layer, q, k_c, k_pe, kv_cache, attn_metadata=None, output=output - ) - - assert result is output - assert torch.all(result == 0) - - @pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys())) @pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"]) def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype): @@ -198,11 +140,12 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype max_model_len=max_seqlen, 
num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1), block_size=block_size, + hf_config_override={ + "index_topk": topk_tokens, + "attn_module_list_cfg": [{"topk_tokens": topk_tokens}], + }, ) model_config = vllm_config.model_config - model_config.hf_config = SimpleNamespace( - attn_module_list_cfg=[{"topk_tokens": topk_tokens}] - ) model_config.hf_text_config = SimpleNamespace( q_lora_rank=None, kv_lora_rank=kv_lora_rank, @@ -301,6 +244,7 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype sdpa_reference = torch.cat(reference_outputs, dim=0) vllm_config.cache_config.cache_dtype = kv_cache_dtype + vllm_config.model_config.hf_config.index_topk = topk_tokens common_attn_metadata = create_common_attn_metadata( batch_spec, @@ -352,7 +296,7 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype debug_indices = debug_indices.expand(metadata.num_actual_tokens, -1).clone() mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices) - ok, reason = flashmla.is_flashmla_supported() + ok, reason = flashmla.is_flashmla_sparse_supported() if not ok: pytest.skip(reason) @@ -397,9 +341,16 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device ) - backend_output = impl.forward( - layer, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, metadata, output=out_buffer - ) + with torch.inference_mode(): + backend_output = impl.forward( + layer, + query_vllm, + kv_c_vllm, + k_pe_vllm, + kv_cache, + metadata, + output=out_buffer, + ) assert backend_output.shape == sdpa_reference.shape assert backend_output.dtype == sdpa_reference.dtype diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index fcc0b6a5f7dea..feed66d33b586 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -181,6 +181,7 @@ def create_vllm_config( max_num_batched_tokens: int = 8192, enable_chunked_prefill: bool = True, add_mock_model_methods: bool = True, + hf_config_override: Optional[dict] = None, ) -> VllmConfig: """Create a VllmConfig for testing with reasonable defaults.""" @@ -235,6 +236,9 @@ def create_vllm_config( lambda self, i: 1.0 / model_config.get_head_size() ** 0.5, model_config ) + if hf_config_override: + model_config.hf_config.update(hf_config_override) + return VllmConfig( model_config=model_config, cache_config=cache_config, diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 0fe01a51ec623..0bf354a95b1ca 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -31,21 +31,47 @@ else: _flashmla_extension_C_AVAILABLE = False -def is_flashmla_supported() -> tuple[bool, Optional[str]]: - """ - Return: is_supported_flag, unsupported_reason (optional). - """ - if not current_platform.is_cuda(): - return False, "FlashMLA is only supported on CUDA devices." - if current_platform.get_device_capability()[0] != 9: - return False, "FlashMLA is only supported on Hopper devices." 
+def _is_flashmla_available() -> tuple[bool, Optional[str]]: if not _flashmla_C_AVAILABLE: return ( False, "vllm._flashmla_C is not available, likely was not " "compiled due to insufficient nvcc version or a supported arch " - "(only sm90a currently) was not in the list of target arches to " - "compile for.", + "was not in the list of target arches to compile for.", + ) + if not _flashmla_extension_C_AVAILABLE: + return ( + False, + "vllm._flashmla_extension_C is not available, likely " + "was not compiled due to a build error.", + ) + + return True, None + + +def is_flashmla_dense_supported() -> tuple[bool, Optional[str]]: + """ + Return: is_supported_flag, unsupported_reason (optional). + """ + is_availble, maybe_reason = _is_flashmla_available() + if not is_availble: + return False, maybe_reason + if current_platform.get_device_capability()[0] != 9: + return False, "FlashMLA Dense is only supported on Hopper devices." + return True, None + + +def is_flashmla_sparse_supported() -> tuple[bool, Optional[str]]: + """ + Return: is_supported_flag, unsupported_reason (optional). + """ + is_availble, maybe_reason = _is_flashmla_available() + if not is_availble: + return False, maybe_reason + if current_platform.get_device_capability()[0] not in (9, 10): + return ( + False, + "FlashMLA Sparse is only supported on Hopper and Blackwell devices.", ) return True, None diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 20568e0d6c514..8a4565b4d1a03 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -146,11 +146,11 @@ class CudaPlatformBase(Platform): use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA" use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA" - from vllm.attention.ops.flashmla import is_flashmla_supported + from vllm.attention.ops.flashmla import is_flashmla_dense_supported if ( use_flashmla - and is_flashmla_supported()[0] + and is_flashmla_dense_supported()[0] and cache_config.block_size != 64 ): cache_config.block_size = 64 @@ -256,7 +256,7 @@ class CudaPlatformBase(Platform): "Set VLLM_USE_V1=1 to enable them." 
) - from vllm.attention.ops.flashmla import is_flashmla_supported + from vllm.attention.ops.flashmla import is_flashmla_dense_supported from vllm.attention.utils.fa_utils import flash_attn_supports_mla if use_sparse: @@ -277,7 +277,7 @@ class CudaPlatformBase(Platform): and block_size in [32, 64] ) use_flashmla = selected_backend == _Backend.FLASHMLA or ( - selected_backend is None and is_flashmla_supported()[0] + selected_backend is None and is_flashmla_dense_supported()[0] ) use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or ( selected_backend is None and flash_attn_supports_mla() diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 56480832bcd1c..6ba2c682760cb 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -10,7 +10,7 @@ from vllm.attention.backends.abstract import AttentionLayer, AttentionType from vllm.attention.ops.flashmla import ( flash_mla_with_kvcache, get_mla_metadata, - is_flashmla_supported, + is_flashmla_dense_supported, ) from vllm.config import VllmConfig from vllm.logger import init_logger @@ -177,7 +177,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): **mla_args, ) - is_supported, reason = is_flashmla_supported() + is_supported, reason = is_flashmla_dense_supported() assert is_supported, reason unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 21d67f832b7ba..49c29de35da15 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math from dataclasses import dataclass from typing import TYPE_CHECKING, ClassVar, Optional @@ -51,12 +50,6 @@ structured as: """ -def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor: - # Convert base-2 LSE to natural-log LSE - # Keep FP32 for numerical stability during the merge. - return lse_base2.to(torch.float32) * math.log(2.0) - - class FlashMLASparseBackend(AttentionBackend): accept_output_buffer: bool = True @@ -100,36 +93,6 @@ class FlashMLASparseBackend(AttentionBackend): return [576] -@dataclass -class MLASparsePrefillMetadata: - # NOTE(Chen): not call it "FlashMLASparsePrefillMetadata" because - # the kernel is not from flashmla - block_table: torch.Tensor - has_context: bool = False - context_lens: Optional[torch.Tensor] = None - - -@dataclass -class FlashMLASparseDecodeAndContextMetadata: - scheduler_metadata: torch.Tensor = None - num_splits: torch.Tensor = None - cache_lens: torch.Tensor = None - prefill_context_lengths: Optional[torch.Tensor] = None - prefill_new_k_start_locs: Optional[torch.Tensor] = None - dummy_block_table: torch.Tensor = None - - def filter_prefill_indices( - self, indices: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - assert self.prefill_context_lengths is not None - prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1) - context_indices = torch.where(indices < prefill_context_lengths, indices, -1) - new_token_indices = torch.where( - indices >= prefill_context_lengths, indices - prefill_context_lengths, -1 - ) - return context_indices, new_token_indices - - @dataclass class FlashMLASparseMetadata: num_reqs: int
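
Usage sketch (not part of the patch): the single capability check is now split into is_flashmla_dense_supported() and is_flashmla_sparse_supported() in vllm/attention/ops/flashmla.py. Below is a minimal example of gating tests on the new helpers, following the same skip pattern the patch applies in tests/kernels/attention. It assumes a CUDA build of vLLM with the FlashMLA extension modules compiled; the test names and the kernel-call placeholders are hypothetical.

    import pytest

    from vllm.attention.ops.flashmla import (
        is_flashmla_dense_supported,
        is_flashmla_sparse_supported,
    )


    def test_dense_flashmla_smoke():
        # Dense FlashMLA needs the compiled kernels and an SM90 (Hopper) GPU.
        ok, reason = is_flashmla_dense_supported()
        if not ok:
            pytest.skip(reason)
        # ... exercise flash_mla_with_kvcache / get_mla_metadata here ...


    def test_sparse_flashmla_smoke():
        # Sparse FlashMLA needs the compiled kernels and an SM90 or SM100
        # (Hopper or Blackwell) GPU.
        ok, reason = is_flashmla_sparse_supported()
        if not ok:
            pytest.skip(reason)
        # ... exercise the sparse FlashMLA entry points here ...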
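
A second sketch (also not part of the patch) for the new hf_config_override parameter of create_vllm_config in tests/v1/attention/utils.py, which applies overrides to the real HF config via hf_config.update() instead of replacing hf_config with a SimpleNamespace as the old test did. It assumes the test helpers are importable as tests.v1.attention.utils; the top-k value and sizes below are hypothetical, and only parameters that appear in the patched test call are used.

    from tests.v1.attention.utils import create_vllm_config

    topk_tokens = 2048  # hypothetical sparse attention top-k

    vllm_config = create_vllm_config(
        max_model_len=4096,
        num_gpu_blocks=2048,
        block_size=64,
        hf_config_override={
            "index_topk": topk_tokens,
            "attn_module_list_cfg": [{"topk_tokens": topk_tokens}],
        },
    )

    # hf_config.update() sets each override as an attribute on the config,
    # so downstream code can read hf_config.index_topk directly.
    assert vllm_config.model_config.hf_config.index_topk == topk_tokens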