mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-02 10:07:06 +08:00
Remove old cutlass mla (#23961)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
parent
47f670b03b
commit
8f3616f422
@ -308,7 +308,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/attention/mla/cutlass_mla_entry.cu"
|
||||
"csrc/quantization/fp8/per_token_group_quant.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
@ -595,7 +594,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/attention/mla/cutlass_mla_kernels.cu"
|
||||
"csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
|
||||
@ -1,38 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
|
||||
void cutlass_mla_decode_sm100a(torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale);
|
||||
#endif
|
||||
|
||||
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale) {
|
||||
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
|
||||
return cutlass_mla_decode_sm100a(out, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
seq_lens, page_table, scale);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA");
|
||||
}
|
||||
@ -1,225 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
#include "device/sm100_mla.hpp"
|
||||
#include "kernel/sm100_mla_tile_scheduler.hpp"
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
template <typename T, bool PersistenceOption = true>
|
||||
struct MlaSm100 {
|
||||
using Element = T;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = T;
|
||||
|
||||
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
|
||||
using TileShapeH = cute::tuple_element_t<0, TileShape>;
|
||||
using TileShapeD = cute::tuple_element_t<2, TileShape>;
|
||||
|
||||
// H K (D_latent D_rope) B
|
||||
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
|
||||
|
||||
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
|
||||
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
|
||||
using StrideO = StrideK; // H D B
|
||||
using StrideLSE = cute::tuple<_1, int>; // H B
|
||||
|
||||
using TileScheduler =
|
||||
std::conditional_t<PersistenceOption, Sm100MlaPersistentTileScheduler,
|
||||
Sm100MlaIndividualTileScheduler>;
|
||||
|
||||
using FmhaKernel =
|
||||
cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
|
||||
TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler,
|
||||
/*kIsCpAsync=*/true>;
|
||||
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Fmha::Arguments args_from_options(
|
||||
at::Tensor const& out, at::Tensor const& q_nope, at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table, double scale) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = q_nope.device().index();
|
||||
hw_info.sm_count =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
hw_info.device_id);
|
||||
|
||||
int batches = q_nope.sizes()[0];
|
||||
int page_count_per_seq = page_table.sizes()[1];
|
||||
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
|
||||
int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
int max_seq_len = page_size * page_count_per_seq;
|
||||
using TileShapeH = typename T::TileShapeH;
|
||||
using TileShapeD = typename T::TileShapeD;
|
||||
auto problem_shape =
|
||||
cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
using StrideQ = typename T::StrideQ;
|
||||
using StrideK = typename T::StrideK;
|
||||
using StrideO = typename T::StrideO;
|
||||
using StrideLSE = typename T::StrideLSE;
|
||||
|
||||
StrideQ stride_Q_latent = cute::make_tuple(
|
||||
static_cast<int64_t>(D_latent), _1{}, static_cast<int64_t>(H * D_latent));
|
||||
StrideQ stride_Q_rope = cute::make_tuple(static_cast<int64_t>(D_rope), _1{},
|
||||
static_cast<int64_t>(H * D_rope));
|
||||
StrideK stride_C =
|
||||
cute::make_tuple(static_cast<int64_t>(D_latent + D_rope), _1{},
|
||||
static_cast<int64_t>(page_size * (D_latent + D_rope)));
|
||||
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
|
||||
StrideLSE stride_LSE = cute::make_tuple(_1{}, static_cast<int>(H));
|
||||
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(D_latent), _1{},
|
||||
static_cast<int64_t>(H * D_latent));
|
||||
|
||||
using Element = typename T::Element;
|
||||
using ElementOut = typename T::ElementOut;
|
||||
using ElementAcc = typename T::ElementAcc;
|
||||
auto Q_latent_ptr = static_cast<Element*>(q_nope.data_ptr());
|
||||
auto Q_rope_ptr = static_cast<Element*>(q_pe.data_ptr());
|
||||
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
|
||||
auto scale_f = static_cast<float>(scale);
|
||||
typename T::Fmha::Arguments arguments{
|
||||
problem_shape,
|
||||
{scale_f, Q_latent_ptr, stride_Q_latent, Q_rope_ptr, stride_Q_rope, C_ptr,
|
||||
stride_C, C_ptr + D_latent, stride_C,
|
||||
static_cast<int*>(seq_lens.data_ptr()),
|
||||
static_cast<int*>(page_table.data_ptr()), stride_PT, page_count_total,
|
||||
page_size},
|
||||
{static_cast<ElementOut*>(out.data_ptr()), stride_O,
|
||||
static_cast<ElementAcc*>(nullptr), stride_LSE},
|
||||
hw_info,
|
||||
1, // split_kv
|
||||
nullptr, // is_var_split_kv
|
||||
};
|
||||
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
||||
// split_kv automatically based on batch size and sequence length to balance
|
||||
// workload across available SMs. Consider using var_split_kv for manual
|
||||
// control if needed.
|
||||
T::Fmha::set_split_kv(arguments);
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename Element>
|
||||
void runMla(at::Tensor const& out, at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens, at::Tensor const& page_table,
|
||||
float scale, cudaStream_t stream) {
|
||||
using MlaSm100Type = MlaSm100<Element>;
|
||||
typename MlaSm100Type::Fmha fmha;
|
||||
auto arguments = args_from_options<MlaSm100Type>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale);
|
||||
size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(q_nope.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
CUTLASS_CHECK(fmha.can_implement(arguments));
|
||||
|
||||
CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
|
||||
|
||||
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
void cutlass_mla_decode_sm100a(torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale) {
|
||||
TORCH_CHECK(q_nope.device().is_cuda(), "q_nope must be on CUDA");
|
||||
TORCH_CHECK(q_nope.dim() == 3, "q_nope must be a 3D tensor");
|
||||
TORCH_CHECK(q_pe.dim() == 3, "q_pe must be a 3D tensor");
|
||||
TORCH_CHECK(kv_c_and_k_pe_cache.dim() == 3,
|
||||
"kv_c_and_k_pe_cache must be a 3D tensor");
|
||||
TORCH_CHECK(seq_lens.dim() == 1, "seq_lens must be a 1D tensor");
|
||||
TORCH_CHECK(page_table.dim() == 2, "page_table must be a 2D tensor");
|
||||
TORCH_CHECK(out.dim() == 3, "out must be a 3D tensor");
|
||||
|
||||
auto B_q_nope = q_nope.size(0);
|
||||
auto H_q_nope = q_nope.size(1);
|
||||
auto D_q_nope = q_nope.size(2);
|
||||
auto B_q_pe = q_pe.size(0);
|
||||
auto H_q_pe = q_pe.size(1);
|
||||
auto D_q_pe = q_pe.size(2);
|
||||
auto B_pt = page_table.size(0);
|
||||
auto PAGE_NUM = page_table.size(1);
|
||||
auto PAGE_SIZE = kv_c_and_k_pe_cache.size(1);
|
||||
auto D_ckv = kv_c_and_k_pe_cache.size(2);
|
||||
auto B_o = out.size(0);
|
||||
auto H_o = out.size(1);
|
||||
auto D_o = out.size(2);
|
||||
|
||||
TORCH_CHECK(D_q_nope == 512, "D_q_nope must be equal to 512");
|
||||
TORCH_CHECK(D_q_pe == 64, "D_q_pe must be equal to 64");
|
||||
TORCH_CHECK(D_ckv == 576, "D_ckv must be equal to 576");
|
||||
TORCH_CHECK(H_q_nope == H_q_pe && H_q_nope == H_o && H_o == 128,
|
||||
"H_q_nope, H_q_pe, and H_o must be equal to 128");
|
||||
TORCH_CHECK(PAGE_SIZE > 0 && (PAGE_SIZE & (PAGE_SIZE - 1)) == 0,
|
||||
"PAGE_SIZE must be a power of 2");
|
||||
TORCH_CHECK(
|
||||
B_q_nope == B_q_pe && B_q_nope == B_pt && B_q_nope == B_o,
|
||||
"Batch dims must be same for page_table, q_nope and q_pe, and out");
|
||||
TORCH_CHECK(PAGE_NUM % (128 / PAGE_SIZE) == 0,
|
||||
"PAGE_NUM must be divisible by 128 / PAGE_SIZE");
|
||||
TORCH_CHECK(D_o == 512, "D_o must be equal to 512");
|
||||
|
||||
TORCH_CHECK(q_nope.dtype() == at::ScalarType::Half ||
|
||||
q_nope.dtype() == at::ScalarType::BFloat16 ||
|
||||
q_nope.dtype() == at::ScalarType::Float8_e4m3fn,
|
||||
"q_nope must be a half, bfloat16, or float8_e4m3fn tensor");
|
||||
TORCH_CHECK(kv_c_and_k_pe_cache.dtype() == q_nope.dtype() &&
|
||||
q_nope.dtype() == q_pe.dtype(),
|
||||
"kv_c_and_k_pe_cache, q_nope, and q_pe must be the same type");
|
||||
TORCH_CHECK(seq_lens.dtype() == torch::kInt32,
|
||||
"seq_lens must be a 32-bit integer tensor");
|
||||
TORCH_CHECK(page_table.dtype() == torch::kInt32,
|
||||
"page_table must be a 32-bit integer tensor");
|
||||
|
||||
auto in_dtype = q_nope.dtype();
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_nope));
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
runMla<cutlass::half_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens,
|
||||
page_table, scale, stream);
|
||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||
runMla<cutlass::bfloat16_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
seq_lens, page_table, scale, stream);
|
||||
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
||||
runMla<cutlass::float_e4m3_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
seq_lens, page_table, scale, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
||||
}
|
||||
}
|
||||
@ -510,13 +510,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]");
|
||||
ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress);
|
||||
|
||||
// CUTLASS MLA decode
|
||||
ops.def(
|
||||
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
|
||||
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
|
||||
" Tensor page_table, float scale) -> ()");
|
||||
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
|
||||
|
||||
// SM100 CUTLASS MLA decode
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"
|
||||
|
||||
@ -1823,15 +1823,6 @@ def flash_mla_with_kvcache(
|
||||
return out, softmax_lse
|
||||
|
||||
|
||||
def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
|
||||
seq_lens: torch.Tensor, page_table: torch.Tensor,
|
||||
scale: float) -> torch.Tensor:
|
||||
torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
seq_lens, page_table, scale)
|
||||
return out
|
||||
|
||||
|
||||
def sm100_cutlass_mla_decode(out: torch.Tensor, lse: torch.Tensor,
|
||||
q_nope: torch.Tensor, q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import ClassVar, Optional
|
||||
from typing import ClassVar, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -109,12 +109,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
"are not implemented for "
|
||||
"CutlassMLAImpl")
|
||||
|
||||
self._use_old_cutlass_mla = False
|
||||
force_old_cutlass = os.environ.get("FORCE_OLD_CUTLASS_MLA", None)
|
||||
if force_old_cutlass:
|
||||
logger.warning_once("Forcing old cutlass mla kernel")
|
||||
self._use_old_cutlass_mla = True
|
||||
|
||||
# TODO: Currently, num_kv_splits is limited to 16 to avoid hanging
|
||||
# issues. In case the code hangs, use:
|
||||
# FORCE_NUM_KV_SPLITS=1
|
||||
@ -219,16 +213,22 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
return out, returned_lse
|
||||
|
||||
def _sm100_forward_decode(
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if type(q) is tuple:
|
||||
q_nope, q_pe = q
|
||||
else:
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
|
||||
# Adjust workspace size (if necessary)
|
||||
self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
|
||||
|
||||
@ -245,57 +245,3 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
)
|
||||
|
||||
return o, (lse if self.need_to_return_lse_for_decode else None)
|
||||
|
||||
# TODO: Currently we leave it here only for backup in case something is
|
||||
# wrong with the new SM100 CUTLASS MLA kernel
|
||||
def _old_forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"FP8 Cutlass MLA not supported with FORCE_OLD_CUTLASS_MLA")
|
||||
|
||||
B = q_nope.shape[0]
|
||||
|
||||
o = torch.empty((B, self.num_heads, self.kv_lora_rank),
|
||||
dtype=q_nope.dtype,
|
||||
device=q_nope.device)
|
||||
|
||||
# Run MLA
|
||||
# Clone q_nope and q_pe to make sure strides computation is correct.
|
||||
q_nope = q_nope.clone()
|
||||
q_pe = q_pe.clone()
|
||||
|
||||
ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
attn_metadata.decode.seq_lens,
|
||||
attn_metadata.decode.block_table, self.scale)
|
||||
|
||||
return o
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if type(q) is tuple:
|
||||
q_nope, q_pe = q
|
||||
else:
|
||||
q_nope, q_pe = torch.split(
|
||||
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
if self._use_old_cutlass_mla:
|
||||
# TODO: Remove the old cutlass MLA kernel after more extensive
|
||||
# testing
|
||||
return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
attn_metadata), None
|
||||
|
||||
return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
attn_metadata)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user