From 980dd4a2c4ca965e4b10483258a006b812a1991f Mon Sep 17 00:00:00 2001 From: CHU Tianxiang Date: Wed, 11 Oct 2023 15:19:53 +0800 Subject: [PATCH] Fix overflow in awq kernel (#1295) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 楚天翔 --- csrc/quantization/awq/gemm_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 3c5d08a18e0a..63542990b80f 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -90,7 +90,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i + (((int)threadIdx.x) % (128 / 8)) * 8; half* C_ptr = C - + blockIdx_z * M * OC // blockIdz.x -> split_k dim + + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + (((int)blockIdx_y) % j_factors1) * 128 + ((int)threadIdx.y) * 64 + (((int)threadIdx.x) % 4) * 2; @@ -323,7 +323,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in + (((int)threadIdx.x) % (64 / 8)) * 8; half* C_ptr = C - + blockIdx_z * M * OC // blockIdz.x -> split_k dim + + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + (((int)blockIdx_y) % j_factors1) * 64 + ((int)threadIdx.y) * 32 + (((int)threadIdx.x) % 4) * 2;