[Kernel] Support running GPTQ 8-bit models in Marlin (#4533)

This commit is contained in:
alexm-nm 2024-05-02 12:56:22 -04:00 committed by GitHub
parent 2a85f93007
commit 7038e8b803
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 552 additions and 323 deletions

View File

@ -132,6 +132,7 @@ torch::Tensor gptq_marlin_gemm(
torch::Tensor &g_idx,
torch::Tensor &perm,
torch::Tensor &workspace,
int64_t num_bits,
int64_t size_m,
int64_t size_n,
int64_t size_k,
@ -141,7 +142,8 @@ torch::Tensor gptq_marlin_repack(
torch::Tensor &b_q_weight,
torch::Tensor &perm,
int64_t size_k,
int64_t size_n);
int64_t size_n,
int64_t num_bits);
#endif
void squeezellm_gemm(

View File

@ -32,7 +32,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
int4 *__restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {}
template <const int threads, // number of threads in a threadblock
template <const int num_bits, // number of bits used for weights
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock
const int thread_n_blocks, // same for n dimension (output)
@ -62,8 +63,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &g_idx,
torch::Tensor &perm, torch::Tensor &workspace,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full) {
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
@ -114,11 +115,21 @@ template <int lut> __device__ inline int lop3(int a, int b, int c) {
return res;
}
// Constructs destination register by taking bytes from 2 sources (based on mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant(int q) {
__device__ inline FragB dequant_4bit(int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
@ -139,6 +150,24 @@ __device__ inline FragB dequant(int q) {
return frag_b;
}
__device__ inline FragB dequant_8bit(int q) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
*reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2 *>(&hi),
*reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
@ -162,6 +191,13 @@ __device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2,
frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
}
// Given 2 floats multiply by 2 scales (halves)
__device__ inline void scale_float(float *c, FragS &s) {
__half *s_ptr = reinterpret_cast<__half *>(&s);
c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int *lock, int count) {
if (threadIdx.x == 0) {
@ -250,7 +286,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
}
}
template <const int threads, // number of threads in a threadblock
template <const int num_bits, // number of bits used for weights
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock
const int thread_n_blocks, // same for n dimension (output)
@ -286,6 +323,8 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
constexpr int pack_factor = 32 / num_bits;
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int parallel = 1;
@ -385,21 +424,25 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
// B sizes/strides
int b_gl_stride = 16 * prob_n / 32;
constexpr int b_sh_stride = 32 * thread_n_blocks / 4;
int b_gl_stride = 16 * prob_n / (pack_factor * 4);
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride);
constexpr int b_sh_wr_delta = threads;
constexpr int b_sh_rd_delta = threads;
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
constexpr int b_sh_wr_delta = threads * b_thread_vecs;
constexpr int b_sh_rd_delta = threads * b_thread_vecs;
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
// Scale sizes/strides without act_order
int s_gl_stride = prob_n / 8;
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups = !has_act_order && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
: 1;
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks
? thread_k_blocks / group_blocks
: 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
@ -425,12 +468,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
int b_gl_rd =
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
int b_sh_wr = threadIdx.x;
int b_sh_rd = threadIdx.x;
int b_sh_wr = threadIdx.x * b_thread_vecs;
int b_sh_rd = threadIdx.x * b_thread_vecs;
// For act_order
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
@ -442,8 +485,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// No act_order
int s_gl_rd;
if constexpr (!has_act_order) {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_sh_stride * slice_col + threadIdx.x;
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
s_sh_stride * slice_col + threadIdx.x;
}
}
int s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
@ -511,7 +558,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2];
I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order
FragS act_frag_s[2][4][4]; // For act-order
@ -575,7 +622,11 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]);
#pragma unroll
for (int j = 0; j < b_thread_vecs; j++) {
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
}
B_ptr[i] += b_gl_rd_delta_o;
}
@ -602,15 +653,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// Only fetch scales if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (s_sh_wr_pred) {
cp_async4_stream(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
}
} else {
for (int i = 0; i < s_tb_groups; i++) {
if (s_sh_wr_pred) {
cp_async4_stream(&sh_s_stage[i * s_sh_stride + s_sh_wr],
&scales_ptr[s_gl_rd]);
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr],
&scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
}
@ -641,14 +692,24 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
for (int i = 0; i < thread_m_blocks; i++)
ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
int4 *sh_b_stage = sh_b + b_sh_stage * pipe;
frag_b_quant[k % 2] = *reinterpret_cast<I4 *>(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]);
#pragma unroll
for (int i = 0; i < b_thread_vecs; i++) {
frag_b_quant[k % 2][i] = *reinterpret_cast<I4 *>(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
}
};
bool is_same_group[stages];
int same_group_id[stages];
auto init_same_group = [&](int pipe) {
if constexpr (!has_act_order) {
is_same_group[pipe] = false;
same_group_id[pipe] = 0;
return;
}
int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
int *sh_g_idx_int_ptr = reinterpret_cast<int *>(sh_g_idx_stage);
@ -767,10 +828,23 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// dequantization and matmul operations.
#pragma unroll
for (int j = 0; j < 4; j++) {
int b_quant = frag_b_quant[k % 2][j];
int b_quant_shift = b_quant >> 8;
FragB frag_b0;
FragB frag_b1;
if constexpr (num_bits == 4) {
int b_quant = frag_b_quant[k % 2][0][j];
int b_quant_shift = b_quant >> 8;
FragB frag_b0 = dequant(b_quant);
frag_b0 = dequant_4bit(b_quant);
frag_b1 = dequant_4bit(b_quant_shift);
} else {
int *frag_b_quant_ptr = reinterpret_cast<int *>(frag_b_quant[k % 2]);
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
frag_b0 = dequant_8bit(b_quant_0);
frag_b1 = dequant_8bit(b_quant_1);
}
// Apply scale to frag_b0
if constexpr (has_act_order) {
@ -782,8 +856,6 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
}
}
FragB frag_b1 = dequant(b_quant_shift);
// Apply scale to frag_b1
if constexpr (has_act_order) {
scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
@ -808,13 +880,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride / 2;
constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride;
constexpr int red_sh_stride = b_sh_stride * 4 * 2;
constexpr int red_sh_delta = b_sh_stride;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
(threadIdx.x % b_sh_stride);
int red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
@ -861,7 +933,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped portioning
// finally have to globally reduce over the results. As the striped partitioning
// minimizes the number of such reductions and our outputs are usually rather
// small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) {
@ -951,13 +1023,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
auto write = [&](int idx, float c0, float c1, FragS &s) {
half2 res = __halves2half2(__float2half(c0), __float2half(c1));
// For per-column quantization we finally apply the scale here
if constexpr (!has_act_order && group_blocks == -1) {
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) {
res = __hmul2(res, s[0]);
}
((half2 *)sh)[idx] = res;
};
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
@ -1023,6 +1097,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at
// index 0.
#pragma unroll
for (int pipe = 0; pipe < stages;) {
#pragma unroll
@ -1070,23 +1145,63 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// For per-column scales, we only fetch them here in the final step before
// write-out
if constexpr (!has_act_order && group_blocks == -1) {
if (last) {
if constexpr (num_bits == 8) {
if (s_sh_wr_pred) {
cp_async4_stream(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
cp_async_fence();
} else {
if (last) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
cp_async_fence();
}
}
}
thread_block_reduce();
if constexpr (!has_act_order && group_blocks == -1) {
if (last) {
if constexpr (num_bits == 8) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4 *>(&frag_s)[1] = sh_s[s_sh_rd + 4];
}
} else {
if (last) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4 *>(&frag_s)[1] = sh_s[s_sh_rd + 4];
}
}
}
}
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
scale_float(reinterpret_cast<float *>(&frag_c[i][j][0][0]),
frag_s[j / 2][2 * (j % 2) + 0]);
scale_float(reinterpret_cast<float *>(&frag_c[i][j][0][2]),
frag_s[j / 2][2 * (j % 2) + 0]);
scale_float(reinterpret_cast<float *>(&frag_c[i][j][1][0]),
frag_s[j / 2][2 * (j % 2) + 1]);
scale_float(reinterpret_cast<float *>(&frag_c[i][j][1][2]),
frag_s[j / 2][2 * (j % 2) + 1]);
}
}
}
}
@ -1125,28 +1240,25 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
}
// if (blockIdx.x == 0 && threadIdx.x == 0) {
// printf("Move\n");
// }
start_pipes();
}
}
}
}
#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
else if (thread_m_blocks == THREAD_M_BLOCKS && \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>, \
Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS> \
Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
prob_k, locks); \
@ -1158,28 +1270,92 @@ typedef struct {
int num_threads;
} thread_config_t;
thread_config_t small_batch_thread_configs[] = {
typedef struct {
int max_m_blocks;
thread_config_t tb_cfg;
} exec_config_t;
thread_config_t thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{128, 128, 256}, // Default
{128, 64, 128}, // Reduce N 2X, same K
{64, 256, 256}, // Reduce K 2X, increase N 2X
{64, 128, 128}, // Reduce K 2X, same N
{64, 256, 256}, // Default (max cache usage)
{64, 128, 128}, // Reduce N, reduce warps
{128, 64, 128}, // Reduce N more, but increase K
};
thread_config_t large_batch_thread_configs[] = {
// Ordered by priority
int get_scales_cache_size(thread_config_t const &th_config, int prob_m,
int prob_n, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full) {
bool cache_scales_chunk = has_act_order && !is_k_full;
// thread_k, thread_n, num_threads
{64, 256, 256}, // Default
{128, 64, 128}, // Reduce N 2X, same K
{64, 128, 128}, // Reduce N 2X, same K
// {128, 64, 128}, // Reduce N 4X, increase K 2X
};
int tb_n = th_config.thread_n;
int tb_k = th_config.thread_k;
bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
int prob_k) {
// Get max scale groups per thread-block
int tb_groups;
if (group_size == -1) {
tb_groups = 1;
} else if (group_size == 0) {
tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size
} else {
tb_groups = div_ceil(tb_k, group_size);
}
if (cache_scales_chunk) {
int load_groups =
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2;
} else {
int tb_scales = tb_groups * tb_n * 2;
return tb_scales * pipe_stages;
}
}
bool is_valid_cache_size(thread_config_t const &th_config, int max_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int scales_cache_size, int max_shared_mem) {
int pack_factor = 32 / num_bits;
// Get B size
int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n;
int b_size = (tb_k * tb_n / pack_factor) * 4;
// Get A size
int m_blocks = div_ceil(prob_m, 16);
int tb_max_m = 16;
while (true) {
if (m_blocks >= max_m_blocks) {
tb_max_m *= max_m_blocks;
break;
}
max_m_blocks--;
if (max_m_blocks == 0) {
TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
}
}
int a_size = (tb_max_m * tb_k) * 2;
float pipe_size = (a_size + b_size) * pipe_stages;
TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
}
bool is_valid_config(thread_config_t const &th_config, int max_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
@ -1201,62 +1377,79 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n,
return false;
}
// Determine cache for scales
int scales_cache_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
// Check that pipeline fits into cache
if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, scales_cache_size, max_shared_mem)) {
return false;
}
return true;
}
thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) {
// TODO: Enable if needed after some more testing
if (prob_m <= 0) {
for (auto th_config : small_batch_thread_configs) {
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
return th_config;
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
int num_bits, int group_size,
bool has_act_order, bool is_k_full,
int max_shared_mem) {
int max_m_blocks = 4;
while (max_m_blocks > 0) {
for (auto th_config : thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full,
max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
}
}
} else {
for (auto th_config : large_batch_thread_configs) {
if (is_valid_config(th_config, prob_m, prob_n, prob_k)) {
return th_config;
}
}
printf("WARNING: Marlin kernel is reducing max_m_blocks due to small SM "
"GPU cache. This may "
"hurt performance. Consider upgrading your GPU.\n");
max_m_blocks--; // Process less M blocks per invocation to reduce cache
// usage
}
return thread_config_t{-1, -1, -1};
return exec_config_t{0, {-1, -1, -1}};
}
#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
\
__CALL_IF(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
void *perm, void *a_tmp, int prob_m, int prob_n, int prob_k,
void *workspace, bool has_act_order, bool is_k_full,
int num_groups, int group_size, int dev = 0,
cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1,
int sms = -1, int max_par = 16) {
void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s,
void *g_idx, void *perm, void *a_tmp, int prob_m,
int prob_n, int prob_k, void *workspace, int num_bits,
bool has_act_order, bool is_k_full, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k,
int thread_n, int sms, int max_par) {
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
@ -1274,25 +1467,34 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
TORCH_CHECK(max_shared_mem > 0);
// Set thread config
thread_config_t th_config;
exec_config_t exec_cfg;
if (thread_k != -1 && thread_n != -1) {
// User-defined config
th_config = thread_config_t{thread_k, thread_n, default_threads};
exec_cfg =
exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};
} else {
// Auto config
th_config = determine_thread_config(prob_m, prob_n, prob_k);
exec_cfg =
determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, max_shared_mem);
}
TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k),
"Invalid thread config: thread_k = " + str(th_config.thread_k) +
", thread_n = " + str(th_config.thread_n) +
", num_threads = " + str(th_config.num_threads) +
" for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " +
str(prob_n) + "]");
TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, max_shared_mem),
"Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
", thread_k = ", exec_cfg.tb_cfg.thread_k,
", thread_n = ", exec_cfg.tb_cfg.thread_n,
", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", max_shared_mem = ", max_shared_mem);
int num_threads = th_config.num_threads;
thread_k = th_config.thread_k;
thread_n = th_config.thread_n;
int num_threads = exec_cfg.tb_cfg.num_threads;
thread_k = exec_cfg.tb_cfg.thread_k;
thread_n = exec_cfg.tb_cfg.thread_n;
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
@ -1352,28 +1554,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
}
// Main loop
for (int i = 0; i < tot_m_blocks; i += 4) {
for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {
int thread_m_blocks = tot_m_blocks - i;
prob_m = tot_m - 16 * i;
int par = 1;
if (thread_m_blocks > 4) {
if (thread_m_blocks > exec_cfg.max_m_blocks) {
// Note that parallel > 1 currently only works for inputs without any
// padding
par = (16 * thread_m_blocks - pad) / 64;
par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
if (par > max_par)
par = max_par;
prob_m = 64 * par;
i += 4 * (par - 1);
thread_m_blocks = 4;
prob_m = (16 * exec_cfg.max_m_blocks) * par;
i += exec_cfg.max_m_blocks * (par - 1);
thread_m_blocks = exec_cfg.max_m_blocks;
}
// Define kernel configurations
if (false) {
}
CALL_IF(16, 4, 256)
CALL_IF(8, 8, 256)
CALL_IF(8, 4, 128)
CALL_IF(4, 8, 128)
CALL_IF(4, 32, 2, 256)
CALL_IF(4, 16, 4, 256)
CALL_IF(4, 8, 4, 128)
CALL_IF(4, 4, 8, 128)
CALL_IF(8, 32, 2, 256)
CALL_IF(8, 16, 4, 256)
CALL_IF(8, 8, 4, 128)
CALL_IF(8, 4, 8, 128)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" +
@ -1395,33 +1601,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx,
torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &g_idx,
torch::Tensor &perm, torch::Tensor &workspace,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full) {
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full) {
// Verify num_bits
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int pack_factor = 32 / num_bits;
// Verify A
TORCH_CHECK(a.size(0) == size_m,
"Shape mismatch: a.size(0) = " + str(a.size(0)) +
", size_m = " + str(size_m));
TORCH_CHECK(a.size(1) == size_k,
"Shape mismatch: a.size(1) = " + str(a.size(1)) +
", size_k = " + str(size_k));
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
", size_m = ", size_m);
TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
", size_k = ", size_k);
// Verify B
TORCH_CHECK(size_k % gptq_marlin::tile_size == 0,
"size_k = " + str(size_k) + " is not divisible by tile_size = " +
str(gptq_marlin::tile_size));
TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
" is not divisible by tile_size = ", gptq_marlin::tile_size);
TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = " +
str(b_q_weight.size(0)) + ", size_k = " + str(size_k) +
", tile_size = " + str(gptq_marlin::tile_size));
TORCH_CHECK(
b_q_weight.size(1) % gptq_marlin::tile_size == 0,
"b_q_weight.size(1) = " + str(b_q_weight.size(1)) +
" is not divisible by tile_size = " + str(gptq_marlin::tile_size));
int actual_size_n = (b_q_weight.size(1) / gptq_marlin::tile_size) *
gptq_marlin::pack_factor_4bit;
TORCH_CHECK(size_n == actual_size_n,
"size_n = " + str(size_n) +
", actual_size_n = " + str(actual_size_n));
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
"b_q_weight.size(1) = ", b_q_weight.size(1),
" is not divisible by tile_size = ", gptq_marlin::tile_size);
int actual_size_n =
(b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
", actual_size_n = ", actual_size_n);
// Verify device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
@ -1457,9 +1662,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
// Verify g_idx and perm
TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) ||
(g_idx.size(0) == size_k && perm.size(0) == size_k),
"Unexpected g_idx.size(0) = " + str(g_idx.size(0)) +
" and perm.size(0) = " + str(perm.size(0)) +
", where size_k = " + str(size_k));
"Unexpected g_idx.size(0) = ", g_idx.size(0),
" and perm.size(0) = ", perm.size(0),
", where size_k = ", size_k);
// Detect groupsize and act_order
int num_groups = -1;
@ -1475,9 +1680,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
if (has_act_order) {
if (is_k_full) {
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
TORCH_CHECK(size_k % num_groups == 0,
"size_k = " + str(size_k) +
", is not divisible by num_groups = " + str(num_groups));
TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by num_groups = ", num_groups);
group_size = size_k / num_groups;
} else {
group_size = 0;
@ -1485,10 +1689,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
} else {
if (num_groups > 1) {
TORCH_CHECK(size_k % num_groups == 0,
"size_k = " + str(size_k) +
", is not divisible by b_scales.size(0) = " +
str(b_scales.size(0)));
TORCH_CHECK(
size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by b_scales.size(0) = ", b_scales.size(0));
group_size = size_k / num_groups;
} else {
group_size = -1;
@ -1496,23 +1699,22 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
}
// Verify workspace size
TORCH_CHECK(size_n % gptq_marlin::min_thread_n == 0,
"size_n = " + str(size_n) +
", is not divisible by min_thread_n = " +
str(gptq_marlin::min_thread_n));
TORCH_CHECK(
size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
int min_workspace_size =
(size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = " + str(workspace.numel()) +
" is below min_workspace_size = " + str(min_workspace_size));
"workspace.numel = ", workspace.numel(),
" is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device();
gptq_marlin::marlin_cuda(
gptq_marlin::marlin_mm_f16i4(
a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n,
size_k, workspace.data_ptr(), has_act_order, is_k_full, num_groups,
group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n,
sms, gptq_marlin::max_par);
size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, gptq_marlin::max_par);
return c;
}

View File

@ -24,8 +24,6 @@ static constexpr int min_thread_k = 64;
static constexpr int tile_size = 16;
static constexpr int max_par = 16;
static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32 bit
template <typename T, int n>
struct Vec {
T elems[n];
@ -51,13 +49,11 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool
"r"(smem), "l"(glob_ptr), "n"(BYTES));
}
__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) {
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("{\n"
" .reg .b64 p;\n"
" createpolicy.fractional.L2::evict_first.b64 p, 1.0;"
" cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr), "n"(BYTES));
}

View File

@ -11,7 +11,7 @@ static constexpr int tile_n_size = tile_k_size * 4;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <int const num_threads, bool const has_perm>
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr,
@ -20,7 +20,8 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
} // namespace gptq_marlin
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n) {
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
@ -28,11 +29,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
#else
template <int const num_threads, bool const has_perm>
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
uint32_t const *__restrict__ perm_ptr,
uint32_t *__restrict__ out_ptr, int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
@ -64,9 +67,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
sh_pipe_ptr += perm_size;
}
constexpr int tile_ints = tile_k_size / pack_factor;
constexpr int stage_n_threads = tile_n_size / 4;
constexpr int stage_k_threads =
has_perm ? tile_k_size : tile_k_size / pack_factor_4bit;
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
constexpr int stage_size = stage_k_threads * stage_n_threads;
auto load_perm_to_shared = [&](int k_tile_id) {
@ -99,9 +103,9 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
reinterpret_cast<uint32_t const *>(sh_perm_ptr);
int src_k = sh_perm_int_ptr[k_id];
int src_k_packed = src_k / pack_factor_4bit;
int src_k_packed = src_k / pack_factor;
cp_async4_stream(
cp_async4(
&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>(&(
b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
@ -113,12 +117,12 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
int n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
int first_k_packed = first_k / pack_factor_4bit;
int first_k_packed = first_k / pack_factor;
cp_async4_stream(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>(
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
first_n + (n_id * 4)])));
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const *>(
&(b_q_weight_ptr[(first_k_packed + k_id) * size_n +
first_n + (n_id * 4)])));
}
}
@ -145,26 +149,27 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
int cur_n = warp_id * 16 + tc_col;
constexpr int sh_stride = 64;
constexpr uint32_t mask = (1 << num_bits) - 1;
int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
uint32_t *sh_stage_int_ptr = reinterpret_cast<uint32_t *>(sh_stage_ptr);
uint32_t *sh_perm_int_ptr = reinterpret_cast<uint32_t *>(sh_perm_ptr);
uint32_t vals[pack_factor_4bit];
uint32_t vals[8];
if constexpr (has_perm) {
for (int i = 0; i < 4; i++) {
int k_idx = tc_row + tc_offsets[i];
uint32_t src_k = sh_perm_int_ptr[k_idx];
uint32_t src_k_pos = src_k % pack_factor_4bit;
uint32_t src_k_pos = src_k % pack_factor;
uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
uint32_t b1_cur_val = (b1_val >> (src_k_pos * 4)) & 0xf;
uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;
uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
uint32_t b2_cur_val = (b2_val >> (src_k_pos * 4)) & 0xf;
uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;
vals[i] = b1_cur_val;
vals[4 + i] = b2_cur_val;
@ -172,41 +177,56 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
} else {
uint32_t b1_val_1 = sh_stage_int_ptr[cur_n];
uint32_t b1_val_2 = sh_stage_int_ptr[sh_stride + cur_n];
uint32_t b2_val_1 = sh_stage_int_ptr[cur_n + 8];
uint32_t b2_val_2 = sh_stage_int_ptr[sh_stride + cur_n + 8];
uint32_t b1_vals[tile_ints];
uint32_t b2_vals[tile_ints];
#pragma unroll
for (int i = 0; i < 2; i++) {
int cur_elem = tc_row + tc_offsets[i];
vals[i] = (b1_val_1 >> (cur_elem * 4)) & 0xf;
vals[4 + i] = (b2_val_1 >> (cur_elem * 4)) & 0xf;
for (int i = 0; i < tile_ints; i++) {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
}
#pragma unroll
for (int i = 2; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i] - 8;
vals[i] = (b1_val_2 >> (cur_elem * 4)) & 0xf;
vals[4 + i] = (b2_val_2 >> (cur_elem * 4)) & 0xf;
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int cur_int = cur_elem / pack_factor;
int cur_pos = cur_elem % pack_factor;
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
}
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
constexpr int pack_idx[pack_factor_4bit] = {0, 2, 4, 6, 1, 3, 5, 7};
if constexpr (num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < pack_factor_4bit; i++) {
res |= vals[pack_idx[i]] << (i * 4);
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else {
constexpr int pack_idx[4] = {0, 2, 1, 3};
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
}
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor_4bit;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
out_ptr[out_offset + th_id * 4 + warp_id] = res;
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
@ -242,19 +262,35 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr,
} // namespace gptq_marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
NUM_BITS, HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
int64_t size_k, int64_t size_n) {
int64_t size_k, int64_t size_n,
int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits;
// Verify B
TORCH_CHECK((size_k / gptq_marlin::pack_factor_4bit) == b_q_weight.size(0),
TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
", size_k = ", size_k,
", pack_factor_4bit = ", gptq_marlin::pack_factor_4bit);
", size_k = ", size_k, ", pack_factor = ", pack_factor);
TORCH_CHECK(b_q_weight.size(1) == size_n,
"b_q_weight.size(1) = ", b_q_weight.size(1),
" is not size_n = ", size_n);
@ -273,10 +309,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
torch::Tensor out = torch::empty(
{size_k / gptq_marlin::tile_size,
size_n * gptq_marlin::tile_size / gptq_marlin::pack_factor_4bit},
options);
torch::Tensor out =
torch::empty({size_k / gptq_marlin::tile_size,
size_n * gptq_marlin::tile_size / pack_factor},
options);
// Detect if there is act_order
bool has_perm = perm.size(0) != 0;
@ -299,23 +335,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
if (has_perm) {
cudaFuncSetAttribute(
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_mem);
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, true>
<<<blocks, gptq_marlin::repack_threads, max_shared_mem,
stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
} else {
cudaFuncSetAttribute(
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_mem);
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, false>
<<<blocks, gptq_marlin::repack_threads, max_shared_mem,
stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n);
if (false) {
}
CALL_IF(4, false)
CALL_IF(4, true)
CALL_IF(8, false)
CALL_IF(8, true)
else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits,
", has_perm = ", has_perm);
}
return out;

View File

@ -39,6 +39,13 @@ MODELS = [
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"),
# act_order==True, group_size=32
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"),
# 8-bit, act_order==True, group_size=channelwise
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit--1g-actorder_True"),
# 8-bit, act_order==True, group_size=128
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-128g-actorder_True"),
# 8-bit, act_order==True, group_size=32
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-32g-actorder_True"),
]
@ -65,8 +72,7 @@ def test_models(
dtype=dtype,
quantization="marlin",
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=1,
disable_custom_all_reduce=True)
tensor_parallel_size=1)
gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
@ -78,8 +84,7 @@ def test_models(
dtype=dtype,
quantization="gptq",
max_model_len=MAX_MODEL_LEN,
tensor_parallel_size=1,
disable_custom_all_reduce=True)
tensor_parallel_size=1)
gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
max_tokens,
num_logprobs)

View File

@ -169,18 +169,20 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
# gptq_marlin
def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
size_k: int, size_n: int) -> torch.Tensor:
return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n)
size_k: int, size_n: int,
num_bits: int) -> torch.Tensor:
return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n,
num_bits)
def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, g_idx: torch.Tensor,
perm: torch.Tensor, workspace: torch.Tensor, size_m: int,
size_n: int, size_k: int,
perm: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int, size_k: int,
is_k_full: bool) -> torch.Tensor:
return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm,
workspace, size_m, size_n, size_k,
is_k_full)
workspace, num_bits, size_m, size_n,
size_k, is_k_full)
# fp8

View File

@ -2,7 +2,6 @@ import enum
from enum import Enum
from typing import Any, Dict, List, Optional
import numpy
import torch
from torch.nn.parameter import Parameter
@ -17,41 +16,13 @@ GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4]
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
# Precompute permutations for Marlin weight and scale shuffling
#
# Marlin works on [16,64] tiles. The goal of the permutations
# is to reorder the weight data so that it is compatible
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the
# kernel will get the data as it is needed for tensor-core
# (without the need to use ldmatrix instructions)
def _get_perms():
perm = []
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm.extend([p + 256 * j for p in perm1])
perm = numpy.array(perm)
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
perm = perm.reshape((-1, 8))[:, interleave].ravel() # type: ignore
perm = torch.from_numpy(perm)
# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits):
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
@ -59,23 +30,21 @@ def _get_perms():
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm, scale_perm, scale_perm_single
_perm, _scale_perm, _scale_perm_single = _get_perms()
return scale_perm, scale_perm_single
def get_pack_factor(num_bits):
assert num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS, (
f"Unsupported num_bits = {num_bits}")
assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
), f"Unsupported num_bits = {num_bits}"
return 32 // num_bits
def marlin_permute_scales(s, size_k, size_n, group_size):
def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
scale_perm, scale_perm_single = get_scale_perms(num_bits)
if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
s = s.reshape((-1, size_n)).contiguous()
return s
@ -279,13 +248,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad=False,
)
set_weight_attrs(
qweight, {
qweight,
{
**extra_weight_attrs,
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
})
},
)
# Activation order
g_idx = Parameter(
@ -296,10 +267,13 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad=False,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs(g_idx, {
**extra_weight_attrs, "input_dim": 0,
"ignore_warning": True
})
set_weight_attrs(
g_idx,
{
**extra_weight_attrs, "input_dim": 0,
"ignore_warning": True
},
)
g_idx_sort_indices = Parameter(
torch.empty(
@ -320,29 +294,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
requires_grad=False,
)
set_weight_attrs(
scales, {
scales,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
})
},
)
# Quantized zero-points
qzeros = Parameter(
torch.empty(scales_and_zp_size,
output_size_per_partition //
self.quant_config.pack_factor,
dtype=torch.int32,
device="meta"),
torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
device="meta",
),
requires_grad=False,
)
set_weight_attrs(
qzeros, {
qzeros,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})
},
)
# Allocate marlin workspace
max_workspace_size = (
@ -405,13 +384,14 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
else:
# Reset g_idx related tensors
layer.g_idx = Parameter(torch.empty(0,
dtype=torch.int,
device=cur_device),
requires_grad=False)
layer.g_idx_sort_indices = Parameter(torch.empty(
0, dtype=torch.int, device=cur_device),
requires_grad=False)
layer.g_idx = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
layer.g_idx_sort_indices = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
# Repack weights
marlin_qweight = ops.gptq_marlin_repack(
@ -419,6 +399,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer.g_idx_sort_indices,
part_size_k,
part_size_n,
self.quant_config.weight_bits,
)
replace_tensor("qweight", marlin_qweight)
@ -428,15 +409,28 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
if self.quant_config.desc_act:
scales_size_k = full_size_k
marlin_scales = marlin_permute_scales(layer.scales, scales_size_k,
scales_size_n,
self.quant_config.group_size)
marlin_scales = marlin_permute_scales(
layer.scales,
scales_size_k,
scales_size_n,
self.quant_config.group_size,
self.quant_config.weight_bits,
)
replace_tensor("scales", marlin_scales)
output = ops.gptq_marlin_gemm(reshaped_x, layer.qweight, layer.scales,
layer.g_idx, layer.g_idx_sort_indices,
layer.workspace, size_m, part_size_n,
part_size_k, layer.is_k_full)
output = ops.gptq_marlin_gemm(
reshaped_x,
layer.qweight,
layer.scales,
layer.g_idx,
layer.g_idx_sort_indices,
layer.workspace,
self.quant_config.weight_bits,
size_m,
part_size_n,
part_size_k,
layer.is_k_full,
)
if bias is not None:
output.add_(bias) # In-place add