[Kernel] Increase precision of GPTQ/AWQ Marlin kernel (#6795)
This commit is contained in:
parent fad5576c58
commit 75acdaa4b6
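The change adds a use_fp32_reduce option to the GPTQ/AWQ Marlin GEMM: when enabled (the new default), partial results that different threadblocks contribute to the same output tile are staged and accumulated in a temporary fp32 buffer (C_tmp) instead of being reduced directly in the fp16/bf16 output buffer. The snippet below is a minimal standalone PyTorch sketch, not vLLM code, illustrating why accumulating many partial sums in fp32 loses less precision than accumulating them in fp16.

# Standalone sketch (illustration only, not part of the commit): compare the
# error of accumulating fp16 partial results in fp16 vs. in fp32.
import torch

torch.manual_seed(0)
num_slices = 64                                # partial results, e.g. one per k-slice
parts = torch.randn(num_slices, 4096).half()   # fp16 partial outputs

ref = parts.double().sum(dim=0)                # high-precision reference

acc_fp16 = torch.zeros(4096, dtype=torch.float16)
acc_fp32 = torch.zeros(4096, dtype=torch.float32)
for p in parts:                                # serial reduction, like the global reduce
    acc_fp16 += p                              # rounds to fp16 after every step
    acc_fp32 += p.float()                      # keeps the running sum in fp32

print("fp16 reduce, max abs error:", (acc_fp16.double() - ref).abs().max().item())
print("fp32 reduce, max abs error:", (acc_fp32.double() - ref).abs().max().item())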
@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
-    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
+    MARLIN_SUPPORTED_GROUP_SIZES, MARLIN_SUPPORTED_NUM_BITS)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     MarlinWorkspace, marlin_quantize)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
@@ -56,6 +56,8 @@ def bench_run(results: List[benchmark.Measurement], model: str,
     (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
      marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)
 
+    marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
+
     # GPTQ quant
     (w_ref, q_w, s, g_idx,
      rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
@@ -87,6 +89,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
         "marlin_w_ref": marlin_w_ref,
         "marlin_q_w": marlin_q_w,
         "marlin_s": marlin_s,
+        "marlin_zp": marlin_zp,
         "marlin_g_idx": marlin_g_idx,
         "marlin_sort_indices": marlin_sort_indices,
         "marlin_rand_perm": marlin_rand_perm,
@@ -125,11 +128,21 @@ def bench_run(results: List[benchmark.Measurement], model: str,
     results.append(
         benchmark.Timer(
             stmt=
-            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)",  # noqa: E501
+            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, False)",  # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
-            description="gptq_marlin_gemm",
+            description="gptq_marlin_gemm_fp16",
         ).blocked_autorange(min_run_time=min_run_time))
 
+    results.append(
+        benchmark.Timer(
+            stmt=
+            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full, False, True)",  # noqa: E501
+            globals=globals,
+            label=label,
+            sub_label=sub_label,
+            description="gptq_marlin_gemm_fp32",
+        ).blocked_autorange(min_run_time=min_run_time))
+
     if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
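In the two stmt strings above, the trailing booleans passed to gptq_marlin_gemm are has_zp=False and use_fp32_reduce, so the benchmark now reports gptq_marlin_gemm_fp16 and gptq_marlin_gemm_fp32 side by side. The fragment below is a hedged, self-contained sketch of the same torch.utils.benchmark comparison pattern; run_gemm and its flag are placeholders, not vLLM symbols.

# Placeholder comparison pattern; substitute a real kernel call for run_gemm.
import torch.utils.benchmark as benchmark

run_gemm = lambda use_fp32_reduce: None        # stand-in for the timed kernel
timings = []
for flag, desc in [(False, "gemm_fp16_reduce"), (True, "gemm_fp32_reduce")]:
    timings.append(
        benchmark.Timer(
            stmt="run_gemm(flag)",
            globals={"run_gemm": run_gemm, "flag": flag},
            label="marlin_gemm",
            sub_label="example",
            description=desc,
        ).blocked_autorange(min_run_time=0.2))
benchmark.Compare(timings).print()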
@@ -183,12 +196,12 @@ def main(args):
                 ) > 0 and is_k_full not in args.limit_k_full:
                     continue
 
-                for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
+                for num_bits in MARLIN_SUPPORTED_NUM_BITS:
                     if len(args.limit_num_bits
                            ) > 0 and num_bits not in args.limit_num_bits:
                         continue
 
-                    for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
+                    for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
                         if len(
                                 args.limit_group_size
                         ) > 0 and group_size not in args.limit_group_size:
@@ -93,7 +93,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& g_idx, torch::Tensor& perm,
                                torch::Tensor& workspace, int64_t num_bits,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp);
+                               bool is_k_full, bool has_zp,
+                               bool use_fp32_reduce);
 
 torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
@@ -59,14 +59,16 @@ __global__ void Marlin(
     const int4* __restrict__ A,  // fp16 input matrix of shape mxk
     const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
     int4* __restrict__ C,        // fp16 output buffer of shape mxn
+    int4* __restrict__ C_tmp,    // fp32 tmp output buffer (for reduce)
     const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                           // (k/groupsize)xn
     const int* __restrict__ g_idx,        // int32 group indices of shape k
-    int num_groups,     // number of scale groups per output channel
-    int prob_m,         // batch dimension m
-    int prob_n,         // output dimension n
-    int prob_k,         // reduction dimension k
-    int* locks          // extra global storage for barrier synchronization
+    int num_groups,       // number of scale groups per output channel
+    int prob_m,           // batch dimension m
+    int prob_n,           // output dimension n
+    int prob_k,           // reduction dimension k
+    int* locks,           // extra global storage for barrier synchronization
+    bool use_fp32_reduce  // whether to use fp32 global reduce
 ) {}
 
 }  // namespace gptq_marlin
@@ -532,16 +534,18 @@ __global__ void Marlin(
     const int4* __restrict__ A,  // fp16 input matrix of shape mxk
     const int4* __restrict__ B,  // 4bit quantized weight matrix of shape kxn
     int4* __restrict__ C,        // fp16 output buffer of shape mxn
+    int4* __restrict__ C_tmp,    // fp32 tmp output buffer (for reduce)
     const int4* __restrict__ scales_ptr,  // fp16 quantization scales of shape
                                           // (k/groupsize)xn
     const int4* __restrict__ zp_ptr,      // 4bit packed zero-points of shape
                                           // (k/groupsize)x(n/pack_factor)
     const int* __restrict__ g_idx,        // int32 group indices of shape k
-    int num_groups,     // number of scale groups per output channel
-    int prob_m,         // batch dimension m
-    int prob_n,         // output dimension n
-    int prob_k,         // reduction dimension k
-    int* locks          // extra global storage for barrier synchronization
+    int num_groups,       // number of scale groups per output channel
+    int prob_m,           // batch dimension m
+    int prob_n,           // output dimension n
+    int prob_k,           // reduction dimension k
+    int* locks,           // extra global storage for barrier synchronization
+    bool use_fp32_reduce  // whether to use fp32 global reduce
 ) {
   // Each threadblock processes one "stripe" of the B matrix with (roughly) the
   // same size, which might involve multiple column "slices" (of width 16 *
@@ -595,6 +599,8 @@ __global__ void Marlin(
   int slice_idx;  // index of threadblock in current slice; numbered bottom to
                   // top
 
+  int par_id = 0;
+
   // We can easily implement parallel problem execution by just remapping
   // indices and advancing global pointers
   if (slice_col_par >= n_tiles) {
@@ -602,6 +608,7 @@ __global__ void Marlin(
     C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
     locks += (slice_col_par / n_tiles) * n_tiles;
     slice_col = slice_col_par % n_tiles;
+    par_id = slice_col_par / n_tiles;
   }
 
   // Compute all information about the current slice which is required for
@@ -632,6 +639,7 @@ __global__ void Marlin(
       C += 16 * thread_m_blocks * prob_n / 8;
       locks += n_tiles;
       slice_col = 0;
+      par_id++;
     }
   };
   init_slice();
@@ -1321,7 +1329,7 @@ __global__ void Marlin(
   // finally have to globally reduce over the results. As the striped
   // partitioning minimizes the number of such reductions and our outputs are
   // usually rather small, we perform this reduction serially in L2 cache.
-  auto global_reduce = [&](bool first = false, bool last = false) {
+  auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
     // We are very careful here to reduce directly in the output buffer to
     // maximize L2 cache utilization in this step. To do this, we write out
     // results in FP16 (but still reduce with FP32 compute).
@@ -1382,6 +1390,53 @@ __global__ void Marlin(
     }
   };
 
+  // Globally reduce over threadblocks that compute the same column block.
+  // We use a tmp C buffer to reduce in full fp32 precision.
+  auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
+    constexpr int tb_m = thread_m_blocks * 16;
+    constexpr int tb_n = thread_n_blocks * 16;
+
+    constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
+
+    constexpr int active_threads = 32 * thread_n_blocks / 4;
+    bool is_th_active = threadIdx.x < active_threads;
+
+    int par_offset = c_size * n_tiles * par_id;
+    int slice_offset = c_size * slice_col;
+
+    constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
+    constexpr int th_size = num_floats * sizeof(float) / 16;
+
+    int c_cur_offset = par_offset + slice_offset;
+
+    if (!is_th_active) {
+      return;
+    }
+
+    if (!first) {
+      float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
+#pragma unroll
+      for (int k = 0; k < th_size; k++) {
+        sh[threadIdx.x] =
+            C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
+
+        float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
+#pragma unroll
+        for (int f = 0; f < 4; f++) {
+          frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
+        }
+      }
+    }
+
+    if (!last) {
+      int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
+#pragma unroll
+      for (int k = 0; k < th_size; k++) {
+        C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
+      }
+    }
+  };
+
   // Write out the reduce final result in the correct layout. We only actually
   // reshuffle matrix fragments in this step, the reduction above is performed
   // in fragment layout.
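The global_reduce_fp32 lambda added above follows a simple staging protocol over the fp32 scratch buffer: every threadblock of a slice except the first loads the partial sums already staged in C_tmp and adds them into its register fragments (the !first branch), every block except the last writes the accumulated fragments back to C_tmp (the !last branch), and only the last block goes on to emit the final fp16/bf16 output. The following is a rough, purely illustrative Python emulation of that first/last staging; it is not CUDA and the names are invented.

# Illustrative emulation of the first/last staging over a shared fp32 scratch
# buffer; "partials" stands in for per-threadblock fp32 partial results.
import torch

def reduce_slice(partials):
    c_tmp = torch.zeros_like(partials[0])      # shared fp32 scratch, like C_tmp
    out = None
    for i, frag in enumerate(partials):
        first, last = i == 0, i == len(partials) - 1
        acc = frag.clone()
        if not first:
            acc += c_tmp                       # load previously staged sums
        if not last:
            c_tmp = acc                        # stage for the next block
        else:
            out = acc.to(torch.float16)        # only the last block emits output
    return out

parts = [torch.randn(8, 8) for _ in range(4)]
expected = sum(parts).to(torch.float16)
assert torch.allclose(reduce_slice(parts), expected, atol=1e-2)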
@@ -1606,7 +1661,11 @@ __global__ void Marlin(
       if (slice_count > 1) {  // only globally reduce if there is more than one
                               // block in a slice
         barrier_acquire(&locks[slice_col], slice_idx);
-        global_reduce(slice_idx == 0, last);
+        if (use_fp32_reduce) {
+          global_reduce_fp32(slice_idx == 0, last);
+        } else {
+          global_reduce_fp16(slice_idx == 0, last);
+        }
         barrier_release(&locks[slice_col], last);
       }
       if (last)  // only the last block in a slice actually writes the result
@@ -1661,8 +1720,8 @@ __global__ void Marlin(
           THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER,    \
           HAS_ZP, GROUP_BLOCKS>                                            \
           <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(               \
-              A_ptr, B_ptr, C_ptr, s_ptr, zp_ptr, g_idx_ptr, num_groups,   \
-              prob_m, prob_n, prob_k, locks);                              \
+              A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,    \
+              num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
   }
 
 typedef struct {
@@ -1801,6 +1860,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
   return true;
 }
 
+int determine_reduce_max_m(int prob_m, int max_par) {
+  constexpr int tile_m_size = 16;
+
+  if (prob_m <= tile_m_size) {
+    return tile_m_size;
+
+  } else if (prob_m <= tile_m_size * 2) {
+    return tile_m_size * 2;
+
+  } else if (prob_m <= tile_m_size * 3) {
+    return tile_m_size * 3;
+
+  } else if (prob_m <= tile_m_size * 4) {
+    return tile_m_size * 4;
+
+  } else {
+    int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par);
+    return tile_m_size * 4 * cur_par;
+  }
+}
+
 exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                       int num_bits, int group_size,
                                       bool has_act_order, bool is_k_full,
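determine_reduce_max_m above sizes the m dimension of the fp32 scratch buffer: up to 64 rows it rounds prob_m up to a whole number of 16-row tiles, and beyond that it allocates 64 rows per parallel sub-problem, capped at max_par (GPTQ_MARLIN_MAX_PARALLEL is 16, as the marlin_utils hunk later in this diff shows). The Python mirror below is only meant to make the resulting buffer footprint concrete; the example sizes are invented.

# Python mirror of the sizing logic plus the resulting fp32 buffer footprint.
# max_par=16 mirrors GPTQ_MARLIN_MAX_PARALLEL; the example sizes are invented.
from math import ceil

def determine_reduce_max_m(prob_m: int, max_par: int = 16) -> int:
    tile_m = 16
    if prob_m <= 4 * tile_m:                   # 1 to 4 row tiles
        return tile_m * ceil(prob_m / tile_m)
    cur_par = min(ceil(prob_m / (4 * tile_m)), max_par)
    return 4 * tile_m * cur_par                # 64 rows per parallel sub-problem

size_m, size_n = 4096, 8192
reduce_max_m = determine_reduce_max_m(size_m)
print("C_tmp elements:", reduce_max_m * size_n)
print("C_tmp bytes   :", reduce_max_m * size_n * 4)   # fp32, 4 bytes per element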
@@ -1880,13 +1960,13 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
   __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
 
 template <typename scalar_t>
-void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
-                     void* g_idx, void* perm, void* a_tmp, int prob_m,
-                     int prob_n, int prob_k, void* workspace, int num_bits,
-                     bool has_act_order, bool is_k_full, bool has_zp,
-                     int num_groups, int group_size, int dev,
+void marlin_mm_f16i4(const void* A, const void* B, void* C, void* C_tmp,
+                     void* s, void* zp, void* g_idx, void* perm, void* a_tmp,
+                     int prob_m, int prob_n, int prob_k, void* workspace,
+                     int num_bits, bool has_act_order, bool is_k_full,
+                     bool has_zp, int num_groups, int group_size, int dev,
                      cudaStream_t stream, int thread_k, int thread_n, int sms,
-                     int max_par) {
+                     int max_par, bool use_fp32_reduce) {
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
   TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
@@ -1970,6 +2050,7 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, void* zp,
   const int4* A_ptr = (const int4*)A;
   const int4* B_ptr = (const int4*)B;
   int4* C_ptr = (int4*)C;
+  int4* C_tmp_ptr = (int4*)C_tmp;
   const int4* s_ptr = (const int4*)s;
   const int4* zp_ptr = (const int4*)zp;
   const int* g_idx_ptr = (const int*)g_idx;
@@ -2049,7 +2130,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& g_idx, torch::Tensor& perm,
                                torch::Tensor& workspace, int64_t num_bits,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp) {
+                               bool is_k_full, bool has_zp,
+                               bool use_fp32_reduce) {
   // Verify num_bits
   TORCH_CHECK(num_bits == 4 || num_bits == 8,
               "num_bits must be 4 or 8. Got = ", num_bits);
@@ -2099,6 +2181,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   torch::Tensor c = torch::empty({size_m, size_n}, options);
   torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
 
+  // Alloc C tmp buffer that is going to be used for the global reduce
+  int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
+  int reduce_n = size_n;
+  auto options_fp32 =
+      torch::TensorOptions().dtype(at::kFloat).device(a.device());
+  if (!use_fp32_reduce) {
+    reduce_max_m = 0;
+    reduce_n = 0;
+  }
+  torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
+
   // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
   // auto -1)
   int thread_k = -1;
@@ -2171,20 +2264,21 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   if (a.scalar_type() == at::ScalarType::Half) {
     marlin::marlin_mm_f16i4<half>(
         a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
-        b_scales.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
-        perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
+        c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
+        b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
+        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par);
+        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     marlin::marlin_mm_f16i4<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
-        c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
-        b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
-        a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
+        c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
+        b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
+        perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
         workspace.data_ptr(), num_bits, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par);
+        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
   } else {
     TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
   }
@@ -27,6 +27,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 
 ACT_ORDER_OPTS = [False, True]
 K_FULL_OPTS = [False, True]
+USE_FP32_REDUCE_OPTS = [False, True]
 
 MARLIN_K_CHUNKS = [128]
 MARLIN_N_CHUNKS = [64, 128, 256]
@@ -175,6 +176,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, num_bits, group_size,
 @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
 @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
 @pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
+@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
 def test_gptq_marlin_gemm(
     k_chunk,
     n_chunk,
@@ -183,6 +185,7 @@ def test_gptq_marlin_gemm(
     mnk_factors,
     act_order,
     is_k_full,
+    use_fp32_reduce,
 ):
     m_factor, n_factor, k_factor = mnk_factors
 
@@ -222,8 +225,9 @@ def test_gptq_marlin_gemm(
         a_input.shape[0],
         b_weight.shape[1],
         a_input.shape[1],
-        is_k_full,
+        is_k_full=is_k_full,
+        has_zp=False,
+        use_fp32_reduce=use_fp32_reduce,
     )
     output_ref = torch.matmul(a_input, w_ref)
 
@@ -365,12 +369,14 @@ def test_fp8_marlin_gemm(
 @pytest.mark.parametrize("num_bits", MARLIN_SUPPORTED_NUM_BITS)
 @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
 @pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
+@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
 def test_awq_marlin_gemm(
     k_chunk,
     n_chunk,
     num_bits,
     group_size,
     mnk_factors,
+    use_fp32_reduce,
 ):
     m_factor, n_factor, k_factor = mnk_factors
 
@@ -407,8 +413,9 @@ def test_awq_marlin_gemm(
         a_input.shape[0],
         b_weight.shape[1],
         a_input.shape[1],
-        is_k_full,
-        has_zp,
+        is_k_full=is_k_full,
+        has_zp=has_zp,
+        use_fp32_reduce=use_fp32_reduce,
     )
     output_ref = torch.matmul(a_input, w_ref)
 
@@ -286,12 +286,12 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                      b_scales: torch.Tensor, b_zeros: torch.Tensor,
                      g_idx: torch.Tensor, perm: torch.Tensor,
                      workspace: torch.Tensor, num_bits: int, size_m: int,
-                     size_n: int, size_k: int, is_k_full: bool,
-                     has_zp: bool) -> torch.Tensor:
+                     size_n: int, size_k: int, is_k_full: bool, has_zp: bool,
+                     use_fp32_reduce: bool) -> torch.Tensor:
     return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
                                          g_idx, perm, workspace, num_bits,
                                          size_m, size_n, size_k, is_k_full,
-                                         has_zp)
+                                         has_zp, use_fp32_reduce)
 
 
 # fp8 marlin
@@ -16,6 +16,11 @@ GPTQ_MARLIN_MAX_PARALLEL = 16
 MARLIN_SUPPORTED_NUM_BITS = [4, 8]
 MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 
+# In case there is a performance issue with Marlin, the variable below can be
+# changed to False, which allows Marlin to perform global reductions in fp16
+# precision (instead of fp32), and therefore, save on some memory movements.
+USE_FP32_REDUCE_DEFAULT = True
+
 
 def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
                             min_capability: Optional[int],
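The comment introduced above is the escape hatch for the new default: fp32 global reductions cost extra global-memory traffic for the scratch buffer, so if that ever shows up as a regression the constant can be switched back in the source, which makes the apply_*_marlin_linear helpers default to fp16 reductions again. A sketch of that one-line edit, assuming the module path used by the imports earlier in this diff:

# In vllm/model_executor/layers/quantization/utils/marlin_utils.py, flipping
# the default back to fp16 global reductions is a one-line source edit:
USE_FP32_REDUCE_DEFAULT = False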
@@ -244,7 +249,8 @@ def apply_gptq_marlin_linear(
         output_size_per_partition: int,
         input_size_per_partition: int,
         is_k_full: bool,
-        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        bias: Optional[torch.Tensor] = None,
+        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
     reshaped_x = input.reshape(-1, input.shape[-1])
     out_shape = input.shape[:-1] + (output_size_per_partition, )
 
@@ -260,7 +266,8 @@
                                   size_n=output_size_per_partition,
                                   size_k=input_size_per_partition,
                                   is_k_full=is_k_full,
-                                  has_zp=False)
+                                  has_zp=False,
+                                  use_fp32_reduce=use_fp32_reduce)
 
     if bias is not None:
         output.add_(bias)  # In-place add
@@ -279,7 +286,8 @@ def apply_awq_marlin_linear(
         num_bits: int,
         output_size_per_partition: int,
         input_size_per_partition: int,
-        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        bias: Optional[torch.Tensor] = None,
+        use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor:
     reshaped_x = input.reshape(-1, input.shape[-1])
     out_shape = input.shape[:-1] + (output_size_per_partition, )
 
@@ -295,7 +303,8 @@
                                  size_n=output_size_per_partition,
                                  size_k=input_size_per_partition,
                                  is_k_full=True,
-                                 has_zp=True)
+                                 has_zp=True,
+                                 use_fp32_reduce=use_fp32_reduce)
 
     if bias is not None:
         output.add_(bias)  # In-place add