[Perf] Support topk softmax fused kernel for broader num_experts (#22211)

Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>

parent 77a6bf07ae
commit 4c558cf62e
@@ -188,7 +188,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(
   It fuses the softmax, max and argmax into a single kernel.

   Limitations:
-  1) This implementation is intended for when the number of experts is a small power of 2.
+  1) This implementation is optimized for when the number of experts is a small power of 2.
+     Additionally, it also supports the case where the number of experts is a multiple of 64,
+     which is still faster than computing softmax and topK separately (only tested on CUDA so far).
   2) This implementation assumes k is small, but will work for any k.
 */

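As context for the comment above: per token (row), the kernel computes a numerically stable softmax over the gating logits and then selects the k largest probabilities together with their expert indices. A minimal scalar reference of that computation, as a hedged sketch (this helper is illustrative only, not part of the commit):

    #include <algorithm>
    #include <cmath>
    #include <numeric>
    #include <vector>

    // Scalar reference for one token: softmax over num_experts gating logits,
    // then the k largest probabilities and their expert indices. This is the
    // computation the fused kernel performs in a single pass per row.
    void topk_softmax_ref(const std::vector<float>& logits, int k,
                          std::vector<float>& weights, std::vector<int>& indices) {
        const float max_logit = *std::max_element(logits.begin(), logits.end());
        std::vector<float> probs(logits.size());
        float sum = 0.f;
        for (size_t i = 0; i < logits.size(); ++i) {
            probs[i] = std::exp(logits[i] - max_logit);  // shift for stability
            sum += probs[i];
        }
        for (auto& p : probs) p /= sum;

        // Rank expert indices by probability and keep the top k.
        std::vector<int> order(logits.size());
        std::iota(order.begin(), order.end(), 0);
        std::partial_sort(order.begin(), order.begin() + k, order.end(),
                          [&](int a, int b) { return probs[a] > probs[b]; });
        weights.assign(k, 0.f);
        indices.assign(k, 0);
        for (int i = 0; i < k; ++i) {
            indices[i] = order[i];
            weights[i] = probs[order[i]];
        }
    }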
@@ -198,8 +200,6 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
     int* source_rows, const int k, const int start_expert, const int end_expert)
 {
     // We begin by enforcing compile time assertions and setting up compile time constants.
-    static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
-    static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
     static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
     static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");

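Both dropped assertions rely on the two's-complement identity that x & -x isolates the lowest set bit of x, so x == (x & -x) holds exactly when x is a power of two (for x > 0). The newly supported expert counts fail the NUM_EXPERTS check, and (as derived after the next hunk) produce a VPT that fails the VPT check, so both had to go while the BYTES_PER_LDG checks still hold. A quick demonstration:

    static_assert(64 == (64 & -64), "64 has a single set bit: a power of 2");
    static_assert(192 != (192 & -192), "192 = 0b11000000, so 192 & -192 == 64");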
@@ -407,12 +407,10 @@ struct TopkConstants
 };
 } // namespace detail

-template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, typename IndType>
+template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType>
 void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
     int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
 {
-    static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
-
     static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
     using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
     static constexpr int VPT = Constants::VPT;
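The maximum load width now flows in as a template parameter instead of a fixed 16 bytes. To see why 8 bytes is the right width for the multiple-of-64 cases, here is an illustrative derivation of the per-thread constants for 192 experts; the member names mirror the FasterTransformer-style TopkConstants but are assumptions, since the struct body is not shown in this diff:

    // Sketch of the compile-time constants for the new 192-expert case,
    // assuming each warp covers one row with whole vector loads.
    constexpr int EXPERTS       = 192;
    constexpr int WARP_WIDTH    = 32;                         // CUDA warp size
    constexpr int BYTES_PER_LDG = 8;                          // BYTES_PER_LDG_MULTIPLE_64
    constexpr int ELTS_PER_LDG  = BYTES_PER_LDG / sizeof(float);        // 2 floats per load
    constexpr int LDGS_PER_THREAD = EXPERTS / (ELTS_PER_LDG * WARP_WIDTH);  // 192/64 = 3
    constexpr int VPT           = LDGS_PER_THREAD * ELTS_PER_LDG;       // 6 values per thread
    static_assert(EXPERTS % (ELTS_PER_LDG * WARP_WIDTH) == 0,
                  "the warp must cover the row with whole vector loads");
    // With the old fixed 16-byte width (4 floats), 192 / (4 * 32) = 1.5 loads per
    // thread, which does not divide evenly -- hence the smaller 8-byte width.
    // Note VPT == 6 is not a power of 2, which is why that static_assert was dropped.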
@@ -425,21 +423,12 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
         input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
 }

-#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                                \
-    switch (warpSize) {                                                          \
-        case 32:                                                                 \
-            topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32>(      \
-                gating_output, nullptr, topk_weights, topk_indices,              \
-                token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
-            break;                                                               \
-        case 64:                                                                 \
-            topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64>(      \
-                gating_output, nullptr, topk_weights, topk_indices,              \
-                token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
-            break;                                                               \
-        default:                                                                 \
-            TORCH_CHECK(false, "Unsupported warp size: ", warpSize);             \
-    }
+#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES)                          \
+    static_assert(WARP_SIZE == 32 || WARP_SIZE == 64,                                 \
+                  "Unsupported warp size. Only 32 and 64 are supported.");            \
+    topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
+        gating_output, nullptr, topk_weights, topk_indices,                           \
+        token_expert_indices, num_tokens, topk, 0, num_experts, stream);

 template <typename IndType>
 void topkGatingSoftmaxKernelLauncher(
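Since WARP_SIZE is a compile-time constant on both CUDA and ROCm builds, the old runtime switch over warpSize collapses into a static_assert plus a single instantiation. For example, LAUNCH_SOFTMAX(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64) expands to roughly this (a sketch, using the surrounding launcher's argument names):

    static_assert(WARP_SIZE == 32 || WARP_SIZE == 64,
                  "Unsupported warp size. Only 32 and 64 are supported.");
    // All constants (load width, VPT, threads per row) are now resolved statically.
    topkGatingSoftmaxLauncherHelper<192, /*WARPS_PER_TB=*/4, WARP_SIZE, /*MAX_BYTES=*/8>(
        gating_output, nullptr, topk_weights, topk_indices,
        token_expert_indices, num_tokens, topk, 0, num_experts, stream);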
@@ -453,38 +442,62 @@ void topkGatingSoftmaxKernelLauncher(
     const int topk,
     cudaStream_t stream) {
     static constexpr int WARPS_PER_TB = 4;
-    auto warpSize = WARP_SIZE;
+    static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
+    static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8;
     switch (num_experts) {
         case 1:
-            LAUNCH_SOFTMAX(1, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
         case 2:
-            LAUNCH_SOFTMAX(2, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
         case 4:
-            LAUNCH_SOFTMAX(4, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
         case 8:
-            LAUNCH_SOFTMAX(8, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
         case 16:
-            LAUNCH_SOFTMAX(16, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
         case 32:
-            LAUNCH_SOFTMAX(32, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
         case 64:
-            LAUNCH_SOFTMAX(64, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
         case 128:
-            LAUNCH_SOFTMAX(128, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
         case 256:
-            LAUNCH_SOFTMAX(256, WARPS_PER_TB);
+            LAUNCH_SOFTMAX(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
             break;
+        case 512:
+            LAUNCH_SOFTMAX(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
+            break;
+        // (CUDA only) Support multiples of 64 when num_experts is not a power of 2.
+        // ROCm uses WARP_SIZE 64, so 8-byte loads do not fit for some of these num_experts;
+        // alternatively we could test 4-byte loads and enable them in the future.
+#ifndef USE_ROCM
+        case 192:
+            LAUNCH_SOFTMAX(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+            break;
+        case 320:
+            LAUNCH_SOFTMAX(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+            break;
+        case 384:
+            LAUNCH_SOFTMAX(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+            break;
+        case 448:
+            LAUNCH_SOFTMAX(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+            break;
+        case 576:
+            LAUNCH_SOFTMAX(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+            break;
+#endif
         default: {
             TORCH_CHECK(softmax_workspace != nullptr,
-                "softmax_workspace must be provided for num_experts that are not a power of 2.");
+                "softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64.");
             static constexpr int TPB = 256;
             moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
                 gating_output, nullptr, softmax_workspace, num_experts);
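The ROCm restriction mentioned in the comment can be verified with a little arithmetic (a sketch under the same whole-vector-loads assumption as above):

    // With 2 floats per 8-byte load, one load round across a warp covers
    // 2 * WARP_SIZE elements, and each row must be a whole number of rounds.
    constexpr int ROUND_CUDA = 2 * 32;   // 64 elements per round (WARP_SIZE 32)
    constexpr int ROUND_ROCM = 2 * 64;   // 128 elements per round (WARP_SIZE 64)
    static_assert(192 % ROUND_CUDA == 0, "192 fits on CUDA");
    static_assert(192 % ROUND_ROCM != 0, "192 is 1.5 rounds on ROCm: does not fit");
    static_assert(384 % ROUND_ROCM == 0, "384 would fit even on ROCm");
    // 4-byte (1-float) loads cover only 64 elements per ROCm round and would fit
    // all of these counts, at the cost of narrower loads -- hence the comment
    // about testing 4-byte loads in the future.

Counts that hit neither a power-of-2 case nor a supported multiple of 64 fall through to the default branch, which runs the unfused path: moeSoftmax into the workspace, followed by the separate top-k kernel.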
@@ -36,7 +36,7 @@ from vllm.model_executor.models.mixtral import MixtralMoE
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types

-NUM_EXPERTS = [8, 64]
+NUM_EXPERTS = [8, 64, 192]
 EP_SIZE = [1, 4]
 TOP_KS = [2, 6]