mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 23:02:19 +08:00
[Perf] Optimize group_topk kernel, 1.9% Throughput improvement, 2.1% TPOT improvement (#30159)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
d9417096d1
commit
0ee6416f67
@ -444,23 +444,27 @@ __device__ inline T apply_sigmoid(T val) {
|
|||||||
return cuda_cast<T, float>(sigmoid_accurate(f));
|
return cuda_cast<T, float>(sigmoid_accurate(f));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <ScoringFunc SF, typename T>
|
||||||
|
__device__ inline T apply_scoring(T val) {
|
||||||
|
if constexpr (SF == SCORING_SIGMOID) {
|
||||||
|
return apply_sigmoid(val);
|
||||||
|
} else {
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, ScoringFunc SF>
|
||||||
__device__ void topk_with_k2(T* output, T const* input, T const* bias,
|
__device__ void topk_with_k2(T* output, T const* input, T const* bias,
|
||||||
cg::thread_block_tile<32> const& tile,
|
cg::thread_block_tile<32> const& tile,
|
||||||
int32_t const lane_id,
|
int32_t const lane_id,
|
||||||
int const num_experts_per_group,
|
int const num_experts_per_group) {
|
||||||
int const scoring_func) {
|
|
||||||
// Get the top2 per thread
|
// Get the top2 per thread
|
||||||
T largest = neg_inf<T>();
|
T largest = neg_inf<T>();
|
||||||
T second_largest = neg_inf<T>();
|
T second_largest = neg_inf<T>();
|
||||||
|
|
||||||
if (num_experts_per_group > WARP_SIZE) {
|
if (num_experts_per_group > WARP_SIZE) {
|
||||||
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
||||||
T value = input[i];
|
T value = apply_scoring<SF>(input[i]);
|
||||||
// Apply scoring function if needed
|
|
||||||
if (scoring_func == SCORING_SIGMOID) {
|
|
||||||
value = apply_sigmoid(value);
|
|
||||||
}
|
|
||||||
value = value + bias[i];
|
value = value + bias[i];
|
||||||
|
|
||||||
if (value > largest) {
|
if (value > largest) {
|
||||||
@ -472,11 +476,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
||||||
T value = input[i];
|
T value = apply_scoring<SF>(input[i]);
|
||||||
// Apply scoring function if needed
|
|
||||||
if (scoring_func == SCORING_SIGMOID) {
|
|
||||||
value = apply_sigmoid(value);
|
|
||||||
}
|
|
||||||
value = value + bias[i];
|
value = value + bias[i];
|
||||||
largest = value;
|
largest = value;
|
||||||
}
|
}
|
||||||
@ -501,13 +501,12 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T, ScoringFunc SF>
|
||||||
__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
|
__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
|
||||||
int64_t const num_tokens,
|
int64_t const num_tokens,
|
||||||
int64_t const num_cases,
|
int64_t const num_cases,
|
||||||
int64_t const n_group,
|
int64_t const n_group,
|
||||||
int64_t const num_experts_per_group,
|
int64_t const num_experts_per_group) {
|
||||||
int const scoring_func) {
|
|
||||||
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
||||||
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
||||||
|
|
||||||
@ -525,21 +524,21 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
|
|||||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||||
asm volatile("griddepcontrol.wait;");
|
asm volatile("griddepcontrol.wait;");
|
||||||
#endif
|
#endif
|
||||||
topk_with_k2(output, input, group_bias, tile, lane_id,
|
topk_with_k2<T, SF>(output, input, group_bias, tile, lane_id,
|
||||||
num_experts_per_group, scoring_func);
|
num_experts_per_group);
|
||||||
}
|
}
|
||||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||||
asm volatile("griddepcontrol.launch_dependents;");
|
asm volatile("griddepcontrol.launch_dependents;");
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename IdxT>
|
template <typename T, typename IdxT, ScoringFunc SF, int NGroup = -1>
|
||||||
__global__ void group_idx_and_topk_idx_kernel(
|
__global__ void group_idx_and_topk_idx_kernel(
|
||||||
T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
|
T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
|
||||||
T const* bias, int64_t const num_tokens, int64_t const n_group,
|
T const* bias, int64_t const num_tokens, int64_t const n_group,
|
||||||
int64_t const topk_group, int64_t const topk, int64_t const num_experts,
|
int64_t const topk_group, int64_t const topk, int64_t const num_experts,
|
||||||
int64_t const num_experts_per_group, bool renormalize,
|
int64_t const num_experts_per_group, bool renormalize,
|
||||||
double routed_scaling_factor, int scoring_func) {
|
double routed_scaling_factor) {
|
||||||
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
||||||
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
||||||
int32_t case_id =
|
int32_t case_id =
|
||||||
@ -549,6 +548,11 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
topk_values += case_id * topk;
|
topk_values += case_id * topk;
|
||||||
topk_indices += case_id * topk;
|
topk_indices += case_id * topk;
|
||||||
|
|
||||||
|
constexpr bool kUseStaticNGroup = (NGroup > 0);
|
||||||
|
// use int32 to avoid implicit conversion
|
||||||
|
int32_t const n_group_i32 =
|
||||||
|
kUseStaticNGroup ? NGroup : static_cast<int32_t>(n_group);
|
||||||
|
|
||||||
int32_t align_num_experts_per_group =
|
int32_t align_num_experts_per_group =
|
||||||
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
|
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
|
||||||
|
|
||||||
@ -574,13 +578,14 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
|
|
||||||
if (case_id < num_tokens) {
|
if (case_id < num_tokens) {
|
||||||
// calculate group_idx
|
// calculate group_idx
|
||||||
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
|
int32_t target_num_min =
|
||||||
|
WARP_SIZE - n_group_i32 + static_cast<int32_t>(topk_group);
|
||||||
// The check is necessary to avoid abnormal input
|
// The check is necessary to avoid abnormal input
|
||||||
if (lane_id < n_group && is_finite(group_scores[lane_id])) {
|
if (lane_id < n_group_i32 && is_finite(group_scores[lane_id])) {
|
||||||
value = group_scores[lane_id];
|
value = group_scores[lane_id];
|
||||||
}
|
}
|
||||||
|
|
||||||
int count_equal_to_top_value = WARP_SIZE - n_group;
|
int count_equal_to_top_value = WARP_SIZE - n_group_i32;
|
||||||
int pre_count_equal_to_top_value = 0;
|
int pre_count_equal_to_top_value = 0;
|
||||||
// Use loop to find the largest top_group
|
// Use loop to find the largest top_group
|
||||||
while (count_equal_to_top_value < target_num_min) {
|
while (count_equal_to_top_value < target_num_min) {
|
||||||
@ -604,7 +609,7 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
int count_equalto_topkth_group = 0;
|
int count_equalto_topkth_group = 0;
|
||||||
bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
|
bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
|
||||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||||
for (int i_group = 0; i_group < n_group; i_group++) {
|
auto process_group = [&](int i_group) {
|
||||||
if ((group_scores[i_group] > topk_group_value) ||
|
if ((group_scores[i_group] > topk_group_value) ||
|
||||||
((group_scores[i_group] == topk_group_value) &&
|
((group_scores[i_group] == topk_group_value) &&
|
||||||
(count_equalto_topkth_group < num_equalto_topkth_group))) {
|
(count_equalto_topkth_group < num_equalto_topkth_group))) {
|
||||||
@ -613,11 +618,10 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
i += WARP_SIZE) {
|
i += WARP_SIZE) {
|
||||||
T candidates = neg_inf<T>();
|
T candidates = neg_inf<T>();
|
||||||
if (i < num_experts_per_group) {
|
if (i < num_experts_per_group) {
|
||||||
// Apply scoring function (if any) and add bias
|
// apply scoring function (if any) and add bias
|
||||||
T input = scores[offset + i];
|
T input = scores[offset + i];
|
||||||
if (is_finite(input)) {
|
if (is_finite(input)) {
|
||||||
T score = (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input)
|
T score = apply_scoring<SF>(input);
|
||||||
: input;
|
|
||||||
candidates = score + bias[offset + i];
|
candidates = score + bias[offset + i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -627,6 +631,17 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
count_equalto_topkth_group++;
|
count_equalto_topkth_group++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if constexpr (kUseStaticNGroup) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i_group = 0; i_group < NGroup; ++i_group) {
|
||||||
|
process_group(i_group);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i_group = 0; i_group < n_group_i32; ++i_group) {
|
||||||
|
process_group(i_group);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
queue.done();
|
queue.done();
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
@ -646,12 +661,13 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
if (i < topk) {
|
if (i < topk) {
|
||||||
// Load the score value (without bias) for normalization
|
// Load the score value (without bias) for normalization
|
||||||
T input = scores[s_topk_idx[i]];
|
T input = scores[s_topk_idx[i]];
|
||||||
value =
|
value = apply_scoring<SF>(input);
|
||||||
(scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) : input;
|
|
||||||
s_topk_value[i] = value;
|
s_topk_value[i] = value;
|
||||||
}
|
}
|
||||||
topk_sum +=
|
if (renormalize) {
|
||||||
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
|
topk_sum +=
|
||||||
|
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -660,13 +676,9 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
if (case_id < num_tokens) {
|
if (case_id < num_tokens) {
|
||||||
if (if_proceed_next_topk) {
|
if (if_proceed_next_topk) {
|
||||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||||
float value;
|
float base = cuda_cast<float, T>(s_topk_value[i]);
|
||||||
if (renormalize) {
|
float value = renormalize ? (base / topk_sum * routed_scaling_factor)
|
||||||
value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
|
: (base * routed_scaling_factor);
|
||||||
routed_scaling_factor;
|
|
||||||
} else {
|
|
||||||
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
|
|
||||||
}
|
|
||||||
topk_indices[i] = s_topk_idx[i];
|
topk_indices[i] = s_topk_idx[i];
|
||||||
topk_values[i] = value;
|
topk_values[i] = value;
|
||||||
}
|
}
|
||||||
@ -684,6 +696,45 @@ __global__ void group_idx_and_topk_idx_kernel(
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, typename IdxT, ScoringFunc SF>
|
||||||
|
inline void launch_group_idx_and_topk_kernel(
|
||||||
|
cudaLaunchConfig_t const& config, T* scores, T* group_scores,
|
||||||
|
float* topk_values, IdxT* topk_indices, T const* bias,
|
||||||
|
int64_t const num_tokens, int64_t const n_group, int64_t const topk_group,
|
||||||
|
int64_t const topk, int64_t const num_experts,
|
||||||
|
int64_t const num_experts_per_group, bool const renormalize,
|
||||||
|
double const routed_scaling_factor) {
|
||||||
|
auto launch = [&](auto* kernel_instance2) {
|
||||||
|
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
|
||||||
|
topk_values, topk_indices, bias, num_tokens, n_group,
|
||||||
|
topk_group, topk, num_experts, num_experts_per_group,
|
||||||
|
renormalize, routed_scaling_factor);
|
||||||
|
};
|
||||||
|
|
||||||
|
switch (n_group) {
|
||||||
|
case 4: {
|
||||||
|
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 4>);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 8: {
|
||||||
|
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 8>);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 16: {
|
||||||
|
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 16>);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case 32: {
|
||||||
|
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 32>);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF>);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, typename IdxT>
|
template <typename T, typename IdxT>
|
||||||
void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
|
void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
|
||||||
IdxT* topk_indices, T const* bias, int64_t const num_tokens,
|
IdxT* topk_indices, T const* bias, int64_t const num_tokens,
|
||||||
@ -694,7 +745,6 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
|
|||||||
cudaStream_t const stream = 0) {
|
cudaStream_t const stream = 0) {
|
||||||
int64_t num_cases = num_tokens * n_group;
|
int64_t num_cases = num_tokens * n_group;
|
||||||
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
|
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
|
||||||
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
|
|
||||||
cudaLaunchConfig_t config;
|
cudaLaunchConfig_t config;
|
||||||
config.gridDim = topk_with_k2_num_blocks;
|
config.gridDim = topk_with_k2_num_blocks;
|
||||||
config.blockDim = BLOCK_SIZE;
|
config.blockDim = BLOCK_SIZE;
|
||||||
@ -705,16 +755,33 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
|
|||||||
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
|
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
|
||||||
config.numAttrs = 1;
|
config.numAttrs = 1;
|
||||||
config.attrs = attrs;
|
config.attrs = attrs;
|
||||||
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
|
auto const sf = static_cast<ScoringFunc>(scoring_func);
|
||||||
num_tokens, num_cases, n_group, num_experts / n_group,
|
int64_t const num_experts_per_group = num_experts / n_group;
|
||||||
scoring_func);
|
auto launch_topk_with_k2 = [&](auto* kernel_instance1) {
|
||||||
|
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
|
||||||
|
num_tokens, num_cases, n_group, num_experts_per_group);
|
||||||
|
};
|
||||||
|
switch (sf) {
|
||||||
|
case SCORING_NONE: {
|
||||||
|
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_NONE>;
|
||||||
|
launch_topk_with_k2(kernel_instance1);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case SCORING_SIGMOID: {
|
||||||
|
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_SIGMOID>;
|
||||||
|
launch_topk_with_k2(kernel_instance1);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// should be guarded by higher level checks.
|
||||||
|
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
|
||||||
|
}
|
||||||
|
|
||||||
int64_t topk_with_k_group_num_blocks =
|
int64_t topk_with_k_group_num_blocks =
|
||||||
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
|
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
|
||||||
size_t dynamic_smem_in_bytes =
|
size_t dynamic_smem_in_bytes =
|
||||||
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
|
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
|
||||||
topk);
|
topk);
|
||||||
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
|
|
||||||
config.gridDim = topk_with_k_group_num_blocks;
|
config.gridDim = topk_with_k_group_num_blocks;
|
||||||
config.blockDim = BLOCK_SIZE;
|
config.blockDim = BLOCK_SIZE;
|
||||||
config.dynamicSmemBytes = dynamic_smem_in_bytes;
|
config.dynamicSmemBytes = dynamic_smem_in_bytes;
|
||||||
@ -723,10 +790,24 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
|
|||||||
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
|
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
|
||||||
config.numAttrs = 1;
|
config.numAttrs = 1;
|
||||||
config.attrs = attrs;
|
config.attrs = attrs;
|
||||||
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
|
switch (sf) {
|
||||||
topk_values, topk_indices, bias, num_tokens, n_group,
|
case SCORING_NONE: {
|
||||||
topk_group, topk, num_experts, num_experts / n_group,
|
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_NONE>(
|
||||||
renormalize, routed_scaling_factor, scoring_func);
|
config, scores, group_scores, topk_values, topk_indices, bias,
|
||||||
|
num_tokens, n_group, topk_group, topk, num_experts,
|
||||||
|
num_experts_per_group, renormalize, routed_scaling_factor);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case SCORING_SIGMOID: {
|
||||||
|
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_SIGMOID>(
|
||||||
|
config, scores, group_scores, topk_values, topk_indices, bias,
|
||||||
|
num_tokens, n_group, topk_group, topk, num_experts,
|
||||||
|
num_experts_per_group, renormalize, routed_scaling_factor);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
|
#define INSTANTIATE_NOAUX_TC(T, IdxT) \
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user