diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu index c88e134ae406b..b024482208d37 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu @@ -29,19 +29,36 @@ struct sm90_fp8_config_default { template typename Epilogue> -struct sm90_fp8_config_M16 { - // M in [1, 16] +struct sm90_fp8_config_M4 { + // M in [1, 4] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; - using TileShape = cute::Shape; - using ClusterShape = cute::Shape; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; using Cutlass3xGemm = cutlass_3x_group_gemm; + KernelSchedule, EpilogueSchedule, true>; +}; + +template typename Epilogue> +struct sm90_fp8_config_M64 { + // M in (4, 64] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; }; template ::Cutlass3xGemm; using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; - using Cutlass3xGemmM16 = typename sm90_fp8_config_M16< + using Cutlass3xGemmM4 = typename sm90_fp8_config_M4< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + using Cutlass3xGemmM64 = typename sm90_fp8_config_M64< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; using Cutlass3xGemmDefault = typename sm90_fp8_config_default< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; @@ -111,7 +130,18 @@ void run_cutlass_moe_mm_sm90( uint32_t const n = out_tensors.size(1); uint32_t const k = a_tensors.size(1); - if (n >= 8192) { + // Use swap_ab for M <= 64 by default to reduce padding + if (m <= 4) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } else if (m <= 64) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } else if (n >= 8192) { cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, per_act_token, @@ -121,11 +151,6 @@ void run_cutlass_moe_mm_sm90( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, per_act_token, per_out_ch); - } else if (m <= 16) { - cutlass_group_gemm_caller( - out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides, per_act_token, - per_out_ch); } else { cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh index bbd82d72e95bd..3225378a6ca0a 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh @@ -22,14 +22,23 @@ using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; using LayoutA = cutlass::layout::RowMajor; +using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; using LayoutB = cutlass::layout::ColumnMajor; -using LayoutC = cutlass::layout::RowMajor; +using LayoutB_Transpose = + typename cutlass::layout::LayoutTranspose::type; +using LayoutD = cutlass::layout::RowMajor; +using LayoutD_Transpose = + typename cutlass::layout::LayoutTranspose::type; +using LayoutC = LayoutD; +using LayoutC_Transpose = LayoutD_Transpose; template typename Epilogue_, typename TileShape, typename ClusterShape, typename KernelSchedule, - typename EpilogueSchedule> + typename EpilogueSchedule, bool swap_ab_ = false> struct cutlass_3x_group_gemm { + static constexpr bool swap_ab = swap_ab_; using ElementAB = ElementAB_; using ElementC = void; using ElementD = ElementC_; @@ -37,9 +46,6 @@ struct cutlass_3x_group_gemm { using Epilogue = Epilogue_; - using StrideC = - cute::remove_pointer_t, cute::Int<0>>>; - static constexpr int AlignmentAB = 128 / cutlass::sizeof_bits::value; static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; @@ -50,19 +56,26 @@ struct cutlass_3x_group_gemm { typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, - ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, - LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp; + ElementAccumulator, ElementC, + conditional_t, AlignmentC, + ElementD, conditional_t, + AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp; static constexpr size_t CEStorageSize = sizeof(typename CollectiveEpilogue::SharedStorage); using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< static_cast(CEStorageSize)>; - using CollectiveMainloop = + using CollectiveMainloop = conditional_t< + swap_ab, + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementAB, LayoutB_Transpose*, AlignmentAB, + ElementAB, LayoutA_Transpose*, AlignmentAB, ElementAccumulator, + TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp, typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB, LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape, - Stages, KernelSchedule>::CollectiveOp; + Stages, KernelSchedule>::CollectiveOp>; using KernelType = enable_sm90_only>; @@ -78,12 +91,12 @@ void cutlass_group_gemm_caller( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides, bool per_act_token, bool per_out_ch) { + static constexpr bool swap_ab = Gemm::swap_ab; + using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; int num_experts = static_cast(expert_offsets.size(0)); - int k_size = a_tensors.size(1); - int n_size = out_tensors.size(1); auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); @@ -110,19 +123,35 @@ void cutlass_group_gemm_caller( problem_sizes.data_ptr()); ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr}; - typename GemmKernel::MainloopArguments mainloop_args{ - static_cast(a_ptrs.data_ptr()), - static_cast(a_strides.data_ptr()), - static_cast(b_ptrs.data_ptr()), - static_cast(b_strides.data_ptr())}; + typename GemmKernel::MainloopArguments mainloop_args; + if constexpr (swap_ab) { + mainloop_args = typename GemmKernel::MainloopArguments{ + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides.data_ptr()), + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides.data_ptr())}; + } else { + mainloop_args = typename GemmKernel::MainloopArguments{ + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides.data_ptr())}; + } // Currently, we are only able to do broadcast on either all or none a_scales // and on either all or none b_scales typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( - static_cast(a_scales_ptrs.data_ptr()), - static_cast(b_scales_ptrs.data_ptr()), - per_act_token, per_out_ch), + swap_ab ? static_cast( + b_scales_ptrs.data_ptr()) + : static_cast( + a_scales_ptrs.data_ptr()), + swap_ab ? static_cast( + a_scales_ptrs.data_ptr()) + : static_cast( + b_scales_ptrs.data_ptr()), + swap_ab ? per_out_ch : per_act_token, + swap_ab ? per_act_token : per_out_ch), nullptr, static_cast(c_strides.data_ptr()), static_cast(out_ptrs.data_ptr()), static_cast(c_strides.data_ptr())}; diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index 80c6589ab1716..623c9a2f096bf 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -6,7 +6,10 @@ #include constexpr uint64_t THREADS_PER_EXPERT = 512; +// threshold must match the dispatch logic in run_cutlass_moe_mm_sm90() +constexpr int SWAP_AB_THRESHOLD = 64; +template __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids, int32_t* problem_sizes1, int32_t* problem_sizes2, @@ -24,40 +27,53 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids, if (threadIdx.x == 0) { int final_occurrences = atomic_buffer[expert_id]; - problem_sizes1[expert_id * 3] = final_occurrences; - problem_sizes1[expert_id * 3 + 1] = 2 * n; - problem_sizes1[expert_id * 3 + 2] = k; - problem_sizes2[expert_id * 3] = final_occurrences; - problem_sizes2[expert_id * 3 + 1] = k; - problem_sizes2[expert_id * 3 + 2] = n; + if constexpr (!SWAP_AB) { + problem_sizes1[expert_id * 3] = final_occurrences; + problem_sizes1[expert_id * 3 + 1] = 2 * n; + problem_sizes1[expert_id * 3 + 2] = k; + problem_sizes2[expert_id * 3] = final_occurrences; + problem_sizes2[expert_id * 3 + 1] = k; + problem_sizes2[expert_id * 3 + 2] = n; + } else { + problem_sizes1[expert_id * 3] = 2 * n; + problem_sizes1[expert_id * 3 + 1] = final_occurrences; + problem_sizes1[expert_id * 3 + 2] = k; + problem_sizes2[expert_id * 3] = k; + problem_sizes2[expert_id * 3 + 1] = final_occurrences; + problem_sizes2[expert_id * 3 + 2] = n; + } } } __global__ void compute_expert_offsets( const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, - int32_t* atomic_buffer, const int num_experts) { + int32_t* atomic_buffer, const int num_experts, const int topk_length) { int32_t tot_offset = 0; expert_offsets[0] = 0; for (int i = 0; i < num_experts; ++i) { atomic_buffer[i] = tot_offset; - tot_offset += problem_sizes1[i * 3]; + tot_offset += topk_length > SWAP_AB_THRESHOLD ? problem_sizes1[i * 3] + : problem_sizes1[i * 3 + 1]; expert_offsets[i + 1] = tot_offset; } } __global__ void compute_expert_blockscale_offsets( const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, - int32_t* blockscale_offsets, int32_t* atomic_buffer, - const int num_experts) { + int32_t* blockscale_offsets, int32_t* atomic_buffer, const int num_experts, + const int topk_length) { int32_t tot_offset = 0; int32_t tot_offset_round = 0; expert_offsets[0] = 0; blockscale_offsets[0] = 0; for (int i = 0; i < num_experts; ++i) { + int32_t cur_offset = topk_length > SWAP_AB_THRESHOLD + ? problem_sizes1[i * 3] + : problem_sizes1[i * 3 + 1]; atomic_buffer[i] = tot_offset; - tot_offset += problem_sizes1[i * 3]; + tot_offset += cur_offset; expert_offsets[i + 1] = tot_offset; - tot_offset_round += (problem_sizes1[i * 3] + (128 - 1)) / 128 * 128; + tot_offset_round += (cur_offset + (128 - 1)) / 128 * 128; blockscale_offsets[i + 1] = tot_offset_round; } } @@ -102,22 +118,36 @@ void get_cutlass_moe_mm_data_caller( torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); - compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); + + if (topk_ids.numel() > SWAP_AB_THRESHOLD) { + compute_problem_sizes<<>>( + static_cast(topk_ids.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, + k); + } else { + compute_problem_sizes<<>>( + static_cast(topk_ids.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, + k); + } + if (blockscale_offsets.has_value()) { compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>( static_cast(problem_sizes1.data_ptr()), static_cast(expert_offsets.data_ptr()), static_cast(blockscale_offsets.value().data_ptr()), - static_cast(atomic_buffer.data_ptr()), num_experts); + static_cast(atomic_buffer.data_ptr()), num_experts, + topk_ids.numel()); } else { compute_expert_offsets<<<1, 1, 0, stream>>>( static_cast(problem_sizes1.data_ptr()), static_cast(expert_offsets.data_ptr()), - static_cast(atomic_buffer.data_ptr()), num_experts); + static_cast(atomic_buffer.data_ptr()), num_experts, + topk_ids.numel()); } compute_arg_sorts<<>>( static_cast(topk_ids.data_ptr()), diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 5fb49c2da4fe0..37727b75b077b 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -25,6 +25,7 @@ MNK_FACTORS = [ (2, 1024, 1536), (2, 3072, 1024), (2, 3072, 1536), + (7, 3072, 1536), (64, 1024, 1024), (64, 1024, 1536), (64, 3072, 1024),