From f21f5ea38c6fa0e824bc00d5762d17e049199cd3 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Tue, 16 Dec 2025 14:50:59 -0500
Subject: [PATCH] [Refactor] Small refactor for group topk (#30562)

Signed-off-by: yewentao256
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
---
 csrc/moe/grouped_topk_kernels.cu              | 13 ++++++++++---
 tests/v1/determinism/test_batch_invariance.py |  1 -
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu
index 5fa367abd96f5..7229e420d3fe4 100644
--- a/csrc/moe/grouped_topk_kernels.cu
+++ b/csrc/moe/grouped_topk_kernels.cu
@@ -446,9 +446,13 @@ __device__ inline T apply_sigmoid(T val) {
 
 template <ScoringFunc SF, typename T>
 __device__ inline T apply_scoring(T val) {
-  if constexpr (SF == SCORING_SIGMOID) {
+  if constexpr (SF == SCORING_NONE) {
+    return val;
+  } else if constexpr (SF == SCORING_SIGMOID) {
     return apply_sigmoid(val);
   } else {
+    static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID,
+                  "Unsupported ScoringFunc in apply_scoring");
     return val;
   }
 }
@@ -670,10 +674,13 @@ __global__ void group_idx_and_topk_idx_kernel(
 
   if (case_id < num_tokens) {
     if (if_proceed_next_topk) {
+      float scale = routed_scaling_factor;
+      if (renormalize) {
+        scale /= topk_sum;
+      }
       for (int i = lane_id; i < topk; i += WARP_SIZE) {
         float base = cuda_cast<float, T>(s_topk_value[i]);
-        float value = renormalize ? (base / topk_sum * routed_scaling_factor)
-                                  : (base * routed_scaling_factor);
+        float value = base * scale;
         topk_indices[i] = s_topk_idx[i];
         topk_values[i] = value;
       }
diff --git a/tests/v1/determinism/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py
index 1c45e7fe366ff..7a58e1c9bad03 100644
--- a/tests/v1/determinism/test_batch_invariance.py
+++ b/tests/v1/determinism/test_batch_invariance.py
@@ -188,7 +188,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     llm = LLM(
         model=model_name,
         tensor_parallel_size=tp_size,
-        # enable_prefix_caching=False,
         max_num_seqs=32,
         max_model_len=8192,
         dtype="bfloat16",  # not everything is supported
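
The first hunk relies on C++17 `if constexpr` to select the scoring function at
compile time, with a `static_assert` in the final branch so that any ScoringFunc
value added later without an explicit branch fails to compile instead of silently
passing values through. Below is a minimal host-side sketch of that pattern; the
enum, the simplified sigmoid, and the main() driver are stand-ins for
illustration, not the vLLM source.

// sketch.cpp -- stand-in for the patch's compile-time scoring dispatch.
#include <cmath>
#include <cstdio>

enum ScoringFunc { SCORING_NONE, SCORING_SIGMOID };

// Simplified sigmoid; the kernel uses its own device implementation.
template <typename T>
inline T apply_sigmoid(T val) {
  return T(1) / (T(1) + std::exp(-val));
}

template <ScoringFunc SF, typename T>
inline T apply_scoring(T val) {
  if constexpr (SF == SCORING_NONE) {
    return val;
  } else if constexpr (SF == SCORING_SIGMOID) {
    return apply_sigmoid(val);
  } else {
    // Only instantiated for a ScoringFunc value with no branch above;
    // the condition is then false and compilation stops here.
    static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID,
                  "Unsupported ScoringFunc in apply_scoring");
    return val;
  }
}

int main() {
  printf("none:    %f\n", apply_scoring<SCORING_NONE>(0.5f));     // 0.500000
  printf("sigmoid: %f\n", apply_scoring<SCORING_SIGMOID>(0.5f));  // ~0.622459
  return 0;
}

Because the assert's condition depends on the template parameter, it is only
evaluated in the branch that survives `if constexpr`, so valid instantiations
compile cleanly while an unhandled enum value is caught at build time.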
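The second hunk is loop-invariant hoisting: `routed_scaling_factor / topk_sum`
does not depend on the loop index, so it is computed once as `scale` and each
lane then performs a single multiply. Note that the grouping changes from
`(base / topk_sum) * factor` to `base * (factor / topk_sum)`, which can differ
in the last bit under floating-point rounding. A small standalone sketch of the
transformation, using made-up values rather than the kernel's data:

// hoist.cpp -- illustrates the scale hoisting from the second hunk.
#include <cstdio>

int main() {
  const float topk_value[4] = {0.4f, 0.3f, 0.2f, 0.1f};  // made-up scores
  const float topk_sum = 0.4f + 0.3f + 0.2f + 0.1f;
  const float routed_scaling_factor = 2.5f;
  const bool renormalize = true;

  // After the refactor: the loop-invariant scale is computed once.
  float scale = routed_scaling_factor;
  if (renormalize) {
    scale /= topk_sum;
  }

  for (int i = 0; i < 4; ++i) {
    float base = topk_value[i];
    // Before: the division and branch were evaluated per element.
    float before = renormalize ? (base / topk_sum * routed_scaling_factor)
                               : (base * routed_scaling_factor);
    // After: one multiply per element.
    float after = base * scale;
    printf("i=%d before=%.9f after=%.9f\n", i, before, after);
  }
  return 0;
}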