[BugFix] Fix UB in per_token_group_quant.cu (#24913)

Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com>
This commit is contained in:
rivos-shreeasish 2025-09-23 09:14:22 -07:00 committed by GitHub
parent f11e3c516b
commit 2357480b1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -12,8 +12,8 @@
 #include "../vectorization_utils.cuh"
 #include "../../dispatch_utils.h"
-__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
-unsigned mask = 0xffff;
+__device__ __forceinline__ float GroupReduceMax(float val) {
+unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
 val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
 val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
@@ -86,7 +86,7 @@ __global__ void per_token_group_quant_8bit_kernel(
 threads_per_group, // stride in group
 scalar_op_cache); // scalar handler
-local_absmax = GroupReduceMax(local_absmax, lane_id);
+local_absmax = GroupReduceMax(local_absmax);
 float y_s = local_absmax / max_8bit;
 if constexpr (SCALE_UE8M0) {