[NVIDIA] Add Cutlass MLA backend (#17625)
This commit is contained in:
parent 8d646c2e53
commit 41aa578428
@@ -119,7 +119,7 @@ typename T::Fmha::Arguments args_from_options(
       {static_cast<ElementOut*>(out.data_ptr()), stride_O,
        static_cast<ElementAcc*>(nullptr), stride_LSE},
       hw_info,
-      -1,       // split_kv
+      1,        // split_kv
       nullptr,  // is_var_split_kv
   };
   // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
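Editor's note: split_kv controls how many chunks the KV sequence is partitioned into for the decode kernel; -1 lets the CUTLASS heuristic pick a split count, and this change pins it to 1 (no splitting) until the split-kv path is reliable (see the amplified-input test below). For intuition, a minimal sketch of how split-KV partial outputs are typically merged from per-split log-sum-exp values (hypothetical helper; not the CUTLASS implementation):

import torch

def merge_kv_splits(o_parts: torch.Tensor,
                    lse_parts: torch.Tensor) -> torch.Tensor:
    # o_parts: [num_splits, num_heads, head_dim] partial attention outputs.
    # lse_parts: [num_splits, num_heads] log-sum-exp of each split's scores.
    lse_max = lse_parts.max(dim=0, keepdim=True).values
    weights = torch.exp(lse_parts - lse_max)  # numerically stable weights
    weights = weights / weights.sum(dim=0, keepdim=True)
    return (o_parts * weights.unsqueeze(-1)).sum(dim=0)  # [num_heads, head_dim]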
@@ -76,7 +76,9 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
     pack_factor = 128 // block_size
     block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

-    q = torch.randn(bs, h_q, d)
+    # Amplify input values to ensure test coverage of edge cases where CUTLASS
+    # kernel errors occur with split_k settings.
+    q = torch.randn(bs, h_q, d) * 100
     block_table = torch.randint(0,
                                 bs * block_num, (bs, block_num),
                                 dtype=torch.int32)
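Editor's note: the pack_factor rounding above pads block_num up to a multiple of 128 // block_size, presumably so the paged cache can be regrouped at 128-token granularity. A quick worked example (values assumed, not from the diff):

block_size = 64
block_num = 7
pack_factor = 128 // block_size  # 2
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
assert block_num == 8  # 7 rounded up to the next multiple of 2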
@@ -1395,6 +1395,7 @@ class EngineArgs:
             "PALLAS_VLLM_V1",
             "TRITON_ATTN_VLLM_V1",
             "TRITON_MLA",
+            "CUTLASS_MLA_VLLM_V1",
             "FLASHMLA",
             "FLASHINFER",
             "FLASHINFER_VLLM_V1",
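Editor's note: registering the name here makes it a legal value for vLLM's VLLM_ATTENTION_BACKEND override. A minimal usage sketch (the model choice is an assumption; any MLA-style model such as the DeepSeek family applies):

import os

# Force the new backend before constructing the engine.
os.environ["VLLM_ATTENTION_BACKEND"] = "CUTLASS_MLA_VLLM_V1"
os.environ["VLLM_USE_V1"] = "1"  # the backend is V1-only per the platform check

from vllm import LLM

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite")  # assumed MLA model
print(llm.generate("Hello")[0].outputs[0].text)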
@@ -183,6 +183,14 @@ class CudaPlatformBase(Platform):
         if use_mla:
             # TODO(lucas): refactor to be more concise
             # we should probably consider factoring out V1 here
+            if selected_backend == _Backend.CUTLASS_MLA_VLLM_V1:
+                if use_v1:
+                    logger.info_once("Using Cutlass MLA backend on V1 engine.")
+                    return ("vllm.v1.attention.backends.mla."
+                            "cutlass_mla.CutlassMLABackend")
+                else:
+                    logger.warning(
+                        "Cutlass MLA backend is only supported on V1 engine")
             if selected_backend == _Backend.TRITON_MLA or block_size != 64:
                 if use_v1:
                     logger.info_once("Using Triton MLA backend on V1 engine.")
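Editor's note: when CUTLASS_MLA_VLLM_V1 is requested without V1, the branch above only warns and does not return, so selection falls through to the TRITON_MLA / block-size check below it. A condensed sketch of that precedence (simplified; the final default is an assumption beyond this hunk, not the vLLM code):

def pick_mla_backend(selected: str, use_v1: bool, block_size: int) -> str:
    # Simplified precedence sketch of the diff above.
    if selected == "CUTLASS_MLA_VLLM_V1" and use_v1:
        return "cutlass_mla.CutlassMLABackend"
    # A V0 request for Cutlass MLA warns, then falls through:
    if selected == "TRITON_MLA" or block_size != 64:
        return "triton_mla"
    return "flashmla"  # assumed default outside this hunk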
@@ -51,6 +51,7 @@ class _Backend(enum.Enum):
     TRITON_MLA_VLLM_V1 = enum.auto()
     FLASHMLA_VLLM_V1 = enum.auto()
     FLASHMLA = enum.auto()  # Supported by V1
+    CUTLASS_MLA_VLLM_V1 = enum.auto()
     HPU_ATTN = enum.auto()
     PALLAS = enum.auto()
     PALLAS_VLLM_V1 = enum.auto()
@@ -350,7 +350,7 @@ class MLACommonMetadataBuilder(Generic[M]):
         self.num_heads = model_config.get_num_attention_heads(
             runner.parallel_config)
         self.mla_dims = get_mla_dims(model_config)
-        self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3)
+        self.aot_schedule = current_platform.is_cuda()
         self.kv_cache_spec = kv_cache_spec

         # Dont try to access the runner on AMD
vllm/v1/attention/backends/mla/cutlass_mla.py (new file, 96 lines)
@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Optional

import torch

import vllm._custom_ops as ops
from vllm.attention.backends.abstract import (AttentionType,
                                              is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonImpl,
                                                   MLACommonMetadata)

logger = init_logger(__name__)


class CutlassMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "CUTLASS_MLA_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["CutlassMLAImpl"]:
        return CutlassMLAImpl


class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         **mla_args)

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "CutlassMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "CutlassMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "CutlassMLA V1 with FP8 KV cache not yet supported")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 Cutlass MLA not yet supported")

        B = q_nope.shape[0]

        o = torch.empty((B, self.num_heads, self.kv_lora_rank),
                        dtype=q_nope.dtype,
                        device=q_nope.device)

        # Run MLA
        # Clone q_nope and q_pe to make sure strides computation is correct.
        q_nope = q_nope.clone()
        q_pe = q_pe.clone()
        ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
                               attn_metadata.decode.seq_lens,
                               attn_metadata.decode.block_table, self.scale)

        return self._v_up_proj(o)
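Editor's note: a shape sketch for the decode entry point above, with DeepSeek-style MLA dimensions assumed (kv_lora_rank=512, rope dim=64) and a made-up cache layout; none of these constants come from this diff:

import torch

B, H = 4, 128                     # decode batch and query heads (assumed)
KV_LORA, ROPE = 512, 64           # assumed DeepSeek-style MLA dims
NUM_BLOCKS, BLOCK_SIZE = 32, 128  # assumed paged-cache layout

q_nope = torch.randn(B, H, KV_LORA, dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn(B, H, ROPE, dtype=torch.bfloat16, device="cuda")
# Compressed KV latents and rope keys are stored together in one cache.
kv_c_and_k_pe_cache = torch.randn(NUM_BLOCKS, BLOCK_SIZE, KV_LORA + ROPE,
                                  dtype=torch.bfloat16, device="cuda")
# _forward_decode allocates o as [B, H, kv_lora_rank]; _v_up_proj then maps
# the latent back to the value head dimension.
o = torch.empty(B, H, KV_LORA, dtype=q_nope.dtype, device="cuda")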