[Refactor] Fix Compile Warning #1444-D (#21462)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-08-01 09:10:30 -04:00 committed by GitHub
parent 5c54d9759d
commit a59cd9d9f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -24,9 +24,12 @@
#ifndef USE_ROCM #ifndef USE_ROCM
#include <cub/util_type.cuh> #include <cub/util_type.cuh>
#include <cub/cub.cuh> #include <cub/cub.cuh>
#include <cuda/std/functional>
using AddOp = cuda::std::plus<float>;
#else #else
#include <hipcub/util_type.hpp> #include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp> #include <hipcub/hipcub.hpp>
using AddOp = cub::Sum;
#endif #endif
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
@ -62,7 +65,6 @@ __launch_bounds__(TPB) __global__
const int thread_row_offset = blockIdx.x * num_cols; const int thread_row_offset = blockIdx.x * num_cols;
cub::Sum sum;
float threadData(-FLT_MAX); float threadData(-FLT_MAX);
// Don't touch finished rows. // Don't touch finished rows.
@ -92,7 +94,7 @@ __launch_bounds__(TPB) __global__
threadData += exp((static_cast<float>(input[idx]) - float_max)); threadData += exp((static_cast<float>(input[idx]) - float_max));
} }
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); const auto Z = BlockReduce(tmpStorage).Reduce(threadData, AddOp());
if (threadIdx.x == 0) if (threadIdx.x == 0)
{ {