[Misc] Clean up cruft from previous FlashMLA sparse implementation (#26125)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson 2025-10-07 22:09:34 -04:00 committed by GitHub
parent 31a4b3e6c4
commit f80e7866c0
9 changed files with 80 additions and 141 deletions

View File

@@ -165,10 +165,10 @@ def test_env(
pytest.skip("FlashMLA only supports block_size 64")
else:
from vllm.v1.attention.backends.mla.flashmla import (
is_flashmla_supported,
is_flashmla_dense_supported,
)
is_supported, _ = is_flashmla_supported()
is_supported, _ = is_flashmla_dense_supported()
if not is_supported:
pytest.skip("FlashMLA not supported on this platform")
else:

View File

@@ -10,7 +10,7 @@ import torch
from vllm.attention.ops.flashmla import (
flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported,
is_flashmla_dense_supported,
)
from vllm.triton_utils import triton
@@ -27,13 +27,15 @@ def cal_diff(
FLASH_MLA_UNSUPPORTED_REASON = (
is_flashmla_supported()[1]
if not is_flashmla_supported()[0]
is_flashmla_dense_supported()[1]
if not is_flashmla_dense_supported()[0]
else "FlashMLA is supported"
)
@pytest.mark.skipif(not is_flashmla_supported()[0], reason=FLASH_MLA_UNSUPPORTED_REASON)
@pytest.mark.skipif(
not is_flashmla_dense_supported()[0], reason=FLASH_MLA_UNSUPPORTED_REASON
)
@pytest.mark.parametrize("b", [128])
@pytest.mark.parametrize("s_q", [1, 2])
@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])

View File

@@ -4,19 +4,12 @@ import pytest
import torch
def _cuda_sm90_available() -> bool:
if not torch.cuda.is_available():
return False
major, _ = torch.cuda.get_device_capability()
return major == 9
def test_sparse_flashmla_metadata_smoke():
import vllm.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_supported()
if not ok or not _cuda_sm90_available():
pytest.skip(reason or "SM90 not available")
ok, reason = fm.is_flashmla_sparse_supported()
if not ok:
pytest.skip(reason)
device = torch.device("cuda")
batch_size = 1
@@ -43,9 +36,9 @@ def test_sparse_flashmla_metadata_smoke():
def test_sparse_flashmla_decode_smoke():
import vllm.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_supported()
if not ok or not _cuda_sm90_available():
pytest.skip(reason or "SM90 not available")
ok, reason = fm.is_flashmla_sparse_supported()
if not ok:
pytest.skip(reason)
device = torch.device("cuda")
batch_size = 1
@@ -106,9 +99,9 @@ def test_sparse_flashmla_decode_smoke():
def test_sparse_flashmla_prefill_smoke():
import vllm.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_supported()
if not ok or not _cuda_sm90_available():
pytest.skip(reason or "SM90 not available")
ok, reason = fm.is_flashmla_sparse_supported()
if not ok:
pytest.skip(reason)
device = torch.device("cuda")
s_q = 1

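The three smoke tests above now share a single gate: `is_flashmla_sparse_supported()` already encodes the device-capability requirement, so the hand-rolled `_cuda_sm90_available()` helper is dropped. A minimal sketch of the consolidated pattern (the helper name below is illustrative, not part of this commit):

```python
import pytest

import vllm.attention.ops.flashmla as fm


def _skip_unless_sparse_flashmla() -> None:
    # One call covers both kernel availability and the SM90/SM100
    # device-capability requirement; the reason string is surfaced
    # verbatim in the pytest skip message.
    ok, reason = fm.is_flashmla_sparse_supported()
    if not ok:
        pytest.skip(reason)
```
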
View File

@@ -24,12 +24,7 @@ from vllm import _custom_ops as ops
from vllm.attention.ops import flashmla
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.utils import cdiv
from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend,
FlashMLASparseDecodeAndContextMetadata,
FlashMLASparseImpl,
FlashMLASparseMetadata,
)
from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseBackend
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
SPARSE_BACKEND_BATCH_SPECS = {
@@ -116,59 +111,6 @@ def _quantize_dequantize_fp8_ds_mla(
return dequant_kv_c, dequant_k_pe
def test_sparse_backend_metadata_registration():
backend = FlashMLASparseBackend
assert backend.get_name() == "FLASHMLA_SPARSE_VLLM_V1"
assert backend.get_metadata_cls() is FlashMLASparseMetadata
assert backend.get_impl_cls() is FlashMLASparseImpl
dtype_list = backend.get_supported_dtypes()
assert torch.bfloat16 in dtype_list
shape = backend.get_kv_cache_shape(
num_blocks=2, block_size=64, num_kv_heads=1, head_size=576
)
assert shape == (2, 64, 576)
def test_sparse_decode_metadata_filters_prefill_indices():
prefill_context_lengths = torch.tensor([4, 2], dtype=torch.int32)
metadata = FlashMLASparseDecodeAndContextMetadata(
scheduler_metadata=torch.tensor([[0]], dtype=torch.int32),
num_splits=torch.tensor([1, 1], dtype=torch.int32),
cache_lens=torch.tensor([10, 12], dtype=torch.int32),
prefill_context_lengths=prefill_context_lengths,
)
indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32)
context_indices, new_token_indices = metadata.filter_prefill_indices(indices)
expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]], dtype=torch.int32)
expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]], dtype=torch.int32)
assert torch.equal(context_indices, expected_context)
assert torch.equal(new_token_indices, expected_new_tokens)
def test_sparse_impl_zero_fills_when_metadata_missing():
impl = FlashMLASparseImpl.__new__(FlashMLASparseImpl)
dummy_layer = object()
q = torch.zeros((2, 1, 3))
k_c = torch.zeros((2, 3))
k_pe = torch.zeros((2, 1, 1))
kv_cache = torch.zeros((1, 1, 1))
output = torch.ones((2, 4))
result = FlashMLASparseImpl.forward(
impl, dummy_layer, q, k_c, k_pe, kv_cache, attn_metadata=None, output=output
)
assert result is output
assert torch.all(result == 0)
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype):
@@ -198,11 +140,12 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype
max_model_len=max_seqlen,
num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1),
block_size=block_size,
hf_config_override={
"index_topk": topk_tokens,
"attn_module_list_cfg": [{"topk_tokens": topk_tokens}],
},
)
model_config = vllm_config.model_config
model_config.hf_config = SimpleNamespace(
attn_module_list_cfg=[{"topk_tokens": topk_tokens}]
)
model_config.hf_text_config = SimpleNamespace(
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
@@ -301,6 +244,7 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype
sdpa_reference = torch.cat(reference_outputs, dim=0)
vllm_config.cache_config.cache_dtype = kv_cache_dtype
vllm_config.model_config.hf_config.index_topk = topk_tokens
common_attn_metadata = create_common_attn_metadata(
batch_spec,
@@ -352,7 +296,7 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype
debug_indices = debug_indices.expand(metadata.num_actual_tokens, -1).clone()
mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)
ok, reason = flashmla.is_flashmla_supported()
ok, reason = flashmla.is_flashmla_sparse_supported()
if not ok:
pytest.skip(reason)
@@ -397,9 +341,16 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype
metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device
)
backend_output = impl.forward(
layer, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, metadata, output=out_buffer
)
with torch.inference_mode():
backend_output = impl.forward(
layer,
query_vllm,
kv_c_vllm,
k_pe_vllm,
kv_cache,
metadata,
output=out_buffer,
)
assert backend_output.shape == sdpa_reference.shape
assert backend_output.dtype == sdpa_reference.dtype

View File

@@ -181,6 +181,7 @@ def create_vllm_config(
max_num_batched_tokens: int = 8192,
enable_chunked_prefill: bool = True,
add_mock_model_methods: bool = True,
hf_config_override: Optional[dict] = None,
) -> VllmConfig:
"""Create a VllmConfig for testing with reasonable defaults."""
@@ -235,6 +236,9 @@ def create_vllm_config(
lambda self, i: 1.0 / model_config.get_head_size() ** 0.5, model_config
)
if hf_config_override:
model_config.hf_config.update(hf_config_override)
return VllmConfig(
model_config=model_config,
cache_config=cache_config,

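A minimal sketch of how a test consumes the new `hf_config_override` keyword, relying on the helper's documented defaults for every other argument and assuming the import path shown below and that `hf_config` is a Transformers `PretrainedConfig`, whose `update()` copies each dict entry onto the config as an attribute; the concrete values are illustrative:

```python
# Assumed import path for the test helper modified above.
from tests.v1.attention.utils import create_vllm_config

topk_tokens = 2048  # illustrative value, not taken from the diff

vllm_config = create_vllm_config(
    hf_config_override={
        "index_topk": topk_tokens,
        "attn_module_list_cfg": [{"topk_tokens": topk_tokens}],
    },
)
# The override keys are now plain attributes on the HF config.
assert vllm_config.model_config.hf_config.index_topk == topk_tokens
```
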
View File

@@ -31,21 +31,47 @@ else:
_flashmla_extension_C_AVAILABLE = False
def is_flashmla_supported() -> tuple[bool, Optional[str]]:
"""
Return: is_supported_flag, unsupported_reason (optional).
"""
if not current_platform.is_cuda():
return False, "FlashMLA is only supported on CUDA devices."
if current_platform.get_device_capability()[0] != 9:
return False, "FlashMLA is only supported on Hopper devices."
def _is_flashmla_available() -> tuple[bool, Optional[str]]:
if not _flashmla_C_AVAILABLE:
return (
False,
"vllm._flashmla_C is not available, likely was not "
"compiled due to insufficient nvcc version or a supported arch "
"(only sm90a currently) was not in the list of target arches to "
"compile for.",
"was not in the list of target arches to compile for.",
)
if not _flashmla_extension_C_AVAILABLE:
return (
False,
"vllm._flashmla_extension_C is not available, likely "
"was not compiled due to a build error.",
)
return True, None
def is_flashmla_dense_supported() -> tuple[bool, Optional[str]]:
"""
Return: is_supported_flag, unsupported_reason (optional).
"""
is_available, maybe_reason = _is_flashmla_available()
if not is_available:
return False, maybe_reason
if current_platform.get_device_capability()[0] != 9:
return False, "FlashMLA Dense is only supported on Hopper devices."
return True, None
def is_flashmla_sparse_supported() -> tuple[bool, Optional[str]]:
"""
Return: is_supported_flag, unsupported_reason (optional).
"""
is_available, maybe_reason = _is_flashmla_available()
if not is_available:
return False, maybe_reason
if current_platform.get_device_capability()[0] not in (9, 10):
return (
False,
"FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
)
return True, None

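For reference, a short sketch of how callers consume the split predicates; each returns an `(is_supported, reason)` tuple, so the reason string can be logged or passed straight to `pytest.skip` (the fallback messages below are illustrative):

```python
from vllm.attention.ops.flashmla import (
    is_flashmla_dense_supported,
    is_flashmla_sparse_supported,
)

# Dense kernels: compiled extension present and running on Hopper (SM90).
dense_ok, dense_reason = is_flashmla_dense_supported()
if not dense_ok:
    print(f"Dense FlashMLA unavailable: {dense_reason}")

# Sparse kernels: compiled extension present and running on Hopper or
# Blackwell (SM90 / SM100).
sparse_ok, sparse_reason = is_flashmla_sparse_supported()
if not sparse_ok:
    print(f"Sparse FlashMLA unavailable: {sparse_reason}")
```
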
View File

@@ -146,11 +146,11 @@ class CudaPlatformBase(Platform):
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA"
from vllm.attention.ops.flashmla import is_flashmla_supported
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
if (
use_flashmla
and is_flashmla_supported()[0]
and is_flashmla_dense_supported()[0]
and cache_config.block_size != 64
):
cache_config.block_size = 64
@@ -256,7 +256,7 @@ class CudaPlatformBase(Platform):
"Set VLLM_USE_V1=1 to enable them."
)
from vllm.attention.ops.flashmla import is_flashmla_supported
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
if use_sparse:
@@ -277,7 +277,7 @@ class CudaPlatformBase(Platform):
and block_size in [32, 64]
)
use_flashmla = selected_backend == _Backend.FLASHMLA or (
selected_backend is None and is_flashmla_supported()[0]
selected_backend is None and is_flashmla_dense_supported()[0]
)
use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
selected_backend is None and flash_attn_supports_mla()

View File

@@ -10,7 +10,7 @@ from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.attention.ops.flashmla import (
flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported,
is_flashmla_dense_supported,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
@@ -177,7 +177,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
**mla_args,
)
is_supported, reason = is_flashmla_supported()
is_supported, reason = is_flashmla_dense_supported()
assert is_supported, reason
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]

View File

@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional
@@ -51,12 +50,6 @@ structured as:
"""
def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor:
# Convert base-2 LSE to natural-log LSE
# Keep FP32 for numerical stability during the merge.
return lse_base2.to(torch.float32) * math.log(2.0)
class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
@@ -100,36 +93,6 @@ class FlashMLASparseBackend(AttentionBackend):
return [576]
@dataclass
class MLASparsePrefillMetadata:
# NOTE(Chen): not call it "FlashMLASparsePrefillMetadata" because
# the kernel is not from flashmla
block_table: torch.Tensor
has_context: bool = False
context_lens: Optional[torch.Tensor] = None
@dataclass
class FlashMLASparseDecodeAndContextMetadata:
scheduler_metadata: torch.Tensor = None
num_splits: torch.Tensor = None
cache_lens: torch.Tensor = None
prefill_context_lengths: Optional[torch.Tensor] = None
prefill_new_k_start_locs: Optional[torch.Tensor] = None
dummy_block_table: torch.Tensor = None
def filter_prefill_indices(
self, indices: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.prefill_context_lengths is not None
prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1)
context_indices = torch.where(indices < prefill_context_lengths, indices, -1)
new_token_indices = torch.where(
indices >= prefill_context_lengths, indices - prefill_context_lengths, -1
)
return context_indices, new_token_indices
@dataclass
class FlashMLASparseMetadata:
num_reqs: int