[Refactor] Small refactor for group topk (#30562)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
Wentao Ye 2025-12-16 14:50:59 -05:00 committed by GitHub
parent ca702a14dc
commit f21f5ea38c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 4 deletions

View File

@ -446,9 +446,13 @@ __device__ inline T apply_sigmoid(T val) {
template <ScoringFunc SF, typename T>
__device__ inline T apply_scoring(T val) {
if constexpr (SF == SCORING_SIGMOID) {
if constexpr (SF == SCORING_NONE) {
return val;
} else if constexpr (SF == SCORING_SIGMOID) {
return apply_sigmoid(val);
} else {
static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID,
"Unsupported ScoringFunc in apply_scoring");
return val;
}
}
@ -670,10 +674,13 @@ __global__ void group_idx_and_topk_idx_kernel(
if (case_id < num_tokens) {
if (if_proceed_next_topk) {
float scale = routed_scaling_factor;
if (renormalize) {
scale /= topk_sum;
}
for (int i = lane_id; i < topk; i += WARP_SIZE) {
float base = cuda_cast<float, T>(s_topk_value[i]);
float value = renormalize ? (base / topk_sum * routed_scaling_factor)
: (base * routed_scaling_factor);
float value = base * scale;
topk_indices[i] = s_topk_idx[i];
topk_values[i] = value;
}

View File

@ -188,7 +188,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
llm = LLM(
model=model_name,
tensor_parallel_size=tp_size,
# enable_prefix_caching=False,
max_num_seqs=32,
max_model_len=8192,
dtype="bfloat16", # not everything is supported