[Kernel] Increase precision of GPTQ/AWQ Marlin kernel (#6795)

This commit is contained in:
Alexander Matveev 2024-07-27 17:52:33 -04:00 committed by GitHub
parent fad5576c58
commit 75acdaa4b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 168 additions and 44 deletions

View File

@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
@ -56,6 +56,8 @@ def bench_run(results: List[benchmark.Measurement], model: str,
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)
marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
# GPTQ quant
(w_ref, q_w, s, g_idx,
rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
@ -87,6 +89,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
"marlin_w_ref": marlin_w_ref,
"marlin_q_w": marlin_q_w,
"marlin_s": marlin_s,
"marlin_zp": marlin_zp,
"marlin_g_idx": marlin_g_idx,
"marlin_sort_indices": marlin_sort_indices,
"marlin_rand_perm": marlin_rand_perm,
@ -125,11 +128,21 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_gemm",
description="gptq_marlin_gemm_fp16",
).blocked_autorange(min_run_time=min_run_time))
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_gemm_fp32",
).blocked_autorange(min_run_time=min_run_time))
if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
@ -183,12 +196,12 @@ def main(args):
) > 0 and is_k_full not in args.limit_k_full:
continue
for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
for num_bits in MARLIN_SUPPORTED_NUM_BITS:
if len(args.limit_num_bits
) > 0 and num_bits not in args.limit_num_bits:
continue
for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
if len(
args.limit_group_size
) > 0 and group_size not in args.limit_group_size:

View File

@ -93,7 +93,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace, int64_t num_bits,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp);
bool is_k_full, bool has_zp,
bool use_fp32_reduce);
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,

View File

@ -59,14 +59,16 @@ __global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int* __restrict__ g_idx, // int32 group indices of shape k
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_fp32_reduce // whether to use fp32 global reduce
) {}
} // namespace gptq_marlin
@ -532,16 +534,18 @@ __global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_fp32_reduce // whether to use fp32 global reduce
) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
@ -595,6 +599,8 @@ __global__ void Marlin(
int slice_idx; // index of threadblock in current slice; numbered bottom to
// top
int par_id = 0;
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if (slice_col_par >= n_tiles) {
@ -602,6 +608,7 @@ __global__ void Marlin(
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
locks += (slice_col_par / n_tiles) * n_tiles;
slice_col = slice_col_par % n_tiles;
par_id = slice_col_par / n_tiles;
}
// Compute all information about the current slice which is required for
@ -632,6 +639,7 @@ __global__ void Marlin(
C += 16 * thread_m_blocks * prob_n / 8;
locks += n_tiles;
slice_col = 0;
par_id++;
}
};
init_slice();
@ -1321,7 +1329,7 @@ __global__ void Marlin(
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) {
auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
@ -1382,6 +1390,53 @@ __global__ void Marlin(
}
};
// Globally reduce over threadblocks that compute the same column block.
// We use a tmp C buffer to reduce in full fp32 precision.
auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
constexpr int tb_m = thread_m_blocks * 16;
constexpr int tb_n = thread_n_blocks * 16;
constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
constexpr int active_threads = 32 * thread_n_blocks / 4;
bool is_th_active = threadIdx.x < active_threads;
int par_offset = c_size * n_tiles * par_id;
int slice_offset = c_size * slice_col;
constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
constexpr int th_size = num_floats * sizeof(float) / 16;
int c_cur_offset = par_offset + slice_offset;
if (!is_th_active) {
return;
}
if (!first) {
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
#pragma unroll
for (int k = 0; k < th_size; k++) {
sh[threadIdx.x] =
C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
#pragma unroll
for (int f = 0; f < 4; f++) {
frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
}
}
}
if (!last) {
int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
#pragma unroll
for (int k = 0; k < th_size; k++) {
C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
@ -1606,7 +1661,11 @@ __global__ void Marlin(
if (slice_count > 1) { // only globally reduce if there is more than one
// block in a slice
barrier_acquire(&locks[slice_col], slice_idx);
global_reduce(slice_idx == 0, last);
if (use_fp32_reduce) {
global_reduce_fp32(slice_idx == 0, last);
} else {
global_reduce_fp16(slice_idx == 0, last);
}
barrier_release(&locks[slice_col], last);
}
if (last) // only the last block in a slice actually writes the result
@ -1661,8 +1720,8 @@ __global__ void Marlin(
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups, \
prob_m, prob_n, prob_k, locks); \
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
}
typedef struct {
@ -1801,6 +1860,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
return true;
}
int determine_reduce_max_m(int prob_m, int max_par) {
constexpr int tile_m_size = 16;
if (prob_m <= tile_m_size) {
return tile_m_size;
} else if (prob_m <= tile_m_size * 2) {
return tile_m_size * 2;
} else if (prob_m <= tile_m_size * 3) {
return tile_m_size * 3;
} else if (prob_m <= tile_m_size * 4) {
return tile_m_size * 4;
} else {
int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par);
return tile_m_size * 4 * cur_par;
}
}
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
int num_bits, int group_size,
bool has_act_order, bool is_k_full,
@ -1880,13 +1960,13 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
template <typename scalar_t>
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
void* g_idx, void* perm, void* a_tmp, int prob_m,
int prob_n, int prob_k, void* workspace, int num_bits,
bool has_act_order, bool is_k_full, bool has_zp,
int num_groups, int group_size, int dev,
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
void* s, void* zp, void* g_idx, void* perm, void* a_tmp,
int prob_m, int prob_n, int prob_k, void* workspace,
int num_bits, bool has_act_order, bool is_k_full,
bool has_zp, int num_groups, int group_size, int dev,
cudaStream_t stream, int thread_k, int thread_n, int sms,
int max_par) {
int max_par, bool use_fp32_reduce) {
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
@ -1970,6 +2050,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* s_ptr = (const int4*)s;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
@ -2049,7 +2130,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace, int64_t num_bits,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp) {
bool is_k_full, bool has_zp,
bool use_fp32_reduce) {
// Verify num_bits
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
@ -2099,6 +2181,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor c = torch::empty({size_m, size_n}, options);
torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
// Alloc C tmp buffer that is going to be used for the global reduce
int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
int reduce_n = size_n;
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (!use_fp32_reduce) {
reduce_max_m = 0;
reduce_n = 0;
}
torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_k = -1;
@ -2171,20 +2264,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
if (a.scalar_type() == at::ScalarType::Half) {
marlin::marlin_mm_f16i4<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
b_scales.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par);
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
marlin::marlin_mm_f16i4<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par);
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
} else {
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
}

View File

@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [False, True]
MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 128, 256]
@ -175,6 +176,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_gptq_marlin_gemm(
k_chunk,
n_chunk,
@ -183,6 +185,7 @@ def test_gptq_marlin_gemm(
mnk_factors,
act_order,
is_k_full,
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors
@ -222,8 +225,9 @@ def test_gptq_marlin_gemm(
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full,
is_k_full=is_k_full,
has_zp=False,
use_fp32_reduce=use_fp32_reduce,
)
output_ref = torch.matmul(a_input, w_ref)
@ -365,12 +369,14 @@ def test_fp8_marlin_gemm(
@pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_awq_marlin_gemm(
k_chunk,
n_chunk,
num_bits,
group_size,
mnk_factors,
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors
@ -407,8 +413,9 @@ def test_awq_marlin_gemm(
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full,
has_zp,
is_k_full=is_k_full,
has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce,
)
output_ref = torch.matmul(a_input, w_ref)

View File

@ -286,12 +286,12 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, b_zeros: torch.Tensor,
g_idx: torch.Tensor, perm: torch.Tensor,
workspace: torch.Tensor, num_bits: int, size_m: int,
size_n: int, size_k: int, is_k_full: bool,
has_zp: bool) -> torch.Tensor:
size_n: int, size_k: int, is_k_full: bool, has_zp: bool,
use_fp32_reduce: bool) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
g_idx, perm, workspace, num_bits,
size_m, size_n, size_k, is_k_full,
has_zp)
has_zp, use_fp32_reduce)
# fp8 marlin

View File

@ -16,6 +16,11 @@ GPTQ_MARLIN_MAX_PARALLEL = 16
MARLIN_SUPPORTED_NUM_BITS = [4, 8]
MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# In case there is a performance issue with Marlin, the variable below can be
# changed to False, which allows Marlin to perform global reductions in fp16
# precision (instead of fp32), and therefore, save on some memory movements.
USE_FP32_REDUCE_DEFAULT = True
def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
min_capability: Optional[int],
@ -244,7 +249,8 @@ def apply_gptq_marlin_linear(
output_size_per_partition: int,
input_size_per_partition: int,
is_k_full: bool,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition, )
@ -260,7 +266,8 @@ def apply_gptq_marlin_linear(
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=is_k_full,
has_zp=False)
has_zp=False,
use_fp32_reduce=use_fp32_reduce)
if bias is not None:
output.add_(bias) # In-place add
@ -279,7 +286,8 @@ def apply_awq_marlin_linear(
num_bits: int,
output_size_per_partition: int,
input_size_per_partition: int,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
bias: Optional[torch.Tensor] = None,
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition, )
@ -295,7 +303,8 @@ def apply_awq_marlin_linear(
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=True,
has_zp=True)
has_zp=True,
use_fp32_reduce=use_fp32_reduce)
if bias is not None:
output.add_(bias) # In-place add