From 4c558cf62ed69fbd8c031809b0a7f8b12afa980b Mon Sep 17 00:00:00 2001
From: shixianc <49539556+shixianc@users.noreply.github.com>
Date: Tue, 12 Aug 2025 21:34:47 -0700
Subject: [PATCH] [Perf] Support topk softmax fused kernel for broader num_experts (#22211)

Signed-off-by: Shixian Cui
Co-authored-by: Shixian Cui
---
 csrc/moe/topk_softmax_kernels.cu | 77 +++++++++++++++++++-------------
 tests/kernels/moe/test_moe.py    |  2 +-
 2 files changed, 46 insertions(+), 33 deletions(-)

diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu
index 7a7865b901de..946c137db636 100644
--- a/csrc/moe/topk_softmax_kernels.cu
+++ b/csrc/moe/topk_softmax_kernels.cu
@@ -188,7 +188,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(
     It fuses the softmax, max and argmax into a single kernel.
 
     Limitations:
-    1) This implementation is intended for when the number of experts is a small power of 2.
+    1) This implementation is optimized for when the number of experts is a small power of 2.
+       Additionally, it also supports the case where the number of experts is a multiple of 64,
+       which is still faster than computing softmax and topK separately (currently only tested on CUDA).
     2) This implementation assumes k is small, but will work for any k.
 */
 
@@ -198,8 +200,6 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
         int* source_rows, const int k, const int start_expert, const int end_expert)
 {
     // We begin by enforcing compile time assertions and setting up compile time constants.
-    static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
-    static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
     static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
     static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
 
@@ -407,12 +407,10 @@ struct TopkConstants
 };
 
 } // namespace detail
 
-template
+template
 void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
     int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
 {
-    static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
-    static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
     using Constants = detail::TopkConstants;
     static constexpr int VPT = Constants::VPT;
@@ -425,21 +423,12 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
         input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
 }
 
-#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
-  switch (warpSize) { \
-    case 32: \
-      topkGatingSoftmaxLauncherHelper( \
-          gating_output, nullptr, topk_weights, topk_indices, \
-          token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
-      break; \
-    case 64: \
-      topkGatingSoftmaxLauncherHelper( \
-          gating_output, nullptr, topk_weights, topk_indices, \
-          token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
-      break; \
-    default: \
-      TORCH_CHECK(false, "Unsupported warp size: ", warpSize); \
-  }
+#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
+    static_assert(WARP_SIZE == 32 || WARP_SIZE == 64, \
+                  "Unsupported warp size. Only 32 and 64 are supported."); \
+    topkGatingSoftmaxLauncherHelper( \
+        gating_output, nullptr, topk_weights, topk_indices, \
+        token_expert_indices, num_tokens, topk, 0, num_experts, stream);
 
 template
 void topkGatingSoftmaxKernelLauncher(
@@ -453,38 +442,62 @@ void topkGatingSoftmaxKernelLauncher(
     const int topk,
     cudaStream_t stream) {
   static constexpr int WARPS_PER_TB = 4;
-  auto warpSize = WARP_SIZE;
+  static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
+  static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8;
   switch (num_experts) {
     case 1:
-      LAUNCH_SOFTMAX(1, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
      break;
    case 2:
-      LAUNCH_SOFTMAX(2, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
      break;
    case 4:
-      LAUNCH_SOFTMAX(4, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
      break;
    case 8:
-      LAUNCH_SOFTMAX(8, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
      break;
    case 16:
-      LAUNCH_SOFTMAX(16, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
      break;
    case 32:
-      LAUNCH_SOFTMAX(32, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
      break;
    case 64:
-      LAUNCH_SOFTMAX(64, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
      break;
    case 128:
-      LAUNCH_SOFTMAX(128, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
      break;
    case 256:
-      LAUNCH_SOFTMAX(256, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
      break;
+    case 512:
+      LAUNCH_SOFTMAX(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
+      break;
+    // (CUDA only) Support multiples of 64 when num_experts is not a power of 2.
+    // ROCm uses WARP_SIZE 64, so 8-byte loads do not cover some of these num_experts evenly;
+    // alternatively, we could test 4-byte loads and enable them in the future.
+#ifndef USE_ROCM
+    case 192:
+      LAUNCH_SOFTMAX(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+      break;
+    case 320:
+      LAUNCH_SOFTMAX(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+      break;
+    case 384:
+      LAUNCH_SOFTMAX(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+      break;
+    case 448:
+      LAUNCH_SOFTMAX(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+      break;
+    case 576:
+      LAUNCH_SOFTMAX(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+      break;
+#endif
    default: {
      TORCH_CHECK(softmax_workspace != nullptr,
-          "softmax_workspace must be provided for num_experts that are not a power of 2.");
+          "softmax_workspace must be provided for num_experts that are not a power of 2 or a multiple of 64.");
      static constexpr int TPB = 256;
      moeSoftmax<<>>(
          gating_output, nullptr, softmax_workspace, num_experts);
diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index 0f1c78704642..49c097718e30 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -36,7 +36,7 @@ from vllm.model_executor.models.mixtral import MixtralMoE
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
 
-NUM_EXPERTS = [8, 64]
+NUM_EXPERTS = [8, 64, 192]
 EP_SIZE = [1, 4]
 TOP_KS = [2, 6]
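
Note on the LAUNCH_SOFTMAX dispatch in the kernel change above: choosing BYTES_PER_LDG_MULTIPLE_64 = 8 for the non-power-of-2 cases comes down to a divisibility condition on the per-warp vector-load layout. The standalone C++ sketch below (no CUDA needed) illustrates that condition under a simplifying assumption: one row of num_experts gating logits must either fit inside a single warp-wide vector load or be an exact multiple of one. The helper name row_fits and the printed table are illustrative only; the kernel's actual TopkConstants applies further layout constraints on top of this.

// Sketch of the warp-wide load-divisibility condition assumed above.
// With BYTES_PER_LDG-wide loads and WARP_SIZE threads, one warp-wide load
// covers (BYTES_PER_LDG / sizeof(float)) * WARP_SIZE gating logits.
#include <cstdio>

bool row_fits(int num_experts, int bytes_per_ldg, int warp_size) {
    const int elts_per_ldg = bytes_per_ldg / static_cast<int>(sizeof(float));
    const int elts_per_warp_ldg = elts_per_ldg * warp_size;
    // The row is covered either by part of a single warp-wide load or by a
    // whole number of warp-wide loads.
    return num_experts <= elts_per_warp_ldg || num_experts % elts_per_warp_ldg == 0;
}

int main() {
    // Expert counts from the switch above: powers of 2 use 16-byte loads,
    // the remaining multiples of 64 use 8-byte loads (CUDA only).
    const int experts[] = {64, 128, 192, 256, 320, 384, 448, 512, 576};
    std::printf("experts  16B/warp32  8B/warp32  8B/warp64\n");
    for (int e : experts) {
        std::printf("%7d  %10d  %9d  %9d\n", e,
                    row_fits(e, 16, 32), row_fits(e, 8, 32), row_fits(e, 8, 64));
    }
    // e.g. 192: 192 % 128 != 0 (16-byte loads, 32-thread warp) but 192 % 64 == 0
    // (8-byte loads, 32-thread warp), while on a 64-thread ROCm warp 192 % 128 != 0,
    // matching the #ifndef USE_ROCM guard above.
    return 0;
}

With a 32-thread warp and 8-byte (two-float) loads, one warp-wide load covers exactly 64 logits, so every multiple of 64 splits into a whole number of loads per thread; on ROCm's 64-thread warps the same load width covers 128 logits, which is why some of these counts do not fit there. Treat the table as intuition for the BYTES_PER_LDG choice, not as the kernel's authoritative rule.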