mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-11 01:17:04 +08:00
[Attention][Spec Decode] FlashMLA spec decode support (#26541)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
parent
87efc681db
commit
82af928c41
@ -1,6 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for v1 MLA backends without GPUModelRunner dependency."""
|
||||
"""Tests for v1 MLA backends without GPUModelRunner dependency.
|
||||
|
||||
Known Issues:
|
||||
- FLASH_ATTN_MLA backend occasionally produces NaN values in
|
||||
test_backend_correctness[mixed_small] when run after
|
||||
test_backend_correctness[small_prefill], but passes when run alone.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -14,6 +20,8 @@ from tests.v1.attention.utils import (
|
||||
)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
from vllm.config.vllm import set_current_vllm_config
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
@ -29,6 +37,10 @@ BACKENDS_TO_TEST = [
|
||||
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
|
||||
BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
|
||||
|
||||
# Remove FLASHMLA from the list if not supported
|
||||
if not is_flashmla_dense_supported()[0]:
|
||||
BACKENDS_TO_TEST.remove(_Backend.FLASHMLA)
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
@ -66,6 +78,12 @@ BATCH_SPECS = {
|
||||
"large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
|
||||
"single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
|
||||
"single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
|
||||
"spec_decode_small": BatchSpec(
|
||||
seq_lens=[128, 256, 512, 1024], query_lens=[4, 4, 4, 4]
|
||||
),
|
||||
"spec_decode_medium": BatchSpec(
|
||||
seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[8, 8, 8, 8, 8, 8]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@ -239,61 +257,64 @@ def run_attention_backend(
|
||||
|
||||
builder_cls, impl_cls = try_get_attention_backend(backend)
|
||||
|
||||
# Build metadata
|
||||
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
|
||||
attn_metadata = builder.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
# Set the current vllm config so that get_current_vllm_config() works
|
||||
# in the backend implementations
|
||||
with set_current_vllm_config(vllm_config):
|
||||
# Build metadata
|
||||
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
|
||||
attn_metadata = builder.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
)
|
||||
|
||||
# Instantiate MLA implementation
|
||||
num_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
head_size = vllm_config.model_config.get_head_size()
|
||||
scale = 1.0 / (head_size**0.5)
|
||||
impl = impl_cls(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype="auto",
|
||||
logits_soft_cap=None,
|
||||
attn_type="decoder",
|
||||
kv_sharing_target_layer_name=None,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
kv_b_proj=mock_kv_b_proj,
|
||||
)
|
||||
# Instantiate MLA implementation
|
||||
num_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
head_size = vllm_config.model_config.get_head_size()
|
||||
scale = 1.0 / (head_size**0.5)
|
||||
impl = impl_cls(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype="auto",
|
||||
logits_soft_cap=None,
|
||||
attn_type="decoder",
|
||||
kv_sharing_target_layer_name=None,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
kv_b_proj=mock_kv_b_proj,
|
||||
)
|
||||
|
||||
# Process weights to create W_UK_T and W_UV attributes needed by MLA
|
||||
act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
|
||||
impl.process_weights_after_loading(act_dtype)
|
||||
# Process weights to create W_UK_T and W_UV attributes needed by MLA
|
||||
act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
|
||||
impl.process_weights_after_loading(act_dtype)
|
||||
|
||||
# Create mock layer and output buffer
|
||||
mock_layer = MockAttentionLayer(device)
|
||||
num_tokens = query.shape[0]
|
||||
output = torch.empty(
|
||||
num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
|
||||
)
|
||||
# Create mock layer and output buffer
|
||||
mock_layer = MockAttentionLayer(device)
|
||||
num_tokens = query.shape[0]
|
||||
output = torch.empty(
|
||||
num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
|
||||
)
|
||||
|
||||
# Run forward pass
|
||||
# NOTE: The query, key, and value are already shaped correctly
|
||||
# in the calling test function.
|
||||
output = impl.forward(
|
||||
mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output
|
||||
)
|
||||
# Run forward pass
|
||||
# NOTE: The query, key, and value are already shaped correctly
|
||||
# in the calling test function.
|
||||
output = impl.forward(
|
||||
mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output
|
||||
)
|
||||
|
||||
return output
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -309,6 +330,8 @@ def run_attention_backend(
|
||||
"large_prefill",
|
||||
"single_decode",
|
||||
"single_prefill",
|
||||
"spec_decode_small",
|
||||
"spec_decode_medium",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"])
|
||||
@ -328,10 +351,39 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
simulated paged KV cache.
|
||||
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
|
||||
"""
|
||||
from vllm.v1.attention.backends.mla.common import QueryLenSupport
|
||||
|
||||
batch_spec = BATCH_SPECS[batch_spec_name]
|
||||
vllm_config = create_vllm_config(
|
||||
model_name=model, max_model_len=max(batch_spec.seq_lens), num_gpu_blocks=2048
|
||||
is_spec_decode_test = batch_spec_name.startswith("spec_decode")
|
||||
spec_decode_backends = {_Backend.FLASH_ATTN_MLA, _Backend.FLASHMLA}
|
||||
|
||||
block_size = 16
|
||||
required_blocks = sum(
|
||||
(seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
|
||||
)
|
||||
# Add 1 for null block at index 0, and some buffer
|
||||
num_gpu_blocks = required_blocks + 1 + 100
|
||||
|
||||
vllm_config = create_vllm_config(
|
||||
model_name=model,
|
||||
max_model_len=max(batch_spec.seq_lens),
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
block_size=block_size,
|
||||
)
|
||||
|
||||
# For spec decode tests, add a speculative_config to set the reorder_batch_threshold
|
||||
if is_spec_decode_test:
|
||||
from vllm.config import SpeculativeConfig
|
||||
|
||||
# Get the query length from the batch spec (they should all be uniform)
|
||||
query_len = batch_spec.query_lens[0]
|
||||
# Set num_speculative_tokens to query_len - 1
|
||||
# (since threshold is 1 + num_spec_tokens)
|
||||
# Use ngram method which doesn't require a draft model
|
||||
vllm_config.speculative_config = SpeculativeConfig(
|
||||
method="ngram", num_speculative_tokens=query_len - 1
|
||||
)
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
|
||||
@ -395,11 +447,37 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
# K_PE (rope component): [s_len, 1, qk_rope_head_dim]
|
||||
k_pe_full = torch.randn(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)
|
||||
|
||||
# Determine if this is decode or prefill
|
||||
# Determine if this sequence uses the decode pipeline or prefill
|
||||
# pipeline for each backend
|
||||
# NOTE: For spec decode tests with uniform query_len > 1, backends that
|
||||
# support spec decode (FLASH_ATTN_MLA with varlen support, FLASHMLA with
|
||||
# uniform support) will use the decode pipeline (MQA-style), while
|
||||
# backends that only support single-token queries will use the prefill
|
||||
# pipeline (MHA-style). This ensures the reference implementation
|
||||
# matches each backend's actual decode/prefill pipeline path.
|
||||
is_decode = []
|
||||
for i, backend in enumerate(BACKENDS_TO_TEST):
|
||||
for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
|
||||
builder_cls, _ = try_get_attention_backend(backend)
|
||||
is_decode.append(q_len <= builder_cls.reorder_batch_threshold)
|
||||
if is_spec_decode_test:
|
||||
query_len_support = getattr(
|
||||
builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
|
||||
)
|
||||
supports_spec = query_len_support != QueryLenSupport.SINGLE_ONLY
|
||||
is_decode.append(supports_spec)
|
||||
else:
|
||||
threshold = getattr(builder_cls, "reorder_batch_threshold", None)
|
||||
query_len_support = getattr(
|
||||
builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
|
||||
)
|
||||
within_threshold = q_len <= threshold if threshold else False
|
||||
if (
|
||||
within_threshold
|
||||
and query_len_support == QueryLenSupport.UNIFORM
|
||||
and i > 0
|
||||
):
|
||||
first_q_len = query_lens[0]
|
||||
within_threshold = q_len == first_q_len
|
||||
is_decode.append(within_threshold)
|
||||
|
||||
# Split q into nope and rope components
|
||||
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
|
||||
@ -478,11 +556,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
|
||||
sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)
|
||||
|
||||
for i, backend in enumerate(BACKENDS_TO_TEST):
|
||||
if is_decode[i]:
|
||||
all_sdpa_outputs[i].append(sdpa_out_i_decode)
|
||||
for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
|
||||
if is_decode[backend_idx]:
|
||||
all_sdpa_outputs[backend_idx].append(sdpa_out_i_decode)
|
||||
else:
|
||||
all_sdpa_outputs[i].append(sdpa_out_i_prefill)
|
||||
all_sdpa_outputs[backend_idx].append(sdpa_out_i_prefill)
|
||||
|
||||
# Inputs for vLLM MLA backends are just the new tokens
|
||||
all_q_vllm.append(q_c)
|
||||
@ -497,9 +575,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
query_vllm = torch.cat(all_q_vllm, dim=0)
|
||||
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
|
||||
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
|
||||
sdpa_outputs = []
|
||||
for i, backend in enumerate(BACKENDS_TO_TEST):
|
||||
sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0))
|
||||
sdpa_outputs = {}
|
||||
for backend_idx, backend in enumerate(BACKENDS_TO_TEST):
|
||||
sdpa_outputs[backend] = torch.cat(all_sdpa_outputs[backend_idx], dim=0)
|
||||
|
||||
# Create mock kv_b_proj using the same weights as reference implementation
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
@ -516,7 +594,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)
|
||||
)
|
||||
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T)
|
||||
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)
|
||||
|
||||
# Create metadata using original batch spec
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
@ -537,7 +615,11 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
)
|
||||
|
||||
# 4. Run vLLM backends and compare
|
||||
for i, backend_name in enumerate(BACKENDS_TO_TEST):
|
||||
for backend_idx, backend_name in enumerate(BACKENDS_TO_TEST):
|
||||
# Skip backends that don't support spec decode for spec decode tests
|
||||
if is_spec_decode_test and backend_name not in spec_decode_backends:
|
||||
continue
|
||||
|
||||
backend_output = run_attention_backend(
|
||||
backend_name,
|
||||
kv_cache_spec,
|
||||
@ -556,14 +638,17 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
mock_kv_b_proj,
|
||||
)
|
||||
|
||||
# Use backend_idx to get the correct SDPA output for this backend
|
||||
expected_output = sdpa_outputs[backend_name]
|
||||
|
||||
# Check shape and dtype consistency
|
||||
assert backend_output.shape == sdpa_outputs[i].shape, (
|
||||
assert backend_output.shape == expected_output.shape, (
|
||||
f"[{backend_name}] shape {backend_output.shape} != "
|
||||
f"SDPA shape {sdpa_outputs[i].shape}"
|
||||
f"SDPA shape {expected_output.shape}"
|
||||
)
|
||||
assert backend_output.dtype == sdpa_outputs[i].dtype, (
|
||||
assert backend_output.dtype == expected_output.dtype, (
|
||||
f"[{backend_name}] dtype {backend_output.dtype} != "
|
||||
f"SDPA dtype {sdpa_outputs[i].dtype}"
|
||||
f"SDPA dtype {expected_output.dtype}"
|
||||
)
|
||||
|
||||
assert torch.isfinite(backend_output).all(), (
|
||||
@ -574,12 +659,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
rtol = 1e-2
|
||||
atol = 5e-1
|
||||
|
||||
max_diff = torch.max(torch.abs(backend_output - sdpa_outputs[i])).item()
|
||||
max_diff = torch.max(torch.abs(backend_output - expected_output)).item()
|
||||
max_rel_diff = torch.max(
|
||||
torch.abs(backend_output - sdpa_outputs[i]) / torch.abs(sdpa_outputs[i])
|
||||
torch.abs(backend_output - expected_output) / torch.abs(expected_output)
|
||||
).item()
|
||||
all_close = torch.allclose(
|
||||
backend_output, sdpa_outputs[i], rtol=rtol, atol=atol
|
||||
backend_output, expected_output, rtol=rtol, atol=atol
|
||||
)
|
||||
|
||||
assert all_close, (
|
||||
|
||||
@ -190,6 +190,7 @@ return curr_o @ W_O
|
||||
import functools
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import ClassVar, Generic, TypeVar
|
||||
|
||||
import torch
|
||||
@ -227,6 +228,24 @@ from vllm.v1.attention.backends.utils import (
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
class QueryLenSupport(Enum):
|
||||
"""Defines the level of query length support for an attention backend's
|
||||
decode pipeline.
|
||||
|
||||
- SINGLE_ONLY: Decode pipeline only supports single-token queries
|
||||
(query_len=1)
|
||||
- UNIFORM: Decode pipeline supports uniform multi-token queries
|
||||
(all requests must have same query_len > 1)
|
||||
- VARLEN: Decode pipeline supports variable-length queries
|
||||
(mixed query lengths in same batch)
|
||||
"""
|
||||
|
||||
SINGLE_ONLY = "single_only"
|
||||
UNIFORM = "uniform"
|
||||
VARLEN = "varlen"
|
||||
|
||||
|
||||
try:
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
|
||||
@ -460,19 +479,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
understand this class
|
||||
"""
|
||||
|
||||
# Whether the backend supports reordering the batch such that
|
||||
# short sequences (i.e. verification for speculative decoding) are
|
||||
# classified as decode requests.
|
||||
# If True, this will increase `reorder_batch_threshold` (below) when
|
||||
# speculative decoding is enabled, and set `require_uniform=True` when
|
||||
# when reordering the batch. Non-uniform decode requests will
|
||||
# fall back to prefill in this case.
|
||||
supports_uniform_spec_as_decode: ClassVar[bool] = False
|
||||
# Defines the level of query length support for this backend.
|
||||
# - SINGLE_ONLY: Only single-token queries (no spec decode support)
|
||||
# - UNIFORM: Supports uniform multi-token queries (spec decode with uniform lengths)
|
||||
# - VARLEN: Supports variable-length queries (spec decode with mixed lengths)
|
||||
# If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when
|
||||
# speculative decoding is enabled.
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY
|
||||
|
||||
# The threshold for reordering the batch into decode and prefill requests.
|
||||
# If > 1, the batch will be reordered such that requests with
|
||||
# query length <= threshold are classified as decode requests.
|
||||
# Use `supports_uniform_spec_as_decode` (above) to set this automatically
|
||||
# Use `query_len_support` (above) to set this automatically
|
||||
# when speculative decoding is enabled.
|
||||
reorder_batch_threshold: int = 1
|
||||
|
||||
@ -599,11 +617,18 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
device=device,
|
||||
)
|
||||
|
||||
supports_spec_as_decode = self.supports_uniform_spec_as_decode
|
||||
supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
|
||||
self._init_reorder_batch_threshold(
|
||||
self.reorder_batch_threshold, supports_spec_as_decode
|
||||
self.reorder_batch_threshold, supports_spec_decode
|
||||
)
|
||||
|
||||
# Validate consistency between query_len_support and reorder_batch_threshold
|
||||
if self.query_len_support == QueryLenSupport.SINGLE_ONLY:
|
||||
assert self.reorder_batch_threshold == 1, (
|
||||
f"reorder_batch_threshold must be 1 when query_len_support is "
|
||||
f"SINGLE_ONLY, got {self.reorder_batch_threshold}"
|
||||
)
|
||||
|
||||
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
|
||||
qo_indptr = prefill.query_start_loc
|
||||
|
||||
@ -745,7 +770,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata,
|
||||
decode_threshold=self.reorder_batch_threshold,
|
||||
require_uniform=self.supports_uniform_spec_as_decode,
|
||||
require_uniform=(self.query_len_support != QueryLenSupport.VARLEN),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@ from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
@ -66,8 +67,8 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
|
||||
|
||||
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
reorder_batch_threshold: int = 512
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
|
||||
reorder_batch_threshold: int = 512 # process small prefills with decode pathway
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -13,6 +13,7 @@ from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
|
||||
@ -22,11 +23,8 @@ FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
||||
|
||||
|
||||
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||
# enable spec-as-decode optimization
|
||||
supports_uniform_spec_as_decode: ClassVar[bool] = True
|
||||
|
||||
# enable full CUDA Graph support for decode-only capture
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
||||
|
||||
|
||||
class FlashInferMLABackend(MLACommonBackend):
|
||||
|
||||
@ -20,8 +20,13 @@ from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
reshape_attn_output_for_spec_decode,
|
||||
reshape_query_for_spec_decode,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -62,6 +67,9 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
||||
reorder_batch_threshold: int = 512 # process small prefills with decode pathway
|
||||
# ^ TODO(matt): tune this
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -216,8 +224,12 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
q = torch.cat(q, dim=-1)
|
||||
|
||||
assert isinstance(q, torch.Tensor)
|
||||
|
||||
num_decodes = attn_metadata.num_decodes
|
||||
q = reshape_query_for_spec_decode(q, num_decodes)
|
||||
|
||||
o, lse = flash_mla_with_kvcache(
|
||||
q=q.unsqueeze(1), # Add seqlen dim of 1 (decode)
|
||||
q=q,
|
||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
cache_seqlens=attn_metadata.decode.seq_lens,
|
||||
@ -230,4 +242,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
descale_k=layer._k_scale.reshape(1),
|
||||
)
|
||||
|
||||
o = reshape_attn_output_for_spec_decode(o)
|
||||
|
||||
return o, lse
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user