mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 20:04:31 +08:00
[Kernel][Bugfix] Fix grouped topk cu (#24146)
Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
This commit is contained in:
parent
a38f8bd54c
commit
e919d6f549
@ -28,6 +28,7 @@ namespace cg = cooperative_groups;
|
||||
namespace vllm {
|
||||
namespace moe {
|
||||
|
||||
constexpr float kNegInfinity = INFINITY * -1;
|
||||
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
|
||||
constexpr int32_t WARP_SIZE = 32;
|
||||
constexpr int32_t BLOCK_SIZE = 512;
|
||||
@ -512,8 +513,8 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
warp_id * topk;
|
||||
s_topk_idx += warp_id * topk;
|
||||
|
||||
T value = cuda::std::numeric_limits<T>::min();
|
||||
T topk_group_value = cuda::std::numeric_limits<T>::min();
|
||||
T value = kNegInfinity;
|
||||
T topk_group_value = kNegInfinity;
|
||||
int32_t num_equalto_topkth_group;
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
@ -539,11 +540,11 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||
if (value == topk_group_value) {
|
||||
value = cuda::std::numeric_limits<T>::min();
|
||||
value = kNegInfinity;
|
||||
}
|
||||
pre_count_equal_to_top_value = count_equal_to_top_value;
|
||||
count_equal_to_top_value = __popc(__ballot_sync(
|
||||
FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
|
||||
FULL_WARP_MASK, (value == cuda_cast<T, float>(kNegInfinity))));
|
||||
}
|
||||
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
|
||||
}
|
||||
@ -555,7 +556,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
|
||||
int count_equalto_topkth_group = 0;
|
||||
bool if_proceed_next_topk =
|
||||
(topk_group_value != cuda::std::numeric_limits<T>::min());
|
||||
(topk_group_value != cuda_cast<T, float>(kNegInfinity));
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
for (int i_group = 0; i_group < n_group; i_group++) {
|
||||
if ((group_scores[i_group] > topk_group_value) ||
|
||||
@ -568,7 +569,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
(i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
|
||||
scores_with_bias[offset + i]))
|
||||
? scores_with_bias[offset + i]
|
||||
: cuda::std::numeric_limits<T>::min();
|
||||
: cuda_cast<T, float>(kNegInfinity);
|
||||
queue.add(candidates, offset + i);
|
||||
}
|
||||
if (group_scores[i_group] == topk_group_value) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user