[Perf] Support topk softmax fused kernel for broader num_experts (#22211)
Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>

parent  77a6bf07ae
commit  4c558cf62e
@@ -188,7 +188,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(
   It fuses the softmax, max and argmax into a single kernel.
 
   Limitations:
-  1) This implementation is intended for when the number of experts is a small power of 2.
+  1) This implementation is optimized for when the number of experts is a small power of 2.
+     Additionally, it also supports the case where the number of experts is a multiple of 64,
+     which is still faster than computing softmax and topK separately (only tested on CUDA so far).
   2) This implementation assumes k is small, but will work for any k.
 */
 
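A quick way to read the new dispatch surface is to enumerate it. The sketch below is illustrative only: takes_fused_path is a hypothetical helper, and the two sets it encodes (powers of 2 up to 512 on both backends, five multiples of 64 on CUDA only) are read off the launcher switch later in this diff.

// Hypothetical helper mirroring the launcher's dispatch: true if
// num_experts takes the fused topk-softmax kernel, false if it falls
// back to the separate moeSoftmax + moeTopK path (which requires a
// softmax_workspace buffer).
constexpr bool takes_fused_path(int num_experts, bool is_rocm) {
  // Powers of 2 from 1 to 512 are handled on both CUDA and ROCm.
  const bool power_of_2 = num_experts > 0 && num_experts <= 512 &&
                          (num_experts & (num_experts - 1)) == 0;
  // The multiples of 64 that are not powers of 2 are CUDA-only.
  const bool multiple_of_64 =
      !is_rocm && (num_experts == 192 || num_experts == 320 ||
                   num_experts == 384 || num_experts == 448 ||
                   num_experts == 576);
  return power_of_2 || multiple_of_64;
}

static_assert(takes_fused_path(256, /*is_rocm=*/false), "power of 2");
static_assert(takes_fused_path(192, /*is_rocm=*/false), "multiple of 64 on CUDA");
static_assert(!takes_fused_path(192, /*is_rocm=*/true), "falls back on ROCm");
static_assert(!takes_fused_path(100, /*is_rocm=*/false), "falls back to moeSoftmax");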
@@ -198,8 +200,6 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
     int* source_rows, const int k, const int start_expert, const int end_expert)
 {
   // We begin by enforcing compile time assertions and setting up compile time constants.
-  static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
-  static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
   static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
   static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
 
@@ -407,12 +407,10 @@ struct TopkConstants
 };
 } // namespace detail
 
-template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, typename IndType>
+template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType>
 void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
     int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
 {
-  static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
-
   static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
   using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
   static constexpr int VPT = Constants::VPT;
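The helper previously hard-coded MAX_BYTES_PER_LDG = 16; it is now a template parameter, so each instantiation can cap its vector load width. A standalone sketch of how the MIN above resolves (MIN is redefined here only so the snippet compiles on its own):

#define MIN(a, b) ((a) < (b) ? (a) : (b))

// Power-of-2 expert counts keep the full 16-byte (float4) loads:
static_assert(MIN(16, sizeof(float) * 256) == 16, "256 experts -> 16-byte loads");
// Very small expert counts are capped by the row size itself:
static_assert(MIN(16, sizeof(float) * 2) == 8, "2 experts -> 8-byte loads");
// Multiple-of-64 counts are launched with an 8-byte cap (see the launcher below):
static_assert(MIN(8, sizeof(float) * 192) == 8, "192 experts -> 8-byte loads");

This is also why the VPT and NUM_EXPERTS power-of-2 static_asserts were removed in the kernel above: with 192 experts on a 32-lane warp, each thread ends up holding 6 values, which is not a power of 2.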
@@ -425,21 +423,12 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
       input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
 }
 
-#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                            \
-  switch (warpSize) {                                                        \
-    case 32:                                                                 \
-      topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32>(        \
-          gating_output, nullptr, topk_weights, topk_indices,                \
-          token_expert_indices, num_tokens, topk, 0, num_experts, stream);   \
-      break;                                                                 \
-    case 64:                                                                 \
-      topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64>(        \
-          gating_output, nullptr, topk_weights, topk_indices,                \
-          token_expert_indices, num_tokens, topk, 0, num_experts, stream);   \
-      break;                                                                 \
-    default:                                                                 \
-      TORCH_CHECK(false, "Unsupported warp size: ", warpSize);               \
-  }
+#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES)                 \
+  static_assert(WARP_SIZE == 32 || WARP_SIZE == 64,                          \
+                "Unsupported warp size. Only 32 and 64 are supported.");     \
+  topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
+      gating_output, nullptr, topk_weights, topk_indices,                    \
+      token_expert_indices, num_tokens, topk, 0, num_experts, stream);
 
 template <typename IndType>
 void topkGatingSoftmaxKernelLauncher(
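WARP_SIZE is a compile-time constant (32 on CUDA, 64 on ROCm builds), so the old runtime switch on warpSize always took exactly one branch; the rewrite instantiates that branch directly and turns the guard into a static_assert. For illustration, LAUNCH_SOFTMAX(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64) on a CUDA build expands to roughly the following (not standalone: the macro pastes into topkGatingSoftmaxKernelLauncher, where these arguments are in scope):

// Approximate expansion with WARP_SIZE == 32 and WARPS_PER_TB == 4:
static_assert(WARP_SIZE == 32 || WARP_SIZE == 64,
              "Unsupported warp size. Only 32 and 64 are supported.");
topkGatingSoftmaxLauncherHelper<192, 4, /*WARP_SIZE_PARAM=*/32,
                                /*MAX_BYTES_PER_LDG=*/8>(
    gating_output, nullptr, topk_weights, topk_indices,
    token_expert_indices, num_tokens, topk, 0, num_experts, stream);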
@@ -453,38 +442,62 @@ void topkGatingSoftmaxKernelLauncher(
     const int topk,
     cudaStream_t stream) {
   static constexpr int WARPS_PER_TB = 4;
-  auto warpSize = WARP_SIZE;
+  static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
+  static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8;
   switch (num_experts) {
     case 1:
-      LAUNCH_SOFTMAX(1, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
       break;
     case 2:
-      LAUNCH_SOFTMAX(2, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
       break;
     case 4:
-      LAUNCH_SOFTMAX(4, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
       break;
     case 8:
-      LAUNCH_SOFTMAX(8, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
       break;
     case 16:
-      LAUNCH_SOFTMAX(16, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
       break;
     case 32:
-      LAUNCH_SOFTMAX(32, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
       break;
     case 64:
-      LAUNCH_SOFTMAX(64, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
       break;
     case 128:
-      LAUNCH_SOFTMAX(128, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
       break;
     case 256:
-      LAUNCH_SOFTMAX(256, WARPS_PER_TB);
+      LAUNCH_SOFTMAX(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
       break;
+    case 512:
+      LAUNCH_SOFTMAX(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
+      break;
+    // (CUDA only) Support multiples of 64 when num_experts is not a power of 2.
+    // ROCm uses WARP_SIZE 64, so 8-byte loads do not fit for some of these
+    // num_experts; alternatively, we can test 4-byte loads and enable it in the future.
+#ifndef USE_ROCM
+    case 192:
+      LAUNCH_SOFTMAX(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+      break;
+    case 320:
+      LAUNCH_SOFTMAX(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+      break;
+    case 384:
+      LAUNCH_SOFTMAX(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+      break;
+    case 448:
+      LAUNCH_SOFTMAX(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+      break;
+    case 576:
+      LAUNCH_SOFTMAX(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
+      break;
+#endif
     default: {
       TORCH_CHECK(softmax_workspace != nullptr,
-          "softmax_workspace must be provided for num_experts that are not a power of 2.");
+          "softmax_workspace must be provided for num_experts that are not a power of 2 or a multiple of 64.");
       static constexpr int TPB = 256;
       moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
           gating_output, nullptr, softmax_workspace, num_experts);
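The two byte-width constants fall out of simple warp arithmetic, which is also what the USE_ROCM guard above encodes. A self-contained sketch for num_experts = 192, assuming each warp splits one row of gating logits evenly across its lanes (the exact TopkConstants layout may differ in detail):

constexpr int kExperts = 192;

// CUDA warp of 32 lanes: each lane owns 192 / 32 = 6 floats of the row.
constexpr int kVptCuda = kExperts / 32;
static_assert(kVptCuda * 32 == kExperts, "row splits evenly across the warp");
// 6 floats = 24 bytes = three 8-byte (float2) loads, so an 8-byte width fits...
static_assert(kVptCuda * sizeof(float) % 8 == 0, "8-byte loads fit");
// ...but a 16-byte (float4) width does not: 24 is not a multiple of 16.
static_assert(kVptCuda * sizeof(float) % 16 != 0, "16-byte loads do not fit");

// ROCm warp of 64 lanes: 192 / 64 = 3 floats = 12 bytes per lane, which is
// not a multiple of 8 -- hence the #ifndef USE_ROCM. As the comment notes,
// 4-byte loads would fit and could enable these cases on ROCm later.
constexpr int kVptRocm = kExperts / 64;
static_assert(kVptRocm * sizeof(float) % 8 != 0, "8-byte loads do not fit");
static_assert(kVptRocm * sizeof(float) % 4 == 0, "4-byte loads would fit");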
@@ -36,7 +36,7 @@ from vllm.model_executor.models.mixtral import MixtralMoE
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
 
-NUM_EXPERTS = [8, 64]
+NUM_EXPERTS = [8, 64, 192]
 EP_SIZE = [1, 4]
 TOP_KS = [2, 6]
 