From 5cc6bddb6ef5e8e5c10de8122a43fd6e8c1e3b4b Mon Sep 17 00:00:00 2001 From: Xiangyu Li Date: Fri, 24 Oct 2025 11:26:13 +0800 Subject: [PATCH] [Kernel] Add GPTQv2 format support for low-bit or asymmetric quantization, by adapting gptq_gemm (#26092) --- csrc/ops.h | 2 +- csrc/quantization/gptq/q_gemm.cu | 223 ++++++++++-------- csrc/torch_bindings.cpp | 3 +- tests/kernels/quantization/test_gptq.py | 8 +- tests/quantization/test_gptq_v2.py | 109 +++++++++ vllm/_custom_ops.py | 11 +- .../layers/quantization/gptq.py | 30 ++- .../kernels/mixed_precision/exllama.py | 7 +- 8 files changed, 295 insertions(+), 98 deletions(-) create mode 100644 tests/quantization/test_gptq_v2.py diff --git a/csrc/ops.h b/csrc/ops.h index eb3d60b77e60a..0bed7492f6616 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -307,7 +307,7 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, int64_t bit); + bool use_exllama, bool use_v2_format, int64_t bit); void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 43b245530e950..8869d7cd521b6 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -185,7 +185,7 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half*, const uint32_t*, const uint32_t*, const half*, half*, const int, const int, const int, const int, - const int*); + const bool, const int*); template __global__ void gemm_half_q_half_gptq_4bit_kernel( @@ -193,12 +193,15 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( const uint32_t* __restrict__ b_gptq_qzeros, const half* __restrict__ b_gptq_scales, half* __restrict__ c, const int size_m, const int size_n, const int size_k, const int groups, - const int* __restrict__ b_q_perm) { + const bool use_v2_format, const int* __restrict__ b_q_perm) { MatrixView_half a_(a, size_m, size_k); MatrixView_half_rw c_(c, size_m, size_n); MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + auto t = threadIdx.x; // Block @@ -256,10 +259,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( half2 y1y16[4][2]; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_f(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]); // Column result float block_c[m_count][4] = {}; @@ -272,10 +275,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( nextgroup += groupsize; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_f(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]); } #pragma unroll @@ -329,12 +332,15 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel( const uint32_t* __restrict__ b_gptq_qzeros, const half* __restrict__ b_gptq_scales, half* __restrict__ c, const int size_m, const int size_n, const int size_k, const int groups, - const int* __restrict__ b_q_perm) { + const bool use_v2_format, const int* __restrict__ b_q_perm) { MatrixView_half a_(a, size_m, size_k); MatrixView_half_rw c_(c, size_m, size_n); MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + auto t = threadIdx.x; // Block @@ -409,10 +415,10 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel( int4 load_int4 = *b_ptr4; half2 dq[4][8]; - dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); - dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); - dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); - dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + zero_offset); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + zero_offset); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + zero_offset); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + zero_offset); #pragma unroll for (int m = 0; m < m_count; m++) { @@ -448,12 +454,15 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel( const uint32_t* __restrict__ b_gptq_qzeros, const half* __restrict__ b_gptq_scales, half* __restrict__ c, const int size_m, const int size_n, const int size_k, const int groups, - const int* __restrict__ b_q_perm) { + const bool use_v2_format, const int* __restrict__ b_q_perm) { MatrixView_half a_(a, size_m, size_k); MatrixView_half_rw c_(c, size_m, size_n); MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + auto t = threadIdx.x; // Block @@ -534,13 +543,13 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel( half2 dq[4][16]; dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], - size_n, zeros[0] + 1); + size_n, zeros[0] + zero_offset); dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], - size_n, zeros[1] + 1); + size_n, zeros[1] + zero_offset); dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], - size_n, zeros[2] + 1); + size_n, zeros[2] + zero_offset); dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], - size_n, zeros[3] + 1); + size_n, zeros[3] + zero_offset); #pragma unroll for (int m = 0; m < m_count; m++) { @@ -574,12 +583,15 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel( const uint32_t* __restrict__ b_gptq_qzeros, const half* __restrict__ b_gptq_scales, half* __restrict__ c, const int size_m, const int size_n, const int size_k, const int groups, - const int* __restrict__ b_q_perm) { + const bool use_v2_format, const int* __restrict__ b_q_perm) { MatrixView_half a_(a, size_m, size_k); MatrixView_half_rw c_(c, size_m, size_n); MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + auto t = threadIdx.x; // Block @@ -658,13 +670,13 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel( half2 dq[4][4]; dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, - zeros[0] + 1); + zeros[0] + zero_offset); dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, - zeros[1] + 1); + zeros[1] + zero_offset); dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, - zeros[2] + 1); + zeros[2] + zero_offset); dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, - zeros[3] + 1); + zeros[3] + zero_offset); for (int m = 0; m < m_count; m++) { block_c[m][0] = @@ -730,7 +742,8 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, const half* b_gptq_scales, const int* b_q_perm, half* c, int size_m, int size_n, int size_k, - int m_count, int groups, int bit) { + int m_count, int groups, bool use_v2_format, + int bit) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -743,20 +756,23 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight, pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - kernel<<>>(a, b_q_weight, b_gptq_qzeros, - b_gptq_scales, c, size_m, size_n, - size_k, groups, b_q_perm); + kernel<<>>( + a, b_q_weight, b_gptq_qzeros, b_gptq_scales, c, size_m, size_n, size_k, + groups, use_v2_format, b_q_perm); } __global__ void reconstruct_exllama_8bit_kernel( const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, - const int groups, half* __restrict__ b) { + const int groups, const bool use_v2_format, half* __restrict__ b) { MatrixView_half_rw b_(b, size_k, size_n); MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; @@ -812,13 +828,13 @@ __global__ void reconstruct_exllama_8bit_kernel( half2 dq[4][4]; dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, - zeros[0] + 1); + zeros[0] + zero_offset); dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, - zeros[1] + 1); + zeros[1] + zero_offset); dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, - zeros[2] + 1); + zeros[2] + zero_offset); dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, - zeros[3] + 1); + zeros[3] + zero_offset); // half* dqh = (half*)dq; if (b_q_perm) { @@ -849,11 +865,14 @@ __global__ void reconstruct_exllama_4bit_kernel( const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, - const int groups, half* __restrict__ b) { + const int groups, const bool use_v2_format, half* __restrict__ b) { MatrixView_half_rw b_(b, size_k, size_n); MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; @@ -888,10 +907,10 @@ __global__ void reconstruct_exllama_4bit_kernel( half2 y1y16[4][2]; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]); __syncthreads(); @@ -904,10 +923,10 @@ __global__ void reconstruct_exllama_4bit_kernel( nextgroup += groupsize; b_gptq_qzeros_.item4(zeros, group, n); b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]); } for (int p = 0; p < 4; p++) { @@ -954,11 +973,14 @@ __global__ void reconstruct_exllama_3bit_kernel( const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, - const int groups, half* __restrict__ b) { + const int groups, const bool use_v2_format, half* __restrict__ b) { MatrixView_half_rw b_(b, size_k, size_n); MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; @@ -1016,13 +1038,13 @@ __global__ void reconstruct_exllama_3bit_kernel( half2 dq[4][16]; dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], - size_n, zeros[0] + 1); + size_n, zeros[0] + zero_offset); dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], - size_n, zeros[1] + 1); + size_n, zeros[1] + zero_offset); dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], - size_n, zeros[2] + 1); + size_n, zeros[2] + zero_offset); dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], - size_n, zeros[3] + 1); + size_n, zeros[3] + zero_offset); if (b_q_perm) { for (int j = 0; j < 16; j++) { @@ -1052,11 +1074,14 @@ __global__ void reconstruct_exllama_2bit_kernel( const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_gptq_qzeros, const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, - const int groups, half* __restrict__ b) { + const int groups, const bool use_v2_format, half* __restrict__ b) { MatrixView_half_rw b_(b, size_k, size_n); MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + auto offset_k = BLOCK_KN_SIZE * blockIdx.y; auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; @@ -1108,10 +1133,10 @@ __global__ void reconstruct_exllama_2bit_kernel( int4 load_int4 = *b_ptr4; half2 dq[4][8]; - dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); - dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); - dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); - dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); + dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + zero_offset); + dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + zero_offset); + dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + zero_offset); + dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + zero_offset); b_ptr += size_n; // half* dqh = (half*)dq; @@ -1143,7 +1168,7 @@ void reconstruct_exllama(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, const half* b_gptq_scales, const int* b_q_perm, half* out, int height, int width, int groups, - int bit) { + bool use_v2_format, int bit) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -1162,14 +1187,14 @@ void reconstruct_exllama(const uint32_t* b_q_weight, const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); reconstruct_exllama_kernel<<>>( b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups, - out); + use_v2_format, out); } __global__ void gemm_half_q_half_alt_4bit_kernel( const half2* __restrict__ vec, const uint32_t* __restrict__ mat, half* __restrict__ mul, const half* __restrict__ scales, const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, - int batch, int height, int width) { + int batch, int height, int width, bool use_v2_format) { int zero_width = width / 8; int vec_height = height * 4; const int blockwidth2 = BLOCK_KN_SIZE / 2; @@ -1179,6 +1204,9 @@ __global__ void gemm_half_q_half_alt_4bit_kernel( int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; if (threadIdx.x < h_end) { for (int m = 0; m < b_end; ++m) { @@ -1223,10 +1251,11 @@ __global__ void gemm_half_q_half_alt_4bit_kernel( half2 zero = __halves2half2( __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - - 1)), - __hmul(scale_f2, - __int2half_rn( - -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))); + zero_offset)), + __hmul( + scale_f2, + __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - + zero_offset))); scales_tmp[tmp_k] = scale; zeros_tmp[tmp_k] = zero; } @@ -1268,7 +1297,7 @@ __global__ void gemm_half_q_half_alt_8bit_kernel( const half2* __restrict__ vec, const uint32_t* __restrict__ mat, half* __restrict__ mul, const half* __restrict__ scales, const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, - int batch, int height, int width) { + int batch, int height, int width, bool use_v2_format) { int zero_width = width / 4; int vec_height = height * 2; const int blockwidth2 = BLOCK_KN_SIZE / 2; @@ -1278,6 +1307,9 @@ __global__ void gemm_half_q_half_alt_8bit_kernel( int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; if (threadIdx.x < h_end) { for (int m = 0; m < b_end; ++m) { @@ -1312,12 +1344,13 @@ __global__ void gemm_half_q_half_alt_8bit_kernel( half scale_f2 = scales[g2 * width + w]; half2 scale = __halves2half2(scale_f, scale_f2); half2 zero = __halves2half2( - __hmul(scale_f, - __int2half_rn( - -((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)), - __hmul(scale_f2, - __int2half_rn( - -((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1))); + __hmul(scale_f, __int2half_rn( + -((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - + zero_offset)), + __hmul( + scale_f2, + __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - + zero_offset))); scales_tmp[tmp_k] = scale; zeros_tmp[tmp_k] = zero; } @@ -1355,7 +1388,7 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, const half* b_gptq_scales, const int* b_g_idx, half* c, int size_m, int size_n, int size_k, - int bit) { + bool use_v2_format, int bit) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -1372,17 +1405,15 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight, const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); kernel<<>>( (const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx, - size_m, size_k / 32 * bit, size_n); + size_m, size_k / 32 * bit, size_n, use_v2_format); } template -__global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int* __restrict__ g_idx, - const int height, const int width, - const int group, - half* __restrict__ out) { +__global__ void reconstruct_gptq_kernel( + const uint32_t* __restrict__ w, const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx, + const int height, const int width, const int group, + const bool use_v2_format, half* __restrict__ out) { // Start of block auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; @@ -1395,6 +1426,9 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, MatrixView_half w_scales_(w_scales, group, width); T w_zeros_(w_zeros, group, width); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + uint32_t w_read = w[blockIdx.y * width + column]; half* out_ptr = out_.item_ptr(row, column); @@ -1402,7 +1436,7 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, for (int s = 0; s < 32; s += bit) { int group = g_idx[row + s / bit]; half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; + uint32_t w_zero = w_zeros_.item(group, column) + zero_offset; half w_item = __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), w_scale); @@ -1415,7 +1449,7 @@ __global__ void reconstruct_gptq_3bit_kernel( const uint32_t* __restrict__ w, const half* __restrict__ w_scales, const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx, const int height, const int width, const int group, - half* __restrict__ out) { + const bool use_v2_format, half* __restrict__ out) { // Start of block auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; auto row = blockIdx.y * 32; @@ -1427,6 +1461,9 @@ __global__ void reconstruct_gptq_3bit_kernel( MatrixView_half w_scales_(w_scales, group, width); MatrixView_q3_row w_zeros_(w_zeros, group, width); + // GPTQv2 and GPTQv1 handles zero points differently + int zero_offset = use_v2_format ? 0 : 1; + uint32_t w1 = w[(blockIdx.y * 3) * width + column]; uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; @@ -1436,7 +1473,7 @@ __global__ void reconstruct_gptq_3bit_kernel( for (int i = 0; i < 32; i += 1) { int group = g_idx[row + i]; half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; + uint32_t w_zero = w_zeros_.item(group, column) + zero_offset; int w_item; if (i == 10) { w_item = (w1 >> 30) | ((w2 << 2) & 0x4); @@ -1456,7 +1493,8 @@ __global__ void reconstruct_gptq_3bit_kernel( void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, const half* b_gptq_scales, const int* b_g_idx, half* out, - int height, int width, int groups, int bit) { + int height, int width, int groups, bool use_v2_format, + int bit) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -1476,7 +1514,7 @@ void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); kernel<<>>(b_q_weight, b_gptq_scales, b_gptq_qzeros, b_g_idx, height, - width, groups, out); + width, groups, use_v2_format, out); } void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, @@ -1484,7 +1522,8 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, const uint32_t* b_gptq_qzeros, const half* b_gptq_scales, const int* b_g_idx, half* c, half* temp_dq, int size_m, int size_n, - int size_k, int groups, bool use_exllama, int bit) { + int size_k, int groups, bool use_exllama, + bool use_v2_format, int bit) { bool use_reconstruct; if (use_exllama) { use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || @@ -1498,10 +1537,10 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, // Reconstruct FP16 matrix, then cuBLAS if (use_exllama) { reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - temp_dq, size_k, size_n, groups, bit); + temp_dq, size_k, size_n, groups, use_v2_format, bit); } else { reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - temp_dq, size_k, size_n, groups, bit); + temp_dq, size_k, size_n, groups, use_v2_format, bit); } const half alpha = __float2half(1.0f); @@ -1517,18 +1556,18 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, if (max_chunks) { gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, c, last_chunk, size_n, size_k, - BLOCK_M_SIZE_MAX, groups, bit); + BLOCK_M_SIZE_MAX, groups, use_v2_format, bit); } if (last_chunk_size) { - gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, - b_gptq_qzeros, b_gptq_scales, b_g_idx, - c + last_chunk * size_n, last_chunk_size, - size_n, size_k, last_chunk_size, groups, bit); + gemm_half_q_half_cuda_part( + a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, c + last_chunk * size_n, last_chunk_size, size_n, size_k, + last_chunk_size, groups, use_v2_format, bit); } } else { gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, - c, size_m, size_n, size_k, bit); + c, size_m, size_n, size_k, use_v2_format, bit); } } @@ -1815,7 +1854,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height, torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, int64_t bit) { + bool use_exllama, bool use_v2_format, int64_t bit) { const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); @@ -1833,7 +1872,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, c.size(1), // n a.size(1), // k b_gptq_qzeros.size(0), // group number - use_exllama, bit); + use_exllama, use_v2_format, bit); return c; } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7e8660349dad5..8f091a429fbef 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -557,7 +557,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // to prevent the meta function registry. ops.def( "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, " - "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) " + "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool " + "use_v2_format, int bit) " "-> Tensor", {stride_tag}); ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); diff --git a/tests/kernels/quantization/test_gptq.py b/tests/kernels/quantization/test_gptq.py index 72e4194c13276..7bc7f97ce75b8 100644 --- a/tests/kernels/quantization/test_gptq.py +++ b/tests/kernels/quantization/test_gptq.py @@ -26,4 +26,10 @@ def test_gptq_gemm_opcheck(): idx = torch.empty((0,), device="cuda", dtype=torch.int32) use_exllama = True bit = 4 - opcheck(torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, bit)) + # Test both GPTQv1 and GPTQv2 format + opcheck( + torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, True, bit) + ) + opcheck( + torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, False, bit) + ) diff --git a/tests/quantization/test_gptq_v2.py b/tests/quantization/test_gptq_v2.py new file mode 100644 index 0000000000000..dbafa2e8e7d1f --- /dev/null +++ b/tests/quantization/test_gptq_v2.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests whether vllm correctly load and run gptq_v2 format checkpoints. + +Run `pytest tests/quantization/test_gptq_v2.py --forked`. +""" + +import pytest +import torch +from transformers import AutoTokenizer + +from vllm import SamplingParams +from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod + +# A dummy small model quantized by GPTQModel, stored in GPTQ v2 format +MODELS = ["XXXXyu/Qwen3-1.7B-w2g64-gptq_v2"] + +# Generate multiple sequences for testing, because an 1.7B 2-bit model +# cannot always generate normal texts. +N_SEQ = 5 + + +@pytest.mark.parametrize("model_id", MODELS) +def test_model_load(vllm_runner, model_id, monkeypatch): + # `LLM.apply_model` requires pickling a function. + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + + # Only check the default GPTQ linear method (used for 2/3-bit models). + # 4/8-bit linear methods like Marlin already support gptq_v2. + linear_method_cls = GPTQLinearMethod + + with vllm_runner(model_id, dtype=torch.float16, max_model_len=512) as llm: + + def check_model(model_id): + for name, submodule in model_id.named_modules(): + # Could check more modules if necessary + if name == "model_id.layers.0.self_attn.qkv_proj": + assert isinstance(submodule.quant_method, linear_method_cls) + + config = submodule.quant_method.quant_config + assert config.checkpoint_format == "gptq_v2" + assert submodule.quant_method.use_v2_format + + # Just break since currently we only check 1 module + break + + # Check if gptq_v2 format is correctly loaded + llm.apply_model(check_model) + + +@pytest.mark.parametrize("model_id", MODELS) +def test_model_inference(vllm_runner, model_id): + # Prepare prompt to test the model's generation result. + prompt = "What is the meaning of life?" + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ] + tokenizer = AutoTokenizer.from_pretrained(model_id) + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, # If thinking model, set it to false + ) + sampling_params = SamplingParams( + n=N_SEQ, + max_tokens=128, + temperature=0.7, + top_p=0.8, + top_k=20, + min_p=0, + presence_penalty=2, + ) + + with vllm_runner(model_id, dtype=torch.float16, max_model_len=512) as llm: + # Generate a response to verify inference correctness + output = llm.generate(text, sampling_params) + + # Make sure the output exists + assert output + assert output[0][1] + assert len(output[0][1]) == N_SEQ + + def has_normal_char_distribution(texts, min_len): + for text in texts: + # Response too short + if len(text) < min_len: + return False + + # Basic ratio checks + letters = sum(c.isalpha() for c in text) + spaces = sum(c.isspace() for c in text) + total = len(text) + + letter_ratio = letters / total + space_ratio = spaces / total + + # At least 1 normal text should exist within output sequences + # Normal text should be mostly letters with reasonable spacing + # Some magic numbers, could be adjusted + if 0.5 <= letter_ratio <= 0.9 and 0.01 <= space_ratio <= 0.3: + return True + # No sequence contains normal text, output might be broken + return False + + # Apply some simple checks for giberish output + # Print the output sequences if failed + assert has_normal_char_distribution(output[0][1], 5), output[0][1] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index eccb9a1ef26fc..9110b0573fc92 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -451,10 +451,18 @@ def gptq_gemm( b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, use_exllama: bool, + use_v2_format: bool, bit: int, ) -> torch.Tensor: return torch.ops._C.gptq_gemm( - a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit + a, + b_q_weight, + b_gptq_qzeros, + b_gptq_scales, + b_g_idx, + use_exllama, + use_v2_format, + bit, ) @@ -468,6 +476,7 @@ if hasattr(torch.ops._C, "gptq_gemm"): b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, use_exllama: bool, + use_v2_format: bool, bit: int, ) -> torch.Tensor: return torch.empty( diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index b7bc3abeb724c..2ad28048cdce4 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -11,6 +11,7 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from torch.nn.parameter import Parameter from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( @@ -36,6 +37,8 @@ if TYPE_CHECKING: else: QuantizationMethods = str +logger = init_logger(__name__) + class GPTQConfig(QuantizationConfig): """Config class for GPTQ. @@ -52,6 +55,7 @@ class GPTQConfig(QuantizationConfig): dynamic: dict[str, dict[str, int | bool]], autoround_version: str = "", modules_in_block_to_quantize: list[str] | None = None, + checkpoint_format: str = "", ) -> None: # GPTQModel use `dynamic` config property to allow per module # quantization config so each module can be individually optimized. @@ -89,12 +93,24 @@ class GPTQConfig(QuantizationConfig): "Currently, only 2/3/4/8-bit weight quantization is " f"supported for GPTQ, but got {self.weight_bits} bits." ) + # Somehow gptq_gemm 4-bit is buggy, maybe fix it in the future. + # For now, show a warning, since gptq_marlin will be used by default. + if self.weight_bits == 4: + logger.warning_once( + "Currently, the 4-bit gptq_gemm kernel for GPTQ is buggy. " + "Please switch to gptq_marlin or gptq_bitblas." + ) self.modules_in_block_to_quantize = modules_in_block_to_quantize or [] # used to identify GPTQ model quantized by autoround self.autoround_version = autoround_version + # GPTQ v1 and v2 format deals with zero points differently. + # Currently GPTQModel stores v1 format checkpoints by default, + # but provides the option to set `format="gptq_v2"` in `QuantizeConfig`. + self.checkpoint_format = checkpoint_format + def __repr__(self) -> str: return ( f"GPTQConfig(weight_bits={self.weight_bits}, " @@ -102,7 +118,8 @@ class GPTQConfig(QuantizationConfig): f"desc_act={self.desc_act}), " f"lm_head_quantized={self.lm_head_quantized}, " f"dynamic={self.dynamic}, " - f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})" + f"modules_in_block_to_quantize={self.modules_in_block_to_quantize}), " + f"checkpoint_format={self.checkpoint_format})" ) @classmethod @@ -137,6 +154,9 @@ class GPTQConfig(QuantizationConfig): modules_in_block_to_quantize = cls.get_from_keys_or( config, ["modules_in_block_to_quantize"], default=None ) + checkpoint_format = cls.get_from_keys_or( + config, ["checkpoint_format"], default="" + ) return cls( weight_bits, group_size, @@ -145,6 +165,7 @@ class GPTQConfig(QuantizationConfig): dynamic, autoround_version, modules_in_block_to_quantize, + checkpoint_format, ) def get_quant_method( @@ -154,6 +175,7 @@ class GPTQConfig(QuantizationConfig): # GPTQ MoE support: fall back to MoeWNA16 for broad compatibility from .moe_wna16 import MoeWNA16Config + # TODO: maybe update this for GPTQv2 format checkpoints config = { "quant_method": "gptq", "bits": self.weight_bits, @@ -210,6 +232,9 @@ class GPTQLinearMethod(LinearMethodBase): def __init__(self, quant_config: GPTQConfig): self.quant_config = quant_config + # GPTQ v1 and v2 format deals with zero points differently + self.use_v2_format = quant_config.checkpoint_format == "gptq_v2" + def create_weights( self, layer: torch.nn.Module, @@ -351,6 +376,8 @@ class GPTQLinearMethod(LinearMethodBase): out_shape = x.shape[:-1] + (layer.qweight.shape[-1],) reshaped_x = x.reshape(-1, x.shape[-1]) + # GPTQ v1 and v2 format checkpoints deals with zero points differently, + # and require different gemm kernels. output = ops.gptq_gemm( reshaped_x, layer.qweight, @@ -358,6 +385,7 @@ class GPTQLinearMethod(LinearMethodBase): layer.scales, layer.g_idx, layer.exllama_state == ExllamaState.READY, + self.use_v2_format, self.quant_config.weight_bits, ) if bias is not None: diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py index 27d8344f6b488..9fba4aafb05a7 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py @@ -145,10 +145,15 @@ class ExllamaLinearKernel(MPLinearKernel): w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer) + # gptq_gemm supports GPTQv2 format by passing use_v2_format=True. + # However, the MPLinearLayerConfig doesn't contain format info. + # So hardcode GPTQv1 format here, to keep its behavior unchanged. + use_v2_format = False + assert w_zp is not None, "Zero points are required by Exllama" assert w_g_idx is not None, "Group index is required by Exllama" output = ops.gptq_gemm( - x_2d, w_q, w_zp, w_s, w_g_idx, True, c.weight_type.size_bits + x_2d, w_q, w_zp, w_s, w_g_idx, True, use_v2_format, c.weight_type.size_bits ) if bias is not None: