[Refactor] Remove useless syncwarp (#30510)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-12-11 17:43:41 -05:00 committed by GitHub
parent c817b14151
commit 61249b177d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -481,8 +481,6 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
largest = value;
}
}
__syncwarp(); // Ensure all threads have valid data before reduction
// Get the top2 warpwise
T max1 = cg::reduce(tile, largest, cg::greater<T>());
@ -589,7 +587,6 @@ __global__ void group_idx_and_topk_idx_kernel(
int pre_count_equal_to_top_value = 0;
// Use loop to find the largset top_group
while (count_equal_to_top_value < target_num_min) {
__syncwarp(); // Ensure all threads have valid data before reduction
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
if (value == topk_group_value) {
value = neg_inf<T>();
@ -644,10 +641,8 @@ __global__ void group_idx_and_topk_idx_kernel(
}
}
queue.done();
__syncwarp();
// Get the topk_idx
queue.dumpIdx(s_topk_idx);
__syncwarp();
}
// Load the valid score value