From ac23d0ba18407d09501fe470573940b63129964f Mon Sep 17 00:00:00 2001 From: c0de128 Date: Wed, 24 Dec 2025 09:02:06 -0600 Subject: [PATCH] [Bugfix][Hardware][AMD] Use dynamic WARP_SIZE in sampler vectorized_process Replace hardcoded WARP_SIZE=32 with the dynamic WARP_SIZE macro from cuda_compat.h to correctly support both Wave64 (MI300X/gfx942) and Wave32 (Strix Halo/gfx1151) architectures. The previous hardcoded value was incorrect for AMD CDNA GPUs which use 64-wide wavefronts. While the current static_assert (kWarpSize >= 4) passes for both 32 and 64, having inconsistent WARP_SIZE definitions across the codebase is a maintenance issue and potential latent bug. Changes: - Add cuda_compat.h include for WARP_SIZE macro - Replace local WARP_SIZE constant with kWarpSize from cuda_compat.h - Update static_assert and comments to use kWarpSize Signed-off-by: c0de128 --- csrc/sampler.cu | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/csrc/sampler.cu b/csrc/sampler.cu index d458f8e4c1d02..f7c091f1d4ee4 100644 --- a/csrc/sampler.cu +++ b/csrc/sampler.cu @@ -1,3 +1,4 @@ +#include "cuda_compat.h" #include "dispatch_utils.h" #include @@ -97,7 +98,9 @@ static inline __device__ bool isPartialMatch(float x, uint32_t pattern) { template __device__ void vectorized_process(size_t thread_rank, size_t num_threads, const T* in, idxT len, Func f) { - constexpr int WARP_SIZE = 32; + // Use dynamic WARP_SIZE from cuda_compat.h to support both + // Wave64 (MI300X/gfx942) and Wave32 (Strix Halo/gfx1151) architectures + constexpr int kWarpSize = WARP_SIZE; using WideT = float4; if constexpr (sizeof(T) >= sizeof(WideT)) { for (idxT i = thread_rank; i < len; i += num_threads) { @@ -132,8 +135,8 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, } } - static_assert(WARP_SIZE >= items_per_scalar); - // and because items_per_scalar > skip_cnt, WARP_SIZE > skip_cnt + static_assert(kWarpSize >= items_per_scalar); + // and because items_per_scalar > skip_cnt, kWarpSize > skip_cnt // no need to use loop if (thread_rank < skip_cnt) { f(in[thread_rank], thread_rank); @@ -142,7 +145,7 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, // len_cast * items_per_scalar + items_per_scalar > len - skip_cnt; // and so // len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= - // WARP_SIZE no need to use loop + // kWarpSize no need to use loop const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank; if (remain_i < len) { f(in[remain_i], remain_i);