mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-29 18:27:14 +08:00
[Refactor] Fix Compile Warning #1444-D (#21462)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
5c54d9759d
commit
a59cd9d9f7
@ -24,9 +24,12 @@
|
|||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
#include <cub/util_type.cuh>
|
#include <cub/util_type.cuh>
|
||||||
#include <cub/cub.cuh>
|
#include <cub/cub.cuh>
|
||||||
|
#include <cuda/std/functional>
|
||||||
|
using AddOp = cuda::std::plus<float>;
|
||||||
#else
|
#else
|
||||||
#include <hipcub/util_type.hpp>
|
#include <hipcub/util_type.hpp>
|
||||||
#include <hipcub/hipcub.hpp>
|
#include <hipcub/hipcub.hpp>
|
||||||
|
using AddOp = cub::Sum;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||||
@ -62,7 +65,6 @@ __launch_bounds__(TPB) __global__
|
|||||||
|
|
||||||
const int thread_row_offset = blockIdx.x * num_cols;
|
const int thread_row_offset = blockIdx.x * num_cols;
|
||||||
|
|
||||||
cub::Sum sum;
|
|
||||||
float threadData(-FLT_MAX);
|
float threadData(-FLT_MAX);
|
||||||
|
|
||||||
// Don't touch finished rows.
|
// Don't touch finished rows.
|
||||||
@ -92,7 +94,7 @@ __launch_bounds__(TPB) __global__
|
|||||||
threadData += exp((static_cast<float>(input[idx]) - float_max));
|
threadData += exp((static_cast<float>(input[idx]) - float_max));
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
|
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, AddOp());
|
||||||
|
|
||||||
if (threadIdx.x == 0)
|
if (threadIdx.x == 0)
|
||||||
{
|
{
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user