diff --git a/csrc/quantization/w8a8/fp8/per_token_group_quant.cu b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu index 52326687f6197..e3ab0676b254e 100644 --- a/csrc/quantization/w8a8/fp8/per_token_group_quant.cu +++ b/csrc/quantization/w8a8/fp8/per_token_group_quant.cu @@ -12,8 +12,8 @@ #include "quantization/vectorization_utils.cuh" #include "dispatch_utils.h" -__device__ __forceinline__ float GroupReduceMax(float val, const int tid) { - unsigned mask = 0xffff; +__device__ __forceinline__ float GroupReduceMax(float val) { + unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff; val = fmaxf(val, __shfl_xor_sync(mask, val, 8)); val = fmaxf(val, __shfl_xor_sync(mask, val, 4)); @@ -86,7 +86,7 @@ __global__ void per_token_group_quant_8bit_kernel( threads_per_group, // stride in group scalar_op_cache); // scalar handler - local_absmax = GroupReduceMax(local_absmax, lane_id); + local_absmax = GroupReduceMax(local_absmax); float y_s = local_absmax / max_8bit; if constexpr (SCALE_UE8M0) {