[Kernel] Support running GPTQ 8-bit models in Marlin (#4533)

alexm-nm, 2024-05-02 12:56:22 -04:00 (committed by GitHub)
commit 7038e8b803, parent 2a85f93007
7 changed files with 552 additions and 323 deletions

File: C++ custom-op declarations (gptq_marlin_gemm / gptq_marlin_repack prototypes)

@ -132,6 +132,7 @@ torch::Tensor gptq_marlin_gemm(
torch::Tensor &g_idx, torch::Tensor &g_idx,
torch::Tensor &perm, torch::Tensor &perm,
torch::Tensor &workspace, torch::Tensor &workspace,
int64_t num_bits,
int64_t size_m, int64_t size_m,
int64_t size_n, int64_t size_n,
int64_t size_k, int64_t size_k,
@ -141,7 +142,8 @@ torch::Tensor gptq_marlin_repack(
torch::Tensor &b_q_weight, torch::Tensor &b_q_weight,
torch::Tensor &perm, torch::Tensor &perm,
int64_t size_k, int64_t size_k,
int64_t size_n); int64_t size_n,
int64_t num_bits);
#endif #endif
void squeezellm_gemm( void squeezellm_gemm(

File: GPTQ Marlin GEMM kernel (CUDA)

@ -32,7 +32,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
int4 *__restrict__ out_int4_ptr, int size_m, int4 *__restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {} int size_k, int block_rows) {}
template <const int threads, // number of threads in a threadblock template <const int num_bits, // number of bits used for weights
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock // dimension (batchsize) of the threadblock
const int thread_n_blocks, // same for n dimension (output) const int thread_n_blocks, // same for n dimension (output)
@ -62,8 +63,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &g_idx, torch::Tensor &b_scales, torch::Tensor &g_idx,
torch::Tensor &perm, torch::Tensor &workspace, torch::Tensor &perm, torch::Tensor &workspace,
int64_t size_m, int64_t size_n, int64_t size_k, int64_t num_bits, int64_t size_m, int64_t size_n,
bool is_k_full) { int64_t size_k, bool is_k_full) {
TORCH_CHECK_NOT_IMPLEMENTED(false, TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0"); "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1}); return torch::empty({1, 1});
@ -114,11 +115,21 @@ template <int lut> __device__ inline int lop3(int a, int b, int c) {
return res; return res;
} }
// Constructs destination register by taking bytes from 2 sources (based on mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small // values. We mostly follow the strategy in the link below, with some small
// changes: // changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant(int q) { __device__ inline FragB dequant_4bit(int q) {
const int LO = 0x000f000f; const int LO = 0x000f000f;
const int HI = 0x00f000f0; const int HI = 0x00f000f0;
const int EX = 0x64006400; const int EX = 0x64006400;
@ -139,6 +150,24 @@ __device__ inline FragB dequant(int q) {
return frag_b; return frag_b;
} }
__device__ inline FragB dequant_8bit(int q) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
*reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2 *>(&hi),
*reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
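The new 8-bit path above relies on two tricks: prmt.b32 interleaves each packed byte with 0x64, so every uint8 value q lands in the low byte of an fp16 whose value is 1024 + q, and subtracting the magic constant 0x6480 (fp16 1152.0 = 1024 + 128) then yields q - 128 without any integer-to-float conversion instructions. A minimal NumPy reference of the same arithmetic, purely for illustration (the helper names are not part of the kernel):

import numpy as np

def prmt(a: int, b: int, sel: int) -> int:
    # Emulates PTX prmt.b32 in its default mode: each selector nibble picks one
    # byte out of the pool {a bytes 0-3, b bytes 4-7}. Sign-replication
    # (selector bit 3) is never used by the masks in this kernel, so it is ignored.
    pool = [(a >> (8 * i)) & 0xFF for i in range(4)] + \
           [(b >> (8 * i)) & 0xFF for i in range(4)]
    out = 0
    for i in range(4):
        out |= pool[(sel >> (4 * i)) & 0x7] << (8 * i)
    return out

def dequant_8bit_ref(q: int):
    # Mirrors dequant_8bit: returns [q0-128, q2-128, q1-128, q3-128],
    # i.e. frag_b[0] holds bytes 0/2 and frag_b[1] holds bytes 1/3.
    start_byte_for_fp16 = 0x64646464
    lo = prmt(q, start_byte_for_fp16, 0x5250)   # bytes 0 and 2, padded with 0x64
    hi = prmt(q, start_byte_for_fp16, 0x5351)   # bytes 1 and 3, padded with 0x64
    result = []
    for word in (lo, hi):
        for half_bits in (word & 0xFFFF, word >> 16):
            # Reinterpret the 16-bit pattern as fp16: 0x64XX encodes 1024 + XX
            val = np.array([half_bits], dtype=np.uint16).view(np.float16)[0]
            result.append(float(val) - 1152.0)  # 0x6480 is fp16 1152.0
    return result

# Packed bytes (low to high): 0x01, 0x80, 0xFF, 0x7F
print(dequant_8bit_ref(0x7FFF8001))   # [-127.0, 127.0, 0.0, -1.0]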
// Multiply dequantized values by the corresponding quantization scale; used // Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization. // only for grouped quantization.
__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { __device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
@ -162,6 +191,13 @@ __device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2,
frag_b[1] = __hmul2(frag_b[1], s_val_3_4); frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
} }
// Given 2 floats multiply by 2 scales (halves)
__device__ inline void scale_float(float *c, FragS &s) {
__half *s_ptr = reinterpret_cast<__half *>(&s);
c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
}
// Wait until barrier reaches `count`, then lock for current threadblock. // Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int *lock, int count) { __device__ inline void barrier_acquire(int *lock, int count) {
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
@ -250,7 +286,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
} }
} }
template <const int threads, // number of threads in a threadblock template <const int num_bits, // number of bits used for weights
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock // dimension (batchsize) of the threadblock
const int thread_n_blocks, // same for n dimension (output) const int thread_n_blocks, // same for n dimension (output)
@ -286,6 +323,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// configurations, while requiring as few slow global cross-threadblock // configurations, while requiring as few slow global cross-threadblock
// reductions as possible. // reductions as possible.
constexpr int pack_factor = 32 / num_bits;
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a // For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions // better partitioning with less reductions
int parallel = 1; int parallel = 1;
@ -385,21 +424,25 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
// B sizes/strides // B sizes/strides
int b_gl_stride = 16 * prob_n / 32; int b_gl_stride = 16 * prob_n / (pack_factor * 4);
constexpr int b_sh_stride = 32 * thread_n_blocks / 4; constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
constexpr int b_sh_wr_delta = threads; constexpr int b_sh_wr_delta = threads * b_thread_vecs;
constexpr int b_sh_rd_delta = threads; constexpr int b_sh_rd_delta = threads * b_thread_vecs;
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
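The new b_thread_vecs factor exists so that the thread-to-column mapping of the shared-memory B tile stays identical for both weight widths; an 8-bit tile is simply twice as many bytes, so each thread moves two int4 (16-byte) vectors per step instead of one. A quick arithmetic check in Python for the 256-wide tile case (thread_n_blocks = 16), with illustrative names:

def b_strides(num_bits: int, thread_n_blocks: int = 16):
    pack_factor = 32 // num_bits          # 8 weights per int32 for 4-bit, 4 for 8-bit
    # int4 (16-byte) vectors per k-slice of the B tile
    b_sh_stride = ((thread_n_blocks * 16) * 16 // pack_factor) // 4
    b_thread_vecs = 1 if num_bits == 4 else 2
    b_sh_stride_threads = b_sh_stride // b_thread_vecs
    return b_sh_stride, b_thread_vecs, b_sh_stride_threads

print(b_strides(4))   # (128, 1, 128)
print(b_strides(8))   # (256, 2, 128)
# b_sh_stride_threads is the same in both cases, so all threadIdx-based indexing
# below is unchanged; only the number of vectors handled per thread differs.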
// Scale sizes/strides without act_order // Scale sizes/strides without act_order
int s_gl_stride = prob_n / 8; int s_gl_stride = prob_n / 8;
constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups = !has_act_order && group_blocks < thread_k_blocks constexpr int s_tb_groups =
? thread_k_blocks / group_blocks !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
: 1; ? thread_k_blocks / group_blocks
: 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride; constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride; int s_gl_rd_delta = s_gl_stride;
@ -425,12 +468,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
int b_gl_rd = int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col; b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row; b_gl_rd += b_gl_rd_delta_o * slice_row;
int b_sh_wr = threadIdx.x; int b_sh_wr = threadIdx.x * b_thread_vecs;
int b_sh_rd = threadIdx.x; int b_sh_rd = threadIdx.x * b_thread_vecs;
// For act_order // For act_order
constexpr int k_iter_size = tb_k / b_sh_wr_iters; constexpr int k_iter_size = tb_k / b_sh_wr_iters;
@ -442,8 +485,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// No act_order // No act_order
int s_gl_rd; int s_gl_rd;
if constexpr (!has_act_order) { if constexpr (!has_act_order) {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + if constexpr (group_blocks == -1) {
s_sh_stride * slice_col + threadIdx.x; s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_sh_stride * slice_col + threadIdx.x;
}
} }
int s_sh_wr = threadIdx.x; int s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride; bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
@ -511,7 +558,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// Register storage for double buffer of shared memory reads. // Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks]; FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2]; I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2]; FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order FragS frag_s[2][4]; // No act-order
FragS act_frag_s[2][4][4]; // For act-order FragS act_frag_s[2][4][4]; // For act-order
@ -575,7 +622,11 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int4 *sh_b_stage = sh_b + b_sh_stage * pipe; int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll #pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) { for (int i = 0; i < b_sh_wr_iters; i++) {
cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); #pragma unroll
for (int j = 0; j < b_thread_vecs; j++) {
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
}
B_ptr[i] += b_gl_rd_delta_o; B_ptr[i] += b_gl_rd_delta_o;
} }
@ -602,15 +653,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// Only fetch scales if this tile starts a new group // Only fetch scales if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) { if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (s_sh_wr_pred) { if (s_sh_wr_pred) {
cp_async4_stream(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
} }
s_gl_rd += s_gl_rd_delta; s_gl_rd += s_gl_rd_delta;
} }
} else { } else {
for (int i = 0; i < s_tb_groups; i++) { for (int i = 0; i < s_tb_groups; i++) {
if (s_sh_wr_pred) { if (s_sh_wr_pred) {
cp_async4_stream(&sh_s_stage[i * s_sh_stride + s_sh_wr], cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
&scales_ptr[s_gl_rd]); &scales_ptr[s_gl_rd]);
} }
s_gl_rd += s_gl_rd_delta; s_gl_rd += s_gl_rd_delta;
} }
@ -641,14 +692,24 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
for (int i = 0; i < thread_m_blocks; i++) for (int i = 0; i < thread_m_blocks; i++)
ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
int4 *sh_b_stage = sh_b + b_sh_stage * pipe; int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
frag_b_quant[k % 2] = *reinterpret_cast<I4 *>(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); #pragma unroll
for (int i = 0; i < b_thread_vecs; i++) {
frag_b_quant[k % 2][i] = *reinterpret_cast<I4 *>(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
}
}; };
bool is_same_group[stages]; bool is_same_group[stages];
int same_group_id[stages]; int same_group_id[stages];
auto init_same_group = [&](int pipe) { auto init_same_group = [&](int pipe) {
if constexpr (!has_act_order) {
is_same_group[pipe] = false;
same_group_id[pipe] = 0;
return;
}
int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
int *sh_g_idx_int_ptr = reinterpret_cast<int *>(sh_g_idx_stage); int *sh_g_idx_int_ptr = reinterpret_cast<int *>(sh_g_idx_stage);
@ -767,10 +828,23 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// dequantization and matmul operations. // dequantization and matmul operations.
#pragma unroll #pragma unroll
for (int j = 0; j < 4; j++) { for (int j = 0; j < 4; j++) {
int b_quant = frag_b_quant[k % 2][j]; FragB frag_b0;
int b_quant_shift = b_quant >> 8; FragB frag_b1;
if constexpr (num_bits == 4) {
int b_quant = frag_b_quant[k % 2][0][j];
int b_quant_shift = b_quant >> 8;
FragB frag_b0 = dequant(b_quant); frag_b0 = dequant_4bit(b_quant);
frag_b1 = dequant_4bit(b_quant_shift);
} else {
int *frag_b_quant_ptr = reinterpret_cast<int *>(frag_b_quant[k % 2]);
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
frag_b0 = dequant_8bit(b_quant_0);
frag_b1 = dequant_8bit(b_quant_1);
}
// Apply scale to frag_b0 // Apply scale to frag_b0
if constexpr (has_act_order) { if constexpr (has_act_order) {
@ -782,8 +856,6 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
} }
} }
FragB frag_b1 = dequant(b_quant_shift);
// Apply scale to frag_b1 // Apply scale to frag_b1
if constexpr (has_act_order) { if constexpr (has_act_order) {
scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
@ -808,13 +880,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// multiple warps that accumulate their partial sums of the same output // multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory. // location; which we have to reduce over in the end. We do in shared memory.
auto thread_block_reduce = [&]() { auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride / 2; constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) { if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride; int red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride * 4 * 2; constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
constexpr int red_sh_delta = b_sh_stride; constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride); (threadIdx.x % b_sh_stride_threads);
// Parallel logarithmic shared memory reduction. We make sure to avoid any // Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only // unnecessary read or write iterations, e.g., for two warps we write only
@ -861,7 +933,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
}; };
// Since multiple threadblocks may process parts of the same column slice, we // Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped portioning // finally have to globally reduce over the results. As the striped partitioning
// minimizes the number of such reductions and our outputs are usually rather // minimizes the number of such reductions and our outputs are usually rather
// small, we perform this reduction serially in L2 cache. // small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) { auto global_reduce = [&](bool first = false, bool last = false) {
@ -951,13 +1023,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
auto write = [&](int idx, float c0, float c1, FragS &s) { auto write = [&](int idx, float c0, float c1, FragS &s) {
half2 res = __halves2half2(__float2half(c0), __float2half(c1)); half2 res = __halves2half2(__float2half(c0), __float2half(c1));
// For per-column quantization we finally apply the scale here // For per-column quantization we finally apply the scale here (only for
if constexpr (!has_act_order && group_blocks == -1) { // 4-bit)
if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) {
res = __hmul2(res, s[0]); res = __hmul2(res, s[0]);
} }
((half2 *)sh)[idx] = res; ((half2 *)sh)[idx] = res;
}; };
if (threadIdx.x / 32 < thread_n_blocks / 4) { if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll #pragma unroll
for (int i = 0; i < thread_m_blocks; i++) { for (int i = 0; i < thread_m_blocks; i++) {
@ -1023,6 +1097,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// ensure all shared memory accesses are static. Note that both pipelines // ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at // have even length meaning that the next iteration will always start at
// index 0. // index 0.
#pragma unroll #pragma unroll
for (int pipe = 0; pipe < stages;) { for (int pipe = 0; pipe < stages;) {
#pragma unroll #pragma unroll
@ -1070,23 +1145,63 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// For per-column scales, we only fetch them here in the final step before // For per-column scales, we only fetch them here in the final step before
// write-out // write-out
if constexpr (!has_act_order && group_blocks == -1) { if constexpr (!has_act_order && group_blocks == -1) {
if (last) { if constexpr (num_bits == 8) {
if (s_sh_wr_pred) { if (s_sh_wr_pred) {
cp_async4_stream(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
} }
cp_async_fence(); cp_async_fence();
} else {
if (last) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
cp_async_fence();
}
} }
} }
thread_block_reduce(); thread_block_reduce();
if constexpr (!has_act_order && group_blocks == -1) { if constexpr (!has_act_order && group_blocks == -1) {
if (last) { if constexpr (num_bits == 8) {
cp_async_wait<0>(); cp_async_wait<0>();
__syncthreads(); __syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) { if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4 *>(&frag_s)[1] = sh_s[s_sh_rd + 4]; reinterpret_cast<int4 *>(&frag_s)[1] = sh_s[s_sh_rd + 4];
} }
} else {
if (last) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4 *>(&frag_s)[1] = sh_s[s_sh_rd + 4];
}
}
}
}
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
scale_float(reinterpret_cast<float *>(&frag_c[i][j][0][0]),
frag_s[j / 2][2 * (j % 2) + 0]);
scale_float(reinterpret_cast<float *>(&frag_c[i][j][0][2]),
frag_s[j / 2][2 * (j % 2) + 0]);
scale_float(reinterpret_cast<float *>(&frag_c[i][j][1][0]),
frag_s[j / 2][2 * (j % 2) + 1]);
scale_float(reinterpret_cast<float *>(&frag_c[i][j][1][2]),
frag_s[j / 2][2 * (j % 2) + 1]);
}
}
} }
} }
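The pre-scaling above is why the 8-bit channelwise case is split out: unscaled int8 partial sums can exceed fp16 range, so the per-channel scale is applied while the accumulators are still fp32, whereas the 4-bit path keeps applying it at write-out. A tiny NumPy illustration of the hazard (the numbers are made up):

import numpy as np

partial_sum_fp32 = np.float32(120000.0)   # plausible magnitude for an unscaled int8 dot product
scale = np.float32(0.0005)                # small per-channel quantization scale

print(np.float16(partial_sum_fp32))            # inf  -> overflows if cast to fp16 before scaling
print(np.float16(partial_sum_fp32 * scale))    # 60.0 -> fine when scaled in fp32 first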
@ -1125,28 +1240,25 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
s_gl_rd = s_sh_stride * slice_col + threadIdx.x; s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} }
// if (blockIdx.x == 0 && threadIdx.x == 0) {
// printf("Move\n");
// }
start_pipes(); start_pipes();
} }
} }
} }
} }
#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
else if (thread_m_blocks == THREAD_M_BLOCKS && \ else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
num_threads == NUM_THREADS) { \ num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \ cudaFuncSetAttribute( \
Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>, \ THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS> \ THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \ <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
prob_k, locks); \ prob_k, locks); \
@ -1158,28 +1270,92 @@ typedef struct {
int num_threads; int num_threads;
} thread_config_t; } thread_config_t;
thread_config_t small_batch_thread_configs[] = { typedef struct {
int max_m_blocks;
thread_config_t tb_cfg;
} exec_config_t;
thread_config_t thread_configs[] = {
// Ordered by priority // Ordered by priority
// thread_k, thread_n, num_threads // thread_k, thread_n, num_threads
{128, 128, 256}, // Default {64, 256, 256}, // Default (max cache usage)
{128, 64, 128}, // Reduce N 2X, same K {64, 128, 128}, // Reduce N, reduce warps
{64, 256, 256}, // Reduce K 2X, increase N 2X {128, 64, 128}, // Reduce N more, but increase K
{64, 128, 128}, // Reduce K 2X, same N
}; };
thread_config_t large_batch_thread_configs[] = { int get_scales_cache_size(thread_config_t const &th_config, int prob_m,
// Ordered by priority int prob_n, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full) {
bool cache_scales_chunk = has_act_order && !is_k_full;
// thread_k, thread_n, num_threads int tb_n = th_config.thread_n;
{64, 256, 256}, // Default int tb_k = th_config.thread_k;
{128, 64, 128}, // Reduce N 2X, same K
{64, 128, 128}, // Reduce N 2X, same K
// {128, 64, 128}, // Reduce N 4X, increase K 2X
};
bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, // Get max scale groups per thread-block
int prob_k) { int tb_groups;
if (group_size == -1) {
tb_groups = 1;
} else if (group_size == 0) {
tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size
} else {
tb_groups = div_ceil(tb_k, group_size);
}
if (cache_scales_chunk) {
int load_groups =
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2;
} else {
int tb_scales = tb_groups * tb_n * 2;
return tb_scales * pipe_stages;
}
}
bool is_valid_cache_size(thread_config_t const &th_config, int max_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int scales_cache_size, int max_shared_mem) {
int pack_factor = 32 / num_bits;
// Get B size
int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n;
int b_size = (tb_k * tb_n / pack_factor) * 4;
// Get A size
int m_blocks = div_ceil(prob_m, 16);
int tb_max_m = 16;
while (true) {
if (m_blocks >= max_m_blocks) {
tb_max_m *= max_m_blocks;
break;
}
max_m_blocks--;
if (max_m_blocks == 0) {
TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
}
}
int a_size = (tb_max_m * tb_k) * 2;
float pipe_size = (a_size + b_size) * pipe_stages;
TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
}
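To see why 8-bit weights may force a smaller max_m_blocks, it helps to plug numbers into the check above. A rough Python rendition for the default {thread_k = 64, thread_n = 256} config, taking pipe_stages = 4 and ignoring the scales cache and the 0.95 safety factor for brevity:

def pipe_bytes(num_bits: int, max_m_blocks: int, tb_k: int = 64, tb_n: int = 256,
               pipe_stages: int = 4) -> int:
    pack_factor = 32 // num_bits
    b_size = (tb_k * tb_n // pack_factor) * 4    # packed B tile, bytes
    a_size = (16 * max_m_blocks) * tb_k * 2      # fp16 A tile, bytes
    return (a_size + b_size) * pipe_stages

print(pipe_bytes(4, 4))   # 65536
print(pipe_bytes(8, 4))   # 98304
print(pipe_bytes(8, 3))   # 90112
# On GPUs with roughly 100 KB of opt-in shared memory per block, the 8-bit
# pipeline with max_m_blocks = 4 no longer fits the 0.95 * (max_shared_mem -
# scales_cache) budget, which is exactly when determine_thread_config() below
# starts lowering max_m_blocks and prints the warning.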
bool is_valid_config(thread_config_t const &th_config, int max_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int max_shared_mem) {
// Sanity // Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 || if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) { th_config.num_threads == -1) {
@ -1201,62 +1377,79 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
return false; return false;
} }
// Determine cache for scales
int scales_cache_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
// Check that pipeline fits into cache
if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, scales_cache_size, max_shared_mem)) {
return false;
}
return true; return true;
} }
thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
int num_bits, int group_size,
// TODO: Enable if needed after some more testing bool has_act_order, bool is_k_full,
if (prob_m <= 0) { int max_shared_mem) {
for (auto th_config : small_batch_thread_configs) { int max_m_blocks = 4;
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { while (max_m_blocks > 0) {
return th_config; for (auto th_config : thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full,
max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
} }
} }
} else { printf("WARNING: Marlin kernel is reducing max_m_blocks due to small SM "
for (auto th_config : large_batch_thread_configs) { "GPU cache. This may "
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { "hurt performance. Consider upgrading your GPU.\n");
return th_config;
} max_m_blocks--; // Process less M blocks per invocation to reduce cache
} // usage
} }
return thread_config_t{-1, -1, -1}; return exec_config_t{0, {-1, -1, -1}};
} }
#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
\ \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\ \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\ \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\ \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx, void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
void *perm, void *a_tmp, int prob_m, int prob_n, int prob_k, void *g_idx, void *perm, void *a_tmp, int prob_m,
void *workspace, bool has_act_order, bool is_k_full, int prob_n, int prob_k, void *workspace, int num_bits,
int num_groups, int group_size, int dev = 0, bool has_act_order, bool is_k_full, int num_groups,
cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1, int group_size, int dev, cudaStream_t stream, int thread_k,
int sms = -1, int max_par = 16) { int thread_n, int sms, int max_par) {
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]"); ", ", prob_n, ", ", prob_k, "]");
@ -1274,25 +1467,34 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
TORCH_CHECK(max_shared_mem > 0); TORCH_CHECK(max_shared_mem > 0);
// Set thread config // Set thread config
thread_config_t th_config; exec_config_t exec_cfg;
if (thread_k != -1 && thread_n != -1) { if (thread_k != -1 && thread_n != -1) {
// User-defined config // User-defined config
th_config = thread_config_t{thread_k, thread_n, default_threads}; exec_cfg =
exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};
} else { } else {
// Auto config // Auto config
th_config = determine_thread_config(prob_m, prob_n, prob_k); exec_cfg =
determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, max_shared_mem);
} }
TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k), TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
"Invalid thread config: thread_k = " + str(th_config.thread_k) + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
", thread_n = " + str(th_config.thread_n) + prob_m, prob_n, prob_k, num_bits, group_size,
", num_threads = " + str(th_config.num_threads) + has_act_order, is_k_full, max_shared_mem),
" for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
str(prob_n) + "]"); ", thread_k = ", exec_cfg.tb_cfg.thread_k,
", thread_n = ", exec_cfg.tb_cfg.thread_n,
", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", max_shared_mem = ", max_shared_mem);
int num_threads = th_config.num_threads; int num_threads = exec_cfg.tb_cfg.num_threads;
thread_k = th_config.thread_k; thread_k = exec_cfg.tb_cfg.thread_k;
thread_n = th_config.thread_n; thread_n = exec_cfg.tb_cfg.thread_n;
int thread_k_blocks = thread_k / 16; int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16; int thread_n_blocks = thread_n / 16;
@ -1352,28 +1554,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
} }
// Main loop // Main loop
for (int i = 0; i < tot_m_blocks; i += 4) { for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {
int thread_m_blocks = tot_m_blocks - i; int thread_m_blocks = tot_m_blocks - i;
prob_m = tot_m - 16 * i; prob_m = tot_m - 16 * i;
int par = 1; int par = 1;
if (thread_m_blocks > 4) { if (thread_m_blocks > exec_cfg.max_m_blocks) {
// Note that parallel > 1 currently only works for inputs without any // Note that parallel > 1 currently only works for inputs without any
// padding // padding
par = (16 * thread_m_blocks - pad) / 64; par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
if (par > max_par) if (par > max_par)
par = max_par; par = max_par;
prob_m = 64 * par; prob_m = (16 * exec_cfg.max_m_blocks) * par;
i += 4 * (par - 1); i += exec_cfg.max_m_blocks * (par - 1);
thread_m_blocks = 4; thread_m_blocks = exec_cfg.max_m_blocks;
} }
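A concrete walk-through of the chunking above: with 200 input rows, max_m_blocks = 4 and max_par = 16, the host issues one launch that covers 192 rows as three parallel batch-64 chunks, plus a second launch for the 16-row (padded) remainder. A small Python sketch of the same loop, for illustration only:

def plan_m_chunks(prob_m: int, max_m_blocks: int = 4, max_par: int = 16):
    tot_m_blocks = -(-prob_m // 16)              # ceil(prob_m / 16)
    pad = 16 * tot_m_blocks - prob_m
    launches, i = [], 0
    while i < tot_m_blocks:
        thread_m_blocks = tot_m_blocks - i
        rows = 16 * thread_m_blocks
        if thread_m_blocks > max_m_blocks:
            par = min((rows - pad) // (16 * max_m_blocks), max_par)
            rows = 16 * max_m_blocks * par       # this launch handles `par` chunks at once
            i += max_m_blocks * (par - 1)
            thread_m_blocks = max_m_blocks
        launches.append((thread_m_blocks, rows))
        i += max_m_blocks
    return launches

print(plan_m_chunks(200))   # [(4, 192), (1, 16)]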
// Define kernel configurations // Define kernel configurations
if (false) { if (false) {
} }
CALL_IF(16, 4, 256) CALL_IF(4, 32, 2, 256)
CALL_IF(8, 8, 256) CALL_IF(4, 16, 4, 256)
CALL_IF(8, 4, 128) CALL_IF(4, 8, 4, 128)
CALL_IF(4, 8, 128) CALL_IF(4, 4, 8, 128)
CALL_IF(8, 32, 2, 256)
CALL_IF(8, 16, 4, 256)
CALL_IF(8, 8, 4, 128)
CALL_IF(8, 4, 8, 128)
else { else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" + str(prob_n) + ", " + str(prob_k) + "]" +
@ -1395,33 +1601,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &g_idx, torch::Tensor &b_scales, torch::Tensor &g_idx,
torch::Tensor &perm, torch::Tensor &workspace, torch::Tensor &perm, torch::Tensor &workspace,
int64_t size_m, int64_t size_n, int64_t size_k, int64_t num_bits, int64_t size_m, int64_t size_n,
bool is_k_full) { int64_t size_k, bool is_k_full) {
// Verify num_bits
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int pack_factor = 32 / num_bits;
// Verify A // Verify A
TORCH_CHECK(a.size(0) == size_m, TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
"Shape mismatch: a.size(0) = " + str(a.size(0)) + ", size_m = ", size_m);
", size_m = " + str(size_m)); TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
TORCH_CHECK(a.size(1) == size_k, ", size_k = ", size_k);
"Shape mismatch: a.size(1) = " + str(a.size(1)) +
", size_k = " + str(size_k));
// Verify B // Verify B
TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
"size_k = " + str(size_k) + " is not divisible by tile_size = " + " is not divisible by tile_size = ", gptq_marlin::tile_size);
str(gptq_marlin::tile_size));
TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = " + "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
", tile_size = " + str(gptq_marlin::tile_size)); TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
TORCH_CHECK( "b_q_weight.size(1) = ", b_q_weight.size(1),
b_q_weight.size(1) % gptq_marlin::tile_size == 0, " is not divisible by tile_size = ", gptq_marlin::tile_size);
"b_q_weight.size(1) = " + str(b_q_weight.size(1)) + int actual_size_n =
" is not divisible by tile_size = " + str(gptq_marlin::tile_size)); (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
int actual_size_n = (b_q_weight.size(1) / gptq_marlin::tile_size) * TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
gptq_marlin::pack_factor_4bit; ", actual_size_n = ", actual_size_n);
TORCH_CHECK(size_n == actual_size_n,
"size_n = " + str(size_n) +
", actual_size_n = " + str(actual_size_n));
// Verify device and strides // Verify device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
@ -1457,9 +1662,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
// Verify g_idx and perm // Verify g_idx and perm
TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) || TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||
(g_idx.size(0) == size_k && perm.size(0) == size_k), (g_idx.size(0) == size_k && perm.size(0) == size_k),
"Unexpected g_idx.size(0) = " + str(g_idx.size(0)) + "Unexpected g_idx.size(0) = ", g_idx.size(0),
" and perm.size(0) = " + str(perm.size(0)) + " and perm.size(0) = ", perm.size(0),
", where size_k = " + str(size_k)); ", where size_k = ", size_k);
// Detect groupsize and act_order // Detect groupsize and act_order
int num_groups = -1; int num_groups = -1;
@ -1475,9 +1680,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
if (has_act_order) { if (has_act_order) {
if (is_k_full) { if (is_k_full) {
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
TORCH_CHECK(size_k % num_groups == 0, TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
"size_k = " + str(size_k) + ", is not divisible by num_groups = ", num_groups);
", is not divisible by num_groups = " + str(num_groups));
group_size = size_k / num_groups; group_size = size_k / num_groups;
} else { } else {
group_size = 0; group_size = 0;
@ -1485,10 +1689,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
} else { } else {
if (num_groups > 1) { if (num_groups > 1) {
TORCH_CHECK(size_k % num_groups == 0, TORCH_CHECK(
"size_k = " + str(size_k) + size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by b_scales.size(0) = " + ", is not divisible by b_scales.size(0) = ", b_scales.size(0));
str(b_scales.size(0)));
group_size = size_k / num_groups; group_size = size_k / num_groups;
} else { } else {
group_size = -1; group_size = -1;
@ -1496,23 +1699,22 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
} }
// Verify workspace size // Verify workspace size
TORCH_CHECK(size_n % gptq_marlin::min_thread_n == 0, TORCH_CHECK(
"size_n = " + str(size_n) + size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
", is not divisible by min_thread_n = " + ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
str(gptq_marlin::min_thread_n));
int min_workspace_size = int min_workspace_size =
(size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
TORCH_CHECK(workspace.numel() >= min_workspace_size, TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = " + str(workspace.numel()) + "workspace.numel = ", workspace.numel(),
" is below min_workspace_size = " + str(min_workspace_size)); " is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device(); int dev = a.get_device();
gptq_marlin::marlin_cuda( gptq_marlin::marlin_mm_f16i4(
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n,
size_k, workspace.data_ptr(), has_act_order, is_k_full, num_groups, size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
sms, gptq_marlin::max_par); thread_k, thread_n, sms, gptq_marlin::max_par);
return c; return c;
} }

File: GPTQ Marlin shared CUDA header (constants and cp.async helpers)

@ -24,8 +24,6 @@ static constexpr int min_thread_k = 64;
static constexpr int tile_size = 16; static constexpr int tile_size = 16;
static constexpr int max_par = 16; static constexpr int max_par = 16;
static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32 bit
template <typename T, int n> template <typename T, int n>
struct Vec { struct Vec {
T elems[n]; T elems[n];
@ -51,13 +49,11 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
"r"(smem), "l"(glob_ptr), "n"(BYTES)); "r"(smem), "l"(glob_ptr), "n"(BYTES));
} }
__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16; const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n" asm volatile("{\n"
" .reg .b64 p;\n" " cp.async.cg.shared.global [%0], [%1], %2;\n"
" createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
" cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
"}\n" ::"r"(smem), "}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES)); "l"(glob_ptr), "n"(BYTES));
} }

File: GPTQ Marlin repack kernel (CUDA)

@ -11,7 +11,7 @@ static constexpr int tile_n_size = tile_k_size * 4;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <int const num_threads, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void __global__ void
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr, uint32_t const *__restrict__ perm_ptr,
@ -20,7 +20,8 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
} // namespace gptq_marlin } // namespace gptq_marlin
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n) { int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED( TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1}); return torch::empty({1, 1});
@ -28,11 +29,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
#else #else
template <int const num_threads, bool const has_perm> template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void __global__ void
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr, uint32_t const *__restrict__ perm_ptr,
uint32_t *__restrict__ out_ptr, int size_k, int size_n) { uint32_t *__restrict__ out_ptr, int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size; int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size; int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x); int block_k_tiles = div_ceil(k_tiles, gridDim.x);
@ -64,9 +67,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
sh_pipe_ptr += perm_size; sh_pipe_ptr += perm_size;
} }
constexpr int tile_ints = tile_k_size / pack_factor;
constexpr int stage_n_threads = tile_n_size / 4; constexpr int stage_n_threads = tile_n_size / 4;
constexpr int stage_k_threads = constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
has_perm ? tile_k_size : tile_k_size / pack_factor_4bit;
constexpr int stage_size = stage_k_threads * stage_n_threads; constexpr int stage_size = stage_k_threads * stage_n_threads;
auto load_perm_to_shared = [&](int k_tile_id) { auto load_perm_to_shared = [&](int k_tile_id) {
@ -99,9 +103,9 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
reinterpret_cast<uint32_t const *>(sh_perm_ptr); reinterpret_cast<uint32_t const *>(sh_perm_ptr);
int src_k = sh_perm_int_ptr[k_id]; int src_k = sh_perm_int_ptr[k_id];
int src_k_packed = src_k / pack_factor_4bit; int src_k_packed = src_k / pack_factor;
cp_async4_stream( cp_async4(
&sh_ptr[k_id * stage_n_threads + n_id], &sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>(&( reinterpret_cast<int4 const *>(&(
b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
@ -113,12 +117,12 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
int n_id = threadIdx.x % stage_n_threads; int n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size; int first_k = k_tile_id * tile_k_size;
int first_k_packed = first_k / pack_factor_4bit; int first_k_packed = first_k / pack_factor;
cp_async4_stream(&sh_ptr[k_id * stage_n_threads + n_id], cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>( reinterpret_cast<int4 const *>(
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + &(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
first_n + (n_id * 4)]))); first_n + (n_id * 4)])));
} }
} }
@ -145,26 +149,27 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
int cur_n = warp_id * 16 + tc_col; int cur_n = warp_id * 16 + tc_col;
constexpr int sh_stride = 64; constexpr int sh_stride = 64;
constexpr uint32_t mask = (1 << num_bits) - 1;
int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
uint32_t *sh_stage_int_ptr = reinterpret_cast<uint32_t *>(sh_stage_ptr); uint32_t *sh_stage_int_ptr = reinterpret_cast<uint32_t *>(sh_stage_ptr);
uint32_t *sh_perm_int_ptr = reinterpret_cast<uint32_t *>(sh_perm_ptr); uint32_t *sh_perm_int_ptr = reinterpret_cast<uint32_t *>(sh_perm_ptr);
uint32_t vals[pack_factor_4bit]; uint32_t vals[8];
if constexpr (has_perm) { if constexpr (has_perm) {
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
int k_idx = tc_row + tc_offsets[i]; int k_idx = tc_row + tc_offsets[i];
uint32_t src_k = sh_perm_int_ptr[k_idx]; uint32_t src_k = sh_perm_int_ptr[k_idx];
uint32_t src_k_pos = src_k % pack_factor_4bit; uint32_t src_k_pos = src_k % pack_factor;
uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
uint32_t b1_cur_val = (b1_val >> (src_k_pos * 4)) & 0xf; uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;
uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
uint32_t b2_cur_val = (b2_val >> (src_k_pos * 4)) & 0xf; uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;
vals[i] = b1_cur_val; vals[i] = b1_cur_val;
vals[4 + i] = b2_cur_val; vals[4 + i] = b2_cur_val;
@ -172,41 +177,56 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
} else { } else {
uint32_t b1_val_1 = sh_stage_int_ptr[cur_n]; uint32_t b1_vals[tile_ints];
uint32_t b1_val_2 = sh_stage_int_ptr[sh_stride + cur_n]; uint32_t b2_vals[tile_ints];
uint32_t b2_val_1 = sh_stage_int_ptr[cur_n + 8];
uint32_t b2_val_2 = sh_stage_int_ptr[sh_stride + cur_n + 8];
#pragma unroll #pragma unroll
for (int i = 0; i < 2; i++) { for (int i = 0; i < tile_ints; i++) {
int cur_elem = tc_row + tc_offsets[i]; b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
vals[i] = (b1_val_1 >> (cur_elem * 4)) & 0xf; b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
vals[4 + i] = (b2_val_1 >> (cur_elem * 4)) & 0xf;
} }
#pragma unroll #pragma unroll
for (int i = 2; i < 4; i++) { for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i] - 8; int cur_elem = tc_row + tc_offsets[i];
vals[i] = (b1_val_2 >> (cur_elem * 4)) & 0xf; int cur_int = cur_elem / pack_factor;
vals[4 + i] = (b2_val_2 >> (cur_elem * 4)) & 0xf; int cur_pos = cur_elem % pack_factor;
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
} }
} }
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of: // Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
constexpr int pack_idx[pack_factor_4bit] = {0, 2, 4, 6, 1, 3, 5, 7}; if constexpr (num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0; uint32_t res = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < pack_factor_4bit; i++) { for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4); res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else {
constexpr int pack_idx[4] = {0, 2, 1, 3};
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
}
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
} }
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor_4bit;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
out_ptr[out_offset + th_id * 4 + warp_id] = res;
}; };
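The two pack orders used above are chosen so that the dequantizers read values back in their original tile order: for 4-bit, the interleave {0, 2, 4, 6, 1, 3, 5, 7} puts eight nibbles into one uint32 the way dequant_4bit (and the >> 8 shift) expects, and for 8-bit, {0, 2, 1, 3} splits the eight values across two uint32 words so that dequant_8bit's prmt masks recover (v0, v1) and (v2, v3) per word. A small Python sketch of just the packing step (names are illustrative):

def pack_tile_vals(vals, num_bits):
    # Packs one thread's 8 extracted values the way the kernel above does:
    # vals[0:4] come from the b1 column, vals[4:8] from the b2 column.
    assert len(vals) == 8
    if num_bits == 4:
        pack_idx = [0, 2, 4, 6, 1, 3, 5, 7]
        res = 0
        for i in range(8):
            res |= (vals[pack_idx[i]] & 0xF) << (i * 4)
        return [res]                  # all 8 nibbles in one uint32
    else:
        pack_idx = [0, 2, 1, 3]
        res1 = res2 = 0
        for i in range(4):
            res1 |= (vals[pack_idx[i]] & 0xFF) << (i * 8)
            res2 |= (vals[4 + pack_idx[i]] & 0xFF) << (i * 8)
        return [res1, res2]           # b1 values and b2 values, 4 bytes each

print([hex(w) for w in pack_tile_vals(list(range(8)), 8)])
# ['0x3010200', '0x7050604'] -> byte order (v0, v2, v1, v3) per word, so
# dequant_8bit's byte-0/2 and byte-1/3 selection yields (v0, v1) and (v2, v3).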
auto start_pipes = [&](int k_tile_id, int n_tile_id) { auto start_pipes = [&](int k_tile_id, int n_tile_id) {
@ -242,19 +262,35 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
} // namespace gptq_marlin } // namespace gptq_marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
NUM_BITS, HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n) { int64_t size_k, int64_t size_n,
int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64 // Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits;
// Verify B // Verify B
TORCH_CHECK((size_k / gptq_marlin::pack_factor_4bit) == b_q_weight.size(0), TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
", size_k = ", size_k, ", size_k = ", size_k, ", pack_factor = ", pack_factor);
", pack_factor_4bit = ", gptq_marlin::pack_factor_4bit);
TORCH_CHECK(b_q_weight.size(1) == size_n, TORCH_CHECK(b_q_weight.size(1) == size_n,
"b_q_weight.size(1) = ", b_q_weight.size(1), "b_q_weight.size(1) = ", b_q_weight.size(1),
" is not size_n = ", size_n); " is not size_n = ", size_n);
@ -273,10 +309,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
auto options = torch::TensorOptions() auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype()) .dtype(b_q_weight.dtype())
.device(b_q_weight.device()); .device(b_q_weight.device());
torch::Tensor out = torch::empty( torch::Tensor out =
{size_k / gptq_marlin::tile_size, torch::empty({size_k / gptq_marlin::tile_size,
size_n * gptq_marlin::tile_size / gptq_marlin::pack_factor_4bit}, size_n * gptq_marlin::tile_size / pack_factor},
options); options);
// Detect if there is act_order // Detect if there is act_order
bool has_perm = perm.size(0) != 0; bool has_perm = perm.size(0) != 0;
@ -299,23 +335,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0); TORCH_CHECK(max_shared_mem > 0);
if (has_perm) { if (false) {
cudaFuncSetAttribute( }
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>, CALL_IF(4, false)
cudaFuncAttributeMaxDynamicSharedMemorySize, CALL_IF(4, true)
max_shared_mem); CALL_IF(8, false)
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true> CALL_IF(8, true)
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, else {
stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", has_perm = ", has_perm);
} else {
cudaFuncSetAttribute(
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_mem);
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>
<<<blocks, gptq_marlin::repack_threads, max_shared_mem,
stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
} }
return out; return out;

File: GPTQ Marlin model correctness test (Python)

@ -39,6 +39,13 @@ MODELS = [
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"),
# act_order==True, group_size=32 # act_order==True, group_size=32
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"),
# 8-bit, act_order==True, group_size=channelwise
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit--1g-actorder_True"),
# 8-bit, act_order==True, group_size=128
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-128g-actorder_True"),
# 8-bit, act_order==True, group_size=32
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-32g-actorder_True"),
] ]
@ -65,8 +72,7 @@ def test_models(
dtype=dtype, dtype=dtype,
quantization="marlin", quantization="marlin",
max_model_len=MAX_MODEL_LEN, max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=1, tensor_parallel_size=1)
disable_custom_all_reduce=True)
gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
@ -78,8 +84,7 @@ def test_models(
dtype=dtype, dtype=dtype,
quantization="gptq", quantization="gptq",
max_model_len=MAX_MODEL_LEN, max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=1, tensor_parallel_size=1)
disable_custom_all_reduce=True)
gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
max_tokens, max_tokens,
num_logprobs) num_logprobs)
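The new entries exercise 8-bit GPTQ checkpoints (channelwise, 128g, and 32g, all with act_order) through the same logprob comparison against the reference GPTQ kernels. A hedged example of loading one of the new 8-bit revisions outside the test harness, assuming the standard vllm.LLM entry point accepts the same arguments the vllm_runner fixture forwards (the max_model_len value here is arbitrary, chosen only to keep the sketch small):

# Sketch: run an 8-bit GPTQ checkpoint through the Marlin path.
# Assumes a CUDA GPU with compute capability >= 8.0 and vLLM built with this change.
from vllm import LLM, SamplingParams

llm = LLM(
    model="TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ",
    revision="gptq-8bit-128g-actorder_True",  # one of the new 8-bit test revisions
    quantization="marlin",                    # same override the test uses
    dtype="half",
    max_model_len=1024,                       # arbitrary small value for this sketch
)
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)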

View File

@ -169,18 +169,20 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
# gptq_marlin # gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
size_k: int, size_n: int) -> torch.Tensor: size_k: int, size_n: int,
return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n) num_bits: int) -> torch.Tensor:
return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
num_bits)
def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, g_idx: torch.Tensor, b_scales: torch.Tensor, g_idx: torch.Tensor,
perm: torch.Tensor, workspace: torch.Tensor, size_m: int, perm: torch.Tensor, workspace: torch.Tensor,
size_n: int, size_k: int, num_bits: int, size_m: int, size_n: int, size_k: int,
is_k_full: bool) -> torch.Tensor: is_k_full: bool) -> torch.Tensor:
return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
workspace, size_m, size_n, size_k, workspace, num_bits, size_m, size_n,
is_k_full) size_k, is_k_full)
# fp8 # fp8
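Both wrappers now take num_bits because the packed int32 layout differs with bit-width: eight 4-bit values or four 8-bit values per 32-bit word. A small self-contained illustration of that pack_factor arithmetic on a GPTQ-style row-packed qweight (illustrative only, not the Marlin tile layout):

import torch

def pack_rows(q: torch.Tensor, num_bits: int) -> torch.Tensor:
    """Pack unsigned num_bits integers along dim 0 into int32 words (GPTQ-style qweight)."""
    pack_factor = 32 // num_bits                      # 8 for 4-bit, 4 for 8-bit
    assert q.shape[0] % pack_factor == 0
    q = q.to(torch.int32).reshape(-1, pack_factor, q.shape[1])
    packed = torch.zeros(q.shape[0], q.shape[2], dtype=torch.int32)
    for i in range(pack_factor):                      # OR each value into its bit field
        packed |= q[:, i, :] << (i * num_bits)
    return packed

# A [size_k, size_n] = [8, 2] block of 4-bit weights packs into [1, 2] int32 words;
# the same block of 8-bit weights packs into [2, 2].
w4 = torch.randint(0, 16, (8, 2))
print(pack_rows(w4, 4).shape)   # torch.Size([1, 2])
w8 = torch.randint(0, 256, (8, 2))
print(pack_rows(w8, 8).shape)   # torch.Size([2, 2])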

View File

@ -2,7 +2,6 @@ import enum
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import numpy
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
@ -17,41 +16,13 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16 GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4] GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True] GPTQ_MARLIN_SUPPORTED_SYM = [True]
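With 8 added to GPTQ_MARLIN_SUPPORTED_NUM_BITS, a GPTQ checkpoint is eligible for the Marlin path when its bit-width, group size, and symmetry all fall in these lists. A minimal standalone check in that spirit (hypothetical helper, not the actual config-class method):

GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]

def is_gptq_marlin_compatible(num_bits: int, group_size: int, sym: bool) -> bool:
    """Rough eligibility check for the Marlin GPTQ path."""
    return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
            and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
            and sym in GPTQ_MARLIN_SUPPORTED_SYM)

print(is_gptq_marlin_compatible(8, 128, True))   # True: the new 8-bit path
print(is_gptq_marlin_compatible(3, 128, True))   # False: 3-bit is not supported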
# Precompute permutations for Marlin weight and scale shuffling # Permutations for Marlin scale shuffling
# def get_scale_perms(num_bits):
# Marlin works on [16,64] tiles. The goal of the permutations
# is to reorder the weight data so that it is compatible
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the
# kernel will get the data as it is needed for tensor-core
# (without the need to use ldmatrix instructions)
def _get_perms():
perm = []
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm.extend([p + 256 * j for p in perm1])
perm = numpy.array(perm)
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
perm = perm.reshape((-1, 8))[:, interleave].ravel() # type: ignore
perm = torch.from_numpy(perm)
scale_perm = [] scale_perm = []
for i in range(8): for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)]) scale_perm.extend([i + 8 * j for j in range(8)])
@ -59,23 +30,21 @@ def _get_perms():
for i in range(4): for i in range(4):
scale_perm_single.extend( scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm, scale_perm, scale_perm_single return scale_perm, scale_perm_single
_perm, _scale_perm, _scale_perm_single = _get_perms()
def get_pack_factor(num_bits): def get_pack_factor(num_bits):
assert num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS, ( assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
f"Unsupported num_bits = {num_bits}") ), f"Unsupported num_bits = {num_bits}"
return 32 // num_bits return 32 // num_bits
def marlin_permute_scales(s, size_k, size_n, group_size): def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
scale_perm, scale_perm_single = get_scale_perms(num_bits)
if group_size < size_k and group_size != -1: if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm] s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else: else:
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single] s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
s = s.reshape((-1, size_n)).contiguous() s = s.reshape((-1, size_n)).contiguous()
return s return s
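marlin_permute_scales now obtains its permutations from get_scale_perms(num_bits) instead of the removed module-level _scale_perm globals, and the weight permutation previously built by _get_perms lives in the CUDA repack kernel. A self-contained sketch of the scale shuffling using the permutation construction shown above (and assuming, as the code here suggests, that the scale permutations themselves are the same for 4-bit and 8-bit):

import torch

def get_scale_perms():
    # Same construction as in the diff above.
    scale_perm = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single

def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                          group_size: int) -> torch.Tensor:
    scale_perm, scale_perm_single = get_scale_perms()
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    return s.reshape((-1, size_n)).contiguous()

# Example: 128-group scales for a 4096x4096 layer keep their [32, 4096] shape;
# only the element order changes.
s = torch.randn(4096 // 128, 4096, dtype=torch.half)
print(marlin_permute_scales(s, 4096, 4096, 128).shape)  # torch.Size([32, 4096])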
@ -279,13 +248,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad=False, requires_grad=False,
) )
set_weight_attrs( set_weight_attrs(
qweight, { qweight,
{
**extra_weight_attrs, **extra_weight_attrs,
"input_dim": 0, "input_dim": 0,
"output_dim": 1, "output_dim": 1,
"packed_dim": 0, "packed_dim": 0,
"pack_factor": self.quant_config.pack_factor, "pack_factor": self.quant_config.pack_factor,
}) },
)
# Activation order # Activation order
g_idx = Parameter( g_idx = Parameter(
@ -296,10 +267,13 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad=False, requires_grad=False,
) )
# Ignore warning from fused linear layers such as QKVParallelLinear. # Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs(g_idx, { set_weight_attrs(
**extra_weight_attrs, "input_dim": 0, g_idx,
"ignore_warning": True {
}) **extra_weight_attrs, "input_dim": 0,
"ignore_warning": True
},
)
g_idx_sort_indices = Parameter( g_idx_sort_indices = Parameter(
torch.empty( torch.empty(
@ -320,29 +294,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad=False, requires_grad=False,
) )
set_weight_attrs( set_weight_attrs(
scales, { scales,
{
**extra_weight_attrs, **extra_weight_attrs,
"input_dim": scales_and_zp_input_dim, "input_dim": scales_and_zp_input_dim,
"output_dim": 1, "output_dim": 1,
}) },
)
# Quantized zero-points # Quantized zero-points
qzeros = Parameter( qzeros = Parameter(
torch.empty(scales_and_zp_size, torch.empty(
output_size_per_partition // scales_and_zp_size,
self.quant_config.pack_factor, output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32, dtype=torch.int32,
device="meta"), device="meta",
),
requires_grad=False, requires_grad=False,
) )
set_weight_attrs( set_weight_attrs(
qzeros, { qzeros,
{
**extra_weight_attrs, **extra_weight_attrs,
"input_dim": scales_and_zp_input_dim, "input_dim": scales_and_zp_input_dim,
"output_dim": 1, "output_dim": 1,
"packed_dim": 1, "packed_dim": 1,
"pack_factor": self.quant_config.pack_factor, "pack_factor": self.quant_config.pack_factor,
}) },
)
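The input_dim / output_dim / packed_dim / pack_factor attributes attached to qweight, scales, and qzeros tell the weight loader how to slice each tensor for tensor parallelism: along a packed dimension, shard offsets and sizes must be divided by pack_factor. A small illustration of that bookkeeping (hypothetical helper, not the vLLM loader itself):

def packed_shard_extent(shard_offset: int, shard_size: int,
                        packed_dim: int, shard_dim: int, pack_factor: int):
    """Return the (offset, size) to use when slicing a possibly packed dimension."""
    if shard_dim == packed_dim:
        assert shard_offset % pack_factor == 0 and shard_size % pack_factor == 0
        return shard_offset // pack_factor, shard_size // pack_factor
    return shard_offset, shard_size

# qweight packs along dim 0 with pack_factor 4 for 8-bit weights:
print(packed_shard_extent(1024, 1024, packed_dim=0, shard_dim=0, pack_factor=4))
# -> (256, 256)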
# Allocate marlin workspace # Allocate marlin workspace
max_workspace_size = ( max_workspace_size = (
@ -405,13 +384,14 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
else: else:
# Reset g_idx related tensors # Reset g_idx related tensors
layer.g_idx = Parameter(torch.empty(0, layer.g_idx = Parameter(
dtype=torch.int, torch.empty(0, dtype=torch.int, device=cur_device),
device=cur_device), requires_grad=False,
requires_grad=False) )
layer.g_idx_sort_indices = Parameter(torch.empty( layer.g_idx_sort_indices = Parameter(
0, dtype=torch.int, device=cur_device), torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False) requires_grad=False,
)
# Repack weights # Repack weights
marlin_qweight = ops.gptq_marlin_repack( marlin_qweight = ops.gptq_marlin_repack(
@ -419,6 +399,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer.g_idx_sort_indices, layer.g_idx_sort_indices,
part_size_k, part_size_k,
part_size_n, part_size_n,
self.quant_config.weight_bits,
) )
replace_tensor("qweight", marlin_qweight) replace_tensor("qweight", marlin_qweight)
@ -428,15 +409,28 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
if self.quant_config.desc_act: if self.quant_config.desc_act:
scales_size_k = full_size_k scales_size_k = full_size_k
marlin_scales = marlin_permute_scales(layer.scales, scales_size_k, marlin_scales = marlin_permute_scales(
scales_size_n, layer.scales,
self.quant_config.group_size) scales_size_k,
scales_size_n,
self.quant_config.group_size,
self.quant_config.weight_bits,
)
replace_tensor("scales", marlin_scales) replace_tensor("scales", marlin_scales)
output = ops.gptq_marlin_gemm(reshaped_x, layer.qweight, layer.scales, output = ops.gptq_marlin_gemm(
layer.g_idx, layer.g_idx_sort_indices, reshaped_x,
layer.workspace, size_m, part_size_n, layer.qweight,
part_size_k, layer.is_k_full) layer.scales,
layer.g_idx,
layer.g_idx_sort_indices,
layer.workspace,
self.quant_config.weight_bits,
size_m,
part_size_n,
part_size_k,
layer.is_k_full,
)
if bias is not None: if bias is not None:
output.add_(bias) # In-place add output.add_(bias) # In-place add
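Taken together, apply() now threads weight_bits through both custom ops: repack the GPTQ weights into Marlin's tile layout on the first forward pass, permute the scales, then run the fused GEMM. A condensed sketch of the GEMM call against the wrapper signature shown earlier, assuming the wrappers above live in vllm/_custom_ops.py and are importable as ops (as the layer's ops. calls suggest); this is a sketch, not the layer's actual apply():

# Condensed sketch of the Marlin GEMM step, following the wrapper signature above.
# Assumes tensors are already on a CUDA device with compute capability >= 8.0.
import torch
from vllm import _custom_ops as ops  # assumed import path for the wrappers above

def marlin_forward(x, qweight, scales, g_idx, g_idx_sort_indices, workspace,
                   num_bits, size_k, size_n, is_k_full, bias=None):
    reshaped_x = x.reshape(-1, x.shape[-1])   # flatten batch dims to [size_m, size_k]
    size_m = reshaped_x.shape[0]
    out = ops.gptq_marlin_gemm(reshaped_x, qweight, scales, g_idx,
                               g_idx_sort_indices, workspace,
                               num_bits, size_m, size_n, size_k, is_k_full)
    if bias is not None:
        out.add_(bias)                        # in-place add, as in apply()
    return out.reshape(x.shape[:-1] + (size_n,))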