mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 20:07:28 +08:00
402 lines
12 KiB
Python
402 lines
12 KiB
Python
"""This file is used for /tests and /benchmarks"""
|
|
from typing import List
|
|
|
|
import numpy
|
|
import torch
|
|
|
|
from vllm.model_executor.layers.quantization.qqq import (
|
|
MARLIN_QQQ_SUPPORTED_NUM_BITS)
|
|
from vllm.scalar_type import ScalarType, scalar_types
|
|
|
|
# Scalar types accepted by gptq_quantize_weights.
SUPPORTED_GPTQ_QUANT_TYPES = [
    scalar_types.uint4b8,
    scalar_types.uint8b128,
]

# Group sizes accepted by the quantizers below; -1 is turned into size_k,
# i.e. a single group spanning the whole K dimension.
SUPPORTED_GROUP_SIZES = [
    -1,
    32,
    64,
    128,
]
|
|
|
|
# NOTE: this is a hack. Each model should eventually register its stacked
# params so this table can be derived from them instead (future PR).
# Maps a fused layer name -> ordered list of the shard names it combines.
FUSED_LAYER_NAME_MAPPING = dict(
    qkv_proj=["q_proj", "k_proj", "v_proj"],
    gate_up_proj=["gate_proj", "up_proj"],
)
|
|
|
|
|
|
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
    """Return whether the layer at ``prefix`` appears in ``ignored_layers``.

    For a fused layer (last path component is a key of
    ``FUSED_LAYER_NAME_MAPPING``, e.g. ``model.layers.0.self_attn.qkv_proj``)
    each shard name (``...q_proj``, ``...k_proj``, ``...v_proj``) is looked up
    instead, and all shards must agree.

    Raises:
        ValueError: if only some shards of a fused layer are ignored.
    """
    # prefix: model.layers.0.self_attn.q_proj -> proj_name: q_proj
    parent, sep, proj_name = prefix.rpartition(".")

    shard_names = FUSED_LAYER_NAME_MAPPING.get(proj_name)
    if shard_names is None:
        return prefix in ignored_layers

    # Substitute only the last path component; a plain str.replace would also
    # rewrite any parent module whose name happens to contain proj_name.
    shard_skipped = {
        f"{parent}{sep}{shard_name}" in ignored_layers
        for shard_name in shard_names
    }
    if len(shard_skipped) != 1:
        raise ValueError(
            f"Detected some but not all shards of {prefix} "
            "are quantized. All shards of fused layers "
            "must have the same precision.")
    return shard_skipped.pop()
|
|
|
|
|
|
def get_pack_factor(num_bits):
    """Return how many ``num_bits``-wide values fit in one 32-bit word.

    Raises:
        AssertionError: unless ``num_bits`` is a positive divisor of 32.
    """
    # Check positivity first: 0 would otherwise raise ZeroDivisionError in the
    # modulo, and negative divisors of 32 (e.g. -4) would yield a negative
    # pack factor.
    assert num_bits > 0 and 32 % num_bits == 0, \
        f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits
|
|
|
|
|
|
def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
    """Randomly permute the K (row) dimension to simulate GPTQ act_order.

    Args:
        q_w: quantized weights, indexed as [k, n].
        w_ref: reference weights; must have the same shape as ``q_w``.
        group_size: number of consecutive (pre-permutation) rows per group.

    Returns:
        ``(w_ref, q_w, g_idx, rand_perm)`` on ``q_w``'s original device: the
        row-permuted ``w_ref`` and ``q_w``, the int32 group index of every
        permuted row, and the permutation that was applied.
    """
    assert q_w.shape == w_ref.shape

    orig_device = q_w.device
    k_size, _ = q_w.shape

    # Row i originally belongs to group i // group_size (vectorized instead of
    # a per-row Python loop; int32 in, int32 out).
    g_idx = torch.div(torch.arange(k_size, dtype=torch.int32),
                      group_size,
                      rounding_mode="floor")

    # Simulate act_order by doing a random permutation on K
    rand_perm = torch.randperm(k_size)

    g_idx = g_idx[rand_perm].contiguous()
    q_w = q_w[rand_perm, :].contiguous()
    w_ref = w_ref[rand_perm, :].contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )
|
|
|
|
|
|
def quantize_weights(w: torch.Tensor,
                     quant_type: ScalarType,
                     group_size: int,
                     zero_points: bool = False,
                     ref_zero_points_after_scales: bool = False):
    """Quantize ``w`` group-wise along K and build a dequantized reference.

    Args:
        w: floating-point weights indexed as [size_k, size_n].
        quant_type: integer target type; ``min()``/``max()`` bound the
            quantized values and, if ``has_bias()``, ``bias`` is added to the
            returned quantized values.
        group_size: number of K rows sharing one scale; -1 means a single
            group over all of K. Must be positive (after mapping -1) and
            divide size_k.
        zero_points: use asymmetric quantization with per-group zero points
            (requires an unsigned ``quant_type`` with ``max() > 0``).
        ref_zero_points_after_scales: build the reference as
            ``w_q * s - zp * s`` instead of ``(w_q - zp) * s``, mirroring
            kernels (e.g. Machete) that apply zero points after scales.

    Returns:
        ``(w_ref, w_q, w_s, w_zp)`` on ``w``'s device: the dequantized
        reference ([size_k, size_n], ``w``'s dtype), the int32 quantized
        values ([size_k, size_n]), the scales ([size_k // group_size, size_n])
        and the int32 zero points (same shape as the scales, or None when
        ``zero_points`` is False).
    """
    assert quant_type.is_integer(), \
        "Floating point quantization may work but has not been tested"

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k
    # Fail clearly instead of with an opaque reshape error below (and avoid a
    # modulo by zero).
    assert group_size > 0 and size_k % group_size == 0, (
        f"group_size = {group_size} must be positive and divide "
        f"size_k = {size_k}")

    # Reshape to [groupsize, -1]: each column is then one (group, n) pair.
    if group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    if zero_points:
        assert not quant_type.is_signed() and quant_type.max() > 0
        w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
        maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
            .clamp(min_q_val, max_q_val).int()
    else:
        # If the bias is such that there are no possible negative/positive
        # values, set the max value to inf to avoid divide by 0
        w_s = torch.max(
            abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
            abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
        maybe_w_zp = None

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied, for this case computing the reference in similar way
    # allows us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and zero_points:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    # The bias is added only to the returned codes, after the reference is
    # computed from the unbiased values.
    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)

    w_s = w_s.reshape((-1, size_n)).contiguous()

    if zero_points:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s.to(device=orig_device),
        maybe_w_zp,
    )
|
|
|
|
|
|
def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
                          group_size: int, act_order: bool):
    """Quantize ``w`` GPTQ-style, optionally simulating act_order.

    Args:
        w: floating-point weights indexed as [size_k, size_n].
        quant_type: one of ``SUPPORTED_GPTQ_QUANT_TYPES``.
        group_size: one of ``SUPPORTED_GROUP_SIZES`` or ``size_k``; -1 means
            a single group over all of K.
        act_order: randomly permute the K rows via ``permute_rows``; needs
            more than one group.

    Returns:
        ``(w_ref, w_q, w_s, g_idx, rand_perm)``. ``g_idx`` and ``rand_perm``
        are empty int tensors when ``act_order`` is False.
    """
    size_k, _ = w.shape

    assert w.is_floating_point(), "w must be float"
    assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \
        f"Unsupported gptq type = {quant_type}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)

    # Apply act_order
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        # -1 means one group spanning K (quantize_weights maps it to size_k).
        # Normalize before checking: a raw -1 would pass `< size_k`, and
        # permute_rows would then compute negative group ids (i // -1).
        effective_group_size = size_k if group_size == -1 else group_size
        assert effective_group_size < size_k, (
            f"For act_order, groupsize = {group_size} must be less than "
            f"size_k = {size_k}")

        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size)

    return w_ref, w_q, w_s, g_idx, rand_perm
|
|
|
|
|
|
# QQQ employs different quant schemes for per-group and
# per-channel quantization.
def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
    """Quantize ``w`` for QQQ, per-group or per-channel.

    Per-group (``group_size < size_k``): symmetric ``num_bits`` quantization
    per group, stored offset by ``2**(num_bits - 1)`` into ``[0, 2**num_bits
    - 1]``. The dequantized reference is then re-quantized to int8 per
    channel, and the group scales are divided by those channel scales.

    Per-channel (``group_size`` is -1 or ``size_k``): symmetric signed
    quantization per channel into ``[-(2**(num_bits-1) - 1),
    2**(num_bits-1) - 1]``. ``s_group`` is empty and the channel scales are
    divided by ``2**(8 - num_bits)``.

    Args:
        w: floating-point weights indexed as [size_k, size_n].
        num_bits: one of ``MARLIN_QQQ_SUPPORTED_NUM_BITS``.
        group_size: one of ``SUPPORTED_GROUP_SIZES`` or ``size_k``.

    Returns:
        ``(w_ref, q_w, s_group, s_channel)`` on ``w``'s device: the
        dequantized reference, the int32 quantized values ([size_k, size_n]),
        the half-precision group scales and the float32 channel scales
        ([1, size_n]).

    NOTE(review): an all-zero group/column yields a zero scale and 0/0 = NaN
    in the divisions below; random test weights do not hit this.
    """
    orig_device = w.device
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"
    assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \
        f"Unsupported num_bits = {num_bits}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k
    # Same guard as quantize_weights: fail clearly instead of in a reshape.
    assert group_size > 0 and size_k % group_size == 0, (
        f"group_size = {group_size} must be positive and divide "
        f"size_k = {size_k}")

    if group_size < size_k:
        # Reshape to [groupsize, -1]
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

        max_q_val = 2**num_bits - 1
        half_q_val = (max_q_val + 1) // 2

        # Compute scale for each group
        s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
        s_group *= 2 / max_q_val  # 2 => symmetric

        # Quantize
        q_w = torch.round(w / s_group).int()
        q_w += half_q_val
        q_w = torch.clamp(q_w, 0, max_q_val)
        # Compute ref (dequantized)
        w_ref = (q_w - half_q_val).half() * s_group

        # Restore original shapes
        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        q_w = reshape_w(q_w)
        w_ref = reshape_w(w_ref)

        # Compute int8 quantization scale for each channel
        s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
        s_channel /= 127.0
        t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
        w_ref = t_int8.half() * s_channel
        s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)

        # Fuse scales
        s_group = (s_group.reshape(-1, size_n).contiguous() /
                   s_channel).to(dtype=torch.half)
    else:
        max_q_val = 2**(num_bits - 1) - 1

        # Compute scale for each channel
        s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
        s_channel /= max_q_val

        # Quantize
        q_w = torch.round(w / s_channel).int()
        q_w = torch.clamp(q_w, -max_q_val, max_q_val)
        # Compute ref (dequantized)
        w_ref = q_w.half() * s_channel

        s_group = torch.tensor([], dtype=torch.half)
        # Divide by 2 ** (8 - num_bits) to offset the right shift applied
        # when unpacking.
        s_channel /= (2**(8 - num_bits))
        s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        s_group.to(device=orig_device),
        s_channel.to(device=orig_device),
    )
|
|
|
|
|
|
def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    """Reorder the rows of ``q_w`` so that their group indices ascend.

    Returns:
        ``(q_w, g_idx, sort_indices)`` on ``q_w``'s original device, where
        ``sort_indices`` is the int32 row permutation that was applied.
    """
    device = q_w.device

    # Sort based on g_idx
    perm = torch.argsort(g_idx)
    perm = perm.to(dtype=torch.int32)

    sorted_g_idx = g_idx[perm].contiguous()
    sorted_q_w = q_w[perm, :].contiguous()

    return tuple(
        t.to(device=device) for t in (sorted_q_w, sorted_g_idx, perm))
|
|
|
|
|
|
def pack_rows(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack every ``pack_factor`` consecutive rows of ``q_w`` into int32 words.

    With ``pack_factor = get_pack_factor(num_bits)``, output row ``r`` holds
    input row ``r * pack_factor + i`` in bits
    ``[num_bits * i, num_bits * (i + 1))``. Each value is masked to its low
    ``num_bits`` bits first, so out-of-range or negative (two's-complement)
    inputs cannot overwrite neighbouring fields.

    Returns:
        int32 tensor of shape ``(size_k // pack_factor, size_n)`` on ``q_w``'s
        device.
    """
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_k % pack_factor == 0

    orig_device = q_w.device

    # Work in uint32 so shifts/ORs are plain bit operations.
    mask = numpy.uint32((1 << num_bits) - 1)
    q_w = q_w.cpu().numpy().astype(numpy.uint32) & mask

    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[i::pack_factor, :] << num_bits * i

    # Reinterpret the packed bits as int32 (words >= 2**31 wrap negative).
    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    return q_res
|
|
|
|
|
|
def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack every ``pack_factor`` consecutive columns of ``q_w`` into int32.

    With ``pack_factor = get_pack_factor(num_bits)``, output column ``c``
    holds input column ``c * pack_factor + i`` in bits
    ``[num_bits * i, num_bits * (i + 1))``. Each value is masked to its low
    ``num_bits`` bits first, so out-of-range or negative (two's-complement)
    inputs cannot overwrite neighbouring fields.

    Returns:
        Contiguous int32 tensor of shape ``(size_k, size_n // pack_factor)``
        on ``q_w``'s device.
    """
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    # Work in uint32 so shifts/ORs are plain bit operations.
    mask = numpy.uint32((1 << num_bits) - 1)
    q_w = q_w.cpu().numpy().astype(numpy.uint32) & mask

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    # Reinterpret the packed bits as int32 (words >= 2**31 wrap negative).
    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res
|
|
|
|
|
|
def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Expand each int32 word of ``packed_q_w`` back into ``num_bits`` fields.

    Field ``i`` (bits ``[num_bits * i, num_bits * (i + 1))``) of packed
    column ``c`` becomes column ``c * pack_factor + i`` of the result, where
    ``pack_factor = get_pack_factor(num_bits)``.

    Returns:
        Contiguous int32 tensor of shape ``(size_k, size_n)`` on
        ``packed_q_w``'s device.
    """
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    expected_shape = (size_k, size_n // pack_factor)
    assert packed_q_w.shape == expected_shape, (
        "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
            packed_q_w.shape, size_k, size_n, pack_factor))

    device = packed_q_w.device
    words = packed_q_w.cpu().numpy().astype(numpy.uint32)
    field_mask = (1 << num_bits) - 1

    unpacked = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
    for field in range(pack_factor):
        # Field `field` of every word lands in columns field, field + pf, ...
        unpacked[:, field::pack_factor] = (
            (words >> (num_bits * field)) & field_mask)

    result = torch.from_numpy(unpacked.astype(numpy.int32)).to(device)
    return result.contiguous()
|
|
|
|
|
|
def gptq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack ``q_w`` in the GPTQ layout, where consecutive K rows share a word.

    This is a thin alias for ``pack_rows`` with identical arguments and
    result.
    """
    packed = pack_rows(q_w=q_w,
                       num_bits=num_bits,
                       size_k=size_k,
                       size_n=size_n)
    return packed
|
|
|
|
|
|
def awq_pack(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Interleave the columns of ``q_w`` in AWQ order, then pack them to int32.

    Within every run of ``len(interleave)`` values, the columns are reordered
    as ``[0, 2, 4, 6, 1, 3, 5, 7]`` (4-bit) or ``[0, 2, 1, 3]`` (8-bit). The
    result is then packed column-wise with ``pack_cols``.

    Raises:
        ValueError: if ``num_bits`` is not 4 or 8.
    """
    assert q_w.shape == (size_k, size_n)

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        # ValueError subclasses Exception, so existing handlers still match.
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
    q_w = q_w.reshape((-1, size_n)).contiguous()

    return pack_cols(q_w, num_bits, size_k, size_n)
|