[Kernel] Add GPTQv2 format support for low-bit or asymmetric quantization, by adapting gptq_gemm (#26092)

This commit is contained in:
Xiangyu Li 2025-10-24 11:26:13 +08:00 committed by GitHub
parent 1f9460c4c1
commit 5cc6bddb6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 295 additions and 98 deletions

View File

@ -307,7 +307,7 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int64_t bit); bool use_exllama, bool use_v2_format, int64_t bit);
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);

View File

@ -185,7 +185,7 @@ typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half*, const uint32_t*,
const uint32_t*, const half*, const uint32_t*, const half*,
half*, const int, const int, half*, const int, const int,
const int, const int, const int, const int,
const int*); const bool, const int*);
template <bool first_block, int m_count> template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_4bit_kernel( __global__ void gemm_half_q_half_gptq_4bit_kernel(
@ -193,12 +193,15 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
const uint32_t* __restrict__ b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, half* __restrict__ c, const half* __restrict__ b_gptq_scales, half* __restrict__ c,
const int size_m, const int size_n, const int size_k, const int groups, const int size_m, const int size_n, const int size_k, const int groups,
const int* __restrict__ b_q_perm) { const bool use_v2_format, const int* __restrict__ b_q_perm) {
MatrixView_half a_(a, size_m, size_k); MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n); MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto t = threadIdx.x; auto t = threadIdx.x;
// Block // Block
@ -256,10 +259,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
half2 y1y16[4][2]; half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n); b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n); b_gptq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
// Column result // Column result
float block_c[m_count][4] = {}; float block_c[m_count][4] = {};
@ -272,10 +275,10 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
nextgroup += groupsize; nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n); b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_f(scales, group, n); b_gptq_scales_.item4_f(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
} }
#pragma unroll #pragma unroll
@ -329,12 +332,15 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
const uint32_t* __restrict__ b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, half* __restrict__ c, const half* __restrict__ b_gptq_scales, half* __restrict__ c,
const int size_m, const int size_n, const int size_k, const int groups, const int size_m, const int size_n, const int size_k, const int groups,
const int* __restrict__ b_q_perm) { const bool use_v2_format, const int* __restrict__ b_q_perm) {
MatrixView_half a_(a, size_m, size_k); MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n); MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto t = threadIdx.x; auto t = threadIdx.x;
// Block // Block
@ -409,10 +415,10 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
int4 load_int4 = *b_ptr4; int4 load_int4 = *b_ptr4;
half2 dq[4][8]; half2 dq[4][8];
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + zero_offset);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + zero_offset);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + zero_offset);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + zero_offset);
#pragma unroll #pragma unroll
for (int m = 0; m < m_count; m++) { for (int m = 0; m < m_count; m++) {
@ -448,12 +454,15 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
const uint32_t* __restrict__ b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, half* __restrict__ c, const half* __restrict__ b_gptq_scales, half* __restrict__ c,
const int size_m, const int size_n, const int size_k, const int groups, const int size_m, const int size_n, const int size_k, const int groups,
const int* __restrict__ b_q_perm) { const bool use_v2_format, const int* __restrict__ b_q_perm) {
MatrixView_half a_(a, size_m, size_k); MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n); MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto t = threadIdx.x; auto t = threadIdx.x;
// Block // Block
@ -534,13 +543,13 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
half2 dq[4][16]; half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0],
size_n, zeros[0] + 1); size_n, zeros[0] + zero_offset);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1],
size_n, zeros[1] + 1); size_n, zeros[1] + zero_offset);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2],
size_n, zeros[2] + 1); size_n, zeros[2] + zero_offset);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3],
size_n, zeros[3] + 1); size_n, zeros[3] + zero_offset);
#pragma unroll #pragma unroll
for (int m = 0; m < m_count; m++) { for (int m = 0; m < m_count; m++) {
@ -574,12 +583,15 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
const uint32_t* __restrict__ b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, half* __restrict__ c, const half* __restrict__ b_gptq_scales, half* __restrict__ c,
const int size_m, const int size_n, const int size_k, const int groups, const int size_m, const int size_n, const int size_k, const int groups,
const int* __restrict__ b_q_perm) { const bool use_v2_format, const int* __restrict__ b_q_perm) {
MatrixView_half a_(a, size_m, size_k); MatrixView_half a_(a, size_m, size_k);
MatrixView_half_rw c_(c, size_m, size_n); MatrixView_half_rw c_(c, size_m, size_n);
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto t = threadIdx.x; auto t = threadIdx.x;
// Block // Block
@ -658,13 +670,13 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
half2 dq[4][4]; half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n,
zeros[0] + 1); zeros[0] + zero_offset);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n,
zeros[1] + 1); zeros[1] + zero_offset);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n,
zeros[2] + 1); zeros[2] + zero_offset);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n,
zeros[3] + 1); zeros[3] + zero_offset);
for (int m = 0; m < m_count; m++) { for (int m = 0; m < m_count; m++) {
block_c[m][0] = block_c[m][0] =
@ -730,7 +742,8 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros, const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_q_perm, const half* b_gptq_scales, const int* b_q_perm,
half* c, int size_m, int size_n, int size_k, half* c, int size_m, int size_n, int size_k,
int m_count, int groups, int bit) { int m_count, int groups, bool use_v2_format,
int bit) {
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE; blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1; blockDim.y = 1;
@ -743,20 +756,23 @@ void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight,
pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); pick_gemm_half_q_half_gptq_kernel(true, m_count, bit);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>(a, b_q_weight, b_gptq_qzeros, kernel<<<gridDim, blockDim, 0, stream>>>(
b_gptq_scales, c, size_m, size_n, a, b_q_weight, b_gptq_qzeros, b_gptq_scales, c, size_m, size_n, size_k,
size_k, groups, b_q_perm); groups, use_v2_format, b_q_perm);
} }
__global__ void reconstruct_exllama_8bit_kernel( __global__ void reconstruct_exllama_8bit_kernel(
const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
const int groups, half* __restrict__ b) { const int groups, const bool use_v2_format, half* __restrict__ b) {
MatrixView_half_rw b_(b, size_k, size_n); MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y; auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
@ -812,13 +828,13 @@ __global__ void reconstruct_exllama_8bit_kernel(
half2 dq[4][4]; half2 dq[4][4];
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n, dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n,
zeros[0] + 1); zeros[0] + zero_offset);
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n, dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n,
zeros[1] + 1); zeros[1] + zero_offset);
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n, dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n,
zeros[2] + 1); zeros[2] + zero_offset);
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n, dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n,
zeros[3] + 1); zeros[3] + zero_offset);
// half* dqh = (half*)dq; // half* dqh = (half*)dq;
if (b_q_perm) { if (b_q_perm) {
@ -849,11 +865,14 @@ __global__ void reconstruct_exllama_4bit_kernel(
const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
const int groups, half* __restrict__ b) { const int groups, const bool use_v2_format, half* __restrict__ b) {
MatrixView_half_rw b_(b, size_k, size_n); MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y; auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
@ -888,10 +907,10 @@ __global__ void reconstruct_exllama_4bit_kernel(
half2 y1y16[4][2]; half2 y1y16[4][2];
b_gptq_qzeros_.item4(zeros, group, n); b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n); b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
__syncthreads(); __syncthreads();
@ -904,10 +923,10 @@ __global__ void reconstruct_exllama_4bit_kernel(
nextgroup += groupsize; nextgroup += groupsize;
b_gptq_qzeros_.item4(zeros, group, n); b_gptq_qzeros_.item4(zeros, group, n);
b_gptq_scales_.item4_h2(scales, group, n); b_gptq_scales_.item4_h2(scales, group, n);
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); dequant_4bit_8_prep_zero(zeros[0] + zero_offset, z1z16[0], y1y16[0]);
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); dequant_4bit_8_prep_zero(zeros[1] + zero_offset, z1z16[1], y1y16[1]);
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); dequant_4bit_8_prep_zero(zeros[2] + zero_offset, z1z16[2], y1y16[2]);
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); dequant_4bit_8_prep_zero(zeros[3] + zero_offset, z1z16[3], y1y16[3]);
} }
for (int p = 0; p < 4; p++) { for (int p = 0; p < 4; p++) {
@ -954,11 +973,14 @@ __global__ void reconstruct_exllama_3bit_kernel(
const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
const int groups, half* __restrict__ b) { const int groups, const bool use_v2_format, half* __restrict__ b) {
MatrixView_half_rw b_(b, size_k, size_n); MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y; auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
@ -1016,13 +1038,13 @@ __global__ void reconstruct_exllama_3bit_kernel(
half2 dq[4][16]; half2 dq[4][16];
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0],
size_n, zeros[0] + 1); size_n, zeros[0] + zero_offset);
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1],
size_n, zeros[1] + 1); size_n, zeros[1] + zero_offset);
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2],
size_n, zeros[2] + 1); size_n, zeros[2] + zero_offset);
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3],
size_n, zeros[3] + 1); size_n, zeros[3] + zero_offset);
if (b_q_perm) { if (b_q_perm) {
for (int j = 0; j < 16; j++) { for (int j = 0; j < 16; j++) {
@ -1052,11 +1074,14 @@ __global__ void reconstruct_exllama_2bit_kernel(
const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm,
const uint32_t* __restrict__ b_gptq_qzeros, const uint32_t* __restrict__ b_gptq_qzeros,
const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, const half* __restrict__ b_gptq_scales, const int size_k, const int size_n,
const int groups, half* __restrict__ b) { const int groups, const bool use_v2_format, half* __restrict__ b) {
MatrixView_half_rw b_(b, size_k, size_n); MatrixView_half_rw b_(b, size_k, size_n);
MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
auto offset_k = BLOCK_KN_SIZE * blockIdx.y; auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
@ -1108,10 +1133,10 @@ __global__ void reconstruct_exllama_2bit_kernel(
int4 load_int4 = *b_ptr4; int4 load_int4 = *b_ptr4;
half2 dq[4][8]; half2 dq[4][8];
dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + 1); dequant_2bit_16(load_int4.x, dq[0], size_n, zeros[0] + zero_offset);
dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + 1); dequant_2bit_16(load_int4.y, dq[1], size_n, zeros[1] + zero_offset);
dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + 1); dequant_2bit_16(load_int4.z, dq[2], size_n, zeros[2] + zero_offset);
dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + 1); dequant_2bit_16(load_int4.w, dq[3], size_n, zeros[3] + zero_offset);
b_ptr += size_n; b_ptr += size_n;
// half* dqh = (half*)dq; // half* dqh = (half*)dq;
@ -1143,7 +1168,7 @@ void reconstruct_exllama(const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros, const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_q_perm, const half* b_gptq_scales, const int* b_q_perm,
half* out, int height, int width, int groups, half* out, int height, int width, int groups,
int bit) { bool use_v2_format, int bit) {
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE; blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1; blockDim.y = 1;
@ -1162,14 +1187,14 @@ void reconstruct_exllama(const uint32_t* b_q_weight,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>( reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>(
b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups, b_q_weight, b_q_perm, b_gptq_qzeros, b_gptq_scales, height, width, groups,
out); use_v2_format, out);
} }
__global__ void gemm_half_q_half_alt_4bit_kernel( __global__ void gemm_half_q_half_alt_4bit_kernel(
const half2* __restrict__ vec, const uint32_t* __restrict__ mat, const half2* __restrict__ vec, const uint32_t* __restrict__ mat,
half* __restrict__ mul, const half* __restrict__ scales, half* __restrict__ mul, const half* __restrict__ scales,
const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx,
int batch, int height, int width) { int batch, int height, int width, bool use_v2_format) {
int zero_width = width / 8; int zero_width = width / 8;
int vec_height = height * 4; int vec_height = height * 4;
const int blockwidth2 = BLOCK_KN_SIZE / 2; const int blockwidth2 = BLOCK_KN_SIZE / 2;
@ -1179,6 +1204,9 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
if (threadIdx.x < h_end) { if (threadIdx.x < h_end) {
for (int m = 0; m < b_end; ++m) { for (int m = 0; m < b_end; ++m) {
@ -1223,10 +1251,11 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
half2 zero = __halves2half2( half2 zero = __halves2half2(
__hmul(scale_f, __hmul(scale_f,
__int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) -
1)), zero_offset)),
__hmul(scale_f2, __hmul(
__int2half_rn( scale_f2,
-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))); __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) -
zero_offset)));
scales_tmp[tmp_k] = scale; scales_tmp[tmp_k] = scale;
zeros_tmp[tmp_k] = zero; zeros_tmp[tmp_k] = zero;
} }
@ -1268,7 +1297,7 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
const half2* __restrict__ vec, const uint32_t* __restrict__ mat, const half2* __restrict__ vec, const uint32_t* __restrict__ mat,
half* __restrict__ mul, const half* __restrict__ scales, half* __restrict__ mul, const half* __restrict__ scales,
const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx,
int batch, int height, int width) { int batch, int height, int width, bool use_v2_format) {
int zero_width = width / 4; int zero_width = width / 4;
int vec_height = height * 2; int vec_height = height * 2;
const int blockwidth2 = BLOCK_KN_SIZE / 2; const int blockwidth2 = BLOCK_KN_SIZE / 2;
@ -1278,6 +1307,9 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2; int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2;
auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
if (threadIdx.x < h_end) { if (threadIdx.x < h_end) {
for (int m = 0; m < b_end; ++m) { for (int m = 0; m < b_end; ++m) {
@ -1312,12 +1344,13 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
half scale_f2 = scales[g2 * width + w]; half scale_f2 = scales[g2 * width + w];
half2 scale = __halves2half2(scale_f, scale_f2); half2 scale = __halves2half2(scale_f, scale_f2);
half2 zero = __halves2half2( half2 zero = __halves2half2(
__hmul(scale_f, __hmul(scale_f, __int2half_rn(
__int2half_rn( -((zeros[g * zero_width + z_w] >> z_mod) & 0xff) -
-((zeros[g * zero_width + z_w] >> z_mod) & 0xff) - 1)), zero_offset)),
__hmul(scale_f2, __hmul(
__int2half_rn( scale_f2,
-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) - 1))); __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xff) -
zero_offset)));
scales_tmp[tmp_k] = scale; scales_tmp[tmp_k] = scale;
zeros_tmp[tmp_k] = zero; zeros_tmp[tmp_k] = zero;
} }
@ -1355,7 +1388,7 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight,
const uint32_t* b_gptq_qzeros, const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_g_idx, const half* b_gptq_scales, const int* b_g_idx,
half* c, int size_m, int size_n, int size_k, half* c, int size_m, int size_n, int size_k,
int bit) { bool use_v2_format, int bit) {
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE; blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1; blockDim.y = 1;
@ -1372,17 +1405,15 @@ void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>( kernel<<<gridDim, blockDim, 0, stream>>>(
(const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx, (const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx,
size_m, size_k / 32 * bit, size_n); size_m, size_k / 32 * bit, size_n, use_v2_format);
} }
template <class T, int bit> template <class T, int bit>
__global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, __global__ void reconstruct_gptq_kernel(
const half* __restrict__ w_scales, const uint32_t* __restrict__ w, const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros, const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx,
const int* __restrict__ g_idx, const int height, const int width, const int group,
const int height, const int width, const bool use_v2_format, half* __restrict__ out) {
const int group,
half* __restrict__ out) {
// Start of block // Start of block
auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
@ -1395,6 +1426,9 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
MatrixView_half w_scales_(w_scales, group, width); MatrixView_half w_scales_(w_scales, group, width);
T w_zeros_(w_zeros, group, width); T w_zeros_(w_zeros, group, width);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
uint32_t w_read = w[blockIdx.y * width + column]; uint32_t w_read = w[blockIdx.y * width + column];
half* out_ptr = out_.item_ptr(row, column); half* out_ptr = out_.item_ptr(row, column);
@ -1402,7 +1436,7 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
for (int s = 0; s < 32; s += bit) { for (int s = 0; s < 32; s += bit) {
int group = g_idx[row + s / bit]; int group = g_idx[row + s / bit];
half w_scale = w_scales_.item(group, column); half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1; uint32_t w_zero = w_zeros_.item(group, column) + zero_offset;
half w_item = half w_item =
__hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero), __hmul(__int2half_rn((int)((w_read >> s) & ((1 << bit) - 1)) - w_zero),
w_scale); w_scale);
@ -1415,7 +1449,7 @@ __global__ void reconstruct_gptq_3bit_kernel(
const uint32_t* __restrict__ w, const half* __restrict__ w_scales, const uint32_t* __restrict__ w, const half* __restrict__ w_scales,
const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx, const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx,
const int height, const int width, const int group, const int height, const int width, const int group,
half* __restrict__ out) { const bool use_v2_format, half* __restrict__ out) {
// Start of block // Start of block
auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
auto row = blockIdx.y * 32; auto row = blockIdx.y * 32;
@ -1427,6 +1461,9 @@ __global__ void reconstruct_gptq_3bit_kernel(
MatrixView_half w_scales_(w_scales, group, width); MatrixView_half w_scales_(w_scales, group, width);
MatrixView_q3_row w_zeros_(w_zeros, group, width); MatrixView_q3_row w_zeros_(w_zeros, group, width);
// GPTQv2 and GPTQv1 handles zero points differently
int zero_offset = use_v2_format ? 0 : 1;
uint32_t w1 = w[(blockIdx.y * 3) * width + column]; uint32_t w1 = w[(blockIdx.y * 3) * width + column];
uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column];
uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column];
@ -1436,7 +1473,7 @@ __global__ void reconstruct_gptq_3bit_kernel(
for (int i = 0; i < 32; i += 1) { for (int i = 0; i < 32; i += 1) {
int group = g_idx[row + i]; int group = g_idx[row + i];
half w_scale = w_scales_.item(group, column); half w_scale = w_scales_.item(group, column);
uint32_t w_zero = w_zeros_.item(group, column) + 1; uint32_t w_zero = w_zeros_.item(group, column) + zero_offset;
int w_item; int w_item;
if (i == 10) { if (i == 10) {
w_item = (w1 >> 30) | ((w2 << 2) & 0x4); w_item = (w1 >> 30) | ((w2 << 2) & 0x4);
@ -1456,7 +1493,8 @@ __global__ void reconstruct_gptq_3bit_kernel(
void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_g_idx, half* out, const half* b_gptq_scales, const int* b_g_idx, half* out,
int height, int width, int groups, int bit) { int height, int width, int groups, bool use_v2_format,
int bit) {
dim3 blockDim, gridDim; dim3 blockDim, gridDim;
blockDim.x = BLOCK_KN_SIZE; blockDim.x = BLOCK_KN_SIZE;
blockDim.y = 1; blockDim.y = 1;
@ -1476,7 +1514,7 @@ void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros,
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>(b_q_weight, b_gptq_scales, kernel<<<gridDim, blockDim, 0, stream>>>(b_q_weight, b_gptq_scales,
b_gptq_qzeros, b_g_idx, height, b_gptq_qzeros, b_g_idx, height,
width, groups, out); width, groups, use_v2_format, out);
} }
void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
@ -1484,7 +1522,8 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
const uint32_t* b_gptq_qzeros, const uint32_t* b_gptq_qzeros,
const half* b_gptq_scales, const int* b_g_idx, const half* b_gptq_scales, const int* b_g_idx,
half* c, half* temp_dq, int size_m, int size_n, half* c, half* temp_dq, int size_m, int size_n,
int size_k, int groups, bool use_exllama, int bit) { int size_k, int groups, bool use_exllama,
bool use_v2_format, int bit) {
bool use_reconstruct; bool use_reconstruct;
if (use_exllama) { if (use_exllama) {
use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) || use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) ||
@ -1498,10 +1537,10 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
// Reconstruct FP16 matrix, then cuBLAS // Reconstruct FP16 matrix, then cuBLAS
if (use_exllama) { if (use_exllama) {
reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
temp_dq, size_k, size_n, groups, bit); temp_dq, size_k, size_n, groups, use_v2_format, bit);
} else { } else {
reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
temp_dq, size_k, size_n, groups, bit); temp_dq, size_k, size_n, groups, use_v2_format, bit);
} }
const half alpha = __float2half(1.0f); const half alpha = __float2half(1.0f);
@ -1517,18 +1556,18 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
if (max_chunks) { if (max_chunks) {
gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, c, last_chunk, size_n, size_k, b_g_idx, c, last_chunk, size_n, size_k,
BLOCK_M_SIZE_MAX, groups, bit); BLOCK_M_SIZE_MAX, groups, use_v2_format, bit);
} }
if (last_chunk_size) { if (last_chunk_size) {
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, gemm_half_q_half_cuda_part(
b_gptq_qzeros, b_gptq_scales, b_g_idx, a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, b_gptq_scales,
c + last_chunk * size_n, last_chunk_size, b_g_idx, c + last_chunk * size_n, last_chunk_size, size_n, size_k,
size_n, size_k, last_chunk_size, groups, bit); last_chunk_size, groups, use_v2_format, bit);
} }
} else { } else {
gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
c, size_m, size_n, size_k, bit); c, size_m, size_n, size_k, use_v2_format, bit);
} }
} }
@ -1815,7 +1854,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height,
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int64_t bit) { bool use_exllama, bool use_v2_format, int64_t bit) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
@ -1833,7 +1872,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
c.size(1), // n c.size(1), // n
a.size(1), // k a.size(1), // k
b_gptq_qzeros.size(0), // group number b_gptq_qzeros.size(0), // group number
use_exllama, bit); use_exllama, use_v2_format, bit);
return c; return c;
} }

View File

@ -557,7 +557,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// to prevent the meta function registry. // to prevent the meta function registry.
ops.def( ops.def(
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, " "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) " "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool "
"use_v2_format, int bit) "
"-> Tensor", "-> Tensor",
{stride_tag}); {stride_tag});
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);

View File

@ -26,4 +26,10 @@ def test_gptq_gemm_opcheck():
idx = torch.empty((0,), device="cuda", dtype=torch.int32) idx = torch.empty((0,), device="cuda", dtype=torch.int32)
use_exllama = True use_exllama = True
bit = 4 bit = 4
opcheck(torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, bit)) # Test both GPTQv1 and GPTQv2 format
opcheck(
torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, True, bit)
)
opcheck(
torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, False, bit)
)

View File

@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests whether vllm correctly load and run gptq_v2 format checkpoints.
Run `pytest tests/quantization/test_gptq_v2.py --forked`.
"""
import pytest
import torch
from transformers import AutoTokenizer
from vllm import SamplingParams
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
# A dummy small model quantized by GPTQModel, stored in GPTQ v2 format
MODELS = ["XXXXyu/Qwen3-1.7B-w2g64-gptq_v2"]
# Generate multiple sequences for testing, because an 1.7B 2-bit model
# cannot always generate normal texts.
N_SEQ = 5
@pytest.mark.parametrize("model_id", MODELS)
def test_model_load(vllm_runner, model_id, monkeypatch):
# `LLM.apply_model` requires pickling a function.
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
# Only check the default GPTQ linear method (used for 2/3-bit models).
# 4/8-bit linear methods like Marlin already support gptq_v2.
linear_method_cls = GPTQLinearMethod
with vllm_runner(model_id, dtype=torch.float16, max_model_len=512) as llm:
def check_model(model_id):
for name, submodule in model_id.named_modules():
# Could check more modules if necessary
if name == "model_id.layers.0.self_attn.qkv_proj":
assert isinstance(submodule.quant_method, linear_method_cls)
config = submodule.quant_method.quant_config
assert config.checkpoint_format == "gptq_v2"
assert submodule.quant_method.use_v2_format
# Just break since currently we only check 1 module
break
# Check if gptq_v2 format is correctly loaded
llm.apply_model(check_model)
@pytest.mark.parametrize("model_id", MODELS)
def test_model_inference(vllm_runner, model_id):
# Prepare prompt to test the model's generation result.
prompt = "What is the meaning of life?"
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
]
tokenizer = AutoTokenizer.from_pretrained(model_id)
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False, # If thinking model, set it to false
)
sampling_params = SamplingParams(
n=N_SEQ,
max_tokens=128,
temperature=0.7,
top_p=0.8,
top_k=20,
min_p=0,
presence_penalty=2,
)
with vllm_runner(model_id, dtype=torch.float16, max_model_len=512) as llm:
# Generate a response to verify inference correctness
output = llm.generate(text, sampling_params)
# Make sure the output exists
assert output
assert output[0][1]
assert len(output[0][1]) == N_SEQ
def has_normal_char_distribution(texts, min_len):
for text in texts:
# Response too short
if len(text) < min_len:
return False
# Basic ratio checks
letters = sum(c.isalpha() for c in text)
spaces = sum(c.isspace() for c in text)
total = len(text)
letter_ratio = letters / total
space_ratio = spaces / total
# At least 1 normal text should exist within output sequences
# Normal text should be mostly letters with reasonable spacing
# Some magic numbers, could be adjusted
if 0.5 <= letter_ratio <= 0.9 and 0.01 <= space_ratio <= 0.3:
return True
# No sequence contains normal text, output might be broken
return False
# Apply some simple checks for giberish output
# Print the output sequences if failed
assert has_normal_char_distribution(output[0][1], 5), output[0][1]

View File

@ -451,10 +451,18 @@ def gptq_gemm(
b_gptq_scales: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, b_g_idx: torch.Tensor,
use_exllama: bool, use_exllama: bool,
use_v2_format: bool,
bit: int, bit: int,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops._C.gptq_gemm( return torch.ops._C.gptq_gemm(
a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit a,
b_q_weight,
b_gptq_qzeros,
b_gptq_scales,
b_g_idx,
use_exllama,
use_v2_format,
bit,
) )
@ -468,6 +476,7 @@ if hasattr(torch.ops._C, "gptq_gemm"):
b_gptq_scales: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, b_g_idx: torch.Tensor,
use_exllama: bool, use_exllama: bool,
use_v2_format: bool,
bit: int, bit: int,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty( return torch.empty(

View File

@ -11,6 +11,7 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
@ -36,6 +37,8 @@ if TYPE_CHECKING:
else: else:
QuantizationMethods = str QuantizationMethods = str
logger = init_logger(__name__)
class GPTQConfig(QuantizationConfig): class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ. """Config class for GPTQ.
@ -52,6 +55,7 @@ class GPTQConfig(QuantizationConfig):
dynamic: dict[str, dict[str, int | bool]], dynamic: dict[str, dict[str, int | bool]],
autoround_version: str = "", autoround_version: str = "",
modules_in_block_to_quantize: list[str] | None = None, modules_in_block_to_quantize: list[str] | None = None,
checkpoint_format: str = "",
) -> None: ) -> None:
# GPTQModel use `dynamic` config property to allow per module # GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized. # quantization config so each module can be individually optimized.
@ -89,12 +93,24 @@ class GPTQConfig(QuantizationConfig):
"Currently, only 2/3/4/8-bit weight quantization is " "Currently, only 2/3/4/8-bit weight quantization is "
f"supported for GPTQ, but got {self.weight_bits} bits." f"supported for GPTQ, but got {self.weight_bits} bits."
) )
# Somehow gptq_gemm 4-bit is buggy, maybe fix it in the future.
# For now, show a warning, since gptq_marlin will be used by default.
if self.weight_bits == 4:
logger.warning_once(
"Currently, the 4-bit gptq_gemm kernel for GPTQ is buggy. "
"Please switch to gptq_marlin or gptq_bitblas."
)
self.modules_in_block_to_quantize = modules_in_block_to_quantize or [] self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
# used to identify GPTQ model quantized by autoround # used to identify GPTQ model quantized by autoround
self.autoround_version = autoround_version self.autoround_version = autoround_version
# GPTQ v1 and v2 format deals with zero points differently.
# Currently GPTQModel stores v1 format checkpoints by default,
# but provides the option to set `format="gptq_v2"` in `QuantizeConfig`.
self.checkpoint_format = checkpoint_format
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
f"GPTQConfig(weight_bits={self.weight_bits}, " f"GPTQConfig(weight_bits={self.weight_bits}, "
@ -102,7 +118,8 @@ class GPTQConfig(QuantizationConfig):
f"desc_act={self.desc_act}), " f"desc_act={self.desc_act}), "
f"lm_head_quantized={self.lm_head_quantized}, " f"lm_head_quantized={self.lm_head_quantized}, "
f"dynamic={self.dynamic}, " f"dynamic={self.dynamic}, "
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})" f"modules_in_block_to_quantize={self.modules_in_block_to_quantize}), "
f"checkpoint_format={self.checkpoint_format})"
) )
@classmethod @classmethod
@ -137,6 +154,9 @@ class GPTQConfig(QuantizationConfig):
modules_in_block_to_quantize = cls.get_from_keys_or( modules_in_block_to_quantize = cls.get_from_keys_or(
config, ["modules_in_block_to_quantize"], default=None config, ["modules_in_block_to_quantize"], default=None
) )
checkpoint_format = cls.get_from_keys_or(
config, ["checkpoint_format"], default=""
)
return cls( return cls(
weight_bits, weight_bits,
group_size, group_size,
@ -145,6 +165,7 @@ class GPTQConfig(QuantizationConfig):
dynamic, dynamic,
autoround_version, autoround_version,
modules_in_block_to_quantize, modules_in_block_to_quantize,
checkpoint_format,
) )
def get_quant_method( def get_quant_method(
@ -154,6 +175,7 @@ class GPTQConfig(QuantizationConfig):
# GPTQ MoE support: fall back to MoeWNA16 for broad compatibility # GPTQ MoE support: fall back to MoeWNA16 for broad compatibility
from .moe_wna16 import MoeWNA16Config from .moe_wna16 import MoeWNA16Config
# TODO: maybe update this for GPTQv2 format checkpoints
config = { config = {
"quant_method": "gptq", "quant_method": "gptq",
"bits": self.weight_bits, "bits": self.weight_bits,
@ -210,6 +232,9 @@ class GPTQLinearMethod(LinearMethodBase):
def __init__(self, quant_config: GPTQConfig): def __init__(self, quant_config: GPTQConfig):
self.quant_config = quant_config self.quant_config = quant_config
# GPTQ v1 and v2 format deals with zero points differently
self.use_v2_format = quant_config.checkpoint_format == "gptq_v2"
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
@ -351,6 +376,8 @@ class GPTQLinearMethod(LinearMethodBase):
out_shape = x.shape[:-1] + (layer.qweight.shape[-1],) out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
reshaped_x = x.reshape(-1, x.shape[-1]) reshaped_x = x.reshape(-1, x.shape[-1])
# GPTQ v1 and v2 format checkpoints deals with zero points differently,
# and require different gemm kernels.
output = ops.gptq_gemm( output = ops.gptq_gemm(
reshaped_x, reshaped_x,
layer.qweight, layer.qweight,
@ -358,6 +385,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer.scales, layer.scales,
layer.g_idx, layer.g_idx,
layer.exllama_state == ExllamaState.READY, layer.exllama_state == ExllamaState.READY,
self.use_v2_format,
self.quant_config.weight_bits, self.quant_config.weight_bits,
) )
if bias is not None: if bias is not None:

View File

@ -145,10 +145,15 @@ class ExllamaLinearKernel(MPLinearKernel):
w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer) w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)
# gptq_gemm supports GPTQv2 format by passing use_v2_format=True.
# However, the MPLinearLayerConfig doesn't contain format info.
# So hardcode GPTQv1 format here, to keep its behavior unchanged.
use_v2_format = False
assert w_zp is not None, "Zero points are required by Exllama" assert w_zp is not None, "Zero points are required by Exllama"
assert w_g_idx is not None, "Group index is required by Exllama" assert w_g_idx is not None, "Group index is required by Exllama"
output = ops.gptq_gemm( output = ops.gptq_gemm(
x_2d, w_q, w_zp, w_s, w_g_idx, True, c.weight_type.size_bits x_2d, w_q, w_zp, w_s, w_g_idx, True, use_v2_format, c.weight_type.size_bits
) )
if bias is not None: if bias is not None: