From 41aa5784287f00b026f3ba225ac18ab3caccc622 Mon Sep 17 00:00:00 2001 From: Kaixi Hou Date: Wed, 4 Jun 2025 12:40:26 +0800 Subject: [PATCH] [NVIDIA] Add Cutlass MLA backend (#17625) --- csrc/attention/mla/cutlass_mla_kernels.cu | 2 +- tests/kernels/test_cutlass_mla_decode.py | 4 +- vllm/engine/arg_utils.py | 1 + vllm/platforms/cuda.py | 8 ++ vllm/platforms/interface.py | 1 + vllm/v1/attention/backends/mla/common.py | 2 +- vllm/v1/attention/backends/mla/cutlass_mla.py | 96 +++++++++++++++++++ 7 files changed, 111 insertions(+), 3 deletions(-) create mode 100644 vllm/v1/attention/backends/mla/cutlass_mla.py diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu index 6743af0cf2dba..f4b6b19f4b232 100644 --- a/csrc/attention/mla/cutlass_mla_kernels.cu +++ b/csrc/attention/mla/cutlass_mla_kernels.cu @@ -119,7 +119,7 @@ typename T::Fmha::Arguments args_from_options( {static_cast(out.data_ptr()), stride_O, static_cast(nullptr), stride_LSE}, hw_info, - -1, // split_kv + 1, // split_kv nullptr, // is_var_split_kv }; // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute diff --git a/tests/kernels/test_cutlass_mla_decode.py b/tests/kernels/test_cutlass_mla_decode.py index c56024b757e14..2b745b84dae6c 100644 --- a/tests/kernels/test_cutlass_mla_decode.py +++ b/tests/kernels/test_cutlass_mla_decode.py @@ -76,7 +76,9 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, pack_factor = 128 // block_size block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor - q = torch.randn(bs, h_q, d) + # Amplify input values to ensure test coverage of edge cases where CUTLASS + # kernel errors occur with split_k settings. + q = torch.randn(bs, h_q, d) * 100 block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b1c4b27a0ca4e..90134683180a7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1395,6 +1395,7 @@ class EngineArgs: "PALLAS_VLLM_V1", "TRITON_ATTN_VLLM_V1", "TRITON_MLA", + "CUTLASS_MLA_VLLM_V1", "FLASHMLA", "FLASHINFER", "FLASHINFER_VLLM_V1", diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 07ae470fabfb8..bde606f0c1ef7 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -183,6 +183,14 @@ class CudaPlatformBase(Platform): if use_mla: # TODO(lucas): refactor to be more concise # we should probably consider factoring out V1 here + if selected_backend == _Backend.CUTLASS_MLA_VLLM_V1: + if use_v1: + logger.info_once("Using Cutlass MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." + "cutlass_mla.CutlassMLABackend") + else: + logger.warning( + "Cutlass MLA backend is only supported on V1 engine") if selected_backend == _Backend.TRITON_MLA or block_size != 64: if use_v1: logger.info_once("Using Triton MLA backend on V1 engine.") diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 1ec9c78a361af..7fef697d8f014 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -51,6 +51,7 @@ class _Backend(enum.Enum): TRITON_MLA_VLLM_V1 = enum.auto() FLASHMLA_VLLM_V1 = enum.auto() FLASHMLA = enum.auto() # Supported by V1 + CUTLASS_MLA_VLLM_V1 = enum.auto() HPU_ATTN = enum.auto() PALLAS = enum.auto() PALLAS_VLLM_V1 = enum.auto() diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 06acbb909a4f6..e6b4f6404632c 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -350,7 +350,7 @@ class MLACommonMetadataBuilder(Generic[M]): self.num_heads = model_config.get_num_attention_heads( runner.parallel_config) self.mla_dims = get_mla_dims(model_config) - self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3) + self.aot_schedule = current_platform.is_cuda() self.kv_cache_spec = kv_cache_spec # Dont try to access the runner on AMD diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py new file mode 100644 index 0000000000000..70aee058e2963 --- /dev/null +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional + +import torch + +import vllm._custom_ops as ops +from vllm.attention.backends.abstract import (AttentionType, + is_quantized_kv_cache) +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata) + +logger = init_logger(__name__) + + +class CutlassMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "CUTLASS_MLA_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type["CutlassMLAImpl"]: + return CutlassMLAImpl + + +class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "CutlassMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "CutlassMLAImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "CutlassMLA V1 with FP8 KV cache not yet supported") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError("FP8 Cutlass MLA not yet supported") + + B = q_nope.shape[0] + + o = torch.empty((B, self.num_heads, self.kv_lora_rank), + dtype=q_nope.dtype, + device=q_nope.device) + + # Run MLA + # Clone q_nope and q_pe to make sure strides computation is correct. + q_nope = q_nope.clone() + q_pe = q_pe.clone() + ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache, + attn_metadata.decode.seq_lens, + attn_metadata.decode.block_table, self.scale) + + return self._v_up_proj(o)