mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 11:15:48 +08:00
[Refactor] Remove useless syncwarp (#30510)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
c817b14151
commit
61249b177d
@ -481,8 +481,6 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
|
|||||||
largest = value;
|
largest = value;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
|
||||||
// Get the top2 warpwise
|
// Get the top2 warpwise
|
||||||
T max1 = cg::reduce(tile, largest, cg::greater<T>());
|
T max1 = cg::reduce(tile, largest, cg::greater<T>());
|
||||||
|
|
||||||
@ -589,7 +587,6 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
int pre_count_equal_to_top_value = 0;
|
int pre_count_equal_to_top_value = 0;
|
||||||
// Use loop to find the largset top_group
|
// Use loop to find the largset top_group
|
||||||
while (count_equal_to_top_value < target_num_min) {
|
while (count_equal_to_top_value < target_num_min) {
|
||||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
|
||||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||||
if (value == topk_group_value) {
|
if (value == topk_group_value) {
|
||||||
value = neg_inf<T>();
|
value = neg_inf<T>();
|
||||||
@ -644,10 +641,8 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
queue.done();
|
queue.done();
|
||||||
__syncwarp();
|
|
||||||
// Get the topk_idx
|
// Get the topk_idx
|
||||||
queue.dumpIdx(s_topk_idx);
|
queue.dumpIdx(s_topk_idx);
|
||||||
__syncwarp();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load the valid score value
|
// Load the valid score value
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user