mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-01 00:30:05 +08:00
[Kernel][Bugfix] Fix grouped topk cu (#24146)
Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
This commit is contained in:
parent
a38f8bd54c
commit
e919d6f549
@ -28,6 +28,7 @@ namespace cg = cooperative_groups;
|
|||||||
namespace vllm {
|
namespace vllm {
|
||||||
namespace moe {
|
namespace moe {
|
||||||
|
|
||||||
|
constexpr float kNegInfinity = INFINITY * -1;
|
||||||
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
|
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
|
||||||
constexpr int32_t WARP_SIZE = 32;
|
constexpr int32_t WARP_SIZE = 32;
|
||||||
constexpr int32_t BLOCK_SIZE = 512;
|
constexpr int32_t BLOCK_SIZE = 512;
|
||||||
@ -512,8 +513,8 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
warp_id * topk;
|
warp_id * topk;
|
||||||
s_topk_idx += warp_id * topk;
|
s_topk_idx += warp_id * topk;
|
||||||
|
|
||||||
T value = cuda::std::numeric_limits<T>::min();
|
T value = kNegInfinity;
|
||||||
T topk_group_value = cuda::std::numeric_limits<T>::min();
|
T topk_group_value = kNegInfinity;
|
||||||
int32_t num_equalto_topkth_group;
|
int32_t num_equalto_topkth_group;
|
||||||
|
|
||||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||||
@ -539,11 +540,11 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||||
if (value == topk_group_value) {
|
if (value == topk_group_value) {
|
||||||
value = cuda::std::numeric_limits<T>::min();
|
value = kNegInfinity;
|
||||||
}
|
}
|
||||||
pre_count_equal_to_top_value = count_equal_to_top_value;
|
pre_count_equal_to_top_value = count_equal_to_top_value;
|
||||||
count_equal_to_top_value = __popc(__ballot_sync(
|
count_equal_to_top_value = __popc(__ballot_sync(
|
||||||
FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
|
FULL_WARP_MASK, (value == cuda_cast<T, float>(kNegInfinity))));
|
||||||
}
|
}
|
||||||
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
|
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
|
||||||
}
|
}
|
||||||
@ -555,7 +556,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
|
|
||||||
int count_equalto_topkth_group = 0;
|
int count_equalto_topkth_group = 0;
|
||||||
bool if_proceed_next_topk =
|
bool if_proceed_next_topk =
|
||||||
(topk_group_value != cuda::std::numeric_limits<T>::min());
|
(topk_group_value != cuda_cast<T, float>(kNegInfinity));
|
||||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||||
for (int i_group = 0; i_group < n_group; i_group++) {
|
for (int i_group = 0; i_group < n_group; i_group++) {
|
||||||
if ((group_scores[i_group] > topk_group_value) ||
|
if ((group_scores[i_group] > topk_group_value) ||
|
||||||
@ -568,7 +569,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
(i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
|
(i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
|
||||||
scores_with_bias[offset + i]))
|
scores_with_bias[offset + i]))
|
||||||
? scores_with_bias[offset + i]
|
? scores_with_bias[offset + i]
|
||||||
: cuda::std::numeric_limits<T>::min();
|
: cuda_cast<T, float>(kNegInfinity);
|
||||||
queue.add(candidates, offset + i);
|
queue.add(candidates, offset + i);
|
||||||
}
|
}
|
||||||
if (group_scores[i_group] == topk_group_value) {
|
if (group_scores[i_group] == topk_group_value) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user