[Kernel] Marlin Expansion: Support AutoGPTQ Models with Marlin (#3922)

Co-authored-by: alexm <alexm@neuralmagic.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
Robert Shaw 2024-04-29 12:35:34 -04:00 committed by GitHub
parent df29793dc7
commit 73c8d677e5
14 changed files with 2627 additions and 105 deletions


@@ -177,6 +177,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/marlin_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/custom_all_reduce.cu")
endif()


@@ -124,6 +124,24 @@ torch::Tensor marlin_gemm(
int64_t size_m,
int64_t size_n,
int64_t size_k);
torch::Tensor gptq_marlin_gemm(
torch::Tensor &a,
torch::Tensor &b_q_weight,
torch::Tensor &b_scales,
torch::Tensor &g_idx,
torch::Tensor &perm,
torch::Tensor &workspace,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full);
torch::Tensor gptq_marlin_repack(
torch::Tensor &b_q_weight,
torch::Tensor &perm,
int64_t size_k,
int64_t size_n);
#endif
void squeezellm_gemm(


@@ -67,6 +67,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif

File diff suppressed because it is too large


@@ -0,0 +1,74 @@
#pragma once
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
namespace gptq_marlin {
// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per
// scheduler allows some more latency hiding. At the same time, we want relatively few warps to have
// many registers per warp and small tiles.
static constexpr int default_threads = 256;
static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
static constexpr int tile_size = 16;
static constexpr int max_par = 16;
static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32-bit word
template <typename T, int n>
struct Vec {
T elems[n];
__device__ T& operator[](int i) { return elems[i]; }
};
using I4 = Vec<int, 4>;
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n"
" .reg .b64 p;\n"
" createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
" cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); }
template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
#endif
} // namespace gptq_marlin
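
For reference, `pack_factor_4bit = 8` describes the packing used throughout these kernels: eight 4-bit quantized values share one 32-bit word, with value i occupying bits [4*i, 4*i + 4). A minimal Python sketch of that layout (illustration only, not part of the commit; the repack kernel below extracts nibbles the same way with `(val >> (pos * 4)) & 0xf`):

def pack_4bit(vals):
    # vals: 8 ints in [0, 16); returns one 32-bit word with value i
    # stored at bit offset 4*i (i.e. pack_factor_4bit = 8).
    assert len(vals) == 8 and all(0 <= v < 16 for v in vals)
    word = 0
    for i, v in enumerate(vals):
        word |= v << (4 * i)
    return word

def unpack_4bit(word):
    # Inverse: the same per-nibble extraction the repack kernel performs.
    return [(word >> (4 * i)) & 0xF for i in range(8)]

assert unpack_4bit(pack_4bit([3, 7, 0, 15, 9, 1, 4, 12])) == [3, 7, 0, 15, 9, 1, 4, 12]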


@@ -0,0 +1,324 @@
#include "gptq_marlin.cuh"
namespace gptq_marlin {
static constexpr int repack_stages = 8;
static constexpr int repack_threads = 256;
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <int const num_threads, bool const has_perm>
__global__ void
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr,
uint32_t *__restrict__ out_ptr, int size_k, int size_n) {}
} // namespace gptq_marlin
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
template <int const num_threads, bool const has_perm>
__global__ void
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr,
uint32_t *__restrict__ out_ptr, int size_k, int size_n) {
int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
int start_k_tile = blockIdx.x * block_k_tiles;
if (start_k_tile >= k_tiles) {
return;
}
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait<repack_stages - 2>();
__syncthreads();
};
extern __shared__ int4 sh[];
constexpr int perm_size = tile_k_size / 4;
int4 *sh_perm_ptr = sh;
int4 *sh_pipe_ptr = sh_perm_ptr;
if constexpr (has_perm) {
sh_pipe_ptr += perm_size;
}
constexpr int stage_n_threads = tile_n_size / 4;
constexpr int stage_k_threads =
has_perm ? tile_k_size : tile_k_size / pack_factor_4bit;
constexpr int stage_size = stage_k_threads * stage_n_threads;
auto load_perm_to_shared = [&](int k_tile_id) {
int first_k_int4 = (k_tile_id * tile_k_size) / 4;
int4 const *perm_int4_ptr = reinterpret_cast<int4 const *>(perm_ptr);
if (threadIdx.x < perm_size) {
sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
}
__syncthreads();
};
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
cp_async_fence();
return;
}
int first_n = n_tile_id * tile_n_size;
int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe;
if constexpr (has_perm) {
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
uint32_t const *sh_perm_int_ptr =
reinterpret_cast<uint32_t const *>(sh_perm_ptr);
int src_k = sh_perm_int_ptr[k_id];
int src_k_packed = src_k / pack_factor_4bit;
cp_async4_stream(
&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>(&(
b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
}
} else {
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
int first_k_packed = first_k / pack_factor_4bit;
cp_async4_stream(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>(
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
first_n + (n_id * 4)])));
}
}
cp_async_fence();
};
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
return;
}
int warp_id = threadIdx.x / 32;
int th_id = threadIdx.x % 32;
if (warp_id >= 4) {
return;
}
int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2;
constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col;
constexpr int sh_stride = 64;
int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
uint32_t *sh_stage_int_ptr = reinterpret_cast<uint32_t *>(sh_stage_ptr);
uint32_t *sh_perm_int_ptr = reinterpret_cast<uint32_t *>(sh_perm_ptr);
uint32_t vals[pack_factor_4bit];
if constexpr (has_perm) {
for (int i = 0; i < 4; i++) {
int k_idx = tc_row + tc_offsets[i];
uint32_t src_k = sh_perm_int_ptr[k_idx];
uint32_t src_k_pos = src_k % pack_factor_4bit;
uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
uint32_t b1_cur_val = (b1_val >> (src_k_pos * 4)) & 0xf;
uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
uint32_t b2_cur_val = (b2_val >> (src_k_pos * 4)) & 0xf;
vals[i] = b1_cur_val;
vals[4 + i] = b2_cur_val;
}
} else {
uint32_t b1_val_1 = sh_stage_int_ptr[cur_n];
uint32_t b1_val_2 = sh_stage_int_ptr[sh_stride + cur_n];
uint32_t b2_val_1 = sh_stage_int_ptr[cur_n + 8];
uint32_t b2_val_2 = sh_stage_int_ptr[sh_stride + cur_n + 8];
#pragma unroll
for (int i = 0; i < 2; i++) {
int cur_elem = tc_row + tc_offsets[i];
vals[i] = (b1_val_1 >> (cur_elem * 4)) & 0xf;
vals[4 + i] = (b2_val_1 >> (cur_elem * 4)) & 0xf;
}
#pragma unroll
for (int i = 2; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i] - 8;
vals[i] = (b1_val_2 >> (cur_elem * 4)) & 0xf;
vals[4 + i] = (b2_val_2 >> (cur_elem * 4)) & 0xf;
}
}
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
constexpr int pack_idx[pack_factor_4bit] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < pack_factor_4bit; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor_4bit;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
out_ptr[out_offset + th_id * 4 + warp_id] = res;
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
if constexpr (has_perm) {
load_perm_to_shared(k_tile_id);
}
start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1);
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
wait_for_stage();
}
n_tile_id += repack_stages;
}
}
}
} // namespace gptq_marlin
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
// Verify B
TORCH_CHECK((size_k / gptq_marlin::pack_factor_4bit) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
", size_k = ", size_k,
", pack_factor_4bit = ", gptq_marlin::pack_factor_4bit);
TORCH_CHECK(b_q_weight.size(1) == size_n,
"b_q_weight.size(1) = ", b_q_weight.size(1),
" is not size_n = ", size_n);
// Verify device and strides
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
torch::Tensor out = torch::empty(
{size_k / gptq_marlin::tile_size,
size_n * gptq_marlin::tile_size / gptq_marlin::pack_factor_4bit},
options);
// Detect if there is act_order
bool has_perm = perm.size(0) != 0;
// Get ptrs
uint32_t const *b_q_weight_ptr =
reinterpret_cast<uint32_t const *>(b_q_weight.data_ptr());
uint32_t const *perm_ptr =
reinterpret_cast<uint32_t const *>(perm.data_ptr());
uint32_t *out_ptr = reinterpret_cast<uint32_t *>(out.data_ptr());
// Get dev info
int dev = b_q_weight.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
int blocks;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
if (has_perm) {
cudaFuncSetAttribute(
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_mem);
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>
<<<blocks, gptq_marlin::repack_threads, max_shared_mem,
stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
} else {
cudaFuncSetAttribute(
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_mem);
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>
<<<blocks, gptq_marlin::repack_threads, max_shared_mem,
stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
}
return out;
}
#endif
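
Two properties of this kernel are easy to sanity-check from the host side: the shape of the repacked tensor allocated in gptq_marlin_repack(), and the nibble interleaving {0, 2, 4, 6, 1, 3, 5, 7} used when the eight gathered 4-bit values are packed back into one 32-bit word. A small Python sketch (illustrative only; the helper names are not part of the commit):

GPTQ_TILE = 16          # gptq_marlin::tile_size
PACK_FACTOR_4BIT = 8    # 4-bit values per 32-bit word

def repacked_shape(size_k, size_n):
    # Mirrors the torch::empty(...) call above:
    # rows = size_k / tile_size, cols = size_n * tile_size / pack_factor_4bit.
    assert size_k % GPTQ_TILE == 0 and size_n % (4 * GPTQ_TILE) == 0
    return (size_k // GPTQ_TILE, size_n * GPTQ_TILE // PACK_FACTOR_4BIT)

# e.g. a 4096x4096 4-bit layer: GPTQ stores qweight as (4096 // 8, 4096) =
# (512, 4096) int32, and the repack produces a (256, 8192) int32 tensor.
assert repacked_shape(4096, 4096) == (256, 8192)

def interleave_pack(vals):
    # Mirrors repack_tile(): the 8 values are written in pack_idx order,
    # res |= vals[pack_idx[i]] << (i * 4), matching FasterTransformer's
    # interleaved numeric conversion.
    pack_idx = [0, 2, 4, 6, 1, 3, 5, 7]
    res = 0
    for i in range(8):
        res |= (vals[pack_idx[i]] & 0xF) << (4 * i)
    return res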


@@ -0,0 +1,93 @@
"""Compares the outputs of gptq vs gptq_marlin
Note: GPTQ and Marlin do not produce bitwise-identical outputs.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are among each other's top-`num_logprobs` selections.
Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin. As a result, we re-run the test
up to 3 times to see if we pass.
Note: This test currently fails running with --forked with the following:
RuntimeError: Cannot re-initialize CUDA in forked subprocess.
To use CUDA with multiprocessing, you must use the 'spawn' start method
Run `pytest tests/models/test_gptq_marlin.py`.
"""
import os
import pytest
import torch
from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
os.environ["TOKENIZERS_PARALLELISM"] = "true"
MAX_MODEL_LEN = 1024
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
gptq_marlin_not_supported = (
capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability())
MODELS = [
# act_order==False, group_size=channelwise
("robertgshaw2/zephyr-7b-beta-channelwise-gptq", "main"),
# act_order==False, group_size=128
("TheBloke/Llama-2-7B-GPTQ", "main"),
# act_order==True, group_size=128
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "main"),
# act_order==True, group_size=64
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"),
# act_order==True, group_size=32
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"),
]
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(gptq_marlin_not_supported,
reason="gptq_marlin is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
vllm_runner,
example_prompts,
model,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
model_name, revision = model
# Run marlin.
gptq_marlin_model = vllm_runner(model_name=model_name,
revision=revision,
dtype=dtype,
quantization="marlin",
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=1,
disable_custom_all_reduce=True)
gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
del gptq_marlin_model
# Run gptq.
gptq_model = vllm_runner(model_name=model_name,
revision=revision,
dtype=dtype,
quantization="gptq",
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=1,
disable_custom_all_reduce=True)
gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
max_tokens,
num_logprobs)
del gptq_model
check_logprobs_close(
outputs_0_lst=gptq_outputs,
outputs_1_lst=gptq_marlin_outputs,
name_0="gptq",
name_1="gptq_marlin",
)


@@ -10,12 +10,12 @@ up to 3 times to see if we pass.
Run `pytest tests/models/test_marlin.py`.
"""
from dataclasses import dataclass
import pytest
import torch
from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
capability = torch.cuda.get_device_capability()
@@ -55,43 +55,24 @@ def test_models(
max_tokens: int,
num_logprobs: int,
) -> None:
marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype)
marlin_model = vllm_runner(model_pair.model_marlin,
dtype=dtype,
quantization="marlin")
marlin_outputs = marlin_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
# Note: not sure why, but deleting just the model on Ada Lovelace
# does not free the GPU memory. On Ampere, deleting just the model
# frees the memory.
del marlin_model
gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype)
gptq_model = vllm_runner(model_pair.model_gptq,
dtype=dtype,
quantization="gptq")
gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
max_tokens,
num_logprobs)
# Note: not sure why, but deleting just the model on Ada Lovelace
# does not free the GPU memory. On Ampere, deleting just the model
# frees the memory.
del gptq_model
# loop through the prompts
for prompt_idx in range(len(example_prompts)):
gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[
prompt_idx]
marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[
prompt_idx]
for idx, (gptq_output_id, marlin_output_id) in enumerate(
zip(gptq_output_ids, marlin_output_ids)):
# If sequence is not an exact match,
if marlin_output_id != gptq_output_id:
# Each predicted token must be in top 5 of the other's
assert gptq_output_id in marlin_logprobs[idx], (
f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n"
f"Marlin:\t{marlin_output_str!r}")
assert marlin_output_id in gptq_logprobs[idx], (
f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n"
f"Marlin:\t{marlin_output_str!r}")
# Break out since sequences will now diverge.
break
check_logprobs_close(
outputs_0_lst=gptq_outputs,
outputs_1_lst=marlin_outputs,
name_0="gptq",
name_1="marlin",
)

tests/models/utils.py (new file, 29 lines)

@@ -0,0 +1,29 @@
def check_logprobs_close(outputs_0_lst, outputs_1_lst, name_0, name_1):
"""Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
"""
# Loop through responses to each prompt.
for prompt_idx, (outputs_0,
outputs_1) in enumerate(zip(outputs_0_lst,
outputs_1_lst)):
output_ids_0, output_str_0, logprobs_0 = outputs_0
output_ids_1, output_str_1, logprobs_1 = outputs_1
# Loop through generated tokens.
for idx, (output_id_0,
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
# If generated tokens don't match, then
if output_id_0 != output_id_1:
# Each predicted token must be in top N logprobs of the other
assert output_id_0 in logprobs_1[idx], (
f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_id_1 in logprobs_0[idx], (
f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
# Break out since sequences will now diverge.
break


@@ -1,64 +0,0 @@
"""Tests whether Marlin models can be loaded from the autogptq config.
Run `pytest tests/quantization/test_autogptq_marlin_configs.py --forked`.
"""
from dataclasses import dataclass
import pytest
from vllm.config import ModelConfig
@dataclass
class ModelPair:
model_marlin: str
model_gptq: str
# Model Id // Expected Kernel
MODELS_QUANT_TYPE = [
# compat: autogptq <=0.7.1 is_marlin_format: bool
("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin"),
("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq"),
# compat: autogptq >=0.8.0 use checkpoint_format: str
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin"),
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq")
]
@pytest.mark.parametrize("model_quant_type", MODELS_QUANT_TYPE)
def test_auto_gptq(model_quant_type: str, ) -> None:
model_path, quant_type = model_quant_type
model_config_no_quant_arg = ModelConfig(
model_path,
model_path,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
quantization=None # case 1
)
model_config_quant_arg = ModelConfig(
model_path,
model_path,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
quantization="gptq" # case 2
)
assert model_config_no_quant_arg.quantization == quant_type, (
f"Expected quant_type == {quant_type} for {model_path}, "
f"but found {model_config_no_quant_arg.quantization} "
"for no --quantization None case")
assert model_config_quant_arg.quantization == quant_type, (
f"Expected quant_type == {quant_type} for {model_path}, "
f"but found {model_config_quant_arg.quantization} "
"for --quantization gptq case")


@@ -0,0 +1,73 @@
"""Tests whether Marlin models can be loaded from the autogptq config.
Run `pytest tests/quantization/test_configs.py --forked`.
"""
from dataclasses import dataclass
import pytest
from vllm.config import ModelConfig
@dataclass
class ModelPair:
model_marlin: str
model_gptq: str
# Model Id // Quantization Arg // Expected Type
MODEL_ARG_EXPTYPES = [
# AUTOGPTQ
# compat: autogptq <=0.7.1 is_marlin_format: bool
# Model Serialized in Marlin Format should always use Marlin kernel.
("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", None, "marlin"),
("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin", "marlin"),
("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "gptq", "marlin"),
("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "awq", "ERROR"),
# Model Serialized in Exllama Format.
("TheBloke/Llama-2-7B-Chat-GPTQ", None, "gptq_marlin"),
("TheBloke/Llama-2-7B-Chat-GPTQ", "marlin", "gptq_marlin"),
("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq", "gptq"),
("TheBloke/Llama-2-7B-Chat-GPTQ", "awq", "ERROR"),
# compat: autogptq >=0.8.0 use checkpoint_format: str
# Model Serialized in Marlin Format should always use Marlin kernel.
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", None, "marlin"),
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin", "marlin"),
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "gptq", "marlin"),
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "awq", "ERROR"),
# Model Serialized in Exllama Format.
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", None, "gptq_marlin"),
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "marlin", "gptq_marlin"),
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq", "gptq"),
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "awq", "ERROR"),
# AUTOAWQ
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", None, "awq"),
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"),
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "marlin", "ERROR"),
("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "gptq", "ERROR"),
]
@pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES)
def test_auto_gptq(model_arg_exptype: str) -> None:
model_path, quantization_arg, expected_type = model_arg_exptype
try:
model_config = ModelConfig(model_path,
model_path,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
quantization=quantization_arg)
found_quantization_type = model_config.quantization
except ValueError:
found_quantization_type = "ERROR"
assert found_quantization_type == expected_type, (
f"Expected quant_type == {expected_type} for {model_path}, "
f"but found {found_quantization_type} "
f"for the --quantization={quantization_arg} case")


@@ -9,11 +9,14 @@ from packaging.version import Version
from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
get_quantization_config)
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
is_neuron)
GPTQMarlinConfig = get_quantization_config("gptq_marlin")
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@@ -138,14 +141,34 @@ class ModelConfig:
is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
or quant_cfg.get("is_marlin_format", False))
# Use marlin if the GPTQ model is serialized in marlin format.
if quant_method == "gptq" and is_format_marlin:
logger.info("The model is serialized in Marlin format. "
"Using Marlin kernel.")
quant_method = "marlin"
if self.quantization == "gptq":
self.quantization = quant_method
# Check which LinearMethod the GPTQ model should use.
if quant_method == "gptq":
# If serialized in Marlin format, use MarlinLinearMethod.
# TODO (@robertgshaw): migrate under GPTQMarlinLinearMethod.
if is_format_marlin:
logger.info("The model is serialized in Marlin format. "
"Using Marlin kernel.")
quant_method = "marlin"
if self.quantization == "gptq":
self.quantization = quant_method
# If convertible to Marlin format, use GPTQMarlinLinearMethod
# unless the user explicitly specified GPTQLinearMethod.
elif GPTQMarlinConfig.is_marlin_compatible(quant_cfg):
if self.quantization == "gptq":
logger.warning(
"The model is convertible to Marlin format, but "
"you specified quantization=gptq. Use "
"quantization=marlin for faster inference.")
else:
logger.info(
"The model is convertible to Marlin format. "
"Using Marlin kernel.")
quant_method = "gptq_marlin"
if self.quantization == "marlin":
self.quantization = quant_method
# Verify quantization configurations.
if self.quantization is None:
self.quantization = quant_method
elif self.quantization != quant_method:
@@ -165,7 +188,7 @@
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in ROCm.")
if self.quantization != "marlin":
if (self.quantization not in ["marlin", "gptq_marlin"]):
logger.warning(
"%s quantization is not fully "
"optimized yet. The speed can be slower than "

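To summarize the new selection logic in ModelConfig above: a GPTQ checkpoint that is already serialized in Marlin format always uses the marlin kernel; otherwise, if the quantization config satisfies the Marlin constraints (GPTQMarlinConfig.is_marlin_compatible), the gptq_marlin kernel is chosen unless the user explicitly requested quantization=gptq. A condensed Python sketch of these rules (hypothetical helper, not code from this commit; it mirrors the expectations encoded in tests/quantization/test_configs.py):

from typing import Optional

def resolve_gptq_kernel(quant_cfg: dict, requested: Optional[str],
                        marlin_compatible: bool) -> str:
    # Condensed restatement of the config.py hunk, for quant_method == "gptq".
    serialized_as_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
                            or quant_cfg.get("is_marlin_format", False))
    if serialized_as_marlin:
        # Marlin-format checkpoints always run on the marlin kernel.
        return "marlin"
    if marlin_compatible and requested != "gptq":
        # Convertible checkpoints are upgraded to gptq_marlin unless the
        # user explicitly pinned quantization=gptq (a warning is logged).
        return "gptq_marlin"
    return "gptq"

assert resolve_gptq_kernel({"is_marlin_format": True}, "gptq", True) == "marlin"
assert resolve_gptq_kernel({}, None, True) == "gptq_marlin"
assert resolve_gptq_kernel({}, "gptq", True) == "gptq"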

@@ -6,6 +6,8 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
@@ -15,6 +17,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"fp8": Fp8Config,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"gptq_marlin": GPTQMarlinConfig,
"marlin": MarlinConfig,
}


@@ -0,0 +1,444 @@
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
import numpy
import torch
from torch.nn.parameter import Parameter
from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
# Precompute permutations for Marlin weight and scale shuffling
#
# Marlin works on [16,64] tiles. The goal of the permutations
# is to reorder the weight data so that it is compatible
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the
# kernel will get the data as it is needed for tensor-core
# (without the need to use ldmatrix instructions)
def _get_perms():
perm = []
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm.extend([p + 256 * j for p in perm1])
perm = numpy.array(perm)
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
perm = perm.reshape((-1, 8))[:, interleave].ravel() # type: ignore
perm = torch.from_numpy(perm)
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm, scale_perm, scale_perm_single
_perm, _scale_perm, _scale_perm_single = _get_perms()
def get_pack_factor(num_bits):
assert num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS, (
f"Unsupported num_bits = {num_bits}")
return 32 // num_bits
def marlin_permute_scales(s, size_k, size_n, group_size):
if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
else:
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
s = s.reshape((-1, size_n)).contiguous()
return s
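# Worked example (illustrative): for size_k = 4096, size_n = 4096 and
# group_size = 128, `s` enters as a (4096 // 128, 4096) = (32, 4096) tensor.
# Since group_size < size_k, every 64 consecutive scales are reordered with
# _scale_perm and the result is reshaped back to (32, 4096) contiguously;
# channelwise scales (group_size == -1) use the 32-element _scale_perm_single
# ordering instead.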
class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin"""
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool) -> None:
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
# Verify
if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
raise ValueError(
f"Marlin does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
"are supported.")
if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
raise ValueError(
f"Marlin does not support group_size = {self.group_size}. "
f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
"are supported.")
if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
raise ValueError(
f"Marlin does not support is_sym = {self.is_sym}. "
f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")
# Init
self.pack_factor = get_pack_factor(weight_bits)
self.tile_size = GPTQ_MARLIN_TILE
self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL
def __repr__(self) -> str:
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})")
@classmethod
def get_name(cls) -> str:
return "gptq_marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
return cls(weight_bits, group_size, desc_act, is_sym)
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
if isinstance(layer, LinearBase):
return GPTQMarlinLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
@classmethod
def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None)
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
# If the capability of the device is too low, cannot convert.
major, minor = torch.cuda.get_device_capability()
device_capability = major * 10 + minor
if device_capability < cls.get_min_capability():
return False
# Otherwise, can convert if model satisfies marlin constraints.
return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
and sym in GPTQ_MARLIN_SUPPORTED_SYM)
class GPTQMarlinState(Enum):
REPACK = enum.auto()
READY = enum.auto()
class GPTQMarlinLinearMethod(LinearMethodBase):
"""Linear method for GPTQ Marlin.
Args:
quant_config: The GPTQ Marlin quantization config.
"""
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
# Validate dtype
if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}")
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_thread_n != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f" min_thread_n = {self.quant_config.min_thread_n}.")
# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_thread_k != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible "
f"by min_thread_k = {self.quant_config.min_thread_k}.")
if (group_size < input_size
and input_size_per_partition % group_size != 0):
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition}"
f" is not divisible by group_size = {group_size}.")
# Detect sharding of scales/zp
# By default, no sharding over "input dim"
scales_and_zp_size = input_size // group_size
scales_and_zp_input_dim = None
if self.quant_config.desc_act:
# Act-order case
assert self.quant_config.group_size != -1
is_k_full = input_size_per_partition == input_size
else:
# No act-order case
# K is always full due to full alignment with
# group-size and shard of scales/zp
is_k_full = True
# If this is a row-parallel case, then shard scales/zp
if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
scales_and_zp_size = input_size_per_partition // group_size
scales_and_zp_input_dim = 0
# Init buffers
# Quantized weights
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight, {
**extra_weight_attrs,
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
})
# Activation order
g_idx = Parameter(
torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs(g_idx, {
**extra_weight_attrs, "input_dim": 0,
"ignore_warning": True
})
g_idx_sort_indices = Parameter(
torch.empty(
g_idx.shape,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(g_idx_sort_indices, extra_weight_attrs)
# Scales
scales = Parameter(
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales, {
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
})
# Quantized zero-points
qzeros = Parameter(
torch.empty(scales_and_zp_size,
output_size_per_partition //
self.quant_config.pack_factor,
dtype=torch.int32,
device="meta"),
requires_grad=False,
)
set_weight_attrs(
qzeros, {
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})
# Allocate marlin workspace
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_thread_n) * self.quant_config.max_parallel
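# e.g. output_size_per_partition = 4096 with min_thread_n = 64 and
# max_parallel = 16 allocates 4096 // 64 * 16 = 1024 int32 zeros below.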
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
requires_grad=False)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
layer.workspace = workspace
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.is_k_full = is_k_full
layer.marlin_state = GPTQMarlinState.REPACK
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
size_m = reshaped_x.shape[0]
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
full_size_k = layer.input_size
out_shape = x.shape[:-1] + (part_size_n, )
if layer.marlin_state == GPTQMarlinState.REPACK:
layer.marlin_state = GPTQMarlinState.READY
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(name, new_t):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
del new_t
cur_device = layer.qweight.device
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
sorted_g_idx = layer.g_idx[g_idx_sort_indices]
replace_tensor("g_idx", sorted_g_idx)
replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
else:
# Reset g_idx related tensors
layer.g_idx = Parameter(torch.empty(0,
dtype=torch.int,
device=cur_device),
requires_grad=False)
layer.g_idx_sort_indices = Parameter(torch.empty(
0, dtype=torch.int, device=cur_device),
requires_grad=False)
# Repack weights
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
layer.g_idx_sort_indices,
part_size_k,
part_size_n,
)
replace_tensor("qweight", marlin_qweight)
# Permute scales
scales_size_k = part_size_k
scales_size_n = part_size_n
if self.quant_config.desc_act:
scales_size_k = full_size_k
marlin_scales = marlin_permute_scales(layer.scales, scales_size_k,
scales_size_n,
self.quant_config.group_size)
replace_tensor("scales", marlin_scales)
output = ops.gptq_marlin_gemm(reshaped_x, layer.qweight, layer.scales,
layer.g_idx, layer.g_idx_sort_indices,
layer.workspace, size_m, part_size_n,
part_size_k, layer.is_k_full)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)