[Kernel] Marlin Expansion: Support AutoGPTQ Models with Marlin (#3922)
Co-authored-by: alexm <alexm@neuralmagic.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in: parent df29793dc7, commit 73c8d677e5
@@ -177,6 +177,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/aqlm/gemm_kernels.cu"
     "csrc/quantization/awq/gemm_kernels.cu"
     "csrc/quantization/marlin/marlin_cuda_kernel.cu"
+    "csrc/quantization/gptq_marlin/gptq_marlin.cu"
+    "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
     "csrc/custom_all_reduce.cu")
 endif()
csrc/ops.h (18 lines changed)

@@ -124,6 +124,24 @@ torch::Tensor marlin_gemm(
   int64_t size_m,
   int64_t size_n,
   int64_t size_k);
+
+torch::Tensor gptq_marlin_gemm(
+  torch::Tensor &a,
+  torch::Tensor &b_q_weight,
+  torch::Tensor &b_scales,
+  torch::Tensor &g_idx,
+  torch::Tensor &perm,
+  torch::Tensor &workspace,
+  int64_t size_m,
+  int64_t size_n,
+  int64_t size_k,
+  bool is_k_full);
+
+torch::Tensor gptq_marlin_repack(
+  torch::Tensor &b_q_weight,
+  torch::Tensor &perm,
+  int64_t size_k,
+  int64_t size_n);
+
 #endif

 void squeezellm_gemm(
@@ -67,6 +67,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
   ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
+  ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ");
+  ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ");
   ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
 #endif
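These bindings surface in Python through the vllm._C extension, which is how the new GPTQMarlinLinearMethod further below invokes them. A minimal hedged sketch of calling the repack op directly; the shapes and placeholder tensors are illustrative only, and it assumes a CUDA GPU with compute capability >= 8.0:

# Illustration only: placeholder tensors, not a real GPTQ checkpoint.
import torch
from vllm._C import ops

size_k, size_n = 4096, 4096   # reduction dim and output dim of the layer
pack_factor = 8               # 8 x 4-bit values per int32 (see gptq_marlin.cuh below)

b_q_weight = torch.zeros(size_k // pack_factor, size_n,
                         dtype=torch.int32, device="cuda")
perm = torch.empty(0, dtype=torch.int, device="cuda")  # empty perm => no act_order

# Repacks a GPTQ-layout weight into the Marlin tile layout;
# the result has shape (size_k // 16, size_n * 16 // 8).
marlin_qweight = ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n)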
csrc/quantization/gptq_marlin/gptq_marlin.cu (new file, 1520 lines)
(File diff suppressed because it is too large.)

csrc/quantization/gptq_marlin/gptq_marlin.cuh (new file, 74 lines)
@@ -0,0 +1,74 @@
#pragma once

#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>

namespace gptq_marlin {

// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per
// scheduler allows some more latency hiding. At the same time, we want relatively few warps to have
// many registers per warp and small tiles.
static constexpr int default_threads = 256;

static constexpr int pipe_stages = 4;  // 4 pipeline stages fit into shared memory

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;

static constexpr int pack_factor_4bit = 8;  // We have 8 4-bit vals inside a 32 bit

template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};

using I4 = Vec<int, 4>;

constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else

__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("{\n"
               "   .reg .pred p;\n"
               "   setp.ne.b32 p, %0, 0;\n"
               "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
               "}\n" ::"r"((int)pred),
               "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("{\n"
               "   .reg .b64 p;\n"
               "   createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
               "   cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
               "}\n" ::"r"(smem),
               "l"(glob_ptr), "n"(BYTES));
}

__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); }

template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

#endif

}  // namespace gptq_marlin
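A note on pack_factor_4bit: each 32-bit word of the quantized weight stores eight 4-bit values. A tiny, self-contained Python sketch of that convention (the helper names are illustrative, not part of the kernel):

# Illustrative only: pack/unpack eight 4-bit values into one 32-bit word,
# matching pack_factor_4bit = 8 above.
def pack_nibbles(vals):
    assert len(vals) == 8 and all(0 <= v < 16 for v in vals)
    word = 0
    for i, v in enumerate(vals):
        word |= v << (4 * i)  # value i occupies bits [4*i, 4*i + 4)
    return word

def unpack_nibbles(word):
    return [(word >> (4 * i)) & 0xF for i in range(8)]

assert unpack_nibbles(pack_nibbles([1, 2, 3, 4, 5, 6, 7, 8])) == [1, 2, 3, 4, 5, 6, 7, 8]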
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu (new file, 324 lines)
@@ -0,0 +1,324 @@
#include "gptq_marlin.cuh"

namespace gptq_marlin {

static constexpr int repack_stages = 8;

static constexpr int repack_threads = 256;

static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800

template <int const num_threads, bool const has_perm>
__global__ void
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
                     uint32_t const *__restrict__ perm_ptr,
                     uint32_t *__restrict__ out_ptr, int size_k, int size_n) {}

}  // namespace gptq_marlin

torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
                                 int64_t size_k, int64_t size_n) {
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
  return torch::empty({1, 1});
}

#else

template <int const num_threads, bool const has_perm>
__global__ void
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
                     uint32_t const *__restrict__ perm_ptr,
                     uint32_t *__restrict__ out_ptr, int size_k, int size_n) {
  int k_tiles = size_k / tile_k_size;
  int n_tiles = size_n / tile_n_size;
  int block_k_tiles = div_ceil(k_tiles, gridDim.x);

  int start_k_tile = blockIdx.x * block_k_tiles;
  if (start_k_tile >= k_tiles) {
    return;
  }

  int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);

  // Wait until the next thread tile has been loaded to shared memory.
  auto wait_for_stage = [&]() {
    // We only have `stages - 2` active fetches since we are double buffering
    // and can only issue the next fetch when it is guaranteed that the previous
    // shared memory load is fully complete (as it may otherwise be
    // overwritten).
    cp_async_wait<repack_stages - 2>();
    __syncthreads();
  };

  extern __shared__ int4 sh[];

  constexpr int perm_size = tile_k_size / 4;

  int4 *sh_perm_ptr = sh;
  int4 *sh_pipe_ptr = sh_perm_ptr;
  if constexpr (has_perm) {
    sh_pipe_ptr += perm_size;
  }

  constexpr int stage_n_threads = tile_n_size / 4;
  constexpr int stage_k_threads =
      has_perm ? tile_k_size : tile_k_size / pack_factor_4bit;
  constexpr int stage_size = stage_k_threads * stage_n_threads;

  auto load_perm_to_shared = [&](int k_tile_id) {
    int first_k_int4 = (k_tile_id * tile_k_size) / 4;

    int4 const *perm_int4_ptr = reinterpret_cast<int4 const *>(perm_ptr);

    if (threadIdx.x < perm_size) {
      sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
    }
    __syncthreads();
  };

  auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      cp_async_fence();
      return;
    }

    int first_n = n_tile_id * tile_n_size;

    int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe;

    if constexpr (has_perm) {
      if (threadIdx.x < stage_size) {
        int k_id = threadIdx.x / stage_n_threads;
        int n_id = threadIdx.x % stage_n_threads;

        uint32_t const *sh_perm_int_ptr =
            reinterpret_cast<uint32_t const *>(sh_perm_ptr);

        int src_k = sh_perm_int_ptr[k_id];
        int src_k_packed = src_k / pack_factor_4bit;

        cp_async4_stream(
            &sh_ptr[k_id * stage_n_threads + n_id],
            reinterpret_cast<int4 const *>(&(
                b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
      }

    } else {
      if (threadIdx.x < stage_size) {
        int k_id = threadIdx.x / stage_n_threads;
        int n_id = threadIdx.x % stage_n_threads;

        int first_k = k_tile_id * tile_k_size;
        int first_k_packed = first_k / pack_factor_4bit;

        cp_async4_stream(&sh_ptr[k_id * stage_n_threads + n_id],
                         reinterpret_cast<int4 const *>(
                             &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
                                              first_n + (n_id * 4)])));
      }
    }

    cp_async_fence();
  };

  auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
    if (n_tile_id >= n_tiles) {
      return;
    }

    int warp_id = threadIdx.x / 32;
    int th_id = threadIdx.x % 32;

    if (warp_id >= 4) {
      return;
    }

    int tc_col = th_id / 4;
    int tc_row = (th_id % 4) * 2;

    constexpr int tc_offsets[4] = {0, 1, 8, 9};

    int cur_n = warp_id * 16 + tc_col;

    constexpr int sh_stride = 64;

    int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
    uint32_t *sh_stage_int_ptr = reinterpret_cast<uint32_t *>(sh_stage_ptr);

    uint32_t *sh_perm_int_ptr = reinterpret_cast<uint32_t *>(sh_perm_ptr);

    uint32_t vals[pack_factor_4bit];

    if constexpr (has_perm) {
      for (int i = 0; i < 4; i++) {
        int k_idx = tc_row + tc_offsets[i];

        uint32_t src_k = sh_perm_int_ptr[k_idx];
        uint32_t src_k_pos = src_k % pack_factor_4bit;

        uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
        uint32_t b1_cur_val = (b1_val >> (src_k_pos * 4)) & 0xf;

        uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
        uint32_t b2_cur_val = (b2_val >> (src_k_pos * 4)) & 0xf;

        vals[i] = b1_cur_val;
        vals[4 + i] = b2_cur_val;
      }

    } else {
      uint32_t b1_val_1 = sh_stage_int_ptr[cur_n];
      uint32_t b1_val_2 = sh_stage_int_ptr[sh_stride + cur_n];

      uint32_t b2_val_1 = sh_stage_int_ptr[cur_n + 8];
      uint32_t b2_val_2 = sh_stage_int_ptr[sh_stride + cur_n + 8];

#pragma unroll
      for (int i = 0; i < 2; i++) {
        int cur_elem = tc_row + tc_offsets[i];
        vals[i] = (b1_val_1 >> (cur_elem * 4)) & 0xf;
        vals[4 + i] = (b2_val_1 >> (cur_elem * 4)) & 0xf;
      }

#pragma unroll
      for (int i = 2; i < 4; i++) {
        int cur_elem = tc_row + tc_offsets[i] - 8;
        vals[i] = (b1_val_2 >> (cur_elem * 4)) & 0xf;
        vals[4 + i] = (b2_val_2 >> (cur_elem * 4)) & 0xf;
      }
    }

    // Result of:
    // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
    constexpr int pack_idx[pack_factor_4bit] = {0, 2, 4, 6, 1, 3, 5, 7};

    uint32_t res = 0;
#pragma unroll
    for (int i = 0; i < pack_factor_4bit; i++) {
      res |= vals[pack_idx[i]] << (i * 4);
    }

    constexpr int tile_size = tile_k_size * tile_n_size / pack_factor_4bit;
    int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;

    out_ptr[out_offset + th_id * 4 + warp_id] = res;
  };

  auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
    for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
      fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
    }

    wait_for_stage();
  };
#pragma unroll
  for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
    int n_tile_id = 0;

    if constexpr (has_perm) {
      load_perm_to_shared(k_tile_id);
    }

    start_pipes(k_tile_id, n_tile_id);

    while (n_tile_id < n_tiles) {
#pragma unroll
      for (int pipe = 0; pipe < repack_stages; pipe++) {
        fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
                        n_tile_id + pipe + repack_stages - 1);
        repack_tile(pipe, k_tile_id, n_tile_id + pipe);
        wait_for_stage();
      }
      n_tile_id += repack_stages;
    }
  }
}

}  // namespace gptq_marlin

torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
                                 int64_t size_k, int64_t size_n) {
  // Verify compatibility with marlin tile of 16x64
  TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
              " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
  TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
              " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);

  // Verify B
  TORCH_CHECK((size_k / gptq_marlin::pack_factor_4bit) == b_q_weight.size(0),
              "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
              ", size_k = ", size_k,
              ", pack_factor_4bit = ", gptq_marlin::pack_factor_4bit);
  TORCH_CHECK(b_q_weight.size(1) == size_n,
              "b_q_weight.size(1) = ", b_q_weight.size(1),
              " is not size_n = ", size_n);

  // Verify device and strides
  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
  TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");

  TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
  TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
  TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
  auto options = torch::TensorOptions()
                     .dtype(b_q_weight.dtype())
                     .device(b_q_weight.device());
  torch::Tensor out = torch::empty(
      {size_k / gptq_marlin::tile_size,
       size_n * gptq_marlin::tile_size / gptq_marlin::pack_factor_4bit},
      options);

  // Detect if there is act_order
  bool has_perm = perm.size(0) != 0;

  // Get ptrs
  uint32_t const *b_q_weight_ptr =
      reinterpret_cast<uint32_t const *>(b_q_weight.data_ptr());
  uint32_t const *perm_ptr =
      reinterpret_cast<uint32_t const *>(perm.data_ptr());
  uint32_t *out_ptr = reinterpret_cast<uint32_t *>(out.data_ptr());

  // Get dev info
  int dev = b_q_weight.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  if (has_perm) {
    cudaFuncSetAttribute(
        gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        max_shared_mem);
    gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>
        <<<blocks, gptq_marlin::repack_threads, max_shared_mem,
           stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);

  } else {
    cudaFuncSetAttribute(
        gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        max_shared_mem);
    gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>
        <<<blocks, gptq_marlin::repack_threads, max_shared_mem,
           stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
  }

  return out;
}

#endif
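To make the repack bookkeeping concrete: per the shape checks and output allocation above, an int32 GPTQ weight of shape (size_k / 8, size_n) is rewritten into an int32 tensor of shape (size_k / 16, size_n * 16 / 8), and each output word packs its eight nibbles in the interleaved order pack_idx = {0, 2, 4, 6, 1, 3, 5, 7}. A small Python sketch of just that bookkeeping (not the kernel itself; the helper names are illustrative):

# Illustrative shape/packing bookkeeping for gptq_marlin_repack.
TILE = 16          # gptq_marlin::tile_size
PACK_FACTOR = 8    # gptq_marlin::pack_factor_4bit

def repacked_shape(size_k, size_n):
    # Input:  (size_k // PACK_FACTOR, size_n) int32, GPTQ layout
    # Output: (size_k // TILE, size_n * TILE // PACK_FACTOR) int32, Marlin tile layout
    return (size_k // TILE, size_n * TILE // PACK_FACTOR)

def pack_interleaved(vals):
    # Mirrors the pack_idx = {0, 2, 4, 6, 1, 3, 5, 7} ordering used in repack_tile.
    pack_idx = [0, 2, 4, 6, 1, 3, 5, 7]
    word = 0
    for i in range(8):
        word |= (vals[pack_idx[i]] & 0xF) << (4 * i)
    return word

assert repacked_shape(4096, 4096) == (256, 8192)  # same element count as (512, 4096)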
tests/models/test_gptq_marlin.py (new file, 93 lines)
@@ -0,0 +1,93 @@
"""Compares the outputs of gptq vs gptq_marlin
Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 3 selections of each other.
Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin. As a result, we re-run the test
up to 3 times to see if we pass.
Note: This test currently fails running with --forked with the following:
    RuntimeError: Cannot re-initialize CUDA in forked subprocess.
    To use CUDA with multiprocessing, you must use the 'spawn' start method
Run `pytest tests/models/test_gptq_marlin.py`.
"""
import os

import pytest
import torch

from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

os.environ["TOKENIZERS_PARALLELISM"] = "true"

MAX_MODEL_LEN = 1024

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
gptq_marlin_not_supported = (
    capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability())

MODELS = [
    # act_order==False, group_size=channelwise
    ("robertgshaw2/zephyr-7b-beta-channelwise-gptq", "main"),
    # act_order==False, group_size=128
    ("TheBloke/Llama-2-7B-GPTQ", "main"),

    # act_order==True, group_size=128
    ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "main"),
    # act_order==True, group_size=64
    ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"),
    # act_order==True, group_size=32
    ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"),
]


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(gptq_marlin_not_supported,
                    reason="gptq_marlin is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    vllm_runner,
    example_prompts,
    model,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    model_name, revision = model

    # Run marlin.
    gptq_marlin_model = vllm_runner(model_name=model_name,
                                    revision=revision,
                                    dtype=dtype,
                                    quantization="marlin",
                                    max_model_len=MAX_MODEL_LEN,
                                    tensor_parallel_size=1,
                                    disable_custom_all_reduce=True)

    gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs(
        example_prompts, max_tokens, num_logprobs)
    del gptq_marlin_model

    # Run gptq.
    gptq_model = vllm_runner(model_name=model_name,
                             revision=revision,
                             dtype=dtype,
                             quantization="gptq",
                             max_model_len=MAX_MODEL_LEN,
                             tensor_parallel_size=1,
                             disable_custom_all_reduce=True)
    gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
                                                       max_tokens,
                                                       num_logprobs)
    del gptq_model

    check_logprobs_close(
        outputs_0_lst=gptq_outputs,
        outputs_1_lst=gptq_marlin_outputs,
        name_0="gptq",
        name_1="gptq_marlin",
    )
@@ -10,12 +10,12 @@ up to 3 times to see if we pass.

 Run `pytest tests/models/test_marlin.py`.
 """

 from dataclasses import dataclass

 import pytest
 import torch

+from tests.models.utils import check_logprobs_close
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

 capability = torch.cuda.get_device_capability()
@@ -55,43 +55,24 @@ def test_models(
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
-    marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype)
+    marlin_model = vllm_runner(model_pair.model_marlin,
+                               dtype=dtype,
+                               quantization="marlin")
     marlin_outputs = marlin_model.generate_greedy_logprobs(
         example_prompts, max_tokens, num_logprobs)
-
-    # Note: not sure why, but deleting just the model on Ada Lovelace
-    # does not free the GPU memory. On Ampere, deleting the just model
-    # frees the memory.
     del marlin_model
-
-    gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype)
+    gptq_model = vllm_runner(model_pair.model_gptq,
+                             dtype=dtype,
+                             quantization="gptq")
     gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
                                                        max_tokens,
                                                        num_logprobs)
-
-    # Note: not sure why, but deleting just the model on Ada Lovelace
-    # does not free the GPU memory. On Ampere, deleting the just model
-    # frees the memory.
     del gptq_model

-    # loop through the prompts
-    for prompt_idx in range(len(example_prompts)):
-        gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[
-            prompt_idx]
-        marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[
-            prompt_idx]
-
-        for idx, (gptq_output_id, marlin_output_id) in enumerate(
-                zip(gptq_output_ids, marlin_output_ids)):
-            # If sequence is not an exact match,
-            if marlin_output_id != gptq_output_id:
-                # Each predicted token must be in top 5 of the other's
-                assert gptq_output_id in marlin_logprobs[idx], (
-                    f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n"
-                    f"Marlin:\t{marlin_output_str!r}")
-                assert marlin_output_id in gptq_logprobs[idx], (
-                    f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n"
-                    f"Marlin:\t{marlin_output_str!r}")
-
-                # Break out since sequences will now diverge.
-                break
+    check_logprobs_close(
+        outputs_0_lst=gptq_outputs,
+        outputs_1_lst=marlin_outputs,
+        name_0="gptq",
+        name_1="marlin",
+    )
tests/models/utils.py (new file, 29 lines)
@@ -0,0 +1,29 @@
def check_logprobs_close(outputs_0_lst, outputs_1_lst, name_0, name_1):
    """Compare the logprobs of two sequences generated by different models,
    which should be similar but not necessarily equal.
    """
    # Loop through responses to each prompt.
    for prompt_idx, (outputs_0,
                     outputs_1) in enumerate(zip(outputs_0_lst,
                                                 outputs_1_lst)):
        output_ids_0, output_str_0, logprobs_0 = outputs_0
        output_ids_1, output_str_1, logprobs_1 = outputs_1

        # Loop through generated tokens.
        for idx, (output_id_0,
                  output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):

            # If generated tokens don't match, then
            if output_id_0 != output_id_1:
                # Each predicted token must be in top N logprobs of the other
                assert output_id_0 in logprobs_1[idx], (
                    f"Test{prompt_idx}:"
                    f"\n{name_0}:\t{output_str_0!r}"
                    f"\n{name_1}:\t{output_str_1!r}")
                assert output_id_1 in logprobs_0[idx], (
                    f"Test{prompt_idx}:"
                    f"\n{name_0}:\t{output_str_0!r}"
                    f"\n{name_1}:\t{output_str_1!r}")

                # Break out since sequences will now diverge.
                break
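For reference, check_logprobs_close expects each entry of an outputs list to be a (token_ids, text, logprobs) triple, where logprobs[i] maps candidate token ids at step i to their log-probabilities. A hedged example with made-up values:

# Made-up data purely to illustrate the expected structure.
outputs_a = [([11, 22, 33], "a b c", [{11: -0.1}, {22: -0.2, 23: -2.0}, {33: -0.3}])]
outputs_b = [([11, 23, 99], "a d e", [{11: -0.1}, {23: -0.2, 22: -1.5}, {99: -0.4}])]

# Passes: the first mismatching tokens (22 vs 23) each appear in the other's
# logprobs at that step, and the comparison stops at the first divergence.
check_logprobs_close(outputs_a, outputs_b, name_0="model_a", name_1="model_b")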
Deleted file (64 lines):
@@ -1,64 +0,0 @@
"""Tests whether Marlin models can be loaded from the autogptq config.

Run `pytest tests/quantization/test_autogptq_marlin_configs.py --forked`.
"""

from dataclasses import dataclass

import pytest

from vllm.config import ModelConfig


@dataclass
class ModelPair:
    model_marlin: str
    model_gptq: str


# Model Id // Expected Kernel
MODELS_QUANT_TYPE = [
    # compat: autogptq <=0.7.1 is_marlin_format: bool
    ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin"),
    ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq"),
    # compat: autogptq >=0.8.0 use checkpoint_format: str
    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin"),
    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq")
]


@pytest.mark.parametrize("model_quant_type", MODELS_QUANT_TYPE)
def test_auto_gptq(model_quant_type: str, ) -> None:
    model_path, quant_type = model_quant_type

    model_config_no_quant_arg = ModelConfig(
        model_path,
        model_path,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
        quantization=None  # case 1
    )

    model_config_quant_arg = ModelConfig(
        model_path,
        model_path,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
        quantization="gptq"  # case 2
    )

    assert model_config_no_quant_arg.quantization == quant_type, (
        f"Expected quant_type == {quant_type} for {model_path}, "
        f"but found {model_config_no_quant_arg.quantization} "
        "for no --quantization None case")

    assert model_config_quant_arg.quantization == quant_type, (
        f"Expected quant_type == {quant_type} for {model_path}, "
        f"but found {model_config_quant_arg.quantization} "
        "for --quantization gptq case")
tests/quantization/test_configs.py (new file, 73 lines)
@@ -0,0 +1,73 @@
"""Tests whether Marlin models can be loaded from the autogptq config.

Run `pytest tests/quantization/test_configs.py --forked`.
"""

from dataclasses import dataclass

import pytest

from vllm.config import ModelConfig


@dataclass
class ModelPair:
    model_marlin: str
    model_gptq: str


# Model Id // Quantization Arg // Expected Type
MODEL_ARG_EXPTYPES = [
    # AUTOGPTQ
    # compat: autogptq <=0.7.1 is_marlin_format: bool
    # Model Serialized in Marlin Format should always use Marlin kernel.
    ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", None, "marlin"),
    ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin", "marlin"),
    ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "gptq", "marlin"),
    ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "awq", "ERROR"),
    # Model Serialized in Exllama Format.
    ("TheBloke/Llama-2-7B-Chat-GPTQ", None, "gptq_marlin"),
    ("TheBloke/Llama-2-7B-Chat-GPTQ", "marlin", "gptq_marlin"),
    ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq", "gptq"),
    ("TheBloke/Llama-2-7B-Chat-GPTQ", "awq", "ERROR"),
    # compat: autogptq >=0.8.0 use checkpoint_format: str
    # Model Serialized in Marlin Format should always use Marlin kernel.
    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", None, "marlin"),
    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin", "marlin"),
    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "gptq", "marlin"),
    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "awq", "ERROR"),
    # Model Serialized in Exllama Format.
    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", None, "gptq_marlin"),
    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "marlin", "gptq_marlin"),
    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq", "gptq"),
    ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "awq", "ERROR"),

    # AUTOAWQ
    ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", None, "awq"),
    ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"),
    ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "marlin", "ERROR"),
    ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "gptq", "ERROR"),
]


@pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES)
def test_auto_gptq(model_arg_exptype: str) -> None:
    model_path, quantization_arg, expected_type = model_arg_exptype

    try:
        model_config = ModelConfig(model_path,
                                   model_path,
                                   tokenizer_mode="auto",
                                   trust_remote_code=False,
                                   seed=0,
                                   dtype="float16",
                                   revision=None,
                                   quantization=quantization_arg)
        found_quantization_type = model_config.quantization
    except ValueError:
        found_quantization_type = "ERROR"

    assert found_quantization_type == expected_type, (
        f"Expected quant_type == {expected_type} for {model_path}, "
        f"but found {found_quantization_type} "
        f"for no --quantization {quantization_arg} case")
@@ -9,11 +9,14 @@ from packaging.version import Version
 from transformers import PretrainedConfig

 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
+                                                     get_quantization_config)
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
                         is_neuron)

+GPTQMarlinConfig = get_quantization_config("gptq_marlin")
+
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -138,14 +141,34 @@ class ModelConfig:
             is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
                                 or quant_cfg.get("is_marlin_format", False))

-            # Use marlin if the GPTQ model is serialized in marlin format.
-            if quant_method == "gptq" and is_format_marlin:
-                logger.info("The model is serialized in Marlin format. "
-                            "Using Marlin kernel.")
-                quant_method = "marlin"
-                if self.quantization == "gptq":
-                    self.quantization = quant_method
+            # Check which LinearMethod the GPTQ model should use.
+            if quant_method == "gptq":
+                # If serialized in Marlin format, use MarlinLinearMethod.
+                # TODO (@robertgshaw): migrate under GPTQMarlinLinearMethod.
+                if is_format_marlin:
+                    logger.info("The model is serialized in Marlin format. "
+                                "Using Marlin kernel.")
+                    quant_method = "marlin"
+                    if self.quantization == "gptq":
+                        self.quantization = quant_method
+
+                # If convertible to Marlin format, use GPTQMarlinLinearMethod
+                # unless the user explicitly specified GPTQLinearMethod.
+                elif GPTQMarlinConfig.is_marlin_compatible(quant_cfg):
+                    if self.quantization == "gptq":
+                        logger.warning(
+                            "The model is convertible to Marlin format, but "
+                            "you specified quantization=gptq. Use "
+                            "quantization=marlin for faster inference.")
+                    else:
+                        logger.info(
+                            "The model is convertible to Marlin format. "
+                            "Using Marlin kernel.")
+                        quant_method = "gptq_marlin"
+                        if self.quantization == "marlin":
+                            self.quantization = quant_method
+
+        # Verify quantization configurations.
         if self.quantization is None:
             self.quantization = quant_method
         elif self.quantization != quant_method:
@@ -165,7 +188,7 @@ class ModelConfig:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
                     f"supported in ROCm.")
-            if self.quantization != "marlin":
+            if (self.quantization not in ["marlin", "gptq_marlin"]):
                 logger.warning(
                     "%s quantization is not fully "
                     "optimized yet. The speed can be slower than "
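Taken together, the detection logic above produces the behavior pinned down by the new tests/quantization/test_configs.py: an exllama-format GPTQ checkpoint with no explicit quantization argument now resolves to the gptq_marlin kernel. A hedged sketch of that outcome (the model id and ModelConfig arguments are copied from the test file; it assumes a CUDA GPU with compute capability >= 8.0):

# Sketch of the expected auto-detection outcome (mirrors test_configs.py).
from vllm.config import ModelConfig

cfg = ModelConfig("TheBloke/Llama-2-7B-Chat-GPTQ",
                  "TheBloke/Llama-2-7B-Chat-GPTQ",
                  tokenizer_mode="auto",
                  trust_remote_code=False,
                  seed=0,
                  dtype="float16",
                  revision=None,
                  quantization=None)
# On a capable GPU the GPTQ checkpoint is auto-upgraded to the Marlin kernel.
assert cfg.quantization == "gptq_marlin"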
@@ -6,6 +6,8 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQMarlinConfig)
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

@@ -15,6 +17,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
+    "gptq_marlin": GPTQMarlinConfig,
     "marlin": MarlinConfig,
 }
vllm/model_executor/layers/quantization/gptq_marlin.py (new file, 444 lines)
@@ -0,0 +1,444 @@
import enum
from enum import Enum
from typing import Any, Dict, List, Optional

import numpy
import torch
from torch.nn.parameter import Parameter

from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]


# Precompute permutations for Marlin weight and scale shuffling
#
# Marlin works on [16,64] tiles. The goal of the permutations
# is to reorder the weight data so that it is compatible
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the
# kernel will get the data as it is needed for tensor-core
# (without the need to use ldmatrix instructions)
def _get_perms():
    perm = []
    for i in range(32):
        perm1 = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm)
    interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    perm = perm.reshape((-1, 8))[:, interleave].ravel()  # type: ignore
    perm = torch.from_numpy(perm)
    scale_perm = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return perm, scale_perm, scale_perm_single


_perm, _scale_perm, _scale_perm_single = _get_perms()


def get_pack_factor(num_bits):
    assert num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS, (
        f"Unsupported num_bits = {num_bits}")
    return 32 // num_bits


def marlin_permute_scales(s, size_k, size_n, group_size):
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
    else:
        s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()

    return s


class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool) -> None:
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym

        # Verify
        if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"Marlin does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
                "are supported.")
        if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Marlin does not support group_size = {self.group_size}. "
                f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
                "are supported.")
        if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
            raise ValueError(
                f"Marlin does not support is_sym = {self.is_sym}. "
                f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")

        # Init
        self.pack_factor = get_pack_factor(weight_bits)
        self.tile_size = GPTQ_MARLIN_TILE
        self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
        self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
        self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        return cls(weight_bits, group_size, desc_act, is_sym)

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
        if isinstance(layer, LinearBase):
            return GPTQMarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
        # Extract data from quant config.
        num_bits = quant_config.get("bits", None)
        group_size = quant_config.get("group_size", None)
        sym = quant_config.get("sym", None)
        desc_act = quant_config.get("desc_act", None)

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        # If the capability of the device is too low, cannot convert.
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor
        if device_capability < cls.get_min_capability():
            return False

        # Otherwise, can convert if model satisfies marlin constraints.
        return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
                and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
                and sym in GPTQ_MARLIN_SUPPORTED_SYM)


class GPTQMarlinState(Enum):
    REPACK = enum.auto()
    READY = enum.auto()


class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        del output_size

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Validate dtype
        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_thread_n != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f" min_thread_n = {self.quant_config.min_thread_n}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_thread_k != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible "
                f"by min_thread_k = {self.quant_config.min_thread_k}.")

        if (group_size < input_size
                and input_size_per_partition % group_size != 0):
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition}"
                f" is not divisible by group_size = {group_size}.")

        # Detect sharding of scales/zp

        # By default, no sharding over "input dim"
        scales_and_zp_size = input_size // group_size
        scales_and_zp_input_dim = None

        if self.quant_config.desc_act:
            # Act-order case
            assert self.quant_config.group_size != -1

            is_k_full = input_size_per_partition == input_size

        else:
            # No act-order case

            # K is always full due to full alignment with
            # group-size and shard of scales/zp
            is_k_full = True

            # If this is a row-parallel case, then shard scales/zp
            if (input_size != input_size_per_partition
                    and self.quant_config.group_size != -1):
                scales_and_zp_size = input_size_per_partition // group_size
                scales_and_zp_input_dim = 0

        # Init buffers

        # Quantized weights
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                **extra_weight_attrs,
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            })

        # Activation order
        g_idx = Parameter(
            torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(g_idx, {
            **extra_weight_attrs, "input_dim": 0,
            "ignore_warning": True
        })

        g_idx_sort_indices = Parameter(
            torch.empty(
                g_idx.shape,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(g_idx_sort_indices, extra_weight_attrs)

        # Scales
        scales = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales, {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
            })

        # Quantized zero-points
        qzeros = Parameter(
            torch.empty(scales_and_zp_size,
                        output_size_per_partition //
                        self.quant_config.pack_factor,
                        dtype=torch.int32,
                        device="meta"),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })

        # Allocate marlin workspace
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_thread_n) * self.quant_config.max_parallel
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                requires_grad=False)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)
        layer.workspace = workspace
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.is_k_full = is_k_full
        layer.marlin_state = GPTQMarlinState.REPACK

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        reshaped_x = x.reshape(-1, x.shape[-1])

        size_m = reshaped_x.shape[0]
        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition
        full_size_k = layer.input_size

        out_shape = x.shape[:-1] + (part_size_n, )

        if layer.marlin_state == GPTQMarlinState.REPACK:
            layer.marlin_state = GPTQMarlinState.READY

            # Newly generated tensors need to replace existing tensors that are
            # already registered as parameters by vLLM (and won't be freed)
            def replace_tensor(name, new_t):
                # It is important to use resize_() here since it ensures
                # the same buffer is reused
                getattr(layer, name).resize_(new_t.shape)
                getattr(layer, name).copy_(new_t)
                del new_t

            cur_device = layer.qweight.device

            # Process act_order
            if self.quant_config.desc_act:
                # Get sorting based on g_idx
                g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)

                sorted_g_idx = layer.g_idx[g_idx_sort_indices]

                replace_tensor("g_idx", sorted_g_idx)
                replace_tensor("g_idx_sort_indices", g_idx_sort_indices)

            else:
                # Reset g_idx related tensors
                layer.g_idx = Parameter(torch.empty(0,
                                                    dtype=torch.int,
                                                    device=cur_device),
                                        requires_grad=False)
                layer.g_idx_sort_indices = Parameter(torch.empty(
                    0, dtype=torch.int, device=cur_device),
                                                     requires_grad=False)

            # Repack weights
            marlin_qweight = ops.gptq_marlin_repack(
                layer.qweight,
                layer.g_idx_sort_indices,
                part_size_k,
                part_size_n,
            )
            replace_tensor("qweight", marlin_qweight)

            # Permute scales
            scales_size_k = part_size_k
            scales_size_n = part_size_n
            if self.quant_config.desc_act:
                scales_size_k = full_size_k

            marlin_scales = marlin_permute_scales(layer.scales, scales_size_k,
                                                  scales_size_n,
                                                  self.quant_config.group_size)
            replace_tensor("scales", marlin_scales)

        output = ops.gptq_marlin_gemm(reshaped_x, layer.qweight, layer.scales,
                                      layer.g_idx, layer.g_idx_sort_indices,
                                      layer.workspace, size_m, part_size_n,
                                      part_size_k, layer.is_k_full)

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
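Finally, a hedged illustration of the config contents that drive the conversion decision in GPTQMarlinConfig above; the dict is made up but uses exactly the keys read by from_config and is_marlin_compatible:

# Illustrative quantize_config.json contents for a Marlin-convertible checkpoint.
quant_cfg = {"bits": 4, "group_size": 128, "sym": True, "desc_act": False}

# True on a GPU with compute capability >= 8.0, since every field falls within
# GPTQ_MARLIN_SUPPORTED_NUM_BITS / _GROUP_SIZES / _SYM.
compatible = GPTQMarlinConfig.is_marlin_compatible(quant_cfg)

# 4-bit weights pack 8 values per 32-bit word.
assert get_pack_factor(4) == 8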