From 0ee6416f678de761f827f4bd387b36f0cbc43a0a Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Mon, 8 Dec 2025 19:44:01 -0500
Subject: [PATCH] [Perf] Optimize `group_topk` kernel, 1.9% throughput
 improvement, 2.1% TPOT improvement (#30159)
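
Template the scoring function and (for the common cases) the group count
into the kernels instead of passing them as runtime arguments:

- `topk_with_k2`, `topk_with_k2_kernel`, and `group_idx_and_topk_idx_kernel`
  take the scoring function as a compile-time parameter (`SCORING_NONE` /
  `SCORING_SIGMOID`); the new `apply_scoring` helper resolves the sigmoid
  branch with `if constexpr` instead of a per-element runtime check.
- `group_idx_and_topk_idx_kernel` is additionally specialized on `n_group`
  (4, 8, 16, 32) so the per-group loop can be unrolled with
  `#pragma unroll`; other group counts fall back to the dynamic-`n_group`
  instantiation.
- The `cg::reduce` that accumulates `topk_sum` is now skipped entirely when
  `renormalize` is false.

The concrete kernel instantiation is selected once at launch time via
`switch` dispatch in `invokeNoAuxTc` and the new
`launch_group_idx_and_topk_kernel` helper.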

Signed-off-by: yewentao256
---
 csrc/moe/grouped_topk_kernels.cu | 175 ++++++++++++++++++++++---------
 1 file changed, 128 insertions(+), 47 deletions(-)

diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu
index 69b4c1fb11d1a..47ee5f021eb4a 100644
--- a/csrc/moe/grouped_topk_kernels.cu
+++ b/csrc/moe/grouped_topk_kernels.cu
@@ -444,23 +444,27 @@ __device__ inline T apply_sigmoid(T val) {
   return cuda_cast<T, float>(sigmoid_accurate(f));
 }
 
-template <typename T>
+template <typename T, int SF>
+__device__ inline T apply_scoring(T val) {
+  if constexpr (SF == SCORING_SIGMOID) {
+    return apply_sigmoid(val);
+  } else {
+    return val;
+  }
+}
+
+template <typename T, int SF>
 __device__ void topk_with_k2(T* output, T const* input, T const* bias,
                              cg::thread_block_tile<32> const& tile,
                              int32_t const lane_id,
-                             int const num_experts_per_group,
-                             int const scoring_func) {
+                             int const num_experts_per_group) {
   // Get the top2 per thread
   T largest = neg_inf<T>();
   T second_largest = neg_inf<T>();
 
   if (num_experts_per_group > WARP_SIZE) {
     for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
-      T value = input[i];
-      // Apply scoring function if needed
-      if (scoring_func == SCORING_SIGMOID) {
-        value = apply_sigmoid(value);
-      }
+      T value = apply_scoring<T, SF>(input[i]);
       value = value + bias[i];
 
       if (value > largest) {
@@ -472,11 +476,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
     }
   } else {
     for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
-      T value = input[i];
-      // Apply scoring function if needed
-      if (scoring_func == SCORING_SIGMOID) {
-        value = apply_sigmoid(value);
-      }
+      T value = apply_scoring<T, SF>(input[i]);
       value = value + bias[i];
       largest = value;
     }
   }
@@ -501,13 +501,12 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
-template <typename T>
+template <typename T, int SF>
 __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
                                     int64_t const num_tokens,
                                     int64_t const num_cases,
                                     int64_t const n_group,
-                                    int64_t const num_experts_per_group,
-                                    int const scoring_func) {
+                                    int64_t const num_experts_per_group) {
   int32_t warp_id = threadIdx.x / WARP_SIZE;
   int32_t lane_id = threadIdx.x % WARP_SIZE;
@@ -525,21 +524,21 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     asm volatile("griddepcontrol.wait;");
 #endif
-    topk_with_k2(output, input, group_bias, tile, lane_id,
-                 num_experts_per_group, scoring_func);
+    topk_with_k2<T, SF>(output, input, group_bias, tile, lane_id,
+                        num_experts_per_group);
   }
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
   asm volatile("griddepcontrol.launch_dependents;");
 #endif
 }
 
-template <typename T, typename IdxT>
+template <typename T, typename IdxT, int SF, int NGroup>
 __global__ void group_idx_and_topk_idx_kernel(
     T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
     T const* bias, int64_t const num_tokens, int64_t const n_group,
     int64_t const topk_group, int64_t const topk, int64_t const num_experts,
     int64_t const num_experts_per_group, bool renormalize,
-    double routed_scaling_factor, int scoring_func) {
+    double routed_scaling_factor) {
   int32_t warp_id = threadIdx.x / WARP_SIZE;
   int32_t lane_id = threadIdx.x % WARP_SIZE;
   int32_t case_id =
@@ -549,6 +548,11 @@ __global__ void group_idx_and_topk_idx_kernel(
   topk_values += case_id * topk;
   topk_indices += case_id * topk;
 
+  constexpr bool kUseStaticNGroup = (NGroup > 0);
+  // use int32 to avoid implicit conversion
+  int32_t const n_group_i32 =
+      kUseStaticNGroup ? NGroup : static_cast<int32_t>(n_group);
+
   int32_t align_num_experts_per_group =
       warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
@@ -574,13 +578,14 @@ __global__ void group_idx_and_topk_idx_kernel(
 
   if (case_id < num_tokens) {
     // calculate group_idx
-    int32_t target_num_min = WARP_SIZE - n_group + topk_group;
+    int32_t target_num_min =
+        WARP_SIZE - n_group_i32 + static_cast<int32_t>(topk_group);
     // The check is necessary to avoid abnormal input
-    if (lane_id < n_group && is_finite(group_scores[lane_id])) {
+    if (lane_id < n_group_i32 && is_finite(group_scores[lane_id])) {
       value = group_scores[lane_id];
     }
 
-    int count_equal_to_top_value = WARP_SIZE - n_group;
+    int count_equal_to_top_value = WARP_SIZE - n_group_i32;
     int pre_count_equal_to_top_value = 0;
     // Use loop to find the largest top_group
     while (count_equal_to_top_value < target_num_min) {
@@ -604,7 +609,7 @@ __global__ void group_idx_and_topk_idx_kernel(
   int count_equalto_topkth_group = 0;
   bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
   if (case_id < num_tokens && if_proceed_next_topk) {
-    for (int i_group = 0; i_group < n_group; i_group++) {
+    auto process_group = [&](int i_group) {
       if ((group_scores[i_group] > topk_group_value) ||
           ((group_scores[i_group] == topk_group_value) &&
            (count_equalto_topkth_group < num_equalto_topkth_group))) {
@@ -613,11 +618,10 @@ __global__ void group_idx_and_topk_idx_kernel(
                i += WARP_SIZE) {
           T candidates = neg_inf<T>();
           if (i < num_experts_per_group) {
-            // Apply scoring function (if any) and add bias
+            // apply scoring function (if any) and add bias
             T input = scores[offset + i];
             if (is_finite(input)) {
-              T score = (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input)
-                                                          : input;
+              T score = apply_scoring<T, SF>(input);
              candidates = score + bias[offset + i];
             }
           }
@@ -627,6 +631,17 @@ __global__ void group_idx_and_topk_idx_kernel(
             count_equalto_topkth_group++;
           }
         }
+    };
+
+    if constexpr (kUseStaticNGroup) {
+#pragma unroll
+      for (int i_group = 0; i_group < NGroup; ++i_group) {
+        process_group(i_group);
+      }
+    } else {
+      for (int i_group = 0; i_group < n_group_i32; ++i_group) {
+        process_group(i_group);
+      }
+    }
     queue.done();
     __syncwarp();
@@ -646,12 +661,13 @@ __global__ void group_idx_and_topk_idx_kernel(
       if (i < topk) {
         // Load the score value (without bias) for normalization
         T input = scores[s_topk_idx[i]];
-        value =
-            (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) : input;
+        value = apply_scoring<T, SF>(input);
         s_topk_value[i] = value;
       }
-      topk_sum +=
-          cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
+      if (renormalize) {
+        topk_sum +=
+            cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
+      }
     }
   }
@@ -660,13 +676,9 @@ __global__ void group_idx_and_topk_idx_kernel(
   if (case_id < num_tokens) {
     if (if_proceed_next_topk) {
       for (int i = lane_id; i < topk; i += WARP_SIZE) {
-        float value;
-        if (renormalize) {
-          value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
-                  routed_scaling_factor;
-        } else {
-          value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
-        }
+        float base = cuda_cast<float, T>(s_topk_value[i]);
+        float value = renormalize ? (base / topk_sum * routed_scaling_factor)
+                                  : (base * routed_scaling_factor);
         topk_indices[i] = s_topk_idx[i];
         topk_values[i] = value;
       }
@@ -684,6 +696,45 @@ __global__ void group_idx_and_topk_idx_kernel(
 #endif
 }
 
+template <typename T, typename IdxT, int SF>
+inline void launch_group_idx_and_topk_kernel(
+    cudaLaunchConfig_t const& config, T* scores, T* group_scores,
+    float* topk_values, IdxT* topk_indices, T const* bias,
+    int64_t const num_tokens, int64_t const n_group, int64_t const topk_group,
+    int64_t const topk, int64_t const num_experts,
+    int64_t const num_experts_per_group, bool const renormalize,
+    double const routed_scaling_factor) {
+  auto launch = [&](auto* kernel_instance2) {
+    cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
+                       topk_values, topk_indices, bias, num_tokens, n_group,
+                       topk_group, topk, num_experts, num_experts_per_group,
+                       renormalize, routed_scaling_factor);
+  };
+
+  switch (n_group) {
+    case 4: {
+      launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 4>);
+      break;
+    }
+    case 8: {
+      launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 8>);
+      break;
+    }
+    case 16: {
+      launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 16>);
+      break;
+    }
+    case 32: {
+      launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 32>);
+      break;
+    }
+    default: {
+      launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 0>);
+      break;
+    }
+  }
+}
+
 template <typename T, typename IdxT>
 void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
                    IdxT* topk_indices, T const* bias, int64_t const num_tokens,
@@ -694,7 +745,6 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
                    cudaStream_t const stream = 0) {
   int64_t num_cases = num_tokens * n_group;
   int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
-  auto* kernel_instance1 = &topk_with_k2_kernel<T>;
   cudaLaunchConfig_t config;
   config.gridDim = topk_with_k2_num_blocks;
   config.blockDim = BLOCK_SIZE;
@@ -705,16 +755,33 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
   attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
   config.numAttrs = 1;
   config.attrs = attrs;
-  cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
-                     num_tokens, num_cases, n_group, num_experts / n_group,
-                     scoring_func);
+  auto const sf = static_cast<int>(scoring_func);
+  int64_t const num_experts_per_group = num_experts / n_group;
+  auto launch_topk_with_k2 = [&](auto* kernel_instance1) {
+    cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
+                       num_tokens, num_cases, n_group, num_experts_per_group);
+  };
+  switch (sf) {
+    case SCORING_NONE: {
+      auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_NONE>;
+      launch_topk_with_k2(kernel_instance1);
+      break;
+    }
+    case SCORING_SIGMOID: {
+      auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_SIGMOID>;
+      launch_topk_with_k2(kernel_instance1);
+      break;
+    }
+    default:
+      // should be guarded by higher level checks.
+      TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
+  }
 
   int64_t topk_with_k_group_num_blocks =
       (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
   size_t dynamic_smem_in_bytes =
       warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK, topk);
-  auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
   config.gridDim = topk_with_k_group_num_blocks;
   config.blockDim = BLOCK_SIZE;
   config.dynamicSmemBytes = dynamic_smem_in_bytes;
@@ -723,10 +790,24 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
   attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
   config.numAttrs = 1;
   config.attrs = attrs;
-  cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
-                     topk_values, topk_indices, bias, num_tokens, n_group,
-                     topk_group, topk, num_experts, num_experts / n_group,
-                     renormalize, routed_scaling_factor, scoring_func);
+  switch (sf) {
+    case SCORING_NONE: {
+      launch_group_idx_and_topk_kernel<T, IdxT, SCORING_NONE>(
+          config, scores, group_scores, topk_values, topk_indices, bias,
+          num_tokens, n_group, topk_group, topk, num_experts,
+          num_experts_per_group, renormalize, routed_scaling_factor);
+      break;
+    }
+    case SCORING_SIGMOID: {
+      launch_group_idx_and_topk_kernel<T, IdxT, SCORING_SIGMOID>(
+          config, scores, group_scores, topk_values, topk_indices, bias,
+          num_tokens, n_group, topk_group, topk, num_experts,
+          num_experts_per_group, renormalize, routed_scaling_factor);
+      break;
+    }
+    default:
+      TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
+  }
 }
 
 #define INSTANTIATE_NOAUX_TC(T, IdxT)                                    \