[Perf] Add swap_ab to SM90 FP8 non-block CUTLASS moe grouped gemm (#20911)

Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>
This commit is contained in:
shixianc 2025-07-17 21:34:43 -07:00 committed by GitHub
parent c7d8724e78
commit 5780121c95
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 135 additions and 50 deletions

View File

@ -29,19 +29,36 @@ struct sm90_fp8_config_default {
template <typename InType, typename OutType, template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue> template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M16 { struct sm90_fp8_config_M4 {
// M in [1, 16] // M in [1, 4]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule = using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = using EpilogueSchedule =
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_64, cute::_64, cute::_128>; using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_4, cute::_1>; using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using Cutlass3xGemm = using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape, cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>; KernelSchedule, EpilogueSchedule, true>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
// M in (4, 64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule =
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_256>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule, true>;
}; };
template <typename InType, typename OutType, template <typename InType, typename OutType,
@ -102,7 +119,9 @@ void run_cutlass_moe_mm_sm90(
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192< using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmM16 = typename sm90_fp8_config_M16< using Cutlass3xGemmM4 = typename sm90_fp8_config_M4<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmM64 = typename sm90_fp8_config_M64<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmDefault = typename sm90_fp8_config_default< using Cutlass3xGemmDefault = typename sm90_fp8_config_default<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
@ -111,7 +130,18 @@ void run_cutlass_moe_mm_sm90(
uint32_t const n = out_tensors.size(1); uint32_t const n = out_tensors.size(1);
uint32_t const k = a_tensors.size(1); uint32_t const k = a_tensors.size(1);
if (n >= 8192) { // Use swap_ab for M <= 64 by default to reduce padding
if (m <= 4) {
cutlass_group_gemm_caller<Cutlass3xGemmM4>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (m <= 64) {
cutlass_group_gemm_caller<Cutlass3xGemmM64>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (n >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmN8192>( cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token, problem_sizes, a_strides, b_strides, c_strides, per_act_token,
@ -121,11 +151,6 @@ void run_cutlass_moe_mm_sm90(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token, problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch); per_out_ch);
} else if (m <= 16) {
cutlass_group_gemm_caller<Cutlass3xGemmM16>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else { } else {
cutlass_group_gemm_caller<Cutlass3xGemmDefault>( cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,

View File

@ -22,14 +22,23 @@ using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp; using OperatorClass = cutlass::arch::OpClassTensorOp;
using LayoutA = cutlass::layout::RowMajor; using LayoutA = cutlass::layout::RowMajor;
using LayoutA_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB = cutlass::layout::ColumnMajor; using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor; using LayoutB_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using LayoutD = cutlass::layout::RowMajor;
using LayoutD_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
using LayoutC = LayoutD;
using LayoutC_Transpose = LayoutD_Transpose;
template <typename ElementAB_, typename ElementC_, template <typename ElementAB_, typename ElementC_,
template <typename, typename, typename> typename Epilogue_, template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule, typename TileShape, typename ClusterShape, typename KernelSchedule,
typename EpilogueSchedule> typename EpilogueSchedule, bool swap_ab_ = false>
struct cutlass_3x_group_gemm { struct cutlass_3x_group_gemm {
static constexpr bool swap_ab = swap_ab_;
using ElementAB = ElementAB_; using ElementAB = ElementAB_;
using ElementC = void; using ElementC = void;
using ElementD = ElementC_; using ElementD = ElementC_;
@ -37,9 +46,6 @@ struct cutlass_3x_group_gemm {
using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>; using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;
using StrideC =
cute::remove_pointer_t<cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;
static constexpr int AlignmentAB = static constexpr int AlignmentAB =
128 / cutlass::sizeof_bits<ElementAB>::value; 128 / cutlass::sizeof_bits<ElementAB>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value; static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
@ -50,19 +56,26 @@ struct cutlass_3x_group_gemm {
typename cutlass::epilogue::collective::CollectiveBuilder< typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, ArchTag, OperatorClass, TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, ElementAccumulator, ElementC,
LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp; conditional_t<swap_ab, LayoutC_Transpose*, LayoutC*>, AlignmentC,
ElementD, conditional_t<swap_ab, LayoutD_Transpose*, LayoutD*>,
AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp;
static constexpr size_t CEStorageSize = static constexpr size_t CEStorageSize =
sizeof(typename CollectiveEpilogue::SharedStorage); sizeof(typename CollectiveEpilogue::SharedStorage);
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(CEStorageSize)>; static_cast<int>(CEStorageSize)>;
using CollectiveMainloop = using CollectiveMainloop = conditional_t<
swap_ab,
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementAB, LayoutB_Transpose*, AlignmentAB,
ElementAB, LayoutA_Transpose*, AlignmentAB, ElementAccumulator,
TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp,
typename cutlass::gemm::collective::CollectiveBuilder< typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB, ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB,
LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape, LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
Stages, KernelSchedule>::CollectiveOp; Stages, KernelSchedule>::CollectiveOp>;
using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal< using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal<
ProblemShape, CollectiveMainloop, CollectiveEpilogue>>; ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;
@ -78,12 +91,12 @@ void cutlass_group_gemm_caller(
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) { bool per_act_token, bool per_out_ch) {
static constexpr bool swap_ab = Gemm::swap_ab;
using ElementAB = typename Gemm::ElementAB; using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD; using ElementD = typename Gemm::ElementD;
int num_experts = static_cast<int>(expert_offsets.size(0)); int num_experts = static_cast<int>(expert_offsets.size(0));
int k_size = a_tensors.size(1);
int n_size = out_tensors.size(1);
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
@ -110,19 +123,35 @@ void cutlass_group_gemm_caller(
problem_sizes.data_ptr()); problem_sizes.data_ptr());
ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr}; ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr};
typename GemmKernel::MainloopArguments mainloop_args{ typename GemmKernel::MainloopArguments mainloop_args;
static_cast<const ElementAB**>(a_ptrs.data_ptr()), if constexpr (swap_ab) {
static_cast<StrideA*>(a_strides.data_ptr()), mainloop_args = typename GemmKernel::MainloopArguments{
static_cast<const ElementAB**>(b_ptrs.data_ptr()), static_cast<const ElementAB**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(b_strides.data_ptr())}; static_cast<StrideB*>(b_strides.data_ptr()),
static_cast<const ElementAB**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(a_strides.data_ptr())};
} else {
mainloop_args = typename GemmKernel::MainloopArguments{
static_cast<const ElementAB**>(a_ptrs.data_ptr()),
static_cast<StrideA*>(a_strides.data_ptr()),
static_cast<const ElementAB**>(b_ptrs.data_ptr()),
static_cast<StrideB*>(b_strides.data_ptr())};
}
// Currently, we are only able to do broadcast on either all or none a_scales // Currently, we are only able to do broadcast on either all or none a_scales
// and on either all or none b_scales // and on either all or none b_scales
typename GemmKernel::EpilogueArguments epilogue_args{ typename GemmKernel::EpilogueArguments epilogue_args{
Gemm::Epilogue::prepare_args( Gemm::Epilogue::prepare_args(
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()), swap_ab ? static_cast<const ElementAccumulator**>(
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()), b_scales_ptrs.data_ptr())
per_act_token, per_out_ch), : static_cast<const ElementAccumulator**>(
a_scales_ptrs.data_ptr()),
swap_ab ? static_cast<const ElementAccumulator**>(
a_scales_ptrs.data_ptr())
: static_cast<const ElementAccumulator**>(
b_scales_ptrs.data_ptr()),
swap_ab ? per_out_ch : per_act_token,
swap_ab ? per_act_token : per_out_ch),
nullptr, static_cast<StrideC*>(c_strides.data_ptr()), nullptr, static_cast<StrideC*>(c_strides.data_ptr()),
static_cast<ElementD**>(out_ptrs.data_ptr()), static_cast<ElementD**>(out_ptrs.data_ptr()),
static_cast<StrideC*>(c_strides.data_ptr())}; static_cast<StrideC*>(c_strides.data_ptr())};

View File

@ -6,7 +6,10 @@
#include <iostream> #include <iostream>
constexpr uint64_t THREADS_PER_EXPERT = 512; constexpr uint64_t THREADS_PER_EXPERT = 512;
// threshold must match the dispatch logic in run_cutlass_moe_mm_sm90()
constexpr int SWAP_AB_THRESHOLD = 64;
template <bool SWAP_AB>
__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids, __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
int32_t* problem_sizes1, int32_t* problem_sizes1,
int32_t* problem_sizes2, int32_t* problem_sizes2,
@ -24,40 +27,53 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
int final_occurrences = atomic_buffer[expert_id]; int final_occurrences = atomic_buffer[expert_id];
problem_sizes1[expert_id * 3] = final_occurrences; if constexpr (!SWAP_AB) {
problem_sizes1[expert_id * 3 + 1] = 2 * n; problem_sizes1[expert_id * 3] = final_occurrences;
problem_sizes1[expert_id * 3 + 2] = k; problem_sizes1[expert_id * 3 + 1] = 2 * n;
problem_sizes2[expert_id * 3] = final_occurrences; problem_sizes1[expert_id * 3 + 2] = k;
problem_sizes2[expert_id * 3 + 1] = k; problem_sizes2[expert_id * 3] = final_occurrences;
problem_sizes2[expert_id * 3 + 2] = n; problem_sizes2[expert_id * 3 + 1] = k;
problem_sizes2[expert_id * 3 + 2] = n;
} else {
problem_sizes1[expert_id * 3] = 2 * n;
problem_sizes1[expert_id * 3 + 1] = final_occurrences;
problem_sizes1[expert_id * 3 + 2] = k;
problem_sizes2[expert_id * 3] = k;
problem_sizes2[expert_id * 3 + 1] = final_occurrences;
problem_sizes2[expert_id * 3 + 2] = n;
}
} }
} }
__global__ void compute_expert_offsets( __global__ void compute_expert_offsets(
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
int32_t* atomic_buffer, const int num_experts) { int32_t* atomic_buffer, const int num_experts, const int topk_length) {
int32_t tot_offset = 0; int32_t tot_offset = 0;
expert_offsets[0] = 0; expert_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) { for (int i = 0; i < num_experts; ++i) {
atomic_buffer[i] = tot_offset; atomic_buffer[i] = tot_offset;
tot_offset += problem_sizes1[i * 3]; tot_offset += topk_length > SWAP_AB_THRESHOLD ? problem_sizes1[i * 3]
: problem_sizes1[i * 3 + 1];
expert_offsets[i + 1] = tot_offset; expert_offsets[i + 1] = tot_offset;
} }
} }
__global__ void compute_expert_blockscale_offsets( __global__ void compute_expert_blockscale_offsets(
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
int32_t* blockscale_offsets, int32_t* atomic_buffer, int32_t* blockscale_offsets, int32_t* atomic_buffer, const int num_experts,
const int num_experts) { const int topk_length) {
int32_t tot_offset = 0; int32_t tot_offset = 0;
int32_t tot_offset_round = 0; int32_t tot_offset_round = 0;
expert_offsets[0] = 0; expert_offsets[0] = 0;
blockscale_offsets[0] = 0; blockscale_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) { for (int i = 0; i < num_experts; ++i) {
int32_t cur_offset = topk_length > SWAP_AB_THRESHOLD
? problem_sizes1[i * 3]
: problem_sizes1[i * 3 + 1];
atomic_buffer[i] = tot_offset; atomic_buffer[i] = tot_offset;
tot_offset += problem_sizes1[i * 3]; tot_offset += cur_offset;
expert_offsets[i + 1] = tot_offset; expert_offsets[i + 1] = tot_offset;
tot_offset_round += (problem_sizes1[i * 3] + (128 - 1)) / 128 * 128; tot_offset_round += (cur_offset + (128 - 1)) / 128 * 128;
blockscale_offsets[i + 1] = tot_offset_round; blockscale_offsets[i + 1] = tot_offset_round;
} }
} }
@ -102,22 +118,36 @@ void get_cutlass_moe_mm_data_caller(
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()), if (topk_ids.numel() > SWAP_AB_THRESHOLD) {
static_cast<int32_t*>(problem_sizes1.data_ptr()), compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
static_cast<int32_t*>(problem_sizes2.data_ptr()), static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
k);
} else {
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
k);
}
if (blockscale_offsets.has_value()) { if (blockscale_offsets.has_value()) {
compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>( compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()), static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()), static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(blockscale_offsets.value().data_ptr()), static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts); static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
topk_ids.numel());
} else { } else {
compute_expert_offsets<<<1, 1, 0, stream>>>( compute_expert_offsets<<<1, 1, 0, stream>>>(
static_cast<const int32_t*>(problem_sizes1.data_ptr()), static_cast<const int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(expert_offsets.data_ptr()), static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts); static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
topk_ids.numel());
} }
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>( compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const int32_t*>(topk_ids.data_ptr()), static_cast<const int32_t*>(topk_ids.data_ptr()),

View File

@ -25,6 +25,7 @@ MNK_FACTORS = [
(2, 1024, 1536), (2, 1024, 1536),
(2, 3072, 1024), (2, 3072, 1024),
(2, 3072, 1536), (2, 3072, 1536),
(7, 3072, 1536),
(64, 1024, 1024), (64, 1024, 1024),
(64, 1024, 1536), (64, 1024, 1536),
(64, 3072, 1024), (64, 3072, 1024),