mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 21:55:38 +08:00
[ROCM] Fix blockReduceSum to use correct warp counts for ROCm and CUDA (#3262)
This commit is contained in:
parent
0bba88df03
commit
e4a28e5316
@ -15,9 +15,6 @@
|
|||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#ifdef USE_ROCM
|
|
||||||
#include <hip/hip_runtime.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
@ -31,11 +28,6 @@
|
|||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
|
||||||
#define WARP_SIZE 32
|
|
||||||
#else
|
|
||||||
#define WARP_SIZE warpSize
|
|
||||||
#endif
|
|
||||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||||
|
|||||||
@ -1,5 +1,15 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
#include <hip/hip_runtime.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
#define WARP_SIZE 32
|
||||||
|
#else
|
||||||
|
#define WARP_SIZE warpSize
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifndef USE_ROCM
|
#ifndef USE_ROCM
|
||||||
#define VLLM_LDG(arg) __ldg(arg)
|
#define VLLM_LDG(arg) __ldg(arg)
|
||||||
#else
|
#else
|
||||||
|
|||||||
@ -24,7 +24,7 @@ namespace vllm {
|
|||||||
template<typename T>
|
template<typename T>
|
||||||
__inline__ __device__ T warpReduceSum(T val) {
|
__inline__ __device__ T warpReduceSum(T val) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1)
|
for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1)
|
||||||
val += VLLM_SHFL_XOR_SYNC(val, mask);
|
val += VLLM_SHFL_XOR_SYNC(val, mask);
|
||||||
return val;
|
return val;
|
||||||
}
|
}
|
||||||
@ -32,7 +32,7 @@ __inline__ __device__ T warpReduceSum(T val) {
|
|||||||
/* Calculate the sum of all elements in a block */
|
/* Calculate the sum of all elements in a block */
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__inline__ __device__ T blockReduceSum(T val) {
|
__inline__ __device__ T blockReduceSum(T val) {
|
||||||
static __shared__ T shared[32];
|
static __shared__ T shared[WARP_SIZE];
|
||||||
int lane = threadIdx.x & 0x1f;
|
int lane = threadIdx.x & 0x1f;
|
||||||
int wid = threadIdx.x >> 5;
|
int wid = threadIdx.x >> 5;
|
||||||
|
|
||||||
@ -45,7 +45,7 @@ __inline__ __device__ T blockReduceSum(T val) {
|
|||||||
|
|
||||||
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
|
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
|
||||||
// blockDim.x is not divided by 32
|
// blockDim.x is not divided by 32
|
||||||
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
|
val = (threadIdx.x < (blockDim.x / (WARP_SIZE * 1.0f))) ? shared[lane] : (T)(0.0f);
|
||||||
val = warpReduceSum<T>(val);
|
val = warpReduceSum<T>(val);
|
||||||
return val;
|
return val;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user