diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu
index 946c137db6366..99c52ef17d08b 100644
--- a/csrc/moe/topk_softmax_kernels.cu
+++ b/csrc/moe/topk_softmax_kernels.cu
@@ -423,12 +423,27 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
         input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
 }
 
+#ifndef USE_ROCM
 #define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES)                             \
-    static_assert(WARP_SIZE == 32 || WARP_SIZE == 64,                                    \
-                  "Unsupported warp size. Only 32 and 64 are supported.");               \
+    static_assert(WARP_SIZE == 32,                                                       \
+                  "Unsupported warp size. Only 32 is supported for CUDA");               \
     topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES, WARP_SIZE>(    \
         gating_output, nullptr, topk_weights, topk_indices,                              \
         token_expert_indices, num_tokens, topk, 0, num_experts, stream);
+#else
+#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES)                             \
+    if (WARP_SIZE == 64) {                                                               \
+        topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES, 64>(       \
+            gating_output, nullptr, topk_weights, topk_indices,                          \
+            token_expert_indices, num_tokens, topk, 0, num_experts, stream);             \
+    } else if (WARP_SIZE == 32) {                                                        \
+        topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES, 32>(       \
+            gating_output, nullptr, topk_weights, topk_indices,                          \
+            token_expert_indices, num_tokens, topk, 0, num_experts, stream);             \
+    } else {                                                                             \
+        assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
+    }
+#endif
 
 template <typename IndType>
 void topkGatingSoftmaxKernelLauncher(
@@ -443,7 +458,9 @@ void topkGatingSoftmaxKernelLauncher(
     cudaStream_t stream) {
     static constexpr int WARPS_PER_TB = 4;
     static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
+#ifndef USE_ROCM
     static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8;
+#endif
     switch (num_experts) {
         case 1:
             LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
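
For context (not part of the diff): the ROCm branch cannot assert a single warp size the way the CUDA branch does, since the wavefront width differs across AMD targets (64 on CDNA-class GPUs, 32 on RDNA-class GPUs), so it dispatches on `WARP_SIZE` at run time while still passing a literal `64` or `32` as the template argument, keeping the warp size a compile-time constant inside each instantiation. A minimal sketch of that dispatch pattern, with a hypothetical `launchWithWarpSize` standing in for `topkGatingSoftmaxLauncherHelper`:

```cpp
#include <cassert>
#include <cstdio>

// Stand-in for the real launcher: the warp size arrives as a template
// parameter, so kernel code can treat it as a compile-time constant.
template <int kWarpSize>
void launchWithWarpSize() {
    std::printf("launching with warp size %d\n", kWarpSize);
}

// Run-time branch, compile-time instantiation: each arm pins the template
// argument to a literal, mirroring the ROCm LAUNCH_SOFTMAX macro above.
void dispatchOnWarpSize(int warp_size) {
    if (warp_size == 64) {
        launchWithWarpSize<64>();  // CDNA-class GPUs (wavefront of 64)
    } else if (warp_size == 32) {
        launchWithWarpSize<32>();  // CUDA GPUs and RDNA-class GPUs
    } else {
        assert(false && "Unsupported warp size. Only 32 and 64 are supported.");
    }
}

int main() {
    dispatchOnWarpSize(32);  // e.g. a warp size queried from device properties
    return 0;
}
```

Passing the literal rather than `WARP_SIZE` itself means both variants get compiled into the ROCm build, but each kernel instantiation still sees a constant warp size, which is what the shared CUDA/ROCm kernel template requires.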