Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 17:16:24 +08:00)
Add marlin unit tests and marlin benchmark script (#4815)

parent 973617ae02
commit 5c342570d7
benchmarks/kernels/benchmark_marlin.py  (new file, 183 lines)
@@ -0,0 +1,183 @@
import argparse

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MarlinWorkspace, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_pack, quantize_weights, sort_weights)

DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]


def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
              size_m, size_k, size_n):
    label = "Quant Matmul"

    sub_label = ("{}, act={} k_full={}, b={}, g={}, "
                 "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits,
                                         group_size, size_m, size_k, size_n))

    print(f"Testing: {sub_label}")

    a = torch.randn(size_m, size_k).to(torch.half).cuda()
    b = torch.rand(size_k, size_n).to(torch.half).cuda()

    a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda())

    # Marlin quant
    (
        marlin_w_ref,
        marlin_q_w,
        marlin_s,
        marlin_g_idx,
        marlin_sort_indices,
        marlin_rand_perm,
    ) = marlin_quantize(b, num_bits, group_size, act_order)

    # GPTQ quant
    (w_ref, q_w, s, g_idx,
     rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
    q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx"
    # so that group ids are increasing
    repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
    if act_order:
        (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)

    # Prepare
    marlin_workspace = MarlinWorkspace(size_n)

    globals = {
        "marlin_w_ref": marlin_w_ref,
        "marlin_q_w": marlin_q_w,
        "marlin_s": marlin_s,
        "marlin_g_idx": marlin_g_idx,
        "marlin_sort_indices": marlin_sort_indices,
        "marlin_rand_perm": marlin_rand_perm,
        "q_w_gptq": q_w_gptq,
        "repack_sort_indices": repack_sort_indices,
        "num_bits": num_bits,
        "group_size": group_size,
        "size_m": size_m,
        "size_n": size_n,
        "size_k": size_k,
        "is_k_full": is_k_full,
        "a": a,
        "a_tmp": a_tmp,
        "gptq_marlin_gemm": ops.gptq_marlin_gemm,
        "gptq_marlin_repack": ops.gptq_marlin_repack,
        "marlin_workspace": marlin_workspace,
    }

    min_run_time = 1

    # Warmup pytorch
    for i in range(5):
        torch.matmul(a, marlin_w_ref)

    results.append(
        benchmark.Timer(
            stmt="torch.matmul(a, marlin_w_ref)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="pytorch_gemm",
        ).blocked_autorange(min_run_time=min_run_time))

    results.append(
        benchmark.Timer(
            stmt=
            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="gptq_marlin_gemm",
        ).blocked_autorange(min_run_time=min_run_time))

    results.append(
        benchmark.Timer(
            stmt=
            "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="gptq_marlin_repack",
        ).blocked_autorange(min_run_time=min_run_time))


def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    results = []

    for model in args.models:
        for layer in WEIGHT_SHAPES[model]:
            size_k = layer[0]
            size_n = layer[1]

            if len(args.limit_k) > 0 and size_k not in args.limit_k:
                continue

            if len(args.limit_n) > 0 and size_n not in args.limit_n:
                continue

            for act_order in ACT_ORDER_OPTS:
                for is_k_full in K_FULL_OPTS:
                    for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
                        for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
                            if len(
                                    args.limit_group_size
                            ) > 0 and group_size not in args.limit_group_size:
                                continue

                            # For act_order, the group_size must be less than
                            # size_k
                            if act_order and (group_size == size_k
                                              or group_size == -1):
                                continue

                            for size_m in args.batch_sizes:
                                bench_run(results, model, act_order, is_k_full,
                                          num_bits, group_size, size_m, size_k,
                                          size_n)

    compare = benchmark.Compare(results)
    compare.print()


# For quick benchmarking use:
#   python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128  # noqa E501
#
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches")
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    parser.add_argument("--batch-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_BATCH_SIZES)
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
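Editor's note: the sketch below (not part of the commit) shows the torch.utils.benchmark Timer/Compare pattern that bench_run and main rely on, reduced to a plain CPU matmul so it runs without vLLM or a Marlin-capable GPU; the shapes and labels here are illustrative only.

# Minimal sketch of the benchmarking pattern used above (CPU-only, no vLLM).
import torch
import torch.utils.benchmark as benchmark

results = []
for size_m, size_k, size_n in [(16, 4096, 4096), (64, 4096, 4096)]:
    a = torch.randn(size_m, size_k, dtype=torch.float32)
    b = torch.randn(size_k, size_n, dtype=torch.float32)
    results.append(
        benchmark.Timer(
            stmt="torch.matmul(a, b)",
            globals={"a": a, "b": b, "torch": torch},
            label="Quant Matmul (sketch)",
            sub_label=f"MKN=({size_m}x{size_k}x{size_n})",
            description="pytorch_gemm",
        ).blocked_autorange(min_run_time=1))

# Compare collects all Timer measurements into one table, grouped by
# label/sub_label/description -- the same way main() prints its results.
benchmark.Compare(results).print()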
benchmarks/kernels/benchmark_shapes.py  (new file, 75 lines)
@@ -0,0 +1,75 @@
WEIGHT_SHAPES = {
    "ideal": [[4 * 256 * 32, 256 * 32]],
    "mistralai/Mistral-7B-v0.1/TP1": [
        [4096, 6144],
        [4096, 4096],
        [4096, 28672],
        [14336, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP2": [
        [4096, 3072],
        [2048, 4096],
        [4096, 14336],
        [7168, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP4": [
        [4096, 1536],
        [1024, 4096],
        [4096, 7168],
        [3584, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP1": [
        [4096, 12288],
        [4096, 4096],
        [4096, 22016],
        [11008, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP2": [
        [4096, 6144],
        [2048, 4096],
        [4096, 11008],
        [5504, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP4": [
        [4096, 3072],
        [1024, 4096],
        [4096, 5504],
        [2752, 4096],
    ],
    "meta-llama/Llama-2-13b-hf/TP1": [
        [5120, 15360],
        [5120, 5120],
        [5120, 27648],
        [13824, 5120],
    ],
    "meta-llama/Llama-2-13b-hf/TP2": [
        [5120, 7680],
        [2560, 5120],
        [5120, 13824],
        [6912, 5120],
    ],
    "meta-llama/Llama-2-13b-hf/TP4": [
        [5120, 3840],
        [1280, 5120],
        [5120, 6912],
        [3456, 5120],
    ],
    "meta-llama/Llama-2-70b-hf/TP1": [
        [8192, 10240],
        [8192, 8192],
        [8192, 57344],
        [28672, 8192],
    ],
    "meta-llama/Llama-2-70b-hf/TP2": [
        [8192, 5120],
        [4096, 8192],
        [8192, 28672],
        [14336, 8192],
    ],
    "meta-llama/Llama-2-70b-hf/TP4": [
        [8192, 2560],
        [2048, 8192],
        [8192, 14336],
        [7168, 8192],
    ],
}
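Editor's note: each entry above is a [size_k, size_n] weight shape for the four quantized linear layers of one decoder block, and the TP2/TP4 variants shard one dimension of the TP1 shapes. The sketch below (not part of the commit; assumes it is run from benchmarks/kernels/ so benchmark_shapes is importable) checks that relationship for the Llama-2-7b entries.

# Sanity sketch: TP-N shapes divide the sharded dimension of the TP1 shapes by N.
from benchmark_shapes import WEIGHT_SHAPES


def check_tp_sharding(base="meta-llama/Llama-2-7b-hf"):
    tp1 = WEIGHT_SHAPES[f"{base}/TP1"]
    for tp in (2, 4):
        for idx, ((k1, n1), (k, n)) in enumerate(
                zip(tp1, WEIGHT_SHAPES[f"{base}/TP{tp}"])):
            if idx in (0, 2):  # column-parallel layers: output dim is sharded
                assert (k, n) == (k1, n1 // tp)
            else:  # row-parallel layers: input dim is sharded
                assert (k, n) == (k1 // tp, n1)


check_tp_sharding()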
tests/kernels/test_marlin_gemm.py  (new file, 158 lines)
@@ -0,0 +1,158 @@
"""Tests for the marlin kernel.

Run `pytest tests/kernels/test_marlin_gemm.py`.
"""
import pytest
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MarlinWorkspace, is_marlin_supported, marlin_quantize, marlin_weights)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_pack, quantize_weights, sort_weights)

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]

K_CHUNKS = [128, 256]
N_CHUNKS = [64, 128, 256]

MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (1, 7 * 4, 5 * 1),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
]


def rand_data(shape):
    data = torch.rand(shape).to(torch.half).cuda()
    return data


@pytest.mark.skipif(not is_marlin_supported(),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", K_CHUNKS)
@pytest.mark.parametrize("n_chunk", N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
                       mnk_factors):
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")

    # Filter act_order
    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(b_weight, num_bits,
                                                       group_size, act_order)

    # Pack to GPTQ format
    q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Pack to Marlin format
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits)

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(
        q_w_gptq,
        sort_indices,
        size_k,
        size_n,
        num_bits,
    )
    torch.cuda.synchronize()

    assert torch.allclose(marlin_q_w_1, marlin_q_w_2)


@pytest.mark.skipif(not is_marlin_supported(),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", K_CHUNKS)
@pytest.mark.parametrize("n_chunk", N_CHUNKS)
@pytest.mark.parametrize("num_bits", GPTQ_MARLIN_SUPPORTED_NUM_BITS)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
def test_marlin_gemm(
    k_chunk,
    n_chunk,
    num_bits,
    group_size,
    mnk_factors,
    act_order,
    is_k_full,
):
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    print(f"MNK = {size_m} {size_n} {size_k}")
    print(f"groupsize = {group_size}")

    # Filter act_order
    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, num_bits, group_size, act_order)

    workspace = MarlinWorkspace(size_n)

    output = ops.gptq_marlin_gemm(
        a_input,
        marlin_q_w,
        marlin_s,
        g_idx,
        sort_indices,
        workspace.scratch,
        num_bits,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    assert torch.allclose(output, output_ref, rtol=1e-2)
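Editor's note: the sketch below (not part of the commit; assumes vllm is installed, but needs no GPU since quant_utils is pure torch/numpy) does a round trip of the GPTQ packing that test_marlin_repack feeds to ops.gptq_marlin_repack. Unpacking with the same shifts and mask must recover the original quantized values.

# CPU-only round trip of the GPTQ bit-packing used by the repack test above.
import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    get_pack_factor, gptq_pack)

num_bits, size_k, size_n = 4, 128, 64
q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

packed = gptq_pack(q_w, num_bits, size_k, size_n)

pack_factor = get_pack_factor(num_bits)  # 32 // num_bits = 8
unpacked = torch.empty_like(q_w)
for i in range(pack_factor):
    # Rows i, i + pack_factor, ... of q_w live in bits
    # [num_bits * i, num_bits * (i + 1)) of each packed int32.
    unpacked[i::pack_factor, :] = (packed >> (num_bits * i)) & (2**num_bits - 1)

assert torch.equal(unpacked, q_w)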
vllm/model_executor/layers/quantization/utils/marlin_utils.py  (new file, 174 lines)
@@ -0,0 +1,174 @@
"""This file is used for /tests and /benchmarks"""
import numpy
import torch

from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_TILE)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    get_pack_factor, quantize_weights, sort_weights)

__cuda_arch = torch.cuda.get_device_capability()


def is_marlin_supported():
    return __cuda_arch[0] >= 8


# Precompute permutations for Marlin weight and scale shuffling
#
# Marlin works on [16, 64] tiles. The goal of the permutations is to reorder
# the weight data so that it is compatible with the tensor-core format that is
# described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type  # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get
# the data as it is needed for tensor-core (without the need to use ldmatrix
# instructions)
def _get_perms(num_bits):
    perm_list = []
    for i in range(32):
        perm1 = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    scale_perm = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return perm, scale_perm, scale_perm_single


_perm = {}
_scale_perm = {}
_scale_perm_single = {}
for num_bits in [4, 8]:
    perm, scale_perm, scale_perm_single = _get_perms(num_bits)
    _perm[num_bits] = perm
    _scale_perm[num_bits] = scale_perm
    _scale_perm_single[num_bits] = scale_perm_single


def marlin_permute_weights(q_w,
                           size_k,
                           size_n,
                           num_bits,
                           tile=GPTQ_MARLIN_TILE):
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    q_w = q_w.reshape(
        (-1, _perm[num_bits].numel()))[:, _perm[num_bits]].reshape(q_w.shape)

    return q_w


def marlin_weights(q_w, size_k, size_n, num_bits):
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, num_bits)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
                           dtype=numpy.uint32)

    for i in range(pack_factor):
        q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)

    return q_packed


def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(_scale_perm[num_bits])))[:,
                                                        _scale_perm[num_bits]]
    else:
        s = s.reshape(
            (-1,
             len(_scale_perm_single[num_bits])))[:,
                                                 _scale_perm_single[num_bits]]
    s = s.reshape((-1, size_n)).contiguous()

    return s


def marlin_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
    act_order: bool,
):
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
                                                       act_order)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reformat to marlin
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, num_bits)

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list


class MarlinWorkspace:

    def __init__(self, out_features):
        assert (out_features % GPTQ_MARLIN_MIN_THREAD_N == 0), (
            "out_features = {} not divisible by GPTQ_MARLIN_MIN_THREAD_N = {}"
            .format(out_features, GPTQ_MARLIN_MIN_THREAD_N))

        max_workspace_size = ((out_features // GPTQ_MARLIN_MIN_THREAD_N) *
                              GPTQ_MARLIN_MAX_PARALLEL)

        self.scratch = torch.zeros(max_workspace_size,
                                   dtype=torch.int,
                                   device="cuda")
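Editor's note: the sketch below (not part of the commit) illustrates that marlin_permute_weights only rearranges values into the [16, 64]-tile layout described in the comment above, producing a (size_k // tile, size_n * tile) tensor containing exactly the same elements. Importing marlin_utils calls torch.cuda.get_device_capability(), so this assumes a CUDA-capable machine with vllm installed.

# Shape and permutation check for marlin_permute_weights (illustrative only).
import torch

from vllm.model_executor.layers.quantization.gptq_marlin import GPTQ_MARLIN_TILE
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_permute_weights)

size_k, size_n, num_bits = 128, 256, 4
q_w = torch.arange(size_k * size_n, dtype=torch.int32).reshape(size_k, size_n)

q_perm = marlin_permute_weights(q_w.clone(), size_k, size_n, num_bits)

# Output is laid out as one row per 16-row slab of K, 16x wider along N.
assert q_perm.shape == (size_k // GPTQ_MARLIN_TILE, size_n * GPTQ_MARLIN_TILE)
# A pure permutation: sorting the flattened values recovers the original range.
assert torch.equal(q_perm.flatten().sort().values, q_w.flatten())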
vllm/model_executor/layers/quantization/utils/quant_utils.py  (new file, 146 lines)
@@ -0,0 +1,146 @@
"""This file is used for /tests and /benchmarks"""
import numpy
import torch

SUPPORTED_NUM_BITS = [4, 8]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]


def get_pack_factor(num_bits):
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
    assert q_w.shape == w_ref.shape

    orig_device = q_w.device
    k_size, _ = q_w.shape

    g_idx = torch.zeros((k_size, ), dtype=torch.int32)
    for i in range(k_size):
        g_idx[i] = i // group_size

    # Simulate act_order by doing a random permutation on K
    rand_perm = torch.randperm(k_size)

    g_idx = g_idx[rand_perm].contiguous()
    q_w = q_w[rand_perm, :].contiguous()
    w_ref = w_ref[rand_perm, :].contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )


def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
                     act_order: bool):
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    max_q_val = 2**num_bits - 1
    half_q_val = (max_q_val + 1) // 2

    # Reshape to [groupsize, -1]
    if group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    s = torch.max(torch.abs(w), 0, keepdim=True)[0]
    s *= 2 / max_q_val  # 2 => symmetric

    # Quantize
    q_w = torch.round(w / s).int()
    q_w += half_q_val
    q_w = torch.clamp(q_w, 0, max_q_val)

    # Compute ref (dequantized)
    w_ref = (q_w - half_q_val).half() * s

    # Restore original shapes
    if group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        q_w = reshape_w(q_w)
        w_ref = reshape_w(w_ref)

    s = s.reshape((-1, size_n)).contiguous()

    # Apply act_order
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        assert (
            group_size < size_k
        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
            group_size, size_k)

        w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size)

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        s.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )


def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    orig_device = q_w.device

    sort_indices = torch.argsort(g_idx).to(
        dtype=torch.int32)  # Sort based on g_idx

    g_idx = g_idx[sort_indices].contiguous()
    q_w = q_w[sort_indices, :].contiguous()

    return (
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        sort_indices.to(device=orig_device),
    )


def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[i::pack_factor, :] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    return q_res
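Editor's note: a standalone float32 sketch (not part of the commit) of the per-group symmetric scheme that quantize_weights implements: the scale is s = 2 * max|w| / max_q_val, values are shifted by half_q_val before clamping, and the reconstruction error of any element is therefore at most s / 2. Shapes here are illustrative.

# One quantization group, written out in plain torch to show the formula.
import torch

num_bits, group_size, size_n = 4, 64, 32
max_q_val = 2**num_bits - 1
half_q_val = (max_q_val + 1) // 2

w = torch.randn(group_size, size_n)  # one group of weights
s = torch.max(torch.abs(w), 0, keepdim=True)[0] * 2 / max_q_val  # symmetric scale

q_w = torch.clamp(torch.round(w / s) + half_q_val, 0, max_q_val)
w_ref = (q_w - half_q_val) * s  # dequantized reference

# Round-to-nearest with a symmetric range bounds the error by half a step.
assert torch.all((w - w_ref).abs() <= s / 2 + 1e-6)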