From e919d6f549f4da22fa60ea394f00aaf93ef23aa0 Mon Sep 17 00:00:00 2001 From: Qiming Zhang Date: Wed, 3 Sep 2025 21:37:37 -0700 Subject: [PATCH] [Kernel][Bugfix] Fix grouped topk cu (#24146) Signed-off-by: mayuyuace --- csrc/moe/grouped_topk_kernels.cu | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index 78f7b3cc1aa25..accbb09858fac 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -28,6 +28,7 @@ namespace cg = cooperative_groups; namespace vllm { namespace moe { +constexpr float kNegInfinity = INFINITY * -1; constexpr unsigned FULL_WARP_MASK = 0xffffffff; constexpr int32_t WARP_SIZE = 32; constexpr int32_t BLOCK_SIZE = 512; @@ -512,8 +513,8 @@ __global__ void group_idx_and_topk_idx_kernel( warp_id * topk; s_topk_idx += warp_id * topk; - T value = cuda::std::numeric_limits::min(); - T topk_group_value = cuda::std::numeric_limits::min(); + T value = kNegInfinity; + T topk_group_value = kNegInfinity; int32_t num_equalto_topkth_group; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) @@ -539,11 +540,11 @@ __global__ void group_idx_and_topk_idx_kernel( __syncwarp(); // Ensure all threads have valid data before reduction topk_group_value = cg::reduce(tile, value, cg::greater()); if (value == topk_group_value) { - value = cuda::std::numeric_limits::min(); + value = kNegInfinity; } pre_count_equal_to_top_value = count_equal_to_top_value; count_equal_to_top_value = __popc(__ballot_sync( - FULL_WARP_MASK, (value == cuda::std::numeric_limits::min()))); + FULL_WARP_MASK, (value == cuda_cast(kNegInfinity)))); } num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; } @@ -555,7 +556,7 @@ __global__ void group_idx_and_topk_idx_kernel( int count_equalto_topkth_group = 0; bool if_proceed_next_topk = - (topk_group_value != cuda::std::numeric_limits::min()); + (topk_group_value != cuda_cast(kNegInfinity)); if (case_id < num_tokens && if_proceed_next_topk) { for (int i_group = 0; i_group < n_group; i_group++) { if ((group_scores[i_group] > topk_group_value) || @@ -568,7 +569,7 @@ __global__ void group_idx_and_topk_idx_kernel( (i < num_experts_per_group) && isfinite(cuda_cast( scores_with_bias[offset + i])) ? scores_with_bias[offset + i] - : cuda::std::numeric_limits::min(); + : cuda_cast(kNegInfinity); queue.add(candidates, offset + i); } if (group_scores[i_group] == topk_group_value) {