mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 15:46:51 +08:00
[Kernel]Add streamK for block-quantized CUTLASS kernels (#12978)
This commit is contained in:
parent
34ad27fe83
commit
839b27c6cc
@ -30,12 +30,18 @@ static inline cute::Shape<int, int, int, int> get_problem_shape(
|
||||
}
|
||||
|
||||
template <typename GemmKernel>
|
||||
void cutlass_gemm_caller(torch::Device device,
|
||||
cute::Shape<int, int, int, int> prob_shape,
|
||||
typename GemmKernel::MainloopArguments mainloop_args,
|
||||
typename GemmKernel::EpilogueArguments epilogue_args) {
|
||||
void cutlass_gemm_caller(
|
||||
torch::Device device, cute::Shape<int, int, int, int> prob_shape,
|
||||
typename GemmKernel::MainloopArguments mainloop_args,
|
||||
typename GemmKernel::EpilogueArguments epilogue_args,
|
||||
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
prob_shape, mainloop_args, epilogue_args};
|
||||
prob_shape,
|
||||
mainloop_args,
|
||||
epilogue_args,
|
||||
hw_info,
|
||||
scheduler};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
@ -22,8 +22,9 @@ namespace vllm {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
|
||||
int TileSizeM_ = 128, class ClusterShape = Shape<_1, _2, _1>>
|
||||
template <typename SchedulerType, typename OutType, int GroupSizeM_,
|
||||
int GroupSizeN_, int GroupSizeK_, int TileSizeM_ = 128,
|
||||
class ClusterShape = Shape<_1, _2, _1>>
|
||||
struct cutlass_3x_gemm_fp8_blockwise {
|
||||
using GroupSizeM = Int<GroupSizeM_>;
|
||||
using GroupSizeN = Int<GroupSizeN_>;
|
||||
@ -84,7 +85,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
|
||||
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||
cutlass::gemm::PersistentScheduler>>;
|
||||
SchedulerType>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
|
||||
@ -150,8 +151,24 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
typename GemmKernel::TileSchedulerArguments scheduler;
|
||||
|
||||
static constexpr bool UsesStreamKScheduler =
|
||||
cute::is_same_v<typename GemmKernel::TileSchedulerTag,
|
||||
cutlass::gemm::StreamKScheduler>;
|
||||
|
||||
if constexpr (UsesStreamKScheduler) {
|
||||
using DecompositionMode = typename cutlass::gemm::kernel::detail::
|
||||
PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
|
||||
using ReductionMode = typename cutlass::gemm::kernel::detail::
|
||||
PersistentTileSchedulerSm90StreamKParams::ReductionMode;
|
||||
|
||||
scheduler.decomposition_mode = DecompositionMode::StreamK;
|
||||
scheduler.reduction_mode = ReductionMode::Nondeterministic;
|
||||
}
|
||||
|
||||
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
|
||||
epilogue_args);
|
||||
epilogue_args, scheduler);
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
@ -160,9 +177,18 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
cutlass_gemm_caller_blockwise<
|
||||
cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
auto k = a.size(1);
|
||||
auto n = b.size(1);
|
||||
|
||||
if (k > 3 * n) {
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
Loading…
x
Reference in New Issue
Block a user