diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index 47ee5f021eb4a..5fa367abd96f5 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -481,8 +481,6 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias, largest = value; } } - - __syncwarp(); // Ensure all threads have valid data before reduction // Get the top2 warpwise T max1 = cg::reduce(tile, largest, cg::greater()); @@ -589,7 +587,6 @@ __global__ void group_idx_and_topk_idx_kernel( int pre_count_equal_to_top_value = 0; // Use loop to find the largset top_group while (count_equal_to_top_value < target_num_min) { - __syncwarp(); // Ensure all threads have valid data before reduction topk_group_value = cg::reduce(tile, value, cg::greater()); if (value == topk_group_value) { value = neg_inf(); @@ -644,10 +641,8 @@ __global__ void group_idx_and_topk_idx_kernel( } } queue.done(); - __syncwarp(); // Get the topk_idx queue.dumpIdx(s_topk_idx); - __syncwarp(); } // Load the valid score value