mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-30 06:57:04 +08:00
[Refactor] Small refactor for group topk (#30562)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
parent
ca702a14dc
commit
f21f5ea38c
@ -446,9 +446,13 @@ __device__ inline T apply_sigmoid(T val) {
|
|||||||
|
|
||||||
template <ScoringFunc SF, typename T>
|
template <ScoringFunc SF, typename T>
|
||||||
__device__ inline T apply_scoring(T val) {
|
__device__ inline T apply_scoring(T val) {
|
||||||
if constexpr (SF == SCORING_SIGMOID) {
|
if constexpr (SF == SCORING_NONE) {
|
||||||
|
return val;
|
||||||
|
} else if constexpr (SF == SCORING_SIGMOID) {
|
||||||
return apply_sigmoid(val);
|
return apply_sigmoid(val);
|
||||||
} else {
|
} else {
|
||||||
|
static_assert(SF == SCORING_NONE || SF == SCORING_SIGMOID,
|
||||||
|
"Unsupported ScoringFunc in apply_scoring");
|
||||||
return val;
|
return val;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -670,10 +674,13 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
|
|
||||||
if (case_id < num_tokens) {
|
if (case_id < num_tokens) {
|
||||||
if (if_proceed_next_topk) {
|
if (if_proceed_next_topk) {
|
||||||
|
float scale = routed_scaling_factor;
|
||||||
|
if (renormalize) {
|
||||||
|
scale /= topk_sum;
|
||||||
|
}
|
||||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||||
float base = cuda_cast<float, T>(s_topk_value[i]);
|
float base = cuda_cast<float, T>(s_topk_value[i]);
|
||||||
float value = renormalize ? (base / topk_sum * routed_scaling_factor)
|
float value = base * scale;
|
||||||
: (base * routed_scaling_factor);
|
|
||||||
topk_indices[i] = s_topk_idx[i];
|
topk_indices[i] = s_topk_idx[i];
|
||||||
topk_values[i] = value;
|
topk_values[i] = value;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -188,7 +188,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
|
|||||||
llm = LLM(
|
llm = LLM(
|
||||||
model=model_name,
|
model=model_name,
|
||||||
tensor_parallel_size=tp_size,
|
tensor_parallel_size=tp_size,
|
||||||
# enable_prefix_caching=False,
|
|
||||||
max_num_seqs=32,
|
max_num_seqs=32,
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
dtype="bfloat16", # not everything is supported
|
dtype="bfloat16", # not everything is supported
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user