mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 08:55:46 +08:00
[Kernel] Marlin Expansion: Support AutoGPTQ Models with Marlin (#3922)
Co-authored-by: alexm <alexm@neuralmagic.com> Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
parent
df29793dc7
commit
73c8d677e5
@ -177,6 +177,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/aqlm/gemm_kernels.cu"
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/quantization/marlin/marlin_cuda_kernel.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||
"csrc/custom_all_reduce.cu")
|
||||
endif()
|
||||
|
||||
|
||||
18
csrc/ops.h
18
csrc/ops.h
@ -124,6 +124,24 @@ torch::Tensor marlin_gemm(
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k);
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor &a,
|
||||
torch::Tensor &b_q_weight,
|
||||
torch::Tensor &b_scales,
|
||||
torch::Tensor &g_idx,
|
||||
torch::Tensor &perm,
|
||||
torch::Tensor &workspace,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k,
|
||||
bool is_k_full);
|
||||
|
||||
torch::Tensor gptq_marlin_repack(
|
||||
torch::Tensor &b_q_weight,
|
||||
torch::Tensor &perm,
|
||||
int64_t size_k,
|
||||
int64_t size_n);
|
||||
#endif
|
||||
|
||||
void squeezellm_gemm(
|
||||
|
||||
@ -67,6 +67,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
|
||||
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
|
||||
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ");
|
||||
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
|
||||
#endif
|
||||
|
||||
|
||||
1520
csrc/quantization/gptq_marlin/gptq_marlin.cu
Normal file
1520
csrc/quantization/gptq_marlin/gptq_marlin.cu
Normal file
File diff suppressed because it is too large
Load Diff
74
csrc/quantization/gptq_marlin/gptq_marlin.cuh
Normal file
74
csrc/quantization/gptq_marlin/gptq_marlin.cuh
Normal file
@ -0,0 +1,74 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace gptq_marlin {
|
||||
|
||||
// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per
|
||||
// schedule allows some more latency hiding. At the same time, we want relatively few warps to have
|
||||
// many registers per warp and small tiles.
|
||||
static constexpr int default_threads = 256;
|
||||
|
||||
static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory
|
||||
|
||||
static constexpr int min_thread_n = 64;
|
||||
static constexpr int min_thread_k = 64;
|
||||
|
||||
static constexpr int tile_size = 16;
|
||||
static constexpr int max_par = 16;
|
||||
|
||||
static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32 bit
|
||||
|
||||
template <typename T, int n>
|
||||
struct Vec {
|
||||
T elems[n];
|
||||
__device__ T& operator[](int i) { return elems[i]; }
|
||||
};
|
||||
|
||||
using I4 = Vec<int, 4>;
|
||||
|
||||
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
// No support for async
|
||||
#else
|
||||
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem), "l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile("{\n"
|
||||
" .reg .b64 p;\n"
|
||||
" createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
|
||||
" cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
|
||||
"}\n" ::"r"(smem),
|
||||
"l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); }
|
||||
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace gptq_marlin
|
||||
324
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
Normal file
324
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
Normal file
@ -0,0 +1,324 @@
|
||||
#include "gptq_marlin.cuh"
|
||||
|
||||
namespace gptq_marlin {
|
||||
|
||||
static constexpr int repack_stages = 8;
|
||||
|
||||
static constexpr int repack_threads = 256;
|
||||
|
||||
static constexpr int tile_k_size = tile_size;
|
||||
static constexpr int tile_n_size = tile_k_size * 4;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
template <int const num_threads, bool const has_perm>
|
||||
__global__ void
|
||||
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
|
||||
uint32_t const *__restrict__ perm_ptr,
|
||||
uint32_t *__restrict__ out_ptr, int size_k, int size_n) {}
|
||||
|
||||
} // namespace gptq_marlin
|
||||
|
||||
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
|
||||
int64_t size_k, int64_t size_n) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
|
||||
return torch::empty({1, 1});
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
template <int const num_threads, bool const has_perm>
|
||||
__global__ void
|
||||
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
|
||||
uint32_t const *__restrict__ perm_ptr,
|
||||
uint32_t *__restrict__ out_ptr, int size_k, int size_n) {
|
||||
int k_tiles = size_k / tile_k_size;
|
||||
int n_tiles = size_n / tile_n_size;
|
||||
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
|
||||
|
||||
int start_k_tile = blockIdx.x * block_k_tiles;
|
||||
if (start_k_tile >= k_tiles) {
|
||||
return;
|
||||
}
|
||||
|
||||
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
|
||||
|
||||
// Wait until the next thread tile has been loaded to shared memory.
|
||||
auto wait_for_stage = [&]() {
|
||||
// We only have `stages - 2` active fetches since we are double buffering
|
||||
// and can only issue the next fetch when it is guaranteed that the previous
|
||||
// shared memory load is fully complete (as it may otherwise be
|
||||
// overwritten).
|
||||
cp_async_wait<repack_stages - 2>();
|
||||
__syncthreads();
|
||||
};
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
|
||||
constexpr int perm_size = tile_k_size / 4;
|
||||
|
||||
int4 *sh_perm_ptr = sh;
|
||||
int4 *sh_pipe_ptr = sh_perm_ptr;
|
||||
if constexpr (has_perm) {
|
||||
sh_pipe_ptr += perm_size;
|
||||
}
|
||||
|
||||
constexpr int stage_n_threads = tile_n_size / 4;
|
||||
constexpr int stage_k_threads =
|
||||
has_perm ? tile_k_size : tile_k_size / pack_factor_4bit;
|
||||
constexpr int stage_size = stage_k_threads * stage_n_threads;
|
||||
|
||||
auto load_perm_to_shared = [&](int k_tile_id) {
|
||||
int first_k_int4 = (k_tile_id * tile_k_size) / 4;
|
||||
|
||||
int4 const *perm_int4_ptr = reinterpret_cast<int4 const *>(perm_ptr);
|
||||
|
||||
if (threadIdx.x < perm_size) {
|
||||
sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
|
||||
}
|
||||
__syncthreads();
|
||||
};
|
||||
|
||||
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
|
||||
if (n_tile_id >= n_tiles) {
|
||||
cp_async_fence();
|
||||
return;
|
||||
}
|
||||
|
||||
int first_n = n_tile_id * tile_n_size;
|
||||
|
||||
int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe;
|
||||
|
||||
if constexpr (has_perm) {
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
uint32_t const *sh_perm_int_ptr =
|
||||
reinterpret_cast<uint32_t const *>(sh_perm_ptr);
|
||||
|
||||
int src_k = sh_perm_int_ptr[k_id];
|
||||
int src_k_packed = src_k / pack_factor_4bit;
|
||||
|
||||
cp_async4_stream(
|
||||
&sh_ptr[k_id * stage_n_threads + n_id],
|
||||
reinterpret_cast<int4 const *>(&(
|
||||
b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
|
||||
}
|
||||
|
||||
} else {
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
int first_k = k_tile_id * tile_k_size;
|
||||
int first_k_packed = first_k / pack_factor_4bit;
|
||||
|
||||
cp_async4_stream(&sh_ptr[k_id * stage_n_threads + n_id],
|
||||
reinterpret_cast<int4 const *>(
|
||||
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
|
||||
first_n + (n_id * 4)])));
|
||||
}
|
||||
}
|
||||
|
||||
cp_async_fence();
|
||||
};
|
||||
|
||||
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
|
||||
if (n_tile_id >= n_tiles) {
|
||||
return;
|
||||
}
|
||||
|
||||
int warp_id = threadIdx.x / 32;
|
||||
int th_id = threadIdx.x % 32;
|
||||
|
||||
if (warp_id >= 4) {
|
||||
return;
|
||||
}
|
||||
|
||||
int tc_col = th_id / 4;
|
||||
int tc_row = (th_id % 4) * 2;
|
||||
|
||||
constexpr int tc_offsets[4] = {0, 1, 8, 9};
|
||||
|
||||
int cur_n = warp_id * 16 + tc_col;
|
||||
|
||||
constexpr int sh_stride = 64;
|
||||
|
||||
int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
|
||||
uint32_t *sh_stage_int_ptr = reinterpret_cast<uint32_t *>(sh_stage_ptr);
|
||||
|
||||
uint32_t *sh_perm_int_ptr = reinterpret_cast<uint32_t *>(sh_perm_ptr);
|
||||
|
||||
uint32_t vals[pack_factor_4bit];
|
||||
|
||||
if constexpr (has_perm) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int k_idx = tc_row + tc_offsets[i];
|
||||
|
||||
uint32_t src_k = sh_perm_int_ptr[k_idx];
|
||||
uint32_t src_k_pos = src_k % pack_factor_4bit;
|
||||
|
||||
uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
|
||||
uint32_t b1_cur_val = (b1_val >> (src_k_pos * 4)) & 0xf;
|
||||
|
||||
uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
|
||||
uint32_t b2_cur_val = (b2_val >> (src_k_pos * 4)) & 0xf;
|
||||
|
||||
vals[i] = b1_cur_val;
|
||||
vals[4 + i] = b2_cur_val;
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
uint32_t b1_val_1 = sh_stage_int_ptr[cur_n];
|
||||
uint32_t b1_val_2 = sh_stage_int_ptr[sh_stride + cur_n];
|
||||
|
||||
uint32_t b2_val_1 = sh_stage_int_ptr[cur_n + 8];
|
||||
uint32_t b2_val_2 = sh_stage_int_ptr[sh_stride + cur_n + 8];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int cur_elem = tc_row + tc_offsets[i];
|
||||
vals[i] = (b1_val_1 >> (cur_elem * 4)) & 0xf;
|
||||
vals[4 + i] = (b2_val_1 >> (cur_elem * 4)) & 0xf;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 2; i < 4; i++) {
|
||||
int cur_elem = tc_row + tc_offsets[i] - 8;
|
||||
vals[i] = (b1_val_2 >> (cur_elem * 4)) & 0xf;
|
||||
vals[4 + i] = (b2_val_2 >> (cur_elem * 4)) & 0xf;
|
||||
}
|
||||
}
|
||||
|
||||
// Result of:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
constexpr int pack_idx[pack_factor_4bit] = {0, 2, 4, 6, 1, 3, 5, 7};
|
||||
|
||||
uint32_t res = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < pack_factor_4bit; i++) {
|
||||
res |= vals[pack_idx[i]] << (i * 4);
|
||||
}
|
||||
|
||||
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor_4bit;
|
||||
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
|
||||
|
||||
out_ptr[out_offset + th_id * 4 + warp_id] = res;
|
||||
};
|
||||
|
||||
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
|
||||
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
|
||||
}
|
||||
|
||||
wait_for_stage();
|
||||
};
|
||||
#pragma unroll
|
||||
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
|
||||
int n_tile_id = 0;
|
||||
|
||||
if constexpr (has_perm) {
|
||||
load_perm_to_shared(k_tile_id);
|
||||
}
|
||||
|
||||
start_pipes(k_tile_id, n_tile_id);
|
||||
|
||||
while (n_tile_id < n_tiles) {
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
||||
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
|
||||
n_tile_id + pipe + repack_stages - 1);
|
||||
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
|
||||
wait_for_stage();
|
||||
}
|
||||
n_tile_id += repack_stages;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gptq_marlin
|
||||
|
||||
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
|
||||
int64_t size_k, int64_t size_n) {
|
||||
// Verify compatibility with marlin tile of 16x64
|
||||
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
|
||||
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
|
||||
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
|
||||
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
|
||||
|
||||
// Verify B
|
||||
TORCH_CHECK((size_k / gptq_marlin::pack_factor_4bit) == b_q_weight.size(0),
|
||||
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
|
||||
", size_k = ", size_k,
|
||||
", pack_factor_4bit = ", gptq_marlin::pack_factor_4bit);
|
||||
TORCH_CHECK(b_q_weight.size(1) == size_n,
|
||||
"b_q_weight.size(1) = ", b_q_weight.size(1),
|
||||
" is not size_n = ", size_n);
|
||||
|
||||
// Verify device and strides
|
||||
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
||||
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
||||
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
|
||||
|
||||
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
|
||||
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
|
||||
TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");
|
||||
|
||||
// Alloc buffers
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
|
||||
auto options = torch::TensorOptions()
|
||||
.dtype(b_q_weight.dtype())
|
||||
.device(b_q_weight.device());
|
||||
torch::Tensor out = torch::empty(
|
||||
{size_k / gptq_marlin::tile_size,
|
||||
size_n * gptq_marlin::tile_size / gptq_marlin::pack_factor_4bit},
|
||||
options);
|
||||
|
||||
// Detect if there is act_order
|
||||
bool has_perm = perm.size(0) != 0;
|
||||
|
||||
// Get ptrs
|
||||
uint32_t const *b_q_weight_ptr =
|
||||
reinterpret_cast<uint32_t const *>(b_q_weight.data_ptr());
|
||||
uint32_t const *perm_ptr =
|
||||
reinterpret_cast<uint32_t const *>(perm.data_ptr());
|
||||
uint32_t *out_ptr = reinterpret_cast<uint32_t *>(out.data_ptr());
|
||||
|
||||
// Get dev info
|
||||
int dev = b_q_weight.get_device();
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
|
||||
int blocks;
|
||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||
|
||||
int max_shared_mem = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
TORCH_CHECK(max_shared_mem > 0);
|
||||
|
||||
if (has_perm) {
|
||||
cudaFuncSetAttribute(
|
||||
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
max_shared_mem);
|
||||
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>
|
||||
<<<blocks, gptq_marlin::repack_threads, max_shared_mem,
|
||||
stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
|
||||
|
||||
} else {
|
||||
cudaFuncSetAttribute(
|
||||
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
max_shared_mem);
|
||||
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>
|
||||
<<<blocks, gptq_marlin::repack_threads, max_shared_mem,
|
||||
stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
#endif
|
||||
93
tests/models/test_gptq_marlin.py
Normal file
93
tests/models/test_gptq_marlin.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Compares the outputs of gptq vs gptq_marlin
|
||||
Note: GPTQ and Marlin do not have bitwise correctness.
|
||||
As a result, in this test, we just confirm that the top selected tokens of the
|
||||
Marlin/GPTQ models are in the top 3 selections of each other.
|
||||
Note: Marlin internally uses locks to synchronize the threads. This can
|
||||
result in very slight nondeterminism for Marlin. As a result, we re-run the test
|
||||
up to 3 times to see if we pass.
|
||||
Note: This test currently fails running with --forked with the following:
|
||||
RuntimeError: Cannot re-initialize CUDA in forked subprocess.
|
||||
To use CUDA with multiprocessing, you must use the 'spawn' start method
|
||||
Run `pytest tests/models/test_gptq_marlin.py`.
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.models.utils import check_logprobs_close
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
MAX_MODEL_LEN = 1024
|
||||
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
gptq_marlin_not_supported = (
|
||||
capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability())
|
||||
|
||||
MODELS = [
|
||||
# act_order==False, group_size=channelwise
|
||||
("robertgshaw2/zephyr-7b-beta-channelwise-gptq", "main"),
|
||||
# act_order==False, group_size=128
|
||||
("TheBloke/Llama-2-7B-GPTQ", "main"),
|
||||
|
||||
# act_order==True, group_size=128
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "main"),
|
||||
# act_order==True, group_size=64
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"),
|
||||
# act_order==True, group_size=32
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=2)
|
||||
@pytest.mark.skipif(gptq_marlin_not_supported,
|
||||
reason="gptq_marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_models(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
model_name, revision = model
|
||||
|
||||
# Run marlin.
|
||||
gptq_marlin_model = vllm_runner(model_name=model_name,
|
||||
revision=revision,
|
||||
dtype=dtype,
|
||||
quantization="marlin",
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
tensor_parallel_size=1,
|
||||
disable_custom_all_reduce=True)
|
||||
|
||||
gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
del gptq_marlin_model
|
||||
|
||||
# Run gptq.
|
||||
gptq_model = vllm_runner(model_name=model_name,
|
||||
revision=revision,
|
||||
dtype=dtype,
|
||||
quantization="gptq",
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
tensor_parallel_size=1,
|
||||
disable_custom_all_reduce=True)
|
||||
gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
|
||||
max_tokens,
|
||||
num_logprobs)
|
||||
del gptq_model
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=gptq_outputs,
|
||||
outputs_1_lst=gptq_marlin_outputs,
|
||||
name_0="gptq",
|
||||
name_1="gptq_marlin",
|
||||
)
|
||||
@ -10,12 +10,12 @@ up to 3 times to see if we pass.
|
||||
|
||||
Run `pytest tests/models/test_marlin.py`.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.models.utils import check_logprobs_close
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
|
||||
capability = torch.cuda.get_device_capability()
|
||||
@ -55,43 +55,24 @@ def test_models(
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype)
|
||||
marlin_model = vllm_runner(model_pair.model_marlin,
|
||||
dtype=dtype,
|
||||
quantization="marlin")
|
||||
marlin_outputs = marlin_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
# Note: not sure why, but deleting just the model on Ada Lovelace
|
||||
# does not free the GPU memory. On Ampere, deleting the just model
|
||||
# frees the memory.
|
||||
del marlin_model
|
||||
|
||||
gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype)
|
||||
gptq_model = vllm_runner(model_pair.model_gptq,
|
||||
dtype=dtype,
|
||||
quantization="gptq")
|
||||
gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
|
||||
max_tokens,
|
||||
num_logprobs)
|
||||
|
||||
# Note: not sure why, but deleting just the model on Ada Lovelace
|
||||
# does not free the GPU memory. On Ampere, deleting the just model
|
||||
# frees the memory.
|
||||
del gptq_model
|
||||
|
||||
# loop through the prompts
|
||||
for prompt_idx in range(len(example_prompts)):
|
||||
gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[
|
||||
prompt_idx]
|
||||
marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[
|
||||
prompt_idx]
|
||||
|
||||
for idx, (gptq_output_id, marlin_output_id) in enumerate(
|
||||
zip(gptq_output_ids, marlin_output_ids)):
|
||||
# If sequence is not an exact match,
|
||||
if marlin_output_id != gptq_output_id:
|
||||
# Each predicted token must be in top 5 of the other's
|
||||
assert gptq_output_id in marlin_logprobs[idx], (
|
||||
f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n"
|
||||
f"Marlin:\t{marlin_output_str!r}")
|
||||
assert marlin_output_id in gptq_logprobs[idx], (
|
||||
f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n"
|
||||
f"Marlin:\t{marlin_output_str!r}")
|
||||
|
||||
# Break out since sequences will now diverge.
|
||||
break
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=gptq_outputs,
|
||||
outputs_1_lst=marlin_outputs,
|
||||
name_0="gptq",
|
||||
name_1="marlin",
|
||||
)
|
||||
|
||||
29
tests/models/utils.py
Normal file
29
tests/models/utils.py
Normal file
@ -0,0 +1,29 @@
|
||||
def check_logprobs_close(outputs_0_lst, outputs_1_lst, name_0, name_1):
|
||||
"""Compare the logprobs of two sequences generated by different models,
|
||||
which should be similar but not necessarily equal.
|
||||
"""
|
||||
# Loop through responses to each prompt.
|
||||
for prompt_idx, (outputs_0,
|
||||
outputs_1) in enumerate(zip(outputs_0_lst,
|
||||
outputs_1_lst)):
|
||||
output_ids_0, output_str_0, logprobs_0 = outputs_0
|
||||
output_ids_1, output_str_1, logprobs_1 = outputs_1
|
||||
|
||||
# Loop through generated tokens.
|
||||
for idx, (output_id_0,
|
||||
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
|
||||
|
||||
# If generated tokens don't match, then
|
||||
if output_id_0 != output_id_1:
|
||||
# Each predicted token must be in top N logprobs of the other
|
||||
assert output_id_0 in logprobs_1[idx], (
|
||||
f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
assert output_id_1 in logprobs_0[idx], (
|
||||
f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
|
||||
# Break out since sequences will now diverge.
|
||||
break
|
||||
@ -1,64 +0,0 @@
|
||||
"""Tests whether Marlin models can be loaded from the autogptq config.
|
||||
|
||||
Run `pytest tests/quantization/test_autogptq_marlin_configs.py --forked`.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelPair:
|
||||
model_marlin: str
|
||||
model_gptq: str
|
||||
|
||||
|
||||
# Model Id // Expected Kernel
|
||||
MODELS_QUANT_TYPE = [
|
||||
# compat: autogptq <=0.7.1 is_marlin_format: bool
|
||||
("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin"),
|
||||
("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq"),
|
||||
# compat: autogptq >=0.8.0 use checkpoint_format: str
|
||||
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin"),
|
||||
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq")
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_quant_type", MODELS_QUANT_TYPE)
|
||||
def test_auto_gptq(model_quant_type: str, ) -> None:
|
||||
model_path, quant_type = model_quant_type
|
||||
|
||||
model_config_no_quant_arg = ModelConfig(
|
||||
model_path,
|
||||
model_path,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
quantization=None # case 1
|
||||
)
|
||||
|
||||
model_config_quant_arg = ModelConfig(
|
||||
model_path,
|
||||
model_path,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
quantization="gptq" # case 2
|
||||
)
|
||||
|
||||
assert model_config_no_quant_arg.quantization == quant_type, (
|
||||
f"Expected quant_type == {quant_type} for {model_path}, "
|
||||
f"but found {model_config_no_quant_arg.quantization} "
|
||||
"for no --quantization None case")
|
||||
|
||||
assert model_config_quant_arg.quantization == quant_type, (
|
||||
f"Expected quant_type == {quant_type} for {model_path}, "
|
||||
f"but found {model_config_quant_arg.quantization} "
|
||||
"for --quantization gptq case")
|
||||
73
tests/quantization/test_configs.py
Normal file
73
tests/quantization/test_configs.py
Normal file
@ -0,0 +1,73 @@
|
||||
"""Tests whether Marlin models can be loaded from the autogptq config.
|
||||
|
||||
Run `pytest tests/quantization/test_configs.py --forked`.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelPair:
|
||||
model_marlin: str
|
||||
model_gptq: str
|
||||
|
||||
|
||||
# Model Id // Quantization Arg // Expected Type
|
||||
MODEL_ARG_EXPTYPES = [
|
||||
# AUTOGPTQ
|
||||
# compat: autogptq <=0.7.1 is_marlin_format: bool
|
||||
# Model Serialized in Marlin Format should always use Marlin kernel.
|
||||
("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", None, "marlin"),
|
||||
("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin", "marlin"),
|
||||
("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "gptq", "marlin"),
|
||||
("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "awq", "ERROR"),
|
||||
# Model Serialized in Exllama Format.
|
||||
("TheBloke/Llama-2-7B-Chat-GPTQ", None, "gptq_marlin"),
|
||||
("TheBloke/Llama-2-7B-Chat-GPTQ", "marlin", "gptq_marlin"),
|
||||
("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq", "gptq"),
|
||||
("TheBloke/Llama-2-7B-Chat-GPTQ", "awq", "ERROR"),
|
||||
# compat: autogptq >=0.8.0 use checkpoint_format: str
|
||||
# Model Serialized in Marlin Format should always use Marlin kernel.
|
||||
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", None, "marlin"),
|
||||
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin", "marlin"),
|
||||
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "gptq", "marlin"),
|
||||
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "awq", "ERROR"),
|
||||
# Model Serialized in Exllama Format.
|
||||
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", None, "gptq_marlin"),
|
||||
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "marlin", "gptq_marlin"),
|
||||
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq", "gptq"),
|
||||
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "awq", "ERROR"),
|
||||
|
||||
# AUTOAWQ
|
||||
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", None, "awq"),
|
||||
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"),
|
||||
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "marlin", "ERROR"),
|
||||
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "gptq", "ERROR"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES)
|
||||
def test_auto_gptq(model_arg_exptype: str) -> None:
|
||||
model_path, quantization_arg, expected_type = model_arg_exptype
|
||||
|
||||
try:
|
||||
model_config = ModelConfig(model_path,
|
||||
model_path,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
quantization=quantization_arg)
|
||||
found_quantization_type = model_config.quantization
|
||||
except ValueError:
|
||||
found_quantization_type = "ERROR"
|
||||
|
||||
assert found_quantization_type == expected_type, (
|
||||
f"Expected quant_type == {expected_type} for {model_path}, "
|
||||
f"but found {found_quantization_type} "
|
||||
f"for no --quantization {quantization_arg} case")
|
||||
@ -9,11 +9,14 @@ from packaging.version import Version
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
|
||||
get_quantization_config)
|
||||
from vllm.transformers_utils.config import get_config, get_hf_text_config
|
||||
from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
|
||||
is_neuron)
|
||||
|
||||
GPTQMarlinConfig = get_quantization_config("gptq_marlin")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
@ -138,14 +141,34 @@ class ModelConfig:
|
||||
is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
|
||||
or quant_cfg.get("is_marlin_format", False))
|
||||
|
||||
# Use marlin if the GPTQ model is serialized in marlin format.
|
||||
if quant_method == "gptq" and is_format_marlin:
|
||||
logger.info("The model is serialized in Marlin format. "
|
||||
"Using Marlin kernel.")
|
||||
quant_method = "marlin"
|
||||
if self.quantization == "gptq":
|
||||
self.quantization = quant_method
|
||||
# Check which LinearMethod the GPTQ model should use.
|
||||
if quant_method == "gptq":
|
||||
# If serialized in Marlin format, use MarlinLinearMethod.
|
||||
# TODO (@robertgshaw): migrate under GPTQMarlinLinearMethod.
|
||||
if is_format_marlin:
|
||||
logger.info("The model is serialized in Marlin format. "
|
||||
"Using Marlin kernel.")
|
||||
quant_method = "marlin"
|
||||
if self.quantization == "gptq":
|
||||
self.quantization = quant_method
|
||||
|
||||
# If convertible to Marlin format, use GPTQMarlinLinearMethod
|
||||
# unless the user explicitly specified GPTQLinearMethod.
|
||||
elif GPTQMarlinConfig.is_marlin_compatible(quant_cfg):
|
||||
if self.quantization == "gptq":
|
||||
logger.warning(
|
||||
"The model is convertible to Marlin format, but "
|
||||
"you specified quantization=gptq. Use "
|
||||
"quantization=marlin for faster inference.")
|
||||
else:
|
||||
logger.info(
|
||||
"The model is convertible to Marlin format. "
|
||||
"Using Marlin kernel.")
|
||||
quant_method = "gptq_marlin"
|
||||
if self.quantization == "marlin":
|
||||
self.quantization = quant_method
|
||||
|
||||
# Verify quantization configurations.
|
||||
if self.quantization is None:
|
||||
self.quantization = quant_method
|
||||
elif self.quantization != quant_method:
|
||||
@ -165,7 +188,7 @@ class ModelConfig:
|
||||
raise ValueError(
|
||||
f"{self.quantization} quantization is currently not "
|
||||
f"supported in ROCm.")
|
||||
if self.quantization != "marlin":
|
||||
if (self.quantization not in ["marlin", "gptq_marlin"]):
|
||||
logger.warning(
|
||||
"%s quantization is not fully "
|
||||
"optimized yet. The speed can be slower than "
|
||||
|
||||
@ -6,6 +6,8 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinConfig)
|
||||
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
||||
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
|
||||
|
||||
@ -15,6 +17,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
"fp8": Fp8Config,
|
||||
"gptq": GPTQConfig,
|
||||
"squeezellm": SqueezeLLMConfig,
|
||||
"gptq_marlin": GPTQMarlinConfig,
|
||||
"marlin": MarlinConfig,
|
||||
}
|
||||
|
||||
|
||||
444
vllm/model_executor/layers/quantization/gptq_marlin.py
Normal file
444
vllm/model_executor/layers/quantization/gptq_marlin.py
Normal file
@ -0,0 +1,444 @@
|
||||
import enum
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm._C import ops
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
|
||||
GPTQ_MARLIN_TILE = 16
|
||||
GPTQ_MARLIN_MIN_THREAD_N = 64
|
||||
GPTQ_MARLIN_MIN_THREAD_K = 128
|
||||
GPTQ_MARLIN_MAX_PARALLEL = 16
|
||||
|
||||
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4]
|
||||
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
GPTQ_MARLIN_SUPPORTED_SYM = [True]
|
||||
|
||||
|
||||
# Precompute permutations for Marlin weight and scale shuffling
|
||||
#
|
||||
# Marlin works on [16,64] tiles. The goal of the permutations
|
||||
# is to reorder the weight data so that it is compatible
|
||||
# with the tensor-core format that is described here:
|
||||
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
|
||||
#
|
||||
# As a result of this reordering, the vector loads inside the
|
||||
# kernel will get the data as it is needed for tensor-core
|
||||
# (without the need to use ldmatrix instructions)
|
||||
def _get_perms():
|
||||
perm = []
|
||||
for i in range(32):
|
||||
perm1 = []
|
||||
col = i // 4
|
||||
for block in [0, 1]:
|
||||
for row in [
|
||||
2 * (i % 4),
|
||||
2 * (i % 4) + 1,
|
||||
2 * (i % 4 + 4),
|
||||
2 * (i % 4 + 4) + 1,
|
||||
]:
|
||||
perm1.append(16 * row + col + 8 * block)
|
||||
for j in range(4):
|
||||
perm.extend([p + 256 * j for p in perm1])
|
||||
|
||||
perm = numpy.array(perm)
|
||||
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
|
||||
perm = perm.reshape((-1, 8))[:, interleave].ravel() # type: ignore
|
||||
perm = torch.from_numpy(perm)
|
||||
scale_perm = []
|
||||
for i in range(8):
|
||||
scale_perm.extend([i + 8 * j for j in range(8)])
|
||||
scale_perm_single = []
|
||||
for i in range(4):
|
||||
scale_perm_single.extend(
|
||||
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
|
||||
return perm, scale_perm, scale_perm_single
|
||||
|
||||
|
||||
_perm, _scale_perm, _scale_perm_single = _get_perms()
|
||||
|
||||
|
||||
def get_pack_factor(num_bits):
|
||||
assert num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS, (
|
||||
f"Unsupported num_bits = {num_bits}")
|
||||
return 32 // num_bits
|
||||
|
||||
|
||||
def marlin_permute_scales(s, size_k, size_n, group_size):
|
||||
if group_size < size_k and group_size != -1:
|
||||
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
|
||||
else:
|
||||
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
|
||||
s = s.reshape((-1, size_n)).contiguous()
|
||||
|
||||
return s
|
||||
|
||||
|
||||
class GPTQMarlinConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ Marlin"""
|
||||
|
||||
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
|
||||
is_sym: bool) -> None:
|
||||
if desc_act and group_size == -1:
|
||||
# In this case, act_order == True is the same as act_order == False
|
||||
# (since we have only one group per output channel)
|
||||
desc_act = False
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.is_sym = is_sym
|
||||
|
||||
# Verify
|
||||
if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
|
||||
raise ValueError(
|
||||
f"Marlin does not support weight_bits = {self.weight_bits}. "
|
||||
f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
|
||||
"are supported.")
|
||||
if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
|
||||
raise ValueError(
|
||||
f"Marlin does not support group_size = {self.group_size}. "
|
||||
f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported.")
|
||||
if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
|
||||
raise ValueError(
|
||||
f"Marlin does not support is_sym = {self.is_sym}. "
|
||||
f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")
|
||||
|
||||
# Init
|
||||
self.pack_factor = get_pack_factor(weight_bits)
|
||||
self.tile_size = GPTQ_MARLIN_TILE
|
||||
self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
|
||||
self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
|
||||
self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "gptq_marlin"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
is_sym = cls.get_from_keys(config, ["sym"])
|
||||
return cls(weight_bits, group_size, desc_act, is_sym)
|
||||
|
||||
def get_quant_method(
|
||||
self,
|
||||
layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return GPTQMarlinLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
num_bits = quant_config.get("bits", None)
|
||||
group_size = quant_config.get("group_size", None)
|
||||
sym = quant_config.get("sym", None)
|
||||
desc_act = quant_config.get("desc_act", None)
|
||||
|
||||
# If we cannot find the info needed in the config, cannot convert.
|
||||
if (num_bits is None or group_size is None or sym is None
|
||||
or desc_act is None):
|
||||
return False
|
||||
|
||||
# If the capability of the device is too low, cannot convert.
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
if device_capability < cls.get_min_capability():
|
||||
return False
|
||||
|
||||
# Otherwise, can convert if model satisfies marlin constraints.
|
||||
return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
|
||||
and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
|
||||
and sym in GPTQ_MARLIN_SUPPORTED_SYM)
|
||||
|
||||
|
||||
class GPTQMarlinState(Enum):
|
||||
REPACK = enum.auto()
|
||||
READY = enum.auto()
|
||||
|
||||
|
||||
class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQ Marlin.
|
||||
|
||||
Args:
|
||||
quant_config: The GPTQ Marlin quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
del output_size
|
||||
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
# Validate dtype
|
||||
if params_dtype != torch.float16:
|
||||
raise ValueError(
|
||||
f"The params dtype must be float16, but got {params_dtype}")
|
||||
|
||||
# Validate output_size_per_partition
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.min_thread_n != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f" min_thread_n = {self.quant_config.min_thread_n}.")
|
||||
|
||||
# Validate input_size_per_partition
|
||||
if input_size_per_partition % self.quant_config.min_thread_k != 0:
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible "
|
||||
f"by min_thread_k = {self.quant_config.min_thread_k}.")
|
||||
|
||||
if (group_size < input_size
|
||||
and input_size_per_partition % group_size != 0):
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = {input_size_per_partition}"
|
||||
f" is not divisible by group_size = {group_size}.")
|
||||
|
||||
# Detect sharding of scales/zp
|
||||
|
||||
# By default, no sharding over "input dim"
|
||||
scales_and_zp_size = input_size // group_size
|
||||
scales_and_zp_input_dim = None
|
||||
|
||||
if self.quant_config.desc_act:
|
||||
# Act-order case
|
||||
assert self.quant_config.group_size != -1
|
||||
|
||||
is_k_full = input_size_per_partition == input_size
|
||||
|
||||
else:
|
||||
# No act-order case
|
||||
|
||||
# K is always full due to full alignment with
|
||||
# group-size and shard of scales/zp
|
||||
is_k_full = True
|
||||
|
||||
# If this is a row-parallel case, then shard scales/zp
|
||||
if (input_size != input_size_per_partition
|
||||
and self.quant_config.group_size != -1):
|
||||
scales_and_zp_size = input_size_per_partition // group_size
|
||||
scales_and_zp_input_dim = 0
|
||||
|
||||
# Init buffers
|
||||
|
||||
# Quantized weights
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
**extra_weight_attrs,
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
"packed_dim": 0,
|
||||
"pack_factor": self.quant_config.pack_factor,
|
||||
})
|
||||
|
||||
# Activation order
|
||||
g_idx = Parameter(
|
||||
torch.empty(
|
||||
input_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
# Ignore warning from fused linear layers such as QKVParallelLinear.
|
||||
set_weight_attrs(g_idx, {
|
||||
**extra_weight_attrs, "input_dim": 0,
|
||||
"ignore_warning": True
|
||||
})
|
||||
|
||||
g_idx_sort_indices = Parameter(
|
||||
torch.empty(
|
||||
g_idx.shape,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(g_idx_sort_indices, extra_weight_attrs)
|
||||
|
||||
# Scales
|
||||
scales = Parameter(
|
||||
torch.empty(
|
||||
scales_and_zp_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
scales, {
|
||||
**extra_weight_attrs,
|
||||
"input_dim": scales_and_zp_input_dim,
|
||||
"output_dim": 1,
|
||||
})
|
||||
|
||||
# Quantized zero-points
|
||||
qzeros = Parameter(
|
||||
torch.empty(scales_and_zp_size,
|
||||
output_size_per_partition //
|
||||
self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
device="meta"),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qzeros, {
|
||||
**extra_weight_attrs,
|
||||
"input_dim": scales_and_zp_input_dim,
|
||||
"output_dim": 1,
|
||||
"packed_dim": 1,
|
||||
"pack_factor": self.quant_config.pack_factor,
|
||||
})
|
||||
|
||||
# Allocate marlin workspace
|
||||
max_workspace_size = (
|
||||
output_size_per_partition //
|
||||
self.quant_config.min_thread_n) * self.quant_config.max_parallel
|
||||
workspace = torch.zeros(max_workspace_size,
|
||||
dtype=torch.int,
|
||||
requires_grad=False)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices)
|
||||
layer.register_parameter("scales", scales)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.workspace = workspace
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.input_size = input_size
|
||||
layer.is_k_full = is_k_full
|
||||
layer.marlin_state = GPTQMarlinState.REPACK
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
size_m = reshaped_x.shape[0]
|
||||
part_size_n = layer.output_size_per_partition
|
||||
part_size_k = layer.input_size_per_partition
|
||||
full_size_k = layer.input_size
|
||||
|
||||
out_shape = x.shape[:-1] + (part_size_n, )
|
||||
|
||||
if layer.marlin_state == GPTQMarlinState.REPACK:
|
||||
layer.marlin_state = GPTQMarlinState.READY
|
||||
|
||||
# Newly generated tensors need to replace existing tensors that are
|
||||
# already registered as parameters by vLLM (and won't be freed)
|
||||
def replace_tensor(name, new_t):
|
||||
# It is important to use resize_() here since it ensures
|
||||
# the same buffer is reused
|
||||
getattr(layer, name).resize_(new_t.shape)
|
||||
getattr(layer, name).copy_(new_t)
|
||||
del new_t
|
||||
|
||||
cur_device = layer.qweight.device
|
||||
|
||||
# Process act_order
|
||||
if self.quant_config.desc_act:
|
||||
# Get sorting based on g_idx
|
||||
g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
|
||||
|
||||
sorted_g_idx = layer.g_idx[g_idx_sort_indices]
|
||||
|
||||
replace_tensor("g_idx", sorted_g_idx)
|
||||
replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
|
||||
|
||||
else:
|
||||
# Reset g_idx related tensors
|
||||
layer.g_idx = Parameter(torch.empty(0,
|
||||
dtype=torch.int,
|
||||
device=cur_device),
|
||||
requires_grad=False)
|
||||
layer.g_idx_sort_indices = Parameter(torch.empty(
|
||||
0, dtype=torch.int, device=cur_device),
|
||||
requires_grad=False)
|
||||
|
||||
# Repack weights
|
||||
marlin_qweight = ops.gptq_marlin_repack(
|
||||
layer.qweight,
|
||||
layer.g_idx_sort_indices,
|
||||
part_size_k,
|
||||
part_size_n,
|
||||
)
|
||||
replace_tensor("qweight", marlin_qweight)
|
||||
|
||||
# Permute scales
|
||||
scales_size_k = part_size_k
|
||||
scales_size_n = part_size_n
|
||||
if self.quant_config.desc_act:
|
||||
scales_size_k = full_size_k
|
||||
|
||||
marlin_scales = marlin_permute_scales(layer.scales, scales_size_k,
|
||||
scales_size_n,
|
||||
self.quant_config.group_size)
|
||||
replace_tensor("scales", marlin_scales)
|
||||
|
||||
output = ops.gptq_marlin_gemm(reshaped_x, layer.qweight, layer.scales,
|
||||
layer.g_idx, layer.g_idx_sort_indices,
|
||||
layer.workspace, size_m, part_size_n,
|
||||
part_size_k, layer.is_k_full)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
Loading…
x
Reference in New Issue
Block a user