mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 15:29:09 +08:00
Merge ac23d0ba18407d09501fe470573940b63129964f into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
e28a0c8fb2
@ -1,3 +1,4 @@
|
|||||||
|
#include "cuda_compat.h"
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
#include <torch/cuda.h>
|
#include <torch/cuda.h>
|
||||||
@ -97,7 +98,9 @@ static inline __device__ bool isPartialMatch(float x, uint32_t pattern) {
|
|||||||
template <typename T, typename idxT, typename Func>
|
template <typename T, typename idxT, typename Func>
|
||||||
__device__ void vectorized_process(size_t thread_rank, size_t num_threads,
|
__device__ void vectorized_process(size_t thread_rank, size_t num_threads,
|
||||||
const T* in, idxT len, Func f) {
|
const T* in, idxT len, Func f) {
|
||||||
constexpr int WARP_SIZE = 32;
|
// Use the compile-time, architecture-dependent WARP_SIZE from cuda_compat.h to support both
|
||||||
|
// Wave64 (MI300X/gfx942) and Wave32 (Strix Halo/gfx1151) architectures
|
||||||
|
constexpr int kWarpSize = WARP_SIZE;
|
||||||
using WideT = float4;
|
using WideT = float4;
|
||||||
if constexpr (sizeof(T) >= sizeof(WideT)) {
|
if constexpr (sizeof(T) >= sizeof(WideT)) {
|
||||||
for (idxT i = thread_rank; i < len; i += num_threads) {
|
for (idxT i = thread_rank; i < len; i += num_threads) {
|
||||||
@ -132,8 +135,8 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static_assert(WARP_SIZE >= items_per_scalar);
|
static_assert(kWarpSize >= items_per_scalar);
|
||||||
// and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt
|
// and because items_per_scalar > skip_cnt, kWarpSize > skip_cnt
|
||||||
// no need to use loop
|
// no need to use loop
|
||||||
if (thread_rank < skip_cnt) {
|
if (thread_rank < skip_cnt) {
|
||||||
f(in[thread_rank], thread_rank);
|
f(in[thread_rank], thread_rank);
|
||||||
@ -142,7 +145,7 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads,
|
|||||||
// len_cast * items_per_scalar + items_per_scalar > len - skip_cnt;
|
// len_cast * items_per_scalar + items_per_scalar > len - skip_cnt;
|
||||||
// and so
|
// and so
|
||||||
// len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <=
|
// len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <=
|
||||||
// WARP_SIZE no need to use loop
|
// kWarpSize, so there is no need to use a loop
|
||||||
const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank;
|
const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank;
|
||||||
if (remain_i < len) {
|
if (remain_i < len) {
|
||||||
f(in[remain_i], remain_i);
|
f(in[remain_i], remain_i);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user