[Misc] Clean up cruft from previous FlashMLA sparse implementation (#26125)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 31a4b3e6c4
commit f80e7866c0
@@ -165,10 +165,10 @@ def test_env(
             pytest.skip("FlashMLA only supports block_size 64")
         else:
             from vllm.v1.attention.backends.mla.flashmla import (
-                is_flashmla_supported,
+                is_flashmla_dense_supported,
             )

-            is_supported, _ = is_flashmla_supported()
+            is_supported, _ = is_flashmla_dense_supported()
             if not is_supported:
                 pytest.skip("FlashMLA not supported on this platform")
             else:
@@ -10,7 +10,7 @@ import torch
 from vllm.attention.ops.flashmla import (
     flash_mla_with_kvcache,
     get_mla_metadata,
-    is_flashmla_supported,
+    is_flashmla_dense_supported,
 )
 from vllm.triton_utils import triton

@@ -27,13 +27,15 @@ def cal_diff(


 FLASH_MLA_UNSUPPORTED_REASON = (
-    is_flashmla_supported()[1]
-    if not is_flashmla_supported()[0]
+    is_flashmla_dense_supported()[1]
+    if not is_flashmla_dense_supported()[0]
     else "FlashMLA is supported"
 )


-@pytest.mark.skipif(not is_flashmla_supported()[0], reason=FLASH_MLA_UNSUPPORTED_REASON)
+@pytest.mark.skipif(
+    not is_flashmla_dense_supported()[0], reason=FLASH_MLA_UNSUPPORTED_REASON
+)
 @pytest.mark.parametrize("b", [128])
 @pytest.mark.parametrize("s_q", [1, 2])
 @pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
@@ -4,19 +4,12 @@ import pytest
 import torch


-def _cuda_sm90_available() -> bool:
-    if not torch.cuda.is_available():
-        return False
-    major, _ = torch.cuda.get_device_capability()
-    return major == 9
-
-
 def test_sparse_flashmla_metadata_smoke():
     import vllm.attention.ops.flashmla as fm

-    ok, reason = fm.is_flashmla_supported()
-    if not ok or not _cuda_sm90_available():
-        pytest.skip(reason or "SM90 not available")
+    ok, reason = fm.is_flashmla_sparse_supported()
+    if not ok:
+        pytest.skip(reason)

     device = torch.device("cuda")
     batch_size = 1
@ -43,9 +36,9 @@ def test_sparse_flashmla_metadata_smoke():
|
||||
def test_sparse_flashmla_decode_smoke():
|
||||
import vllm.attention.ops.flashmla as fm
|
||||
|
||||
ok, reason = fm.is_flashmla_supported()
|
||||
if not ok or not _cuda_sm90_available():
|
||||
pytest.skip(reason or "SM90 not available")
|
||||
ok, reason = fm.is_flashmla_sparse_supported()
|
||||
if not ok:
|
||||
pytest.skip(reason)
|
||||
|
||||
device = torch.device("cuda")
|
||||
batch_size = 1
|
||||
@@ -106,9 +99,9 @@ def test_sparse_flashmla_decode_smoke():
 def test_sparse_flashmla_prefill_smoke():
     import vllm.attention.ops.flashmla as fm

-    ok, reason = fm.is_flashmla_supported()
-    if not ok or not _cuda_sm90_available():
-        pytest.skip(reason or "SM90 not available")
+    ok, reason = fm.is_flashmla_sparse_supported()
+    if not ok:
+        pytest.skip(reason)

     device = torch.device("cuda")
     s_q = 1
@@ -24,12 +24,7 @@ from vllm import _custom_ops as ops
 from vllm.attention.ops import flashmla
 from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.utils import cdiv
-from vllm.v1.attention.backends.mla.flashmla_sparse import (
-    FlashMLASparseBackend,
-    FlashMLASparseDecodeAndContextMetadata,
-    FlashMLASparseImpl,
-    FlashMLASparseMetadata,
-)
+from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseBackend
 from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks

 SPARSE_BACKEND_BATCH_SPECS = {
@@ -116,59 +111,6 @@ def _quantize_dequantize_fp8_ds_mla(
     return dequant_kv_c, dequant_k_pe


-def test_sparse_backend_metadata_registration():
-    backend = FlashMLASparseBackend
-
-    assert backend.get_name() == "FLASHMLA_SPARSE_VLLM_V1"
-    assert backend.get_metadata_cls() is FlashMLASparseMetadata
-    assert backend.get_impl_cls() is FlashMLASparseImpl
-
-    dtype_list = backend.get_supported_dtypes()
-    assert torch.bfloat16 in dtype_list
-
-    shape = backend.get_kv_cache_shape(
-        num_blocks=2, block_size=64, num_kv_heads=1, head_size=576
-    )
-    assert shape == (2, 64, 576)
-
-
-def test_sparse_decode_metadata_filters_prefill_indices():
-    prefill_context_lengths = torch.tensor([4, 2], dtype=torch.int32)
-    metadata = FlashMLASparseDecodeAndContextMetadata(
-        scheduler_metadata=torch.tensor([[0]], dtype=torch.int32),
-        num_splits=torch.tensor([1, 1], dtype=torch.int32),
-        cache_lens=torch.tensor([10, 12], dtype=torch.int32),
-        prefill_context_lengths=prefill_context_lengths,
-    )
-
-    indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32)
-
-    context_indices, new_token_indices = metadata.filter_prefill_indices(indices)
-
-    expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]], dtype=torch.int32)
-    expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]], dtype=torch.int32)
-
-    assert torch.equal(context_indices, expected_context)
-    assert torch.equal(new_token_indices, expected_new_tokens)
-
-
-def test_sparse_impl_zero_fills_when_metadata_missing():
-    impl = FlashMLASparseImpl.__new__(FlashMLASparseImpl)
-    dummy_layer = object()
-    q = torch.zeros((2, 1, 3))
-    k_c = torch.zeros((2, 3))
-    k_pe = torch.zeros((2, 1, 1))
-    kv_cache = torch.zeros((1, 1, 1))
-    output = torch.ones((2, 4))
-
-    result = FlashMLASparseImpl.forward(
-        impl, dummy_layer, q, k_c, k_pe, kv_cache, attn_metadata=None, output=output
-    )
-
-    assert result is output
-    assert torch.all(result == 0)
-
-
 @pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
 @pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
 def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype):
@@ -198,11 +140,12 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype
         max_model_len=max_seqlen,
         num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1),
         block_size=block_size,
+        hf_config_override={
+            "index_topk": topk_tokens,
+            "attn_module_list_cfg": [{"topk_tokens": topk_tokens}],
+        },
     )
     model_config = vllm_config.model_config
-    model_config.hf_config = SimpleNamespace(
-        attn_module_list_cfg=[{"topk_tokens": topk_tokens}]
-    )
     model_config.hf_text_config = SimpleNamespace(
         q_lora_rank=None,
         kv_lora_rank=kv_lora_rank,
@@ -301,6 +244,7 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype
     sdpa_reference = torch.cat(reference_outputs, dim=0)

     vllm_config.cache_config.cache_dtype = kv_cache_dtype
+    vllm_config.model_config.hf_config.index_topk = topk_tokens

     common_attn_metadata = create_common_attn_metadata(
         batch_spec,
@@ -352,7 +296,7 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype
     debug_indices = debug_indices.expand(metadata.num_actual_tokens, -1).clone()
     mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)

-    ok, reason = flashmla.is_flashmla_supported()
+    ok, reason = flashmla.is_flashmla_sparse_supported()
     if not ok:
         pytest.skip(reason)

@@ -397,9 +341,16 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype
         metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device
     )

-    backend_output = impl.forward(
-        layer, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, metadata, output=out_buffer
-    )
+    with torch.inference_mode():
+        backend_output = impl.forward(
+            layer,
+            query_vllm,
+            kv_c_vllm,
+            k_pe_vllm,
+            kv_cache,
+            metadata,
+            output=out_buffer,
+        )

     assert backend_output.shape == sdpa_reference.shape
     assert backend_output.dtype == sdpa_reference.dtype
@@ -181,6 +181,7 @@ def create_vllm_config(
     max_num_batched_tokens: int = 8192,
     enable_chunked_prefill: bool = True,
     add_mock_model_methods: bool = True,
+    hf_config_override: Optional[dict] = None,
 ) -> VllmConfig:
     """Create a VllmConfig for testing with reasonable defaults."""

@@ -235,6 +236,9 @@ def create_vllm_config(
         lambda self, i: 1.0 / model_config.get_head_size() ** 0.5, model_config
     )

+    if hf_config_override:
+        model_config.hf_config.update(hf_config_override)
+
     return VllmConfig(
         model_config=model_config,
         cache_config=cache_config,
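Illustrative usage sketch (not part of this commit): with the new hf_config_override parameter, a test can thread extra HF-config fields through create_vllm_config instead of patching model_config.hf_config afterwards, mirroring the decode-correctness test hunk above. The helper's import path and the other argument values here are assumptions made for illustration.

# Sketch only; the import path and argument values are assumed, not taken from this diff.
from tests.v1.attention.utils import create_vllm_config

topk_tokens = 2048  # hypothetical value

vllm_config = create_vllm_config(
    block_size=64,
    max_model_len=4096,
    # Merged into model_config.hf_config via hf_config.update(...) in the hunk above.
    hf_config_override={
        "index_topk": topk_tokens,
        "attn_module_list_cfg": [{"topk_tokens": topk_tokens}],
    },
)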
@@ -31,21 +31,47 @@ else:
     _flashmla_extension_C_AVAILABLE = False


-def is_flashmla_supported() -> tuple[bool, Optional[str]]:
-    """
-    Return: is_supported_flag, unsupported_reason (optional).
-    """
-    if not current_platform.is_cuda():
-        return False, "FlashMLA is only supported on CUDA devices."
-    if current_platform.get_device_capability()[0] != 9:
-        return False, "FlashMLA is only supported on Hopper devices."
+def _is_flashmla_available() -> tuple[bool, Optional[str]]:
     if not _flashmla_C_AVAILABLE:
         return (
             False,
             "vllm._flashmla_C is not available, likely was not "
             "compiled due to insufficient nvcc version or a supported arch "
-            "(only sm90a currently) was not in the list of target arches to "
-            "compile for.",
+            "was not in the list of target arches to compile for.",
         )
     if not _flashmla_extension_C_AVAILABLE:
         return (
             False,
             "vllm._flashmla_extension_C is not available, likely "
             "was not compiled due to a build error.",
         )

     return True, None
+
+
+def is_flashmla_dense_supported() -> tuple[bool, Optional[str]]:
+    """
+    Return: is_supported_flag, unsupported_reason (optional).
+    """
+    is_availble, maybe_reason = _is_flashmla_available()
+    if not is_availble:
+        return False, maybe_reason
+    if current_platform.get_device_capability()[0] != 9:
+        return False, "FlashMLA Dense is only supported on Hopper devices."
+    return True, None
+
+
+def is_flashmla_sparse_supported() -> tuple[bool, Optional[str]]:
+    """
+    Return: is_supported_flag, unsupported_reason (optional).
+    """
+    is_availble, maybe_reason = _is_flashmla_available()
+    if not is_availble:
+        return False, maybe_reason
+    if current_platform.get_device_capability()[0] not in (9, 10):
+        return (
+            False,
+            "FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
+        )
+    return True, None
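Illustrative gating sketch (not part of this commit): the split helpers introduced above are meant to be queried independently, as the updated tests and the platform-selection hunks below do. The skip_unless_* wrapper names here are hypothetical.

# Sketch only; the wrapper names are made up for illustration.
import pytest

from vllm.attention.ops.flashmla import (
    is_flashmla_dense_supported,
    is_flashmla_sparse_supported,
)


def skip_unless_dense_flashmla() -> None:
    # Dense FlashMLA: needs the compiled _flashmla_C ops and a Hopper (SM9x) GPU.
    ok, reason = is_flashmla_dense_supported()
    if not ok:
        pytest.skip(reason)


def skip_unless_sparse_flashmla() -> None:
    # Sparse FlashMLA: additionally allowed on Blackwell (SM10x) devices.
    ok, reason = is_flashmla_sparse_supported()
    if not ok:
        pytest.skip(reason)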
@@ -146,11 +146,11 @@ class CudaPlatformBase(Platform):
         use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
         use_flashinfer_mla = envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA"

-        from vllm.attention.ops.flashmla import is_flashmla_supported
+        from vllm.attention.ops.flashmla import is_flashmla_dense_supported

         if (
             use_flashmla
-            and is_flashmla_supported()[0]
+            and is_flashmla_dense_supported()[0]
             and cache_config.block_size != 64
         ):
             cache_config.block_size = 64
@@ -256,7 +256,7 @@ class CudaPlatformBase(Platform):
                 "Set VLLM_USE_V1=1 to enable them."
             )

-        from vllm.attention.ops.flashmla import is_flashmla_supported
+        from vllm.attention.ops.flashmla import is_flashmla_dense_supported
         from vllm.attention.utils.fa_utils import flash_attn_supports_mla

         if use_sparse:
@@ -277,7 +277,7 @@ class CudaPlatformBase(Platform):
                 and block_size in [32, 64]
             )
             use_flashmla = selected_backend == _Backend.FLASHMLA or (
-                selected_backend is None and is_flashmla_supported()[0]
+                selected_backend is None and is_flashmla_dense_supported()[0]
             )
             use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
                 selected_backend is None and flash_attn_supports_mla()
@@ -10,7 +10,7 @@ from vllm.attention.backends.abstract import AttentionLayer, AttentionType
 from vllm.attention.ops.flashmla import (
     flash_mla_with_kvcache,
     get_mla_metadata,
-    is_flashmla_supported,
+    is_flashmla_dense_supported,
 )
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -177,7 +177,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
             **mla_args,
         )

-        is_supported, reason = is_flashmla_supported()
+        is_supported, reason = is_flashmla_dense_supported()
         assert is_supported, reason

         unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import math
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, ClassVar, Optional

@@ -51,12 +50,6 @@ structured as:
 """


-def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor:
-    # Convert base-2 LSE to natural-log LSE
-    # Keep FP32 for numerical stability during the merge.
-    return lse_base2.to(torch.float32) * math.log(2.0)
-
-
 class FlashMLASparseBackend(AttentionBackend):
     accept_output_buffer: bool = True

@@ -100,36 +93,6 @@ class FlashMLASparseBackend(AttentionBackend):
         return [576]


-@dataclass
-class MLASparsePrefillMetadata:
-    # NOTE(Chen): not call it "FlashMLASparsePrefillMetadata" because
-    # the kernel is not from flashmla
-    block_table: torch.Tensor
-    has_context: bool = False
-    context_lens: Optional[torch.Tensor] = None
-
-
-@dataclass
-class FlashMLASparseDecodeAndContextMetadata:
-    scheduler_metadata: torch.Tensor = None
-    num_splits: torch.Tensor = None
-    cache_lens: torch.Tensor = None
-    prefill_context_lengths: Optional[torch.Tensor] = None
-    prefill_new_k_start_locs: Optional[torch.Tensor] = None
-    dummy_block_table: torch.Tensor = None
-
-    def filter_prefill_indices(
-        self, indices: torch.Tensor
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        assert self.prefill_context_lengths is not None
-        prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1)
-        context_indices = torch.where(indices < prefill_context_lengths, indices, -1)
-        new_token_indices = torch.where(
-            indices >= prefill_context_lengths, indices - prefill_context_lengths, -1
-        )
-        return context_indices, new_token_indices
-
-
 @dataclass
 class FlashMLASparseMetadata:
     num_reqs: int