diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 79a546554fa1e..8f24be89578b8 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -24,6 +24,7 @@ #include "attention_dtypes.h" #include "attention_utils.cuh" +#include "cuda_compat.h" #ifdef USE_ROCM #include @@ -33,12 +34,6 @@ typedef __hip_bfloat16 __nv_bfloat16; #include "../quantization/fp8/nvidia/quant_utils.cuh" #endif -#ifndef USE_ROCM - #define WARP_SIZE 32 -#else - #define WARP_SIZE warpSize -#endif - #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) @@ -670,7 +665,6 @@ __global__ void paged_attention_v2_reduce_kernel( } // namespace vllm -#undef WARP_SIZE #undef MAX #undef MIN #undef DIVIDE_ROUND_UP diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 46108a32d719b..7a5ef10f8ef3b 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -18,12 +18,7 @@ */ #include "attention_kernels.cuh" - -#ifndef USE_ROCM - #define WARP_SIZE 32 -#else - #define WARP_SIZE warpSize -#endif +#include "cuda_compat.h" #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -187,7 +182,6 @@ void paged_attention_v1( CALL_V1_LAUNCHER_BLOCK_SIZE) } -#undef WARP_SIZE #undef MAX #undef MIN #undef DIVIDE_ROUND_UP diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 9358c0d9f6a2a..b45b28dad05ea 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -18,12 +18,7 @@ */ #include "attention_kernels.cuh" - -#ifndef USE_ROCM - #define WARP_SIZE 32 -#else - #define WARP_SIZE warpSize -#endif +#include "cuda_compat.h" #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -197,7 +192,6 @@ void paged_attention_v2( CALL_V2_LAUNCHER_BLOCK_SIZE) } -#undef WARP_SIZE #undef MAX #undef MIN #undef DIVIDE_ROUND_UP diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index 82e55613d915a..affa051c75951 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -4,10 +4,10 @@ #include #endif -#ifndef USE_ROCM - #define WARP_SIZE 32 +#if defined(USE_ROCM) && defined(__GFX9__) + #define WARP_SIZE 64 #else - #define WARP_SIZE warpSize + #define WARP_SIZE 32 #endif #ifndef USE_ROCM