# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS,
BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)


class BitBLASConfig(QuantizationConfig):
"""Config class for BitBLAS.
Reference: https://github.com/Microsoft/BitBLAS
"""
TORCH_DTYPE = torch.float16
STORAGE_DTYPE = "int8" # assume int8 storage
TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE)
# "original" or "rescale" or "quantized",
# gptq_with_bitblas prefer "quantized implementation"
ZEROS_MODE = "quantized"
def __init__(
self,
weight_bits: int,
group_size: Optional[int],
desc_act: Optional[bool],
is_sym: Optional[bool],
quant_method: Optional[str],
lm_head_quantized: bool,
) -> None:
try:
import bitblas
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
raise ImportError(
"bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
except ImportError as e:
bitblas_import_exception = e
raise ValueError(
"Trying to use the bitblas backend, but could not import"
f"with the following error: {bitblas_import_exception}. "
"Please install bitblas through the following command: "
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
) from bitblas_import_exception
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.quant_method = quant_method
self.lm_head_quantized = lm_head_quantized
# Verify
if self.weight_bits not in BITBLAS_SUPPORTED_NUM_BITS:
raise ValueError(
f"BitBLAS does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} "
"are supported.")
if self.is_sym not in BITBLAS_SUPPORTED_SYM:
raise ValueError(
f"BitBLAS does not support is_sym = {self.is_sym}. "
f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported.")
storage_dtype = self.STORAGE_DTYPE
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
self.storage_dtype = storage_dtype
self.storage_torch_dtype = self.TORCH_STORAGE_DTYPE
        # Number of weight elements packed into one storage element,
        # e.g. two 4-bit weights per int8 storage word.
        self.pack_factor = storage_nbit // weight_bits
self.nbits = weight_bits
# Zeros type for the quantized weights.
        self.zeros_mode = self.ZEROS_MODE

    def __repr__(self) -> str:
return (f"BitBLASConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"is_sym={self.is_sym}, "
f"quant_method={self.quant_method})")
@classmethod
def get_name(cls) -> QuantizationMethods:
return "bitblas"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
    # TODO: confirm the actual minimum GPU compute capability required.
    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @staticmethod
def get_from_keys(config: Dict[str, Any],
keys: List[str],
default: Any = None) -> Any:
"""Get a value from the model's quantization config."""
for key in keys:
if key in config:
return config[key]
return default
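
    # Illustrative example of a GPTQ-style quantize_config.json that
    # `from_config` below can parse (the values shown are hypothetical):
    #   {"bits": 4, "group_size": 128, "desc_act": false, "sym": true,
    #    "quant_method": "gptq", "lm_head": false}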
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"], -1)
desc_act = cls.get_from_keys(config, ["desc_act"], False)
is_sym = cls.get_from_keys(config, ["sym"], False)
quant_method = cls.get_from_keys(config, ["quant_method"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
                   lm_head_quantized)

    @classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_bitblas_format: bool
is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"
or hf_quant_cfg.get("is_bitblas_format", False))
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
or user_quant == "bitblas")
if is_bitblas_format and is_valid_user_quant:
msg = ("The model is serialized in {} format. Using {} kernel.".
format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
        return None

    def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["BitBLASLinearMethod"]:
if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
and self.lm_head_quantized):
return BitBLASLinearMethod(self)
        return None


class BitBLASLinearMethod(LinearMethodBase):
"""Linear method for BitBLAS.
Args:
quant_config: The BitBLAS quantization config.
"""
    # Use BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS instead of
    # BITBLAS_OPTIMIZE_FEATURES if higher contiguous-batching performance
    # is desired.
    OPT_FEATURES = BITBLAS_OPTIMIZE_FEATURES
ENABLE_TUNING = True
BITBLAS_DTYPES = {
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.half: "float16",
torch.int8: "int8",
    }

    def __init__(self, quant_config: BitBLASConfig):
        self.quant_config = quant_config

    def create_weights_gptq(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
"""Creates quantized weights for use in linear operations.
The function initializes and returns a dictionary containing quantized
weights, scales, and zeros
for performing quantized matrix multiplication operations.
Args:
input_size_per_partition: The size of the input partition.
output_size_per_partition: The size of the output partition.
input_size: The total size of the input (unused).
output_size: The total size of the output (unused).
params_dtype:
The data type of the parameters (expected to be torch.float16).
Returns:
A dictionary containing the quantized weights ('qweight'),
scales ('scales'), and zeros ('zeros').
Raises:
ValueError: If `params_dtype` is not `torch.float16` or if the
input size per partition is not divisible by the group size in
`quant_config`.
"""
del input_size, output_size # Unused arguments.
weight_loader = extra_weight_attrs["weight_loader"]
        if params_dtype not in self.quant_config.get_supported_act_dtypes():
            raise ValueError(
                "Parameter data type must be one of "
                f"{self.quant_config.get_supported_act_dtypes()}, "
                f"but got {params_dtype}")
group_size = self.quant_config.group_size
if group_size is None:
group_size = -1
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if (group_size != -1 and input_size_per_partition % group_size != 0):
raise ValueError(
f"Input size per partition ({input_size_per_partition}) must "
f"be divisible by group size ({group_size}).")
# Initialize or retrieve the BitBLAS matrix multiplication operator.
self._configure_bitblas_matmul(
input_size_per_partition,
output_size_per_partition,
params_dtype=params_dtype,
enable_tuning=self.ENABLE_TUNING,
bias=False,
layout="nt",
bits=self.quant_config.weight_bits,
)
        # Initialize the packed quantized weights. Their shape (and tiling) is
        # dictated by the configured BitBLAS operator, so it is queried from
        # the operator below.
qweight = PackedvLLMParameter(
data=torch.empty(
self.bitblas_matmul.retrieve_weight_shape(),
device="cuda",
dtype=self.quant_config.storage_torch_dtype,
requires_grad=False,
),
input_dim=1,
output_dim=0,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
bitblas_tile_size=(self.bitblas_matmul.retrieve_weight_shape()[-2]
if self.bitblas_matmul.propagate_b else None),
weight_loader=weight_loader,
)
# Compute the number of input groups for channel-wise quantization.
input_groups = (1 if group_size == -1 else input_size_per_partition //
group_size)
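        # e.g. input_size_per_partition=4096 with group_size=128 gives 32
        # groups; group_size=-1 collapses to a single group (channel-wise).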
# Initialize scales and zeros for the quantized weights.
weight_scale_args = {
"data":
torch.empty(
output_size_per_partition,
input_groups,
device="cuda",
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if input_groups == 1:
scales = ChannelQuantScaleParameter(output_dim=0,
**weight_scale_args)
else:
scales = GroupQuantScaleParameter(output_dim=0,
input_dim=1,
**weight_scale_args)
if self.quant_config.zeros_mode == "quantized":
zeros = PackedvLLMParameter(
data=torch.empty(
input_groups,
output_size_per_partition // self.quant_config.pack_factor,
device="cuda",
dtype=self.quant_config.storage_torch_dtype,
requires_grad=False,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader,
)
else:
zeros = BasevLLMParameter(
torch.empty(output_size_per_partition,
input_groups,
device="cuda",
dtype=params_dtype),
weight_loader=weight_loader,
)
        # Record how the zeros tensor is sharded so the weight loader applies
        # the correct input/output dimensions.
set_weight_attrs(zeros, {
"input_dim": None if input_groups == 1 else 1,
"output_dim": 0,
})
layer.register_parameter("qweight", qweight)
layer.register_parameter("scales", scales)
layer.register_parameter("zeros", zeros)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if self.quant_config.quant_method == "gptq":
return self.create_weights_gptq(layer, input_size_per_partition,
output_partition_sizes, input_size,
output_size, params_dtype,
**extra_weight_attrs)
else:
raise ValueError(
f"Unsupported quant_method {self.quant_config.quant_method}")
def _configure_bitblas_matmul(
self,
infeatures,
outfeatures,
params_dtype,
enable_tuning,
bias,
layout,
bits,
out_dtype="float16",
):
from bitblas import MatmulConfig
bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
with_scaling = False
with_zeros = False
group_size = self.quant_config.group_size
zeros_mode = self.quant_config.zeros_mode
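        # GPTQ checkpoints always carry per-group scales; asymmetric ones also
        # carry zero points and store weights as unsigned ints, while symmetric
        # ones use signed ints and need no zero points.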
if self.quant_config.quant_method == "gptq":
with_scaling = True
with_zeros = True
W_dtype = f"uint{bits}"
if self.quant_config.is_sym:
with_zeros = False
W_dtype = f"int{bits}"
else:
raise ValueError(
f"Unsupported quant_method {self.quant_config.quant_method}")
matmul_config = MatmulConfig(
N=outfeatures,
K=infeatures,
A_dtype=bitblas_dtype,
W_dtype=W_dtype,
out_dtype=out_dtype,
accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
storage_dtype=self.quant_config.STORAGE_DTYPE,
with_scaling=with_scaling,
with_zeros=with_zeros,
group_size=group_size,
with_bias=bias,
layout=layout,
zeros_mode=zeros_mode,
)
self.bitblas_matmul = self._get_or_create_bitblas_operator(
            matmul_config, enable_tuning)

    def _get_or_create_bitblas_operator(self, config, enable_tuning):
from bitblas import Matmul, auto_detect_nvidia_target
from bitblas.cache import get_database_path, global_operator_cache
BITBLAS_DATABASE_PATH = get_database_path()
BITBLAS_TARGET = auto_detect_nvidia_target()
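        # Tuned operators are cached in-process and persisted to an on-disk
        # database keyed by the detected GPU target, so later runs reuse
        # previously tuned kernels instead of re-tuning.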
if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
BITBLAS_TARGET)
bitblas_matmul = global_operator_cache.get(config)
if bitblas_matmul is None:
bitblas_matmul = Matmul(config,
target=BITBLAS_TARGET,
enable_tuning=False)
if enable_tuning:
TUNING_MESSAGE = (f"BitBLAS Operator {config} is tuning ...")
logger.info(TUNING_MESSAGE)
bitblas_matmul.hardware_aware_finetune(topk=20)
global_operator_cache.add(config, bitblas_matmul)
global_operator_cache.save_into_database(
BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
TUNED_MESSAGE = (
f"BitBLAS Operator {config} tuned and saved to database.")
logger.info(TUNED_MESSAGE)
else:
_message = f"BitBLAS Operator {config} created."
logger.info(_message)
else:
_message = (
f"BitBLAS Operator {config} found in global_operator_cache.")
logger.info(_message)
        return bitblas_matmul

    def apply_gptq(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales
qzeros = layer.zeros
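        # Flatten any leading batch/sequence dims so the BitBLAS GEMM sees a
        # 2-D [num_tokens, hidden_size] input.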
x_2d = x.view(-1, x.shape[-1])
if self.quant_config.is_sym:
output_2d = self.bitblas_matmul(x_2d, qweight, scales)
else:
output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
if bias is not None:
output.add_(bias) # In-place add
        return output

    def apply(
self,
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
if self.quant_config.quant_method == "gptq":
return self.apply_gptq(*args, **kwargs)
else:
raise ValueError(
f"Unsupported quant_method {self.quant_config.quant_method}")