[ROCm][Bugfix] Fix compilation error in topk softmax fused kernel (#22819)

Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
This commit is contained in:
kliuae 2025-08-14 04:45:03 +08:00 committed by GitHub
parent df0e0f023e
commit c6cd5ca3d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -423,12 +423,27 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
}
#ifndef USE_ROCM
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
static_assert(WARP_SIZE == 32 || WARP_SIZE == 64, \
"Unsupported warp size. Only 32 and 64 are supported."); \
static_assert(WARP_SIZE == 32, \
"Unsupported warp size. Only 32 is supported for CUDA"); \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, stream);
#else
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
if (WARP_SIZE == 64) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
} else if (WARP_SIZE == 32) { \
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES>( \
gating_output, nullptr, topk_weights, topk_indices, \
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
} else { \
assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
}
#endif
template <typename IndType>
void topkGatingSoftmaxKernelLauncher(
@ -443,7 +458,9 @@ void topkGatingSoftmaxKernelLauncher(
cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
#ifndef USE_ROCM
static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8;
#endif
switch (num_experts) {
case 1:
LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);