From 2357480b1abefb2b7177046133080917e1449324 Mon Sep 17 00:00:00 2001 From: rivos-shreeasish Date: Tue, 23 Sep 2025 09:14:22 -0700 Subject: [PATCH] [BugFix] Fix UB in per_token_group_quant.cu (#24913) Signed-off-by: Shreeasish Kumar --- csrc/quantization/fp8/per_token_group_quant.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/quantization/fp8/per_token_group_quant.cu b/csrc/quantization/fp8/per_token_group_quant.cu index f5b40e35b6e5a..91d489fdef862 100644 --- a/csrc/quantization/fp8/per_token_group_quant.cu +++ b/csrc/quantization/fp8/per_token_group_quant.cu @@ -12,8 +12,8 @@ #include "../vectorization_utils.cuh" #include "../../dispatch_utils.h" -__device__ __forceinline__ float GroupReduceMax(float val, const int tid) { - unsigned mask = 0xffff; +__device__ __forceinline__ float GroupReduceMax(float val) { + unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff; val = fmaxf(val, __shfl_xor_sync(mask, val, 8)); val = fmaxf(val, __shfl_xor_sync(mask, val, 4)); @@ -86,7 +86,7 @@ __global__ void per_token_group_quant_8bit_kernel( threads_per_group, // stride in group scalar_op_cache); // scalar handler - local_absmax = GroupReduceMax(local_absmax, lane_id); + local_absmax = GroupReduceMax(local_absmax); float y_s = local_absmax / max_8bit; if constexpr (SCALE_UE8M0) {