From e4a28e53165902ffc5daf20977c70885d0c05768 Mon Sep 17 00:00:00 2001
From: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
Date: Sun, 10 Mar 2024 17:27:45 -0500
Subject: [PATCH] [ROCM] Fix blockReduceSum to use correct warp counts for ROCm and CUDA (#3262)

---
 csrc/attention/attention_kernels.cu |  8 --------
 csrc/cuda_compat.h                  | 10 ++++++++++
 csrc/reduction_utils.cuh            |  6 +++---
 3 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index b5be3befa07e..5e61668d5cc1 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -15,9 +15,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifdef USE_ROCM
-#include <hip/hip_runtime.h>
-#endif
 
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
@@ -31,11 +28,6 @@
 
 #include <algorithm>
 
-#ifndef USE_ROCM
-#define WARP_SIZE 32
-#else
-#define WARP_SIZE warpSize
-#endif
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h
index aa58dd73c148..c711d8d1b24b 100644
--- a/csrc/cuda_compat.h
+++ b/csrc/cuda_compat.h
@@ -1,5 +1,15 @@
 #pragma once
 
+#ifdef USE_ROCM
+#include <hip/hip_runtime.h>
+#endif
+
+#ifndef USE_ROCM
+  #define WARP_SIZE 32
+#else
+  #define WARP_SIZE warpSize
+#endif
+
 #ifndef USE_ROCM
 #define VLLM_LDG(arg) __ldg(arg)
 #else
diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh
index b95ccef16207..210bf0b023ab 100644
--- a/csrc/reduction_utils.cuh
+++ b/csrc/reduction_utils.cuh
@@ -24,7 +24,7 @@ namespace vllm {
 template<typename T>
 __inline__ __device__ T warpReduceSum(T val) {
 #pragma unroll
-  for (int mask = 16; mask > 0; mask >>= 1)
+  for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1)
     val += VLLM_SHFL_XOR_SYNC(val, mask);
   return val;
 }
@@ -32,7 +32,7 @@ __inline__ __device__ T warpReduceSum(T val) {
 
 /* Calculate the sum of all elements in a block */
 template<typename T>
 __inline__ __device__ T blockReduceSum(T val) {
-  static __shared__ T shared[32];
+  static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
@@ -45,7 +45,7 @@ __inline__ __device__ T blockReduceSum(T val) {
 
   // Use float division by the warp size, not a shift, so that a trailing
   // partial warp is counted when blockDim.x is not a multiple of the warp size
-  val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
+  val = (threadIdx.x < (blockDim.x / (WARP_SIZE * 1.0f))) ? shared[lane] : (T)(0.0f);
   val = warpReduceSum<T>(val);
   return val;
 }
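
Editor's note (illustration, not part of the commit): the fix matters because a
"warp" is 32 threads on NVIDIA hardware but a 64-thread wavefront on AMD
hardware, where warpSize comes from hip_runtime.h (hence the relocated
include). Below is a minimal, self-contained sketch of the same reduction
pattern; the SHFL_XOR macro and the %- and /-based lane arithmetic are
assumptions of this note, not the commit's code (the commit keeps
VLLM_SHFL_XOR_SYNC and the hard-coded "& 0x1f" / ">> 5" lane math).

// Sketch only: a warp-size-portable block reduction. Assumes WARP_SIZE is a
// compile-time constant on both backends, as the commit itself does when it
// sizes the shared[] array with WARP_SIZE.
#ifndef USE_ROCM
  #define WARP_SIZE 32
  // CUDA's synchronizing XOR shuffle, with a full-warp member mask.
  #define SHFL_XOR(var, laneMask) __shfl_xor_sync(0xffffffff, (var), (laneMask))
#else
  #include <hip/hip_runtime.h>
  #define WARP_SIZE warpSize
  // HIP's XOR shuffle takes no member mask.
  #define SHFL_XOR(var, laneMask) __shfl_xor((var), (laneMask))
#endif

template <typename T>
__inline__ __device__ T warpReduceSum(T val) {
  // Butterfly reduction across one warp: the starting stride WARP_SIZE/2 is
  // 16 under CUDA and 32 under ROCm, which is exactly what the patch fixes.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1)
    val += SHFL_XOR(val, mask);
  return val;
}

template <typename T>
__inline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[WARP_SIZE];  // one partial sum per warp
  int lane = threadIdx.x % WARP_SIZE;     // lane index within the warp
  int wid = threadIdx.x / WARP_SIZE;      // warp index within the block

  val = warpReduceSum<T>(val);
  if (lane == 0)
    shared[wid] = val;          // warp leaders publish their partial sums
  __syncthreads();

  // Float division (not a shift) counts a trailing partial warp when
  // blockDim.x is not a multiple of WARP_SIZE.
  val = (threadIdx.x < (blockDim.x / (WARP_SIZE * 1.0f))) ? shared[lane]
                                                          : (T)(0.0f);
  // First warp reduces the per-warp partials to the block total.
  return warpReduceSum<T>(val);
}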