From a59cd9d9f7fd89e19beeffb7e7f89437d413eafb Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Fri, 1 Aug 2025 09:10:30 -0400
Subject: [PATCH] [Refactor] Fix Compile Warning #1444-D (#21462)

Signed-off-by: yewentao256
---
 csrc/moe/topk_softmax_kernels.cu | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu
index 0b505d2e04a21..7a7865b901de1 100644
--- a/csrc/moe/topk_softmax_kernels.cu
+++ b/csrc/moe/topk_softmax_kernels.cu
@@ -24,9 +24,12 @@
 #ifndef USE_ROCM
     #include <cub/util_type.cuh>
     #include <cub/cub.cuh>
+    #include <cuda/std/functional>
+    using AddOp = cuda::std::plus<float>;
 #else
     #include <hipcub/util_type.hpp>
     #include <hipcub/hipcub.hpp>
+    using AddOp = cub::Sum;
 #endif
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
@@ -62,7 +65,6 @@ __launch_bounds__(TPB) __global__
 
     const int thread_row_offset = blockIdx.x * num_cols;
 
-    cub::Sum sum;
     float threadData(-FLT_MAX);
 
     // Don't touch finished rows.
@@ -92,7 +94,7 @@ __launch_bounds__(TPB) __global__
         threadData += exp((static_cast<float>(input[idx]) - float_max));
     }
 
-    const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
+    const auto Z = BlockReduce(tmpStorage).Reduce(threadData, AddOp());
 
     if (threadIdx.x == 0)
     {