mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 17:24:30 +08:00
[Kernel] Support decode context parallelism on Blackwell with CUTLASS MLA (#24385)
Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
795b6951cd
commit
86173ad593
@ -36,6 +36,7 @@ limitations under the License.
|
|||||||
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
|
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
|
||||||
void sm100_cutlass_mla_decode(
|
void sm100_cutlass_mla_decode(
|
||||||
torch::Tensor const& out,
|
torch::Tensor const& out,
|
||||||
|
torch::Tensor const& lse,
|
||||||
torch::Tensor const& q_nope,
|
torch::Tensor const& q_nope,
|
||||||
torch::Tensor const& q_pe,
|
torch::Tensor const& q_pe,
|
||||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||||
@ -99,6 +100,7 @@ struct MlaSm100 {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
typename T::Fmha::Arguments args_from_options(
|
typename T::Fmha::Arguments args_from_options(
|
||||||
at::Tensor const& out,
|
at::Tensor const& out,
|
||||||
|
at::Tensor const& lse,
|
||||||
at::Tensor const& q_nope,
|
at::Tensor const& q_nope,
|
||||||
at::Tensor const& q_pe,
|
at::Tensor const& q_pe,
|
||||||
at::Tensor const& kv_c_and_k_pe_cache,
|
at::Tensor const& kv_c_and_k_pe_cache,
|
||||||
@ -162,7 +164,10 @@ typename T::Fmha::Arguments args_from_options(
|
|||||||
stride_PT,
|
stride_PT,
|
||||||
page_count_total,
|
page_count_total,
|
||||||
page_size},
|
page_size},
|
||||||
{static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
|
{static_cast<ElementOut*>(out.data_ptr()),
|
||||||
|
stride_O,
|
||||||
|
static_cast<ElementAcc*>(lse.defined() ? lse.data_ptr() : nullptr),
|
||||||
|
stride_LSE},
|
||||||
hw_info,
|
hw_info,
|
||||||
// TODO(trevor-m): Change split_kv back to -1 when
|
// TODO(trevor-m): Change split_kv back to -1 when
|
||||||
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
|
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
|
||||||
@ -181,6 +186,7 @@ typename T::Fmha::Arguments args_from_options(
|
|||||||
template <typename Element, typename ElementOut, bool IsPaged128, typename PersistenceOption>
|
template <typename Element, typename ElementOut, bool IsPaged128, typename PersistenceOption>
|
||||||
void runMla(
|
void runMla(
|
||||||
at::Tensor const& out,
|
at::Tensor const& out,
|
||||||
|
at::Tensor const& lse,
|
||||||
at::Tensor const& q_nope,
|
at::Tensor const& q_nope,
|
||||||
at::Tensor const& q_pe,
|
at::Tensor const& q_pe,
|
||||||
at::Tensor const& kv_c_and_k_pe_cache,
|
at::Tensor const& kv_c_and_k_pe_cache,
|
||||||
@ -192,7 +198,7 @@ void runMla(
|
|||||||
cudaStream_t stream) {
|
cudaStream_t stream) {
|
||||||
using MlaSm100Type = MlaSm100<Element, ElementOut, IsPaged128, PersistenceOption>;
|
using MlaSm100Type = MlaSm100<Element, ElementOut, IsPaged128, PersistenceOption>;
|
||||||
typename MlaSm100Type::Fmha fmha;
|
typename MlaSm100Type::Fmha fmha;
|
||||||
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
|
auto arguments = args_from_options<MlaSm100Type>(out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
|
||||||
|
|
||||||
CUTLASS_CHECK(fmha.can_implement(arguments));
|
CUTLASS_CHECK(fmha.can_implement(arguments));
|
||||||
|
|
||||||
@ -214,6 +220,7 @@ void runMla(
|
|||||||
|
|
||||||
void sm100_cutlass_mla_decode(
|
void sm100_cutlass_mla_decode(
|
||||||
torch::Tensor const& out,
|
torch::Tensor const& out,
|
||||||
|
torch::Tensor const& lse,
|
||||||
torch::Tensor const& q_nope,
|
torch::Tensor const& q_nope,
|
||||||
torch::Tensor const& q_pe,
|
torch::Tensor const& q_pe,
|
||||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||||
@ -234,13 +241,13 @@ void sm100_cutlass_mla_decode(
|
|||||||
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
|
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
|
||||||
if (in_dtype == at::ScalarType::Half) {
|
if (in_dtype == at::ScalarType::Half) {
|
||||||
runMla<cutlass::half_t, cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
runMla<cutlass::half_t, cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||||
runMla<cutlass::bfloat16_t, cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
runMla<cutlass::bfloat16_t, cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||||
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
||||||
runMla<cutlass::float_e4m3_t, cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
runMla<cutlass::float_e4m3_t, cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
out, lse, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
||||||
}
|
}
|
||||||
|
|||||||
@ -516,10 +516,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
|
|
||||||
// SM100 CUTLASS MLA decode
|
// SM100 CUTLASS MLA decode
|
||||||
ops.def(
|
ops.def(
|
||||||
"sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
|
"sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"
|
||||||
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
|
" Tensor q_pe, Tensor kv_c_and_k_pe_cache,"
|
||||||
" Tensor page_table, Tensor workspace, float "
|
" Tensor seq_lens, Tensor page_table,"
|
||||||
"scale,"
|
" Tensor workspace, float scale,"
|
||||||
" int num_kv_splits) -> ()");
|
" int num_kv_splits) -> ()");
|
||||||
// conditionally compiled so impl in source file
|
// conditionally compiled so impl in source file
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -14,14 +15,20 @@ from vllm.triton_utils import triton
|
|||||||
def cal_diff(x: torch.Tensor,
|
def cal_diff(x: torch.Tensor,
|
||||||
y: torch.Tensor,
|
y: torch.Tensor,
|
||||||
name: str,
|
name: str,
|
||||||
use_fp8: bool = False) -> None:
|
use_fp8: bool = False,
|
||||||
|
diff_threshold: Optional[float] = None) -> None:
|
||||||
x, y = x.double(), y.double()
|
x, y = x.double(), y.double()
|
||||||
cos_diff = 1 - 2 * (x * y).sum().item() / max(
|
cos_diff = 1 - 2 * (x * y).sum().item() / max(
|
||||||
(x * x + y * y).sum().item(), 1e-12)
|
(x * x + y * y).sum().item(), 1e-12)
|
||||||
if (use_fp8):
|
if diff_threshold is not None:
|
||||||
assert cos_diff < 1e-4
|
# directly compare the cos_diff with the threshold
|
||||||
|
assert cos_diff < diff_threshold
|
||||||
else:
|
else:
|
||||||
assert cos_diff < 1e-5
|
# use the default threshold
|
||||||
|
if (use_fp8):
|
||||||
|
assert cos_diff < 1e-4
|
||||||
|
else:
|
||||||
|
assert cos_diff < 1e-5
|
||||||
|
|
||||||
|
|
||||||
CUTLASS_MLA_UNSUPPORTED_REASON = \
|
CUTLASS_MLA_UNSUPPORTED_REASON = \
|
||||||
@ -118,11 +125,13 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
|
|||||||
dtype=torch.uint8)
|
dtype=torch.uint8)
|
||||||
|
|
||||||
out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype)
|
out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype)
|
||||||
|
output_lse = torch.empty((b, MAX_HEADS),
|
||||||
ops.sm100_cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache_flat,
|
dtype=torch.float32,
|
||||||
cache_seqlens, block_table, workspace,
|
device=q_nope.device)
|
||||||
scale, 1)
|
ops.sm100_cutlass_mla_decode(out_ans, output_lse, q_nope, q_pe,
|
||||||
return out_ans[:, :h_q].contiguous()
|
kv_cache_flat, cache_seqlens, block_table,
|
||||||
|
workspace, scale, 1)
|
||||||
|
return out_ans[:, :h_q].contiguous(), output_lse[:, :h_q].contiguous()
|
||||||
|
|
||||||
def scaled_dot_product_attention(query, key, value, is_causal=False):
|
def scaled_dot_product_attention(query, key, value, is_causal=False):
|
||||||
query = query.float()
|
query = query.float()
|
||||||
@ -165,11 +174,14 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
|
|||||||
lse[i] = lse_i
|
lse[i] = lse_i
|
||||||
return out, lse
|
return out, lse
|
||||||
|
|
||||||
out_cutlass = cutlass_mla()
|
out_cutlass, lse_cutlass = cutlass_mla()
|
||||||
out_torch, lse_torch = ref_mla()
|
out_torch, lse_torch = ref_mla()
|
||||||
# Extract the single token (s_q=1) slice to match cutlass output shape
|
# Extract the single token (s_q=1) slice to match cutlass output shape
|
||||||
out_torch_slice = out_torch[:, 0, :, :] # [b, h_q, dv]
|
out_torch_slice = out_torch[:, 0, :, :] # [b, h_q, dv]
|
||||||
|
lse_torch_slice = lse_torch[:, 0, :] # [b, h_q]
|
||||||
cal_diff(out_cutlass, out_torch_slice, "out", use_fp8)
|
cal_diff(out_cutlass, out_torch_slice, "out", use_fp8)
|
||||||
|
# lse has larger numerical error, so use a larger threshold
|
||||||
|
cal_diff(lse_cutlass, lse_torch_slice, "lse", use_fp8, diff_threshold=1e-3)
|
||||||
|
|
||||||
t = triton.testing.do_bench(cutlass_mla)
|
t = triton.testing.do_bench(cutlass_mla)
|
||||||
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
|
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
|
||||||
|
|||||||
@ -1833,13 +1833,13 @@ def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
|
def sm100_cutlass_mla_decode(out: torch.Tensor, lse: torch.Tensor,
|
||||||
q_pe: torch.Tensor,
|
q_nope: torch.Tensor, q_pe: torch.Tensor,
|
||||||
kv_c_and_k_pe_cache: torch.Tensor,
|
kv_c_and_k_pe_cache: torch.Tensor,
|
||||||
seq_lens: torch.Tensor, page_table: torch.Tensor,
|
seq_lens: torch.Tensor, page_table: torch.Tensor,
|
||||||
workspace: torch.Tensor, scale: float,
|
workspace: torch.Tensor, scale: float,
|
||||||
num_kv_splits: int) -> torch.Tensor:
|
num_kv_splits: int) -> torch.Tensor:
|
||||||
torch.ops._C.sm100_cutlass_mla_decode(out, q_nope, q_pe,
|
torch.ops._C.sm100_cutlass_mla_decode(out, lse, q_nope, q_pe,
|
||||||
kv_c_and_k_pe_cache, seq_lens,
|
kv_c_and_k_pe_cache, seq_lens,
|
||||||
page_table, workspace, scale,
|
page_table, workspace, scale,
|
||||||
num_kv_splits)
|
num_kv_splits)
|
||||||
|
|||||||
@ -76,6 +76,7 @@ g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB
|
|||||||
|
|
||||||
|
|
||||||
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||||
|
can_return_lse_for_decode: bool = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -138,7 +139,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
workspace: torch.Tensor,
|
workspace: torch.Tensor,
|
||||||
sm_scale: float,
|
sm_scale: float,
|
||||||
num_kv_splits: int,
|
num_kv_splits: int,
|
||||||
) -> torch.Tensor:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
assert (q_nope.ndim == 3
|
assert (q_nope.ndim == 3
|
||||||
), f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
|
), f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
|
||||||
assert (
|
assert (
|
||||||
@ -193,9 +194,13 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype)
|
dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype)
|
||||||
else q_nope.dtype)
|
else q_nope.dtype)
|
||||||
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
|
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
|
||||||
|
lse = (torch.empty(
|
||||||
|
(B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
|
||||||
|
if self.need_to_return_lse_for_decode else torch.Tensor())
|
||||||
|
|
||||||
ops.sm100_cutlass_mla_decode(
|
ops.sm100_cutlass_mla_decode(
|
||||||
out,
|
out,
|
||||||
|
lse,
|
||||||
q_nope,
|
q_nope,
|
||||||
q_pe,
|
q_pe,
|
||||||
kv_c_and_k_pe_cache,
|
kv_c_and_k_pe_cache,
|
||||||
@ -205,7 +210,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
sm_scale,
|
sm_scale,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
)
|
)
|
||||||
return out[:, :H].contiguous()
|
returned_lse = lse[:, :H].contiguous(
|
||||||
|
) if self.need_to_return_lse_for_decode else lse
|
||||||
|
return out[:, :H].contiguous(), returned_lse
|
||||||
|
|
||||||
def _sm100_forward_decode(
|
def _sm100_forward_decode(
|
||||||
self,
|
self,
|
||||||
@ -213,7 +220,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
q_pe: torch.Tensor,
|
q_pe: torch.Tensor,
|
||||||
kv_c_and_k_pe_cache: torch.Tensor,
|
kv_c_and_k_pe_cache: torch.Tensor,
|
||||||
attn_metadata: MLACommonMetadata,
|
attn_metadata: MLACommonMetadata,
|
||||||
) -> torch.Tensor:
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
assert kv_c_and_k_pe_cache.numel() > 0
|
assert kv_c_and_k_pe_cache.numel() > 0
|
||||||
assert attn_metadata.decode is not None
|
assert attn_metadata.decode is not None
|
||||||
|
|
||||||
@ -226,13 +233,18 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
q_nope = q_nope.clone()
|
q_nope = q_nope.clone()
|
||||||
q_pe = q_pe.clone()
|
q_pe = q_pe.clone()
|
||||||
|
|
||||||
o = self._sm100_cutlass_mla_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
|
o, lse = self._sm100_cutlass_mla_decode(
|
||||||
attn_metadata.decode.seq_lens,
|
q_nope,
|
||||||
attn_metadata.decode.block_table,
|
q_pe,
|
||||||
self._workspace.get_buf(),
|
kv_c_and_k_pe_cache,
|
||||||
self.scale, self._num_kv_splits)
|
attn_metadata.decode.seq_lens,
|
||||||
|
attn_metadata.decode.block_table,
|
||||||
|
self._workspace.get_buf(),
|
||||||
|
self.scale,
|
||||||
|
self._num_kv_splits,
|
||||||
|
)
|
||||||
|
|
||||||
return o
|
return o, (lse if self.need_to_return_lse_for_decode else None)
|
||||||
|
|
||||||
# TODO: Currently we leave it here only for backup in case something is
|
# TODO: Currently we leave it here only for backup in case something is
|
||||||
# wrong with the new SM100 CUTLASS MLA kernel
|
# wrong with the new SM100 CUTLASS MLA kernel
|
||||||
@ -286,4 +298,4 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
attn_metadata), None
|
attn_metadata), None
|
||||||
|
|
||||||
return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
|
return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||||
attn_metadata), None
|
attn_metadata)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user