From c3fab5f7691c55e9fd0de5ed373f4dd5fb2152cf Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 16 Oct 2024 19:46:06 -0400 Subject: [PATCH] [Bugfix][Kernel] Prevent integer overflow in fp8 dynamic per-token quantize kernel (#9425) --- csrc/quantization/fp8/common.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 7e23f9225776..f2c609c1b68c 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -204,8 +204,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( int const tid = threadIdx.x; int const token_idx = blockIdx.x; - scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size]; - FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size]; + // Use int64 to avoid overflowing an int32 when calculating this offset + int64_t offset = static_cast(token_idx) * hidden_size; + scalar_t const* __restrict__ token_input = &input[offset]; + FP8_TYPE* __restrict__ token_output = &out[offset]; // For vectorization, token_input and token_output pointers need to be // aligned at 8-byte and 4-byte addresses respectively.