mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 04:54:56 +08:00
[Kernel] Flashinfer MLA (trtllm-gen) decode kernel integration (#21078)
Signed-off-by: hjjq <hanjieq@nvidia.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
parent
fba7856581
commit
dcb28a332b
@ -729,7 +729,8 @@ steps:
|
||||
# num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
|
||||
- pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
|
||||
- pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
|
||||
- pytest -v -s tests/kernels/test_cutlass_mla_decode.py
|
||||
- pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py
|
||||
- pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py
|
||||
# Quantization
|
||||
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
|
||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
|
||||
|
||||
123
tests/kernels/attention/test_flashinfer_mla_decode.py
Normal file
123
tests/kernels/attention/test_flashinfer_mla_decode.py
Normal file
@ -0,0 +1,123 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
from torch import Tensor
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(
|
||||
reason="FlashInfer MLA Requires compute capability of 10 or above.",
|
||||
allow_module_level=True)
|
||||
|
||||
|
||||
def ref_mla(
|
||||
out: Tensor, # (bs, num_heads, v_head_dim)
|
||||
query: Tensor, # (bs, num_heads, head_dim)
|
||||
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
|
||||
scale: float,
|
||||
block_tables: Tensor, # (bs, max_num_blocks)
|
||||
seq_lens: Tensor, # (bs,)
|
||||
):
|
||||
bs, num_heads, v_head_dim = out.shape
|
||||
head_dim = query.shape[2]
|
||||
|
||||
for i in range(bs):
|
||||
# gather and flatten KV-cache
|
||||
kv = kv_cache[
|
||||
block_tables[i]] # (max_num_blocks, block_size, head_dim)
|
||||
kv = kv.view(1, -1,
|
||||
head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim)
|
||||
v = kv[:, :, :v_head_dim]
|
||||
|
||||
q = query[i].view(num_heads, 1, head_dim)
|
||||
o = F.scaled_dot_product_attention(q,
|
||||
kv,
|
||||
v,
|
||||
scale=scale,
|
||||
enable_gqa=True)
|
||||
out[i] = o.view(num_heads, v_head_dim)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("bs", [1, 2, 4, 16])
|
||||
@pytest.mark.parametrize("block_size", [32, 64])
|
||||
def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int):
|
||||
torch.set_default_device('cuda')
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Deepseek R1 config
|
||||
num_heads = 128
|
||||
kv_lora_rank = 512
|
||||
qk_nope_head_dim = 128
|
||||
qk_rope_head_dim = 64
|
||||
qk_head_dim = kv_lora_rank + qk_rope_head_dim
|
||||
scale = (qk_nope_head_dim + qk_rope_head_dim)**-0.5
|
||||
|
||||
MAX_SEQ_LEN = 1024
|
||||
|
||||
seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1, )).item() for _ in range(bs)]
|
||||
seq_lens[-1] = MAX_SEQ_LEN
|
||||
max_seq_len = max(seq_lens)
|
||||
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
|
||||
# Generate block tables with random but unique block IDs
|
||||
# From https://github.com/flashinfer-ai/flashinfer/pull/1222
|
||||
blocks_per_seq = (seq_lens_tensor + block_size - 1) // block_size
|
||||
max_num_blocks_per_seq = max(blocks_per_seq.max().item(), 4)
|
||||
total_blocks_needed = sum(blocks_per_seq)
|
||||
# Get random unique IDs for all blocks
|
||||
all_block_ids = torch.randperm(total_blocks_needed)
|
||||
|
||||
block_id = 0
|
||||
block_tables = torch.zeros(
|
||||
(bs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
# Populate block tables and track block assignments
|
||||
block_id = 0
|
||||
for i in range(bs):
|
||||
num_blocks_needed = blocks_per_seq[i]
|
||||
block_tables[i, :num_blocks_needed] = all_block_ids[block_id:block_id +
|
||||
num_blocks_needed]
|
||||
block_id += num_blocks_needed
|
||||
|
||||
kv_cache = torch.randn(block_tables.numel(), block_size,
|
||||
qk_head_dim).to(dtype)
|
||||
q = torch.randn(bs, num_heads, qk_head_dim).to(dtype)
|
||||
|
||||
out_ref = q.new_zeros(bs, num_heads, kv_lora_rank)
|
||||
ref_mla(out_ref, q, kv_cache, scale, block_tables, seq_lens_tensor)
|
||||
|
||||
workspace_buffer = torch.zeros(
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE,
|
||||
dtype=torch.uint8,
|
||||
device=q.device,
|
||||
)
|
||||
# Flashinfer MLA expects the query to be of shape
|
||||
# (bs, q_len_per_request, num_heads, qk_head_dim),
|
||||
# where q_len_per_request is the MTP query length (=1 without MTP)
|
||||
q = q.unsqueeze(1)
|
||||
|
||||
out_ans = trtllm_batch_decode_with_kv_cache_mla(
|
||||
query=q,
|
||||
kv_cache=kv_cache.unsqueeze(1),
|
||||
workspace_buffer=workspace_buffer,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
block_tables=block_tables,
|
||||
seq_lens=seq_lens_tensor,
|
||||
max_seq_len=max_seq_len,
|
||||
bmm1_scale=scale,
|
||||
)
|
||||
out_ans = out_ans.squeeze(1)
|
||||
torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2)
|
||||
@ -1504,6 +1504,7 @@ class EngineArgs:
|
||||
"FLASH_ATTN_MLA",
|
||||
"FLASHINFER",
|
||||
"FLASHINFER_VLLM_V1",
|
||||
"FLASHINFER_MLA",
|
||||
"ROCM_AITER_MLA",
|
||||
"TORCH_SDPA_VLLM_V1",
|
||||
"FLEX_ATTENTION",
|
||||
|
||||
@ -228,6 +228,8 @@ class CudaPlatformBase(Platform):
|
||||
use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
|
||||
selected_backend is None and cls.is_device_capability(100)
|
||||
and block_size == 128)
|
||||
use_flashinfermla = (selected_backend == _Backend.FLASHINFER_MLA
|
||||
and cls.has_device_capability(100))
|
||||
use_flashmla = selected_backend in [
|
||||
_Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1
|
||||
] or (selected_backend is None and is_flashmla_supported()[0])
|
||||
@ -252,6 +254,19 @@ class CudaPlatformBase(Platform):
|
||||
else:
|
||||
logger.warning(
|
||||
"Cutlass MLA backend is only supported on V1 engine")
|
||||
if use_flashinfermla:
|
||||
if use_v1:
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
set_kv_cache_layout)
|
||||
set_kv_cache_layout("HND")
|
||||
logger.info_once(
|
||||
"Using FlashInfer MLA backend on V1 engine.")
|
||||
return ("vllm.v1.attention.backends.mla."
|
||||
"flashinfer_mla.FlashInferMLABackend")
|
||||
else:
|
||||
logger.warning(
|
||||
"FlashInfer MLA backend is only supported on V1 engine"
|
||||
)
|
||||
if use_flashmla:
|
||||
if block_size != 64:
|
||||
logger.warning(
|
||||
|
||||
@ -51,6 +51,7 @@ class _Backend(enum.Enum):
|
||||
TORCH_SDPA_VLLM_V1 = enum.auto()
|
||||
FLASHINFER = enum.auto()
|
||||
FLASHINFER_VLLM_V1 = enum.auto()
|
||||
FLASHINFER_MLA = enum.auto()
|
||||
TRITON_MLA = enum.auto() # Supported by V1
|
||||
TRITON_MLA_VLLM_V1 = enum.auto()
|
||||
CUTLASS_MLA = enum.auto()
|
||||
|
||||
@ -381,6 +381,7 @@ class MLACommonMetadata(Generic[D]):
|
||||
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
query_start_loc: torch.Tensor
|
||||
@ -644,6 +645,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
max_query_len = common_attn_metadata.max_query_len
|
||||
max_seq_len = common_attn_metadata.max_seq_len
|
||||
|
||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
||||
# function. We should avoid GPU -> CPU sync as much as possible because
|
||||
@ -830,6 +832,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
attn_metadata = self.metadata_cls(
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
num_actual_tokens=num_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
slot_mapping=slot_mapping,
|
||||
|
||||
110
vllm/v1/attention/backends/mla/flashinfer_mla.py
Normal file
110
vllm/v1/attention/backends/mla/flashinfer_mla.py
Normal file
@ -0,0 +1,110 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
||||
|
||||
|
||||
class FlashInferMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHINFER_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashInferMLAImpl"]:
|
||||
return FlashInferMLAImpl
|
||||
|
||||
|
||||
g_fi_workspace = torch.zeros(
|
||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
|
||||
dtype=torch.uint8,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
|
||||
class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"FlashInferMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"FlashInferMLAImpl")
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"FlashInferMLA V1 with FP8 KV cache not yet supported")
|
||||
|
||||
self._workspace_buffer = g_fi_workspace
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if isinstance(q, tuple):
|
||||
q_nope, q_pe = q
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
|
||||
# trtllm API requires extra dimension q_len_per_request for MTP
|
||||
q = q.unsqueeze(1)
|
||||
|
||||
o = trtllm_batch_decode_with_kv_cache_mla(
|
||||
query=q,
|
||||
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
seq_lens=attn_metadata.decode.seq_lens,
|
||||
max_seq_len=attn_metadata.max_seq_len,
|
||||
bmm1_scale=self.scale,
|
||||
)
|
||||
|
||||
# TODO: Return LSE pending support from Flashinfer API:
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/1566
|
||||
return o, None
|
||||
Loading…
x
Reference in New Issue
Block a user