From c6cd5ca3d3294f62bf5fad25ada8192ba39249b9 Mon Sep 17 00:00:00 2001 From: kliuae <17350011+kliuae@users.noreply.github.com> Date: Thu, 14 Aug 2025 04:45:03 +0800 Subject: [PATCH] [ROCm][Bugfix] Fix compilation error in topk softmax fused kernel (#22819) Signed-off-by: kliuae --- csrc/moe/topk_softmax_kernels.cu | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 946c137db6366..99c52ef17d08b 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -423,12 +423,27 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert); } +#ifndef USE_ROCM #define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ - static_assert(WARP_SIZE == 32 || WARP_SIZE == 64, \ - "Unsupported warp size. Only 32 and 64 are supported."); \ + static_assert(WARP_SIZE == 32, \ + "Unsupported warp size. Only 32 is supported for CUDA"); \ topkGatingSoftmaxLauncherHelper( \ gating_output, nullptr, topk_weights, topk_indices, \ token_expert_indices, num_tokens, topk, 0, num_experts, stream); +#else +#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \ + if (WARP_SIZE == 64) { \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, \ + token_expert_indices, num_tokens, topk, 0, num_experts, stream); \ + } else if (WARP_SIZE == 32) { \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, \ + token_expert_indices, num_tokens, topk, 0, num_experts, stream); \ + } else { \ + assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \ + } +#endif template void topkGatingSoftmaxKernelLauncher( @@ -443,7 +458,9 @@ void topkGatingSoftmaxKernelLauncher( cudaStream_t stream) { static constexpr int WARPS_PER_TB = 4; static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16; +#ifndef USE_ROCM static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8; +#endif switch (num_experts) { case 1: LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);