diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu index 820bf81dd1a02..c60f1823b8a1d 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -36,6 +36,7 @@ limitations under the License. #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040 void sm100_cutlass_mla_decode( torch::Tensor const& out, + torch::Tensor const& lse, torch::Tensor const& q_nope, torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, @@ -99,6 +100,7 @@ struct MlaSm100 { template typename T::Fmha::Arguments args_from_options( at::Tensor const& out, + at::Tensor const& lse, at::Tensor const& q_nope, at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, @@ -162,7 +164,10 @@ typename T::Fmha::Arguments args_from_options( stride_PT, page_count_total, page_size}, - {static_cast(out.data_ptr()), stride_O, static_cast(nullptr), stride_LSE}, + {static_cast(out.data_ptr()), + stride_O, + static_cast(lse.defined() ? lse.data_ptr() : nullptr), + stride_LSE}, hw_info, // TODO(trevor-m): Change split_kv back to -1 when // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will @@ -181,6 +186,7 @@ typename T::Fmha::Arguments args_from_options( template void runMla( at::Tensor const& out, + at::Tensor const& lse, at::Tensor const& q_nope, at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, @@ -192,7 +198,7 @@ void runMla( cudaStream_t stream) { using MlaSm100Type = MlaSm100; typename MlaSm100Type::Fmha fmha; - auto arguments = args_from_options(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits); + auto arguments = args_from_options(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits); CUTLASS_CHECK(fmha.can_implement(arguments)); @@ -214,6 +220,7 @@ void runMla( void sm100_cutlass_mla_decode( torch::Tensor const& out, + torch::Tensor const& lse, torch::Tensor const& q_nope, torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache, @@ -234,13 +241,13 @@ void sm100_cutlass_mla_decode( DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] { if (in_dtype == at::ScalarType::Half) { runMla>( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else if (in_dtype == at::ScalarType::BFloat16) { runMla>( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { runMla>( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else { TORCH_CHECK(false, "Unsupported input data type of MLA"); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 95fb5b197f534..d3f50d1076cb0 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -516,10 +516,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // SM100 CUTLASS MLA decode ops.def( - "sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," - " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," - " Tensor page_table, Tensor workspace, float " - "scale," + "sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope," + " Tensor q_pe, Tensor kv_c_and_k_pe_cache," + " Tensor seq_lens, Tensor page_table," + " Tensor workspace, float scale," " int num_kv_splits) -> ()"); // conditionally compiled so impl in source file diff --git a/tests/kernels/test_cutlass_mla_decode.py b/tests/kernels/test_cutlass_mla_decode.py index 85984324b1967..820dac0e6cec9 100644 --- a/tests/kernels/test_cutlass_mla_decode.py +++ b/tests/kernels/test_cutlass_mla_decode.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math import random +from typing import Optional import pytest import torch @@ -14,14 +15,20 @@ from vllm.triton_utils import triton def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, - use_fp8: bool = False) -> None: + use_fp8: bool = False, + diff_threshold: Optional[float] = None) -> None: x, y = x.double(), y.double() cos_diff = 1 - 2 * (x * y).sum().item() / max( (x * x + y * y).sum().item(), 1e-12) - if (use_fp8): - assert cos_diff < 1e-4 + if diff_threshold is not None: + # directly compare the cos_diff with the threshold + assert cos_diff < diff_threshold else: - assert cos_diff < 1e-5 + # use the default threshold + if (use_fp8): + assert cos_diff < 1e-4 + else: + assert cos_diff < 1e-5 CUTLASS_MLA_UNSUPPORTED_REASON = \ @@ -118,11 +125,13 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, dtype=torch.uint8) out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype) - - ops.sm100_cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache_flat, - cache_seqlens, block_table, workspace, - scale, 1) - return out_ans[:, :h_q].contiguous() + output_lse = torch.empty((b, MAX_HEADS), + dtype=torch.float32, + device=q_nope.device) + ops.sm100_cutlass_mla_decode(out_ans, output_lse, q_nope, q_pe, + kv_cache_flat, cache_seqlens, block_table, + workspace, scale, 1) + return out_ans[:, :h_q].contiguous(), output_lse[:, :h_q].contiguous() def scaled_dot_product_attention(query, key, value, is_causal=False): query = query.float() @@ -165,11 +174,14 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, lse[i] = lse_i return out, lse - out_cutlass = cutlass_mla() + out_cutlass, lse_cutlass = cutlass_mla() out_torch, lse_torch = ref_mla() # Extract the single token (s_q=1) slice to match cutlass output shape out_torch_slice = out_torch[:, 0, :, :] # [b, h_q, dv] + lse_torch_slice = lse_torch[:, 0, :] # [b, h_q] cal_diff(out_cutlass, out_torch_slice, "out", use_fp8) + # lse has larger numerical error, so use a larger threshold + cal_diff(lse_cutlass, lse_torch_slice, "lse", use_fp8, diff_threshold=1e-3) t = triton.testing.do_bench(cutlass_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 545f4cb48bf47..6e9a8df0a56a2 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1833,13 +1833,13 @@ def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, return out -def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, - q_pe: torch.Tensor, +def sm100_cutlass_mla_decode(out: torch.Tensor, lse: torch.Tensor, + q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, seq_lens: torch.Tensor, page_table: torch.Tensor, workspace: torch.Tensor, scale: float, num_kv_splits: int) -> torch.Tensor: - torch.ops._C.sm100_cutlass_mla_decode(out, q_nope, q_pe, + torch.ops._C.sm100_cutlass_mla_decode(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, scale, num_kv_splits) diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 95dce8d8e2eef..6017445402eca 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -76,6 +76,7 @@ g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): + can_return_lse_for_decode: bool = True def __init__( self, @@ -138,7 +139,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): workspace: torch.Tensor, sm_scale: float, num_kv_splits: int, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: assert (q_nope.ndim == 3 ), f"q_nope must be a 3D tensor, but got {q_nope.ndim}" assert ( @@ -193,9 +194,13 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype) else q_nope.dtype) out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype) + lse = (torch.empty( + (B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device) + if self.need_to_return_lse_for_decode else torch.Tensor()) ops.sm100_cutlass_mla_decode( out, + lse, q_nope, q_pe, kv_c_and_k_pe_cache, @@ -205,7 +210,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): sm_scale, num_kv_splits, ) - return out[:, :H].contiguous() + returned_lse = lse[:, :H].contiguous( + ) if self.need_to_return_lse_for_decode else lse + return out[:, :H].contiguous(), returned_lse def _sm100_forward_decode( self, @@ -213,7 +220,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None @@ -226,13 +233,18 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): q_nope = q_nope.clone() q_pe = q_pe.clone() - o = self._sm100_cutlass_mla_decode(q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata.decode.seq_lens, - attn_metadata.decode.block_table, - self._workspace.get_buf(), - self.scale, self._num_kv_splits) + o, lse = self._sm100_cutlass_mla_decode( + q_nope, + q_pe, + kv_c_and_k_pe_cache, + attn_metadata.decode.seq_lens, + attn_metadata.decode.block_table, + self._workspace.get_buf(), + self.scale, + self._num_kv_splits, + ) - return o + return o, (lse if self.need_to_return_lse_for_decode else None) # TODO: Currently we leave it here only for backup in case something is # wrong with the new SM100 CUTLASS MLA kernel @@ -286,4 +298,4 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): attn_metadata), None return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata), None + attn_metadata)