From 7326fef568b14ecbdc13e24d3b2b974691de0187 Mon Sep 17 00:00:00 2001
From: Andreas Karatzas
Date: Mon, 15 Dec 2025 20:25:30 +0000
Subject: [PATCH 1/2] fix output zeroing race condition in GPTQ GEMM kernels

Signed-off-by: Andreas Karatzas
---
 csrc/quantization/gptq/q_gemm.cu | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu
index 8869d7cd521b6..b4bd756d1fb55 100644
--- a/csrc/quantization/gptq/q_gemm.cu
+++ b/csrc/quantization/gptq/q_gemm.cu
@@ -233,10 +233,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
   // Zero output
   if (n >= size_n) return;
 
-  if (blockIdx.z == 0) {
-    for (int m = 0; m < m_count; m++)
-      *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
-  }
+  // if (blockIdx.z == 0) {
+  //   for (int m = 0; m < m_count; m++)
+  //     *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
+  // }
 
   __syncthreads();
 
@@ -1857,7 +1857,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                         bool use_exllama, bool use_v2_format, int64_t bit) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
   auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
-  at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
+  at::Tensor c = torch::zeros({a.size(0), b_q_weight.size(1)}, options);
   at::Tensor temp_dq = torch::empty(
       {b_q_weight.size(0) * 32 / bit, b_q_weight.size(1)}, options);
 

From a36191a7745759a674036029b0556c37159713d3 Mon Sep 17 00:00:00 2001
From: Andreas Karatzas
Date: Mon, 15 Dec 2025 21:01:15 +0000
Subject: [PATCH 2/2] fix output zeroing race condition in GPTQ GEMM kernels

Signed-off-by: Andreas Karatzas
---
 csrc/quantization/gptq/q_gemm.cu | 26 --------------------------
 1 file changed, 26 deletions(-)

diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu
index b4bd756d1fb55..8a29ad5ab2dd8 100644
--- a/csrc/quantization/gptq/q_gemm.cu
+++ b/csrc/quantization/gptq/q_gemm.cu
@@ -233,11 +233,6 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
   // Zero output
   if (n >= size_n) return;
 
-  // if (blockIdx.z == 0) {
-  //   for (int m = 0; m < m_count; m++)
-  //     *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
-  // }
-
   __syncthreads();
 
   // Find initial group
@@ -372,11 +367,6 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
   // Zero output
   if (n >= size_n) return;
 
-  if (blockIdx.z == 0) {
-    for (int m = 0; m < m_count; m++)
-      *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
-  }
-
   __syncthreads();
 
   // Find initial group
@@ -494,11 +484,6 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
   // Zero output
   if (n >= size_n) return;
 
-  if (blockIdx.z == 0) {
-    for (int m = 0; m < m_count; m++)
-      *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
-  }
-
   __syncthreads();
 
   // Find initial group
@@ -623,11 +608,6 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
   // Zero output
   if (n >= size_n) return;
 
-  if (blockIdx.z == 0) {
-    for (int m = 0; m < m_count; m++)
-      *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
-  }
-
   __syncthreads();
 
   // Find initial group
@@ -1224,9 +1204,6 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
         __halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4));
   }
 
-  if (blockIdx.z == 0) {
-    for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0);
-  }
   __syncthreads();
 
   int i = width * h + w;
@@ -1319,9 +1296,6 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
     }
   }
 
-  if (blockIdx.z == 0) {
-    for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0);
-  }
   __syncthreads();
 
   int i = width * h + w;