[Kernel] Add streamK for block-quantized CUTLASS kernels (#12978)

leoneo authored 2025-02-21 14:14:24 +08:00, committed by GitHub
parent 34ad27fe83, commit 839b27c6cc
2 changed files with 44 additions and 12 deletions
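
Summary: this change threads CUTLASS tile-scheduler arguments through the block-quantized FP8 GEMM path so it can run under the Stream-K scheduler. cutlass_gemm_caller gains a defaulted TileSchedulerArguments parameter, cutlass_3x_gemm_fp8_blockwise is templated on a SchedulerType instead of hard-coding cutlass::gemm::PersistentScheduler, and the SM90 dispatch picks cutlass::gemm::StreamKScheduler when K > 3 * N (tall-K shapes that yield few output tiles), falling back to the persistent scheduler otherwise.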


@@ -30,12 +30,18 @@ static inline cute::Shape<int, int, int, int> get_problem_shape(
 }
 
 template <typename GemmKernel>
-void cutlass_gemm_caller(torch::Device device,
-                         cute::Shape<int, int, int, int> prob_shape,
-                         typename GemmKernel::MainloopArguments mainloop_args,
-                         typename GemmKernel::EpilogueArguments epilogue_args) {
+void cutlass_gemm_caller(
+    torch::Device device, cute::Shape<int, int, int, int> prob_shape,
+    typename GemmKernel::MainloopArguments mainloop_args,
+    typename GemmKernel::EpilogueArguments epilogue_args,
+    typename GemmKernel::TileSchedulerArguments scheduler = {}) {
+  cutlass::KernelHardwareInfo hw_info;
   typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
-                                      prob_shape, mainloop_args, epilogue_args};
+                                      prob_shape,
+                                      mainloop_args,
+                                      epilogue_args,
+                                      hw_info,
+                                      scheduler};
 
   // Launch the CUTLASS GEMM kernel.
   using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
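
Note: the new scheduler parameter defaults to an empty TileSchedulerArguments, so existing call sites that pass only the problem shape, mainloop, and epilogue arguments compile unchanged. hw_info is default-constructed here; when sm_count is left at zero, CUTLASS's tile-scheduler setup queries the device's SM count itself (Stream-K needs it to size its workload split), though populating it explicitly would avoid repeating that query per launch.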


@@ -22,8 +22,9 @@ namespace vllm {
 
 using namespace cute;
 
-template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
-          int TileSizeM_ = 128, class ClusterShape = Shape<_1, _2, _1>>
+template <typename SchedulerType, typename OutType, int GroupSizeM_,
+          int GroupSizeN_, int GroupSizeK_, int TileSizeM_ = 128,
+          class ClusterShape = Shape<_1, _2, _1>>
 struct cutlass_3x_gemm_fp8_blockwise {
   using GroupSizeM = Int<GroupSizeM_>;
   using GroupSizeN = Int<GroupSizeN_>;
@@ -84,7 +85,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
   using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
       Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
-      cutlass::gemm::PersistentScheduler>>;
+      SchedulerType>>;
 
   struct GemmKernel : public KernelType {};
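
With the scheduler type lifted into a template parameter, the same blockwise GEMM definition instantiates against either cutlass::gemm::PersistentScheduler or cutlass::gemm::StreamKScheduler; GemmUniversal selects the matching tile-scheduler implementation from this tag and exposes it back as GemmKernel::TileSchedulerTag and GemmKernel::TileSchedulerArguments, which the caller below inspects.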
@@ -150,8 +151,24 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   typename GemmKernel::EpilogueArguments epilogue_args{
       {}, c_ptr, c_stride, c_ptr, c_stride};
 
+  typename GemmKernel::TileSchedulerArguments scheduler;
+
+  static constexpr bool UsesStreamKScheduler =
+      cute::is_same_v<typename GemmKernel::TileSchedulerTag,
+                      cutlass::gemm::StreamKScheduler>;
+
+  if constexpr (UsesStreamKScheduler) {
+    using DecompositionMode = typename cutlass::gemm::kernel::detail::
+        PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
+    using ReductionMode = typename cutlass::gemm::kernel::detail::
+        PersistentTileSchedulerSm90StreamKParams::ReductionMode;
+
+    scheduler.decomposition_mode = DecompositionMode::StreamK;
+    scheduler.reduction_mode = ReductionMode::Nondeterministic;
+  }
+
   c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
-                                       epilogue_args);
+                                       epilogue_args, scheduler);
 }
 
 template <typename OutType>
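
Two details worth noting in this hunk. First, the if constexpr guard is load-bearing: PersistentScheduler's TileSchedulerArguments has no decomposition_mode or reduction_mode members, so the assignments must be compiled out of the persistent instantiation. Second, ReductionMode::Nondeterministic lets Stream-K CTAs combine their partial K-sums in whatever order they finish; this is faster than the deterministic mode but gives up bitwise run-to-run reproducibility, which is inherent to reordering floating-point addition. A standalone C++ illustration (not from the patch):

// Float addition is not associative, so the order in which Stream-K
// partial sums are reduced can change the low-order bits of the output.
#include <cstdio>

int main() {
  float a = 1e8f, b = -1e8f, c = 1.0f;
  std::printf("(a + b) + c = %.1f\n", (a + b) + c);  // prints 1.0
  std::printf("a + (b + c) = %.1f\n", a + (b + c));  // prints 0.0
  return 0;
}

Callers that need reproducible output could switch to ReductionMode::Deterministic at some performance cost.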
@@ -160,9 +177,18 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
                                              torch::Tensor const& b,
                                              torch::Tensor const& a_scales,
                                              torch::Tensor const& b_scales) {
-  cutlass_gemm_caller_blockwise<
-      cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>>(out, a, b, a_scales,
-                                                           b_scales);
+  auto k = a.size(1);
+  auto n = b.size(1);
+
+  if (k > 3 * n) {
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+        cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
+        out, a, b, a_scales, b_scales);
+  } else {
+    cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
+        cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
+        out, a, b, a_scales, b_scales);
+  }
 }
 
 }  // namespace vllm
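
Why k > 3 * n? With K much larger than N (and M small, as in decode-heavy workloads), the output grid has very few tiles, so a data-parallel persistent schedule occupies only a handful of SMs, each running a long mainloop over K; Stream-K additionally splits each tile's K-iterations across SMs and reduces the partials. A standalone back-of-the-envelope sketch follows; the 3x threshold itself is this commit's heuristic, while the tile shape and SM count below are assumptions for illustration:

// Standalone illustration (not from the patch). Assumed numbers: 128x128
// output tiles and 132 SMs (H100); the real kernel's tile shape comes from
// the collective's configuration.
#include <cstdio>

int main() {
  const int sm_count = 132;
  const int tile_m = 128, tile_n = 128;

  // A tall-K, narrow-N shape that satisfies the k > 3 * n dispatch test.
  int m = 64, n = 512, k = 7168;

  int tiles = ((m + tile_m - 1) / tile_m) * ((n + tile_n - 1) / tile_n);
  std::printf("k > 3*n: %s\n", k > 3 * n ? "yes" : "no");
  std::printf("output tiles: %d vs SMs: %d\n", tiles, sm_count);
  // Only 4 output tiles: a data-parallel persistent schedule keeps at most
  // 4 of 132 SMs busy, each grinding through a long K mainloop. Stream-K
  // spreads those K-iterations across otherwise-idle SMs and then reduces
  // the partial sums, which is why it wins on shapes like this.
  return 0;
}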