mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:04:58 +08:00
Add marlin unit tests and marlin benchmark script (#4815)
This commit is contained in:
parent
973617ae02
commit
5c342570d7
183
benchmarks/kernels/benchmark_marlin.py
Normal file
183
benchmarks/kernels/benchmark_marlin.py
Normal file
@ -0,0 +1,183 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
from benchmark_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
MarlinWorkspace, marlin_quantize)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
gptq_pack, quantize_weights, sort_weights)
|
||||
|
||||
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||
|
||||
ACT_ORDER_OPTS = [False, True]
|
||||
K_FULL_OPTS = [False, True]
|
||||
|
||||
|
||||
def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
|
||||
size_m, size_k, size_n):
|
||||
label = "Quant Matmul"
|
||||
|
||||
sub_label = ("{}, act={} k_full={}, b={}, g={}, "
|
||||
"MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits,
|
||||
group_size, size_m, size_k, size_n))
|
||||
|
||||
print(f"Testing: {sub_label}")
|
||||
|
||||
a = torch.randn(size_m, size_k).to(torch.half).cuda()
|
||||
b = torch.rand(size_k, size_n).to(torch.half).cuda()
|
||||
|
||||
a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda())
|
||||
|
||||
# Marlin quant
|
||||
(
|
||||
marlin_w_ref,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_g_idx,
|
||||
marlin_sort_indices,
|
||||
marlin_rand_perm,
|
||||
) = marlin_quantize(b, num_bits, group_size, act_order)
|
||||
|
||||
# GPTQ quant
|
||||
(w_ref, q_w, s, g_idx,
|
||||
rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
|
||||
q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)
|
||||
|
||||
# For act_order, sort the "weights" and "g_idx"
|
||||
# so that group ids are increasing
|
||||
repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
|
||||
if act_order:
|
||||
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
|
||||
|
||||
# Prepare
|
||||
marlin_workspace = MarlinWorkspace(size_n)
|
||||
|
||||
globals = {
|
||||
"marlin_w_ref": marlin_w_ref,
|
||||
"marlin_q_w": marlin_q_w,
|
||||
"marlin_s": marlin_s,
|
||||
"marlin_g_idx": marlin_g_idx,
|
||||
"marlin_sort_indices": marlin_sort_indices,
|
||||
"marlin_rand_perm": marlin_rand_perm,
|
||||
"q_w_gptq": q_w_gptq,
|
||||
"repack_sort_indices": repack_sort_indices,
|
||||
"num_bits": num_bits,
|
||||
"group_size": group_size,
|
||||
"size_m": size_m,
|
||||
"size_n": size_n,
|
||||
"size_k": size_k,
|
||||
"is_k_full": is_k_full,
|
||||
"a": a,
|
||||
"a_tmp": a_tmp,
|
||||
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
|
||||
"gptq_marlin_repack": ops.gptq_marlin_repack,
|
||||
"marlin_workspace": marlin_workspace,
|
||||
}
|
||||
|
||||
min_run_time = 1
|
||||
|
||||
# Warmup pytorch
|
||||
for i in range(5):
|
||||
torch.matmul(a, marlin_w_ref)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="torch.matmul(a, marlin_w_ref)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="pytorch_gemm",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_gemm",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_repack",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
|
||||
def main(args):
|
||||
print("Benchmarking models:")
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
|
||||
results = []
|
||||
|
||||
for model in args.models:
|
||||
for layer in WEIGHT_SHAPES[model]:
|
||||
size_k = layer[0]
|
||||
size_n = layer[1]
|
||||
|
||||
if len(args.limit_k) > 0 and size_k not in args.limit_k:
|
||||
continue
|
||||
|
||||
if len(args.limit_n) > 0 and size_n not in args.limit_n:
|
||||
continue
|
||||
|
||||
for act_order in ACT_ORDER_OPTS:
|
||||
for is_k_full in K_FULL_OPTS:
|
||||
for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
|
||||
for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
|
||||
if len(
|
||||
args.limit_group_size
|
||||
) > 0 and group_size not in args.limit_group_size:
|
||||
continue
|
||||
|
||||
# For act_order, the group_size must be less than
|
||||
# size_k
|
||||
if act_order and (group_size == size_k
|
||||
or group_size == -1):
|
||||
continue
|
||||
|
||||
for size_m in args.batch_sizes:
|
||||
bench_run(results, model, act_order, is_k_full,
|
||||
num_bits, group_size, size_m, size_k,
|
||||
size_n)
|
||||
|
||||
compare = benchmark.Compare(results)
|
||||
compare.print()
|
||||
|
||||
|
||||
# For quick benchmarking use:
|
||||
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 # noqa E501
|
||||
#
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark Marlin across specified models/shapes/batches")
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES.keys(),
|
||||
)
|
||||
parser.add_argument("--batch-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_BATCH_SIZES)
|
||||
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
75
benchmarks/kernels/benchmark_shapes.py
Normal file
75
benchmarks/kernels/benchmark_shapes.py
Normal file
@ -0,0 +1,75 @@
|
||||
WEIGHT_SHAPES = {
|
||||
"ideal": [[4 * 256 * 32, 256 * 32]],
|
||||
"mistralai/Mistral-7B-v0.1/TP1": [
|
||||
[4096, 6144],
|
||||
[4096, 4096],
|
||||
[4096, 28672],
|
||||
[14336, 4096],
|
||||
],
|
||||
"mistralai/Mistral-7B-v0.1/TP2": [
|
||||
[4096, 3072],
|
||||
[2048, 4096],
|
||||
[4096, 14336],
|
||||
[7168, 4096],
|
||||
],
|
||||
"mistralai/Mistral-7B-v0.1/TP4": [
|
||||
[4096, 1536],
|
||||
[1024, 4096],
|
||||
[4096, 7168],
|
||||
[3584, 4096],
|
||||
],
|
||||
"meta-llama/Llama-2-7b-hf/TP1": [
|
||||
[4096, 12288],
|
||||
[4096, 4096],
|
||||
[4096, 22016],
|
||||
[11008, 4096],
|
||||
],
|
||||
"meta-llama/Llama-2-7b-hf/TP2": [
|
||||
[4096, 6144],
|
||||
[2048, 4096],
|
||||
[4096, 11008],
|
||||
[5504, 4096],
|
||||
],
|
||||
"meta-llama/Llama-2-7b-hf/TP4": [
|
||||
[4096, 3072],
|
||||
[1024, 4096],
|
||||
[4096, 5504],
|
||||
[2752, 4096],
|
||||
],
|
||||
"meta-llama/Llama-2-13b-hf/TP1": [
|
||||
[5120, 15360],
|
||||
[5120, 5120],
|
||||
[5120, 27648],
|
||||
[13824, 5120],
|
||||
],
|
||||
"meta-llama/Llama-2-13b-hf/TP2": [
|
||||
[5120, 7680],
|
||||
[2560, 5120],
|
||||
[5120, 13824],
|
||||
[6912, 5120],
|
||||
],
|
||||
"meta-llama/Llama-2-13b-hf/TP4": [
|
||||
[5120, 3840],
|
||||
[1280, 5120],
|
||||
[5120, 6912],
|
||||
[3456, 5120],
|
||||
],
|
||||
"meta-llama/Llama-2-70b-hf/TP1": [
|
||||
[8192, 10240],
|
||||
[8192, 8192],
|
||||
[8192, 57344],
|
||||
[28672, 8192],
|
||||
],
|
||||
"meta-llama/Llama-2-70b-hf/TP2": [
|
||||
[8192, 5120],
|
||||
[4096, 8192],
|
||||
[8192, 28672],
|
||||
[14336, 8192],
|
||||
],
|
||||
"meta-llama/Llama-2-70b-hf/TP4": [
|
||||
[8192, 2560],
|
||||
[2048, 8192],
|
||||
[8192, 14336],
|
||||
[7168, 8192],
|
||||
],
|
||||
}
|
||||
158
tests/kernels/test_marlin_gemm.py
Normal file
158
tests/kernels/test_marlin_gemm.py
Normal file
@ -0,0 +1,158 @@
|
||||
"""Tests for the marlin kernel.
|
||||
|
||||
Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
|
||||
"""
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
MarlinWorkspace, is_marlin_supported, marlin_quantize, marlin_weights)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
gptq_pack, quantize_weights, sort_weights)
|
||||
|
||||
ACT_ORDER_OPTS = [False, True]
|
||||
K_FULL_OPTS = [False, True]
|
||||
|
||||
K_CHUNKS = [128, 256]
|
||||
N_CHUNKS = [64, 128, 256]
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 1, 1),
|
||||
(1, 4, 8),
|
||||
(1, 7, 5),
|
||||
(1, 7 * 4, 5 * 1),
|
||||
(13, 17, 67),
|
||||
(26, 37, 13),
|
||||
(67, 13, 11),
|
||||
]
|
||||
|
||||
|
||||
def rand_data(shape):
|
||||
data = torch.rand(shape).to(torch.half).cuda()
|
||||
return data
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_marlin_supported(),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", N_CHUNKS)
|
||||
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
|
||||
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
|
||||
mnk_factors):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
print(f"MNK = {size_m} {size_n} {size_k}")
|
||||
|
||||
# Filter act_order
|
||||
if act_order:
|
||||
if group_size == -1:
|
||||
return
|
||||
if group_size == size_k:
|
||||
return
|
||||
|
||||
# Normalize group_size
|
||||
if group_size == -1:
|
||||
group_size = size_k
|
||||
assert group_size <= size_k
|
||||
|
||||
# Create input
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
# Quantize (and apply act_order if provided)
|
||||
w_ref, q_w, s, g_idx, rand_perm = quantize_weights(b_weight, num_bits,
|
||||
group_size, act_order)
|
||||
|
||||
# Pack to GPTQ format
|
||||
q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)
|
||||
|
||||
# For act_order, sort the "weights" and "g_idx" so that group ids are
|
||||
# increasing
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
|
||||
if act_order:
|
||||
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
|
||||
|
||||
# Pack to Marlin format
|
||||
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits)
|
||||
|
||||
# Run Marlin repack GPU kernel
|
||||
marlin_q_w_2 = ops.gptq_marlin_repack(
|
||||
q_w_gptq,
|
||||
sort_indices,
|
||||
size_k,
|
||||
size_n,
|
||||
num_bits,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
assert torch.allclose(marlin_q_w_1, marlin_q_w_2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_marlin_supported(),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", N_CHUNKS)
|
||||
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
|
||||
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
||||
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
|
||||
def test_marlin_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
num_bits,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
act_order,
|
||||
is_k_full,
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
print(f"MNK = {size_m} {size_n} {size_k}")
|
||||
print(f"groupsize = {group_size}")
|
||||
|
||||
if act_order:
|
||||
if group_size == -1:
|
||||
return
|
||||
if group_size == size_k:
|
||||
return
|
||||
|
||||
a_input = rand_data((size_m, size_k))
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||
b_weight, num_bits, group_size, act_order)
|
||||
|
||||
workspace = MarlinWorkspace(size_n)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace.scratch,
|
||||
num_bits,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full,
|
||||
)
|
||||
output_ref = torch.matmul(a_input, w_ref)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
assert torch.allclose(output, output_ref, rtol=1e-2)
|
||||
174
vllm/model_executor/layers/quantization/utils/marlin_utils.py
Normal file
174
vllm/model_executor/layers/quantization/utils/marlin_utils.py
Normal file
@ -0,0 +1,174 @@
|
||||
"""This file is used for /tests and /benchmarks"""
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_TILE)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
get_pack_factor, quantize_weights, sort_weights)
|
||||
|
||||
__cuda_arch = torch.cuda.get_device_capability()
|
||||
|
||||
|
||||
def is_marlin_supported():
|
||||
return __cuda_arch[0] >= 8
|
||||
|
||||
|
||||
# Precompute permutations for Marlin weight and scale shuffling # noqa: E501
|
||||
#
|
||||
# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
|
||||
# with the tensor-core format that is described here:
|
||||
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
|
||||
#
|
||||
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
|
||||
# (without the need to use ldmatrix instructions) # noqa: E501
|
||||
def _get_perms(num_bits):
|
||||
perm_list = []
|
||||
for i in range(32):
|
||||
perm1 = []
|
||||
col = i // 4
|
||||
for block in [0, 1]:
|
||||
for row in [
|
||||
2 * (i % 4),
|
||||
2 * (i % 4) + 1,
|
||||
2 * (i % 4 + 4),
|
||||
2 * (i % 4 + 4) + 1,
|
||||
]:
|
||||
perm1.append(16 * row + col + 8 * block)
|
||||
for j in range(4):
|
||||
perm_list.extend([p + 256 * j for p in perm1])
|
||||
|
||||
perm = numpy.array(perm_list)
|
||||
|
||||
if num_bits == 4:
|
||||
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
|
||||
elif num_bits == 8:
|
||||
interleave = numpy.array([0, 2, 1, 3])
|
||||
else:
|
||||
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
|
||||
|
||||
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
|
||||
perm = torch.from_numpy(perm)
|
||||
scale_perm = []
|
||||
for i in range(8):
|
||||
scale_perm.extend([i + 8 * j for j in range(8)])
|
||||
scale_perm_single = []
|
||||
for i in range(4):
|
||||
scale_perm_single.extend(
|
||||
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
|
||||
return perm, scale_perm, scale_perm_single
|
||||
|
||||
|
||||
_perm = {}
|
||||
_scale_perm = {}
|
||||
_scale_perm_single = {}
|
||||
for num_bits in [4, 8]:
|
||||
perm, scale_perm, scale_perm_single = _get_perms(num_bits)
|
||||
_perm[num_bits] = perm
|
||||
_scale_perm[num_bits] = scale_perm
|
||||
_scale_perm_single[num_bits] = scale_perm_single
|
||||
|
||||
|
||||
def marlin_permute_weights(q_w,
|
||||
size_k,
|
||||
size_n,
|
||||
num_bits,
|
||||
tile=GPTQ_MARLIN_TILE):
|
||||
assert q_w.shape == (size_k, size_n)
|
||||
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
|
||||
assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
|
||||
|
||||
# Permute weights to 16x64 marlin tiles
|
||||
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
|
||||
q_w = q_w.permute((0, 2, 1, 3))
|
||||
q_w = q_w.reshape((size_k // tile, size_n * tile))
|
||||
|
||||
q_w = q_w.reshape(
|
||||
(-1, _perm[num_bits].numel()))[:, _perm[num_bits]].reshape(q_w.shape)
|
||||
|
||||
return q_w
|
||||
|
||||
|
||||
def marlin_weights(q_w, size_k, size_n, num_bits):
|
||||
# Permute
|
||||
q_w = marlin_permute_weights(q_w, size_k, size_n, num_bits)
|
||||
|
||||
# Pack
|
||||
pack_factor = get_pack_factor(num_bits)
|
||||
orig_device = q_w.device
|
||||
|
||||
q_w = q_w.cpu().numpy().astype(numpy.uint32)
|
||||
|
||||
q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
|
||||
dtype=numpy.uint32)
|
||||
|
||||
for i in range(pack_factor):
|
||||
q_packed |= q_w[:, i::pack_factor] << num_bits * i
|
||||
|
||||
q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)
|
||||
|
||||
return q_packed
|
||||
|
||||
|
||||
def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
|
||||
if group_size < size_k and group_size != -1:
|
||||
s = s.reshape((-1, len(_scale_perm[num_bits])))[:,
|
||||
_scale_perm[num_bits]]
|
||||
else:
|
||||
s = s.reshape(
|
||||
(-1,
|
||||
len(_scale_perm_single[num_bits])))[:,
|
||||
_scale_perm_single[num_bits]]
|
||||
s = s.reshape((-1, size_n)).contiguous()
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def marlin_quantize(
|
||||
w: torch.Tensor,
|
||||
num_bits: int,
|
||||
group_size: int,
|
||||
act_order: bool,
|
||||
):
|
||||
size_k, size_n = w.shape
|
||||
|
||||
# Normalize group_size
|
||||
if group_size == -1:
|
||||
group_size = size_k
|
||||
assert group_size <= size_k
|
||||
|
||||
# Quantize (and apply act_order if provided)
|
||||
w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
|
||||
act_order)
|
||||
|
||||
# For act_order, sort the "weights" and "g_idx" so that group ids are
|
||||
# increasing
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
|
||||
if act_order:
|
||||
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
|
||||
|
||||
# Reformat to marlin
|
||||
marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits)
|
||||
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, num_bits)
|
||||
|
||||
# Create result
|
||||
res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
|
||||
for i in range(len(res_list)):
|
||||
res_list[i] = res_list[i].to(w.device)
|
||||
|
||||
return res_list
|
||||
|
||||
|
||||
class MarlinWorkspace:
|
||||
|
||||
def __init__(self, out_features):
|
||||
assert (out_features % GPTQ_MARLIN_MIN_THREAD_N == 0), (
|
||||
"out_features = {} is undivisible by GPTQ_MARLIN_MIN_THREAD_N = {}"
|
||||
.format(out_features, GPTQ_MARLIN_MIN_THREAD_N))
|
||||
|
||||
max_workspace_size = ((out_features // GPTQ_MARLIN_MIN_THREAD_N) *
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
|
||||
self.scratch = torch.zeros(max_workspace_size,
|
||||
dtype=torch.int,
|
||||
device="cuda")
|
||||
146
vllm/model_executor/layers/quantization/utils/quant_utils.py
Normal file
146
vllm/model_executor/layers/quantization/utils/quant_utils.py
Normal file
@ -0,0 +1,146 @@
|
||||
"""This file is used for /tests and /benchmarks"""
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
SUPPORTED_NUM_BITS = [4, 8]
|
||||
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
|
||||
def get_pack_factor(num_bits):
|
||||
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
|
||||
return 32 // num_bits
|
||||
|
||||
|
||||
def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
|
||||
assert q_w.shape == w_ref.shape
|
||||
|
||||
orig_device = q_w.device
|
||||
k_size, _ = q_w.shape
|
||||
|
||||
g_idx = torch.zeros((k_size, ), dtype=torch.int32)
|
||||
for i in range(k_size):
|
||||
g_idx[i] = i // group_size
|
||||
|
||||
# Simulate act_order by doing a random permutation on K
|
||||
rand_perm = torch.randperm(k_size)
|
||||
|
||||
g_idx = g_idx[rand_perm].contiguous()
|
||||
q_w = q_w[rand_perm, :].contiguous()
|
||||
w_ref = w_ref[rand_perm, :].contiguous()
|
||||
|
||||
return (
|
||||
w_ref.to(device=orig_device),
|
||||
q_w.to(device=orig_device),
|
||||
g_idx.to(device=orig_device),
|
||||
rand_perm.to(device=orig_device),
|
||||
)
|
||||
|
||||
|
||||
def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
|
||||
act_order: bool):
|
||||
orig_device = w.device
|
||||
size_k, size_n = w.shape
|
||||
|
||||
assert w.is_floating_point(), "w must be float"
|
||||
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
|
||||
assert group_size in SUPPORTED_GROUP_SIZES + [
|
||||
size_k
|
||||
], f"Unsupported groupsize = {group_size}"
|
||||
|
||||
if group_size == -1:
|
||||
group_size = size_k
|
||||
assert group_size <= size_k
|
||||
|
||||
max_q_val = 2**num_bits - 1
|
||||
half_q_val = (max_q_val + 1) // 2
|
||||
|
||||
# Reshape to [groupsize, -1]
|
||||
if group_size < size_k:
|
||||
w = w.reshape((-1, group_size, size_n))
|
||||
w = w.permute(1, 0, 2)
|
||||
w = w.reshape((group_size, -1))
|
||||
|
||||
# Compute scale for each group
|
||||
s = torch.max(torch.abs(w), 0, keepdim=True)[0]
|
||||
s *= 2 / max_q_val # 2 => symmetric
|
||||
|
||||
# Quantize
|
||||
q_w = torch.round(w / s).int()
|
||||
q_w += half_q_val
|
||||
q_w = torch.clamp(q_w, 0, max_q_val)
|
||||
|
||||
# Compute ref (dequantized)
|
||||
w_ref = (q_w - half_q_val).half() * s
|
||||
|
||||
# Restore original shapes
|
||||
if group_size < size_k:
|
||||
|
||||
def reshape_w(w):
|
||||
w = w.reshape((group_size, -1, size_n))
|
||||
w = w.permute(1, 0, 2)
|
||||
w = w.reshape((size_k, size_n)).contiguous()
|
||||
return w
|
||||
|
||||
q_w = reshape_w(q_w)
|
||||
w_ref = reshape_w(w_ref)
|
||||
|
||||
s = s.reshape((-1, size_n)).contiguous()
|
||||
|
||||
# Apply act_order
|
||||
g_idx = torch.empty(0, dtype=torch.int, device=w.device)
|
||||
rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
|
||||
if act_order:
|
||||
assert (
|
||||
group_size < size_k
|
||||
), "For act_order, groupsize = {} must be less than size_k = {}".format(
|
||||
group_size, size_k)
|
||||
|
||||
w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)
|
||||
|
||||
return (
|
||||
w_ref.to(device=orig_device),
|
||||
q_w.to(device=orig_device),
|
||||
s.to(device=orig_device),
|
||||
g_idx.to(device=orig_device),
|
||||
rand_perm.to(device=orig_device),
|
||||
)
|
||||
|
||||
|
||||
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
|
||||
orig_device = q_w.device
|
||||
|
||||
sort_indices = torch.argsort(g_idx).to(
|
||||
dtype=torch.int32) # Sort based on g_idx
|
||||
|
||||
g_idx = g_idx[sort_indices].contiguous()
|
||||
q_w = q_w[sort_indices, :].contiguous()
|
||||
|
||||
return (
|
||||
q_w.to(device=orig_device),
|
||||
g_idx.to(device=orig_device),
|
||||
sort_indices.to(device=orig_device),
|
||||
)
|
||||
|
||||
|
||||
def gptq_pack(
|
||||
q_w: torch.Tensor,
|
||||
num_bits: int,
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
):
|
||||
assert q_w.shape == (size_k, size_n)
|
||||
|
||||
pack_factor = get_pack_factor(num_bits)
|
||||
assert size_k % pack_factor == 0
|
||||
|
||||
orig_device = q_w.device
|
||||
|
||||
q_w = q_w.cpu().numpy().astype(numpy.uint32)
|
||||
|
||||
q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
|
||||
|
||||
for i in range(pack_factor):
|
||||
q_res |= q_w[i::pack_factor, :] << num_bits * i
|
||||
|
||||
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
|
||||
return q_res
|
||||
Loading…
x
Reference in New Issue
Block a user