mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-19 09:46:11 +08:00
298 lines
11 KiB
Python
298 lines
11 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from typing import Any, Optional
|
|
|
|
import torch
|
|
from torch.nn.parameter import Parameter
|
|
|
|
from vllm import _custom_ops as ops
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
QuantizationConfig)
|
|
from vllm.model_executor.parameter import (BasevLLMParameter,
|
|
ChannelQuantScaleParameter,
|
|
GroupQuantScaleParameter,
|
|
PackedvLLMParameter)
|
|
from vllm.scalar_type import scalar_types
|
|
|
|
logger = init_logger(__name__)

# Tile edge length for the Marlin24 weight layout.
GPTQ_MARLIN_24_TILE = 16
# The per-partition output size must be a multiple of this value.
GPTQ_MARLIN_24_MIN_THREAD_N = 128
# The per-partition input size must be a multiple of this value.
GPTQ_MARLIN_24_MIN_THREAD_K = 128
# Max parallel problems to solve at once; also scales the workspace buffer.
GPTQ_MARLIN_24_MAX_PARALLEL = 64

# Scalar types accepted by GPTQMarlin24Config.
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [
    scalar_types.uint4b8,
    scalar_types.uint8b128,
]
# Group sizes accepted by GPTQMarlin24Config.
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [
    -1,
    128,
]
|
|
|
|
|
|
class GPTQMarlin24Config(QuantizationConfig):
    """Config class for Marlin24.

    Validates the requested bit width and group size and stores the
    layout constants that ``GPTQMarlin24LinearMethod`` reads when it
    allocates and checks weights.

    Args:
        weight_bits: Bit width of the quantized weights. 4 maps to
            ``scalar_types.uint4b8`` and 8 to ``scalar_types.uint8b128``.
        group_size: Quantization group size. Must be one of
            ``GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES``.

    Raises:
        ValueError: If ``weight_bits`` does not map to a supported
            quant type, or if ``group_size`` is not supported.
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
    ) -> None:
        super().__init__()
        quant_type = {
            4: scalar_types.uint4b8,
            8: scalar_types.uint8b128,
        }.get(weight_bits)

        self.group_size = group_size

        # Verify. Report the offending weight_bits: quant_type is None
        # whenever the bit width is unknown, which tells the user nothing.
        if quant_type is None or \
                quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
            raise ValueError(
                f"Marlin_24 does not support weight_bits = {weight_bits} "
                f"(quant_type = {quant_type}). "
                f"Only quant_types = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
                "are supported.")
        if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Marlin_24 does not support group_size = {self.group_size}. "
                f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
                "are supported.")

        self.quant_type = quant_type

        # Number of quantized values packed into one 32 bit datatype.
        self.pack_factor = 32 // self.quant_type.size_bits

        # Tile size used by marlin kernels.
        self.tile_size = GPTQ_MARLIN_24_TILE

        # Min out_features dim
        self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N

        # Min in_features dim
        self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        """Return a short description with the quant type and group size."""
        return "Marlin24Config(quant_type={}, group_size={})".format(
            self.quant_type, self.group_size)

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the name of this quantization method."""
        return "gptq_marlin_24"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the supported activation dtypes (only ``torch.half``)."""
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this method."""
        # TODO: Need to figure it out (value carried over unconfirmed).
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the config file names searched for this method."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config":
        """Build a config from the ``bits`` and ``group_size`` entries."""
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        """Select this method for checkpoints serialized as ``marlin_24``.

        Returns:
            This method's name when ``hf_quant_cfg["checkpoint_format"]``
            is ``"marlin_24"`` and ``user_quant`` is ``None``, ``"gptq"``
            or ``"gptq_marlin_24"``; otherwise ``None``.
        """
        is_marlin_24_format = (
            hf_quant_cfg.get("checkpoint_format") == "marlin_24")

        is_valid_user_quant = (user_quant is None or user_quant == "gptq"
                               or user_quant == "gptq_marlin_24")

        if is_marlin_24_format and is_valid_user_quant:
            name = cls.get_name()
            # Lazy %-style arguments: the message is only rendered when
            # the INFO level is enabled.
            logger.info(
                "The model is serialized in %s format. "
                "Using %s kernel.", name, name)
            return name

        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["GPTQMarlin24LinearMethod"]:
        """Return a Marlin24 linear method for ``LinearBase`` layers.

        Any other layer type yields ``None``. ``prefix`` is not used.
        """
        if isinstance(layer, LinearBase):
            return GPTQMarlin24LinearMethod(self)
        return None
|
|
|
|
|
|
class GPTQMarlin24LinearMethod(LinearMethodBase):
    """Linear method for Marlin24.

    Allocates the packed weight, meta, scale and workspace parameters for
    a layer and runs the forward pass through
    ``ops.gptq_marlin_24_gemm``.

    Args:
        quant_config: The Marlin24 quantization config.
    """

    def __init__(self, quant_config: GPTQMarlin24Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Validate the partition sizes and register the layer parameters.

        Registers ``B_24`` (packed weights), ``B_meta``, ``s`` (scales)
        and ``workspace`` on ``layer``. All tensors are allocated on the
        ``"cuda"`` device.

        Raises:
            ValueError: If ``params_dtype`` is not ``torch.float16`` or a
                partition size fails one of the divisibility checks.
            KeyError: If ``extra_weight_attrs`` has no ``weight_loader``.
        """
        del output_size  # Unused.
        weight_loader = extra_weight_attrs["weight_loader"]
        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        cfg = self.quant_config
        n_out = sum(output_partition_sizes)
        k_in = input_size_per_partition

        # Output side: must divide evenly by both limits, checked in order.
        for label, divisor in (("min_n_threads", cfg.min_n_threads),
                               ("pack_factor", cfg.pack_factor)):
            if n_out % divisor != 0:
                raise ValueError(
                    f"Weight output_size_per_partition = {n_out} "
                    f"is not divisible by {label} = {divisor}.")

        # Input side: thread granularity first, then the group size
        # (skipped when group_size is -1).
        if k_in % cfg.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = {k_in} "
                f"is not divisible by min_k_threads = {cfg.min_k_threads}.")
        if cfg.group_size != -1 and k_in % cfg.group_size != 0:
            raise ValueError(
                f"Weight input_size_per_partition = {k_in} "
                f"is not divisible by group_size = {cfg.group_size}.")

        # Check that we have at least 4 tiles horizontally in the shard
        tiles_per_perm = cfg.perm_len // (cfg.tile_size**2)
        if n_out % tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized weights packed into Int32.
        qweight_rows = k_in // cfg.tile_size // 2
        qweight_cols = n_out * cfg.tile_size // cfg.pack_factor
        qweight = PackedvLLMParameter(
            data=torch.empty(qweight_rows,
                             qweight_cols,
                             device="cuda",
                             dtype=torch.int32),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=cfg.pack_factor,
            marlin_tile_size=cfg.tile_size,
            weight_loader=weight_loader,
        )

        # Meta
        meta_rows = k_in // 8 // 2 // 2
        meta_cols = n_out * 2
        meta = PackedvLLMParameter(
            data=torch.empty(meta_rows,
                             meta_cols,
                             device="cuda",
                             dtype=torch.int16),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=1,
            marlin_tile_size=2,
            weight_loader=weight_loader,
        )

        # Determine if channelwise or not
        if cfg.group_size == -1:
            input_groups = 1
        else:
            input_groups = k_in // cfg.group_size

        scale_data = torch.empty(input_groups,
                                 n_out,
                                 device="cuda",
                                 dtype=params_dtype)
        if input_groups == 1:
            scales = ChannelQuantScaleParameter(data=scale_data,
                                                weight_loader=weight_loader,
                                                output_dim=1)
        else:
            scales = GroupQuantScaleParameter(data=scale_data,
                                              weight_loader=weight_loader,
                                              output_dim=1,
                                              input_dim=0)

        # Allocate workspace (Used for internal locking mechanism)
        workspace_len = (n_out // cfg.min_n_threads) * cfg.max_parallel
        workspace = BasevLLMParameter(
            data=torch.zeros(workspace_len, device="cuda", dtype=torch.int),
            weight_loader=weight_loader,
        )

        for param_name, param in (
            ("B_24", qweight),
            ("B_meta", meta),
            ("s", scales),
            ("workspace", workspace),
        ):
            layer.register_parameter(param_name, param)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap the loaded parameters as plain ``Parameter`` objects."""
        # required by torch.compile
        for param_name in ("B_24", "s", "B_meta", "workspace"):
            loaded = getattr(layer, param_name)
            setattr(layer, param_name,
                    Parameter(loaded.data, requires_grad=False))

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the quantized matmul for ``x`` and add ``bias`` if given.

        ``x`` is flattened to 2-D over its leading dimensions, passed to
        ``ops.gptq_marlin_24_gemm``, and the result is viewed back so the
        leading dimensions of ``x`` are kept.
        """
        flat_x = x.view(-1, x.shape[-1])
        num_rows, in_features = flat_x.shape
        # The output width is read from the second dim of the scales.
        out_features = layer.s.shape[1]

        flat_out = ops.gptq_marlin_24_gemm(
            flat_x,
            layer.B_24,
            layer.B_meta,
            layer.s,
            layer.workspace,
            self.quant_config.quant_type,
            num_rows,
            out_features,
            in_features,
        )

        output = flat_out.view(x.shape[:-1] + (flat_out.shape[1], ))
        if bias is None:
            return output

        output.add_(bias)  # In-place add
        return output
|