[Kernel] Support Microsoft Runtime Kernel Lib for our Low Precision Computation - BitBLAS (#6036)

Signed-off-by: xinyuxiao <xinyuxiao2024@gmail.com>
Co-authored-by: xinyuxiao <xinyuxiao2024@gmail.com>
Lei Wang 2025-04-22 16:01:36 +08:00 committed by GitHub
parent c4ab9f3e71
commit 8d32dc603d
15 changed files with 1864 additions and 7 deletions

View File

@ -0,0 +1,236 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
MINIMUM_BITBLAS_VERSION)
try:
import bitblas
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
raise ImportError("bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
except ImportError as e:
bitblas_import_exception = e
raise ValueError("Trying to use the bitblas backend, but could not import"
f"with the following error: {bitblas_import_exception}. "
"Please install bitblas through the following command: "
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
) from bitblas_import_exception
from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target
from vllm.utils import FlexibleArgumentParser
parser = FlexibleArgumentParser(
description="Benchmark BitBLAS int4 on a specific target.")
# Add arguments to the parser
parser.add_argument(
"--target",
type=str,
default=auto_detect_nvidia_target(),
help="Specify the target device for benchmarking.",
)
parser.add_argument("--group_size",
type=int,
default=None,
help="Group size for grouped quantization.")
parser.add_argument(
"--A_dtype",
type=str,
default="float16",
choices=["float16", "float32", "float64", "int32", "int8"],
help="Data type of activation A.",
)
parser.add_argument(
"--W_dtype",
type=str,
default="int4",
choices=[
"float16",
"float32",
"float64",
"int32",
"int8",
"int4",
"int2",
"int1",
"nf4",
"fp4_e2m1",
],
help="Data type of weight W.",
)
parser.add_argument(
"--accum_dtype",
type=str,
default="float16",
choices=["float16", "int32"],
help="Data type for accumulation.",
)
parser.add_argument(
"--out_dtype",
type=str,
default="float16",
choices=["float16", "float32", "int32", "int8"],
help="Data type for output.",
)
parser.add_argument(
"--layout",
type=str,
default="nt",
choices=["nt", "nn"],
help="Matrix layout, 'nt' for non-transpose A and transpose W.",
)
parser.add_argument("--with_bias",
action="store_true",
help="Include bias in the benchmark.")
parser.add_argument(
"--with_scaling",
action="store_true",
help="Include scaling factor in the quantization.",
)
parser.add_argument("--with_zeros",
action="store_true",
help="Include zeros in the quantization.")
parser.add_argument(
"--zeros_mode",
type=str,
default=None,
choices=["original", "rescale", "quantized"],
help="Specify the mode for calculating zeros.",
)
# Parse the arguments
args = parser.parse_args()
# Assign arguments to variables
target = args.target
A_dtype = args.A_dtype
W_dtype = args.W_dtype
accum_dtype = args.accum_dtype
out_dtype = args.out_dtype
layout = args.layout
with_bias = args.with_bias
group_size = args.group_size
with_scaling = args.with_scaling
with_zeros = args.with_zeros
zeros_mode = args.zeros_mode
# Define a list of shared arguments that repeat in every config
shared_args = [
A_dtype,
W_dtype,
out_dtype,
accum_dtype,
layout,
with_bias,
group_size,
with_scaling,
with_zeros,
zeros_mode,
]
# Define just the (M, K, N) shapes in a more compact list
shapes = [
# square test
(1, 16384, 16384),
# BLOOM-176B
(1, 43008, 14336),
(1, 14336, 14336),
(1, 57344, 14336),
(1, 14336, 57344),
# OPT-65B
(1, 9216, 9216),
(1, 36864, 9216),
(1, 9216, 36864),
(1, 22016, 8192),
# LLAMA-70B/65B
(1, 8192, 22016),
(1, 8192, 8192),
(1, 28672, 8192),
(1, 8192, 28672),
# square test
(16384, 16384, 16384),
# BLOOM-176B
(8192, 43008, 14336),
(8192, 14336, 14336),
(8192, 57344, 14336),
(8192, 14336, 57344),
# OPT-65B
(8192, 9216, 9216),
(8192, 36864, 9216),
(8192, 9216, 36864),
(8192, 22016, 8192),
# LLAMA-70B/65B
(8192, 8192, 22016),
(8192, 8192, 8192),
(8192, 28672, 8192),
(8192, 8192, 28672),
]
# Build test shapes with all the shared arguments
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args))
for shape in shapes]
benchmark_sets = []
benchmark_sets.extend(test_shapes)
benchmark_results = {}
for config_class, operator, input_args in benchmark_sets:
config = config_class(*input_args)
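# Constructing the operator with enable_tuning=True lets BitBLAS tune the
# kernel for the detected target before its latency is profiled below.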
matmul = operator(config, target=target, enable_tuning=True)
kernel_latency = matmul.profile_latency()
print("Time cost is: {:.3f} ms".format(kernel_latency))
profile_config = {
f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": {
"BitBLAS_top20_latency": kernel_latency,
}
}
benchmark_results.update(profile_config)
# Define headers for the table
headers = [
"PrimFunc",
"Input Arguments",
"BitBLAS Top20 Latency",
]
# Calculate column widths for pretty printing
col_widths = [0, 0, 0]
for config_key, values in benchmark_results.items():
args_split = config_key.split("-")
func_name = args_split[0]
input_args_str = "-".join(args_split[1:])
col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
col_widths[1] = max(col_widths[1],
len(input_args_str) + 2,
len(headers[1]) + 2)
col_widths[2] = max(col_widths[2],
len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
len(headers[2]) + 2)
# break only if you want to measure widths from a single example;
# otherwise, let it loop over all items.
# Print header
for i, header in enumerate(headers):
headers[i] = header.ljust(col_widths[i])
print("".join(headers))
print("-" * sum(col_widths))
# Print rows
for config_key, values in benchmark_results.items():
args_split = config_key.split("-")
func_name = args_split[0]
input_args_str = "-".join(args_split[1:])
row = [
func_name,
input_args_str,
f"{values['BitBLAS_top20_latency']:.3f} ms",
]
row_str = "".join(
[str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)])
print(row_str)
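A usage sketch for the benchmark above (the script path below is a hypothetical placeholder; substitute the actual location of this file in the repository):
```console
# Hypothetical file name; all flags shown are defined by the parser above.
python benchmark_bitblas.py --W_dtype int4 --A_dtype float16 --group_size 128 --with_scaling --with_zeros
```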

View File

@ -0,0 +1,40 @@
# BitBLAS
vLLM now supports [BitBLAS](https://github.com/microsoft/BitBLAS) for more efficient and flexible model inference. Compared to other quantization frameworks, BitBLAS supports a wider range of precision combinations.
Below are the steps to utilize BitBLAS with vLLM.
```console
pip install "bitblas>=0.1.0"
```
vLLM reads the model's config file and supports pre-quantized checkpoints.
You can find pre-quantized models on:
- [Hugging Face (BitBLAS)](https://huggingface.co/models?other=bitblas)
- [Hugging Face (GPTQ)](https://huggingface.co/models?other=gptq)
Usually, these repositories have a `quantize_config.json` file that includes a `quantization_config` section.
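For orientation, the sketch below lists the fields such a config section typically carries; the field names match what the BitBLAS config classes in this change parse, and the values are illustrative only:
```python
# Illustrative quantize_config.json contents for a GPTQ checkpoint that the
# BitBLAS backends can consume. Values are examples, not from a real model.
example_quantize_config = {
    "bits": 4,             # weight precision
    "group_size": 128,     # quantization group size (-1 means per-channel)
    "desc_act": False,     # activation reordering (act_order)
    "sym": False,          # symmetric quantization
    "quant_method": "gptq",
}
```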
## Read a bitblas-format checkpoint
```python
from vllm import LLM
import torch
# "hxbgsyxh/llama-13b-4bit-g-1-bitblas" is a pre-quantized checkpoint.
model_id = "hxbgsyxh/llama-13b-4bit-g-1-bitblas"
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, quantization="bitblas")
```
## Read a gptq-format checkpoint
```python
from vllm import LLM
import torch
# "hxbgsyxh/llama-13b-4bit-g-1" is a pre-quantized checkpoint.
model_id = "hxbgsyxh/llama-13b-4bit-g-1"
llm = LLM(model=model_id, dtype=torch.float16, trust_remote_code=True, quantization="bitblas", max_model_len=1024)
```
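Once a model is loaded, generation works the same as for any other vLLM model. A minimal sketch continuing from the `llm` object created above:
```python
from vllm import SamplingParams

# Greedy decoding with the BitBLAS-backed model loaded above.
prompts = ["The capital of France is"]
outputs = llm.generate(prompts, SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```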

View File

@ -11,6 +11,7 @@ Quantization trades off model precision for smaller memory footprint, allowing l
supported_hardware
auto_awq
bnb
bitblas
gguf
gptqmodel
int4

View File

@ -74,6 +74,17 @@ The table below shows the compatibility of various quantization implementations
* ❌
* ❌
* ❌
- * BitBLAS (GPTQ)
* ✅︎
* ✅︎
* ✅︎
* ✅︎
* ✅︎
* ✅︎
* ❌
* ❌
* ❌
* ❌
- * AQLM
* ✅︎
* ✅︎

View File

@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of a GPTQ model to a bitblas model.
Note: GPTQ and bitblas outputs are not bitwise identical.
As a result, in this test, we just confirm that the top selected tokens of the
bitblas/GPTQ models are in the top 3 selections of each other.
Note: bitblas internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for bitblas. As a result, we re-run the
test up to 3 times to see if we pass.
Run `pytest tests/models/test_bitblas.py`.
"""
from dataclasses import dataclass
import pytest
from .utils import check_logprobs_close
@dataclass
class ModelPair:
model_bitblas: str
model_gptq: str
model_pairs = [
ModelPair(model_bitblas="hxbgsyxh/opt-125m-4bit-128g-bitblas",
model_gptq="hxbgsyxh/opt-125m-4bit-128g"),
]
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
vllm_runner,
example_prompts,
model_pair: ModelPair,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with vllm_runner(model_pair.model_bitblas,
dtype=dtype,
quantization="bitblas") as bitblas_model:
bitblas_outputs = bitblas_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
with vllm_runner(model_pair.model_gptq, dtype=dtype,
quantization="gptq") as gptq_model:
gptq_outputs = gptq_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=gptq_outputs,
outputs_1_lst=bitblas_outputs,
name_0="gptq",
name_1="bitblas",
)

View File

@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of a GPTQ model to a bitblas model.
Note: GPTQ and bitblas outputs are not bitwise identical.
As a result, in this test, we just confirm that the top selected tokens of the
bitblas/GPTQ models are in the top 3 selections of each other.
Note: bitblas internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for bitblas. As a result, we re-run the
test up to 3 times to see if we pass.
Run `pytest tests/models/test_bitblas.py`.
"""
from dataclasses import dataclass
import pytest
from .utils import check_logprobs_close
@dataclass
class ModelPair:
model_gptq: str
model_pairs = [
ModelPair(model_gptq="hxbgsyxh/opt-125m-4bit-128g"),
]
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
vllm_runner,
example_prompts,
model_pair: ModelPair,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with vllm_runner(model_pair.model_gptq,
dtype=dtype,
quantization="bitblas") as bitblas_model:
bitblas_outputs = bitblas_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
with vllm_runner(model_pair.model_gptq, dtype=dtype,
quantization="gptq") as gptq_model:
gptq_outputs = gptq_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=gptq_outputs,
outputs_1_lst=bitblas_outputs,
name_0="gptq",
name_1="gptq_bitblas",
)

View File

@ -750,7 +750,8 @@ class ModelConfig:
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors",
"compressed-tensors", "experts_int8", "quark", "nvfp4"
"compressed-tensors", "experts_int8", "quark", "nvfp4", "bitblas",
"gptq_bitblas"
]
if self.quantization is not None:
self.quantization = self.quantization.lower()

View File

@ -31,6 +31,8 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod",
"BitBLASLinearMethod",
"GPTQBitBLASLinearMethod",
"AWQMarlinLinearMethod",
"AWQLinearMethod",
"GPTQMarlinLinearMethod",
@ -50,6 +52,15 @@ WEIGHT_LOADER_V2_SUPPORTED = [
]
def adjust_bitblas_shard(param, shard_size, shard_offset):
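# When the parameter carries a bitblas_tile_size (tiled weight layout),
# shard sizes and offsets are expressed in units of tiles, so scale them down.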
bitblas_tile_size = getattr(param, "bitblas_tile_size", None)
if bitblas_tile_size is not None:
return (shard_size // bitblas_tile_size,
shard_offset // bitblas_tile_size)
return shard_size, shard_offset
def adjust_marlin_shard(param, shard_size, shard_offset):
marlin_tile_size = getattr(param, "marlin_tile_size", None)
if marlin_tile_size is None:
@ -615,6 +626,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
shard_size, shard_offset = adjust_bitblas_shard(
param, shard_size, shard_offset)
if use_bitsandbytes_4bit:
index = list(itertools.accumulate([0] + self.output_sizes))
orig_offsets = {
@ -646,6 +660,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
shard_size, shard_offset = adjust_bitblas_shard(
param, shard_size, shard_offset)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)

View File

@ -18,9 +18,11 @@ QUANTIZATION_METHODS: List[str] = [
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin",
"bitblas",
"gguf",
"gptq_marlin_24",
"gptq_marlin",
"gptq_bitblas",
"awq_marlin",
"gptq",
"compressed-tensors",
@ -85,6 +87,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .aqlm import AQLMConfig
from .awq import AWQConfig
from .awq_marlin import AWQMarlinConfig
from .bitblas import BitBLASConfig
from .bitsandbytes import BitsAndBytesConfig
from .compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig)
@ -94,6 +97,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .fp8 import Fp8Config
from .gguf import GGUFConfig
from .gptq import GPTQConfig
from .gptq_bitblas import GPTQBitBLASConfig
from .gptq_marlin import GPTQMarlinConfig
from .gptq_marlin_24 import GPTQMarlin24Config
from .hqq_marlin import HQQMarlinConfig
@ -119,9 +123,11 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig,
"bitblas": BitBLASConfig,
"gguf": GGUFConfig,
"gptq_marlin_24": GPTQMarlin24Config,
"gptq_marlin": GPTQMarlinConfig,
"gptq_bitblas": GPTQBitBLASConfig,
"awq_marlin": AWQMarlinConfig,
"gptq": GPTQConfig,
"compressed-tensors": CompressedTensorsConfig,
@ -146,4 +152,4 @@ __all__ = [
"QuantizationConfig",
"get_quantization_config",
"QUANTIZATION_METHODS",
]
]

View File

@ -0,0 +1,459 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS,
BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class BitBLASConfig(QuantizationConfig):
"""Config class for BitBLAS.
Reference: https://github.com/Microsoft/BitBLAS
"""
TORCH_DTYPE = torch.float16
STORAGE_DTYPE = "int8" # assume int8 storage
TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE)
# "original" or "rescale" or "quantized",
# gptq_with_bitblas prefer "quantized implementation"
ZEROS_MODE = "quantized"
def __init__(
self,
weight_bits: int,
group_size: Optional[int],
desc_act: Optional[bool],
is_sym: Optional[bool],
quant_method: Optional[str],
lm_head_quantized: bool,
) -> None:
try:
import bitblas
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
raise ImportError(
"bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
except ImportError as e:
bitblas_import_exception = e
raise ValueError(
"Trying to use the bitblas backend, but could not import"
f"with the following error: {bitblas_import_exception}. "
"Please install bitblas through the following command: "
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
) from bitblas_import_exception
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.quant_method = quant_method
self.lm_head_quantized = lm_head_quantized
# Verify
if self.weight_bits not in BITBLAS_SUPPORTED_NUM_BITS:
raise ValueError(
f"BitBLAS does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} "
"are supported.")
if self.is_sym not in BITBLAS_SUPPORTED_SYM:
raise ValueError(
f"BitBLAS does not support is_sym = {self.is_sym}. "
f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported.")
storage_dtype = self.STORAGE_DTYPE
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
self.storage_dtype = storage_dtype
self.storage_torch_dtype = self.TORCH_STORAGE_DTYPE
# Number of weight elements packed into one storage element.
self.pack_factor = storage_nbit // weight_bits
self.nbits = weight_bits
# Zeros type for the quantized weights.
self.zeros_mode = self.ZEROS_MODE
def __repr__(self) -> str:
return (f"BitBLASConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}, "
f"is_sym={self.is_sym}, "
f"quant_method={self.quant_method})")
@classmethod
def get_name(cls) -> str:
return "bitblas"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
# Minimum GPU compute capability required by BitBLAS (SM70 / Volta).
def get_min_capability(cls) -> int:
return 70
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@staticmethod
def get_from_keys(config: Dict[str, Any],
keys: List[str],
default: Any = None) -> Any:
"""Get a value from the model's quantization config."""
for key in keys:
if key in config:
return config[key]
return default
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"], -1)
desc_act = cls.get_from_keys(config, ["desc_act"], False)
is_sym = cls.get_from_keys(config, ["sym"], False)
quant_method = cls.get_from_keys(config, ["quant_method"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
lm_head_quantized)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_bitblas_format: bool
is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"
or hf_quant_cfg.get("is_bitblas_format", False))
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
or user_quant == "bitblas")
if is_bitblas_format and is_valid_user_quant:
msg = ("The model is serialized in {} format. Using {} kernel.".
format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
return None
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["BitBLASLinearMethod"]:
if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
and self.lm_head_quantized):
return BitBLASLinearMethod(self)
return None
class BitBLASLinearMethod(LinearMethodBase):
"""Linear method for BitBLAS.
Args:
quant_config: The BitBLAS quantization config.
"""
# Use BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS instead of
# BITBLAS_OPTIMIZE_FEATURES for higher contiguous-batching performance.
OPT_FEATURES = BITBLAS_OPTIMIZE_FEATURES
ENABLE_TUNING = True
BITBLAS_DTYPES = {
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.half: "float16",
torch.int8: "int8",
}
def __init__(self, quant_config: BitBLASConfig):
self.quant_config = quant_config
def create_weights_gptq(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
"""Creates quantized weights for use in linear operations.
The function initializes the quantized weights, scales, and zeros
used for quantized matrix multiplication and registers them as
parameters on the given layer.
Args:
input_size_per_partition: The size of the input partition.
output_partition_sizes: The sizes of the output partitions.
input_size: The total size of the input (unused).
output_size: The total size of the output (unused).
params_dtype:
The data type of the parameters (expected to be torch.float16).
Returns:
None. The quantized weights ('qweight'), scales ('scales'), and
zeros ('zeros') are registered as parameters on the layer.
Raises:
ValueError: If `params_dtype` is not `torch.float16` or if the
input size per partition is not divisible by the group size in
`quant_config`.
"""
del input_size, output_size # Unused arguments.
weight_loader = extra_weight_attrs["weight_loader"]
if params_dtype not in self.quant_config.get_supported_act_dtypes():
raise ValueError("Parameter data type must be torch.float16, "
f"but got {params_dtype}")
group_size = self.quant_config.group_size
if group_size is None:
group_size = -1
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if (group_size != -1 and input_size_per_partition % group_size != 0):
raise ValueError(
f"Input size per partition ({input_size_per_partition}) must "
f"be divisible by group size ({group_size}).")
# Initialize or retrieve the BitBLAS matrix multiplication operator.
self._configure_bitblas_matmul(
input_size_per_partition,
output_size_per_partition,
params_dtype=params_dtype,
enable_tuning=self.ENABLE_TUNING,
bias=False,
layout="nt",
bits=self.quant_config.weight_bits,
)
# Initialize quantized weights with dimensions
# Quantized 4Bit weights packed.
qweight = PackedvLLMParameter(
data=torch.empty(
self.bitblas_matmul.retrieve_weight_shape(),
device="cuda",
dtype=self.quant_config.storage_torch_dtype,
requires_grad=False,
),
input_dim=1,
output_dim=0,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
bitblas_tile_size=(self.bitblas_matmul.retrieve_weight_shape()[-2]
if self.bitblas_matmul.propagate_b else None),
weight_loader=weight_loader,
)
# Compute the number of input groups for channel-wise quantization.
input_groups = (1 if group_size == -1 else input_size_per_partition //
group_size)
# Initialize scales and zeros for the quantized weights.
weight_scale_args = {
"data":
torch.empty(
output_size_per_partition,
input_groups,
device="cuda",
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if input_groups == 1:
scales = ChannelQuantScaleParameter(output_dim=0,
**weight_scale_args)
else:
scales = GroupQuantScaleParameter(output_dim=0,
input_dim=1,
**weight_scale_args)
if self.quant_config.zeros_mode == "quantized":
zeros = PackedvLLMParameter(
data=torch.empty(
input_groups,
output_size_per_partition // self.quant_config.pack_factor,
device="cuda",
dtype=self.quant_config.storage_torch_dtype,
requires_grad=False,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader,
)
else:
zeros = BasevLLMParameter(
torch.empty(output_size_per_partition,
input_groups,
device="cuda",
dtype=params_dtype),
weight_loader=weight_loader,
)
# Set attributes to indicate how scales and zeros are applied.
set_weight_attrs(zeros, {
"input_dim": None if input_groups == 1 else 1,
"output_dim": 0,
})
layer.register_parameter("qweight", qweight)
layer.register_parameter("scales", scales)
layer.register_parameter("zeros", zeros)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if self.quant_config.quant_method == "gptq":
return self.create_weights_gptq(layer, input_size_per_partition,
output_partition_sizes, input_size,
output_size, params_dtype,
**extra_weight_attrs)
else:
raise ValueError(
f"Unsupported quant_method {self.quant_config.quant_method}")
def _configure_bitblas_matmul(
self,
infeatures,
outfeatures,
params_dtype,
enable_tuning,
bias,
layout,
bits,
out_dtype="float16",
):
from bitblas import MatmulConfig
bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
with_scaling = False
with_zeros = False
group_size = self.quant_config.group_size
zeros_mode = self.quant_config.zeros_mode
if self.quant_config.quant_method == "gptq":
with_scaling = True
with_zeros = True
W_dtype = f"uint{bits}"
if self.quant_config.is_sym:
with_zeros = False
W_dtype = f"int{bits}"
else:
raise ValueError(
f"Unsupported quant_method {self.quant_config.quant_method}")
matmul_config = MatmulConfig(
N=outfeatures,
K=infeatures,
A_dtype=bitblas_dtype,
W_dtype=W_dtype,
out_dtype=out_dtype,
accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
storage_dtype=self.quant_config.STORAGE_DTYPE,
with_scaling=with_scaling,
with_zeros=with_zeros,
group_size=group_size,
with_bias=bias,
layout=layout,
zeros_mode=zeros_mode,
)
self.bitblas_matmul = self._get_or_create_bitblas_operator(
matmul_config, enable_tuning)
def _get_or_create_bitblas_operator(self, config, enable_tuning):
from bitblas import Matmul, auto_detect_nvidia_target
from bitblas.cache import get_database_path, global_operator_cache
BITBLAS_DATABASE_PATH = get_database_path()
BITBLAS_TARGET = auto_detect_nvidia_target()
if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
BITBLAS_TARGET)
bitblas_matmul = global_operator_cache.get(config)
if bitblas_matmul is None:
bitblas_matmul = Matmul(config,
target=BITBLAS_TARGET,
enable_tuning=False)
if enable_tuning:
TUNING_MESSAGE = (f"BitBLAS Operator {config} is tuning ...")
logger.info(TUNING_MESSAGE)
bitblas_matmul.hardware_aware_finetune(topk=20)
global_operator_cache.add(config, bitblas_matmul)
global_operator_cache.save_into_database(
BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
TUNED_MESSAGE = (
f"BitBLAS Operator {config} tuned and saved to database.")
logger.info(TUNED_MESSAGE)
else:
_message = f"BitBLAS Operator {config} created."
logger.info(_message)
else:
_message = (
f"BitBLAS Operator {config} found in global_operator_cache.")
logger.info(_message)
return bitblas_matmul
def apply_gptq(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales
qzeros = layer.zeros
x_2d = x.view(-1, x.shape[-1])
if self.quant_config.is_sym:
output_2d = self.bitblas_matmul(x_2d, qweight, scales)
else:
output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
if bias is not None:
output.add_(bias) # In-place add
return output
def apply(
self,
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
if self.quant_config.quant_method == "gptq":
return self.apply_gptq(*args, **kwargs)
else:
raise ValueError(
f"Unsupported quant_method {self.quant_config.quant_method}")

View File

@ -0,0 +1,438 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Set
import torch
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
BitBLASLinearKernel, MPLinearLayerConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_SUPPORTED_NUM_BITS as GPTQ_BITBLAS_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_SUPPORTED_SYM as GPTQ_BITBLAS_SUPPORTED_SYM)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
MINIMUM_BITBLAS_VERSION, bitblas_repeat_scales_on_all_ranks,
check_bitblas_supported, verify_bitblas_supported)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
class GPTQBitBLASConfig(QuantizationConfig):
"""Config class for GPTQ BitBLAS"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
TORCH_DTYPE = torch.float16
GPTQ_CKPT_STORAGE_DTYPE = (
"int32" # GPTQ Default Checkpoints use int32 as storage dtype
)
GPTQ_BITBLAS_STORAGE_DTYPE = "int8" # BitBLAS uses int8 as storage dtype
TORCH_BITBLAS_STORAGE_DTYPE = getattr(torch, GPTQ_BITBLAS_STORAGE_DTYPE)
# "original" or "rescale" or "quantized",
# the gptq_bitblas prefer "quantized"
ZEROS_MODE = "quantized"
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
quant_method: Optional[str],
lm_head_quantized: bool,
) -> None:
try:
import bitblas
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
raise ImportError(
"bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
except ImportError as e:
bitblas_import_exception = e
raise ValueError(
"Trying to use the bitblas backend, but could not import"
f"with the following error: {bitblas_import_exception}. "
"Please install bitblas through the following command: "
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
) from bitblas_import_exception
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.quant_method = quant_method
self.lm_head_quantized = lm_head_quantized
# Verify
if self.weight_bits not in GPTQ_BITBLAS_SUPPORTED_NUM_BITS:
raise ValueError(
f"BitBLAS does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {GPTQ_BITBLAS_SUPPORTED_NUM_BITS} "
"are supported.")
if self.is_sym not in GPTQ_BITBLAS_SUPPORTED_SYM:
raise ValueError(
f"BitBLAS does not support is_sym = {self.is_sym}. "
f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported.")
self.storage_dtype = self.GPTQ_BITBLAS_STORAGE_DTYPE
storage_nbit = int("".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE
if c.isdigit()))
# Number of weight elements packed into each 32-bit checkpoint word.
self.pack_factor = storage_nbit // weight_bits
self.nbits = weight_bits
# Zeros type for the quantized weights.
self.zeros_mode = self.ZEROS_MODE
if (weight_bits, is_sym) not in self.TYPE_MAP:
raise ValueError("Unsupported quantization config: "
f"bits={weight_bits}, sym={is_sym}")
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
def __repr__(self) -> str:
return (f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})"
f"is_sym={self.is_sym}, "
f"quant_method={self.quant_method})")
@classmethod
def get_name(cls) -> str:
return "gptq_bitblas"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 70
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQBitBLASConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
quant_method = cls.get_from_keys(config, ["quant_method"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
lm_head_quantized)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "bitblas"
or user_quant == "gptq_bitblas")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "gptq":
logger.info("Detected that the model can run with gptq_bitblas"
", however you specified quantization=gptq explicitly,"
" so forcing gptq. Use quantization=gptq_bitblas for"
" faster inference")
return None
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQBitBLASLinearMethod"]:
if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
and self.lm_head_quantized):
return GPTQBitBLASLinearMethod(self)
return None
@property
def torch_storage_dtype(self) -> torch.dtype:
return self.TORCH_BITBLAS_STORAGE_DTYPE
@classmethod
def is_gptq_bitblas_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
sym = quant_config.get("sym")
desc_act = quant_config.get("desc_act")
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
if (num_bits, sym) not in cls.TYPE_MAP:
return False
# If the capability of the device is too low, cannot convert.
major, minor = torch.cuda.get_device_capability()
device_capability = major * 10 + minor
if device_capability < cls.get_min_capability():
return False
# Otherwise, can convert if model satisfies bitblas constraints.
return check_bitblas_supported(quant_type=cls.TYPE_MAP[(num_bits,
sym)],
group_size=group_size)
class GPTQBitBLASLinearMethod(LinearMethodBase):
"""Linear method for GPTQ BitBLAS.
Args:
quant_config: The GPTQ BitBLAS quantization config.
"""
kernel_type = BitBLASLinearKernel
_kernel_backends_being_used: Set[str] = set()
def __init__(self, quant_config: GPTQBitBLASConfig) -> None:
self.quant_config = quant_config
# Verify supported on platform.
verify_bitblas_supported(quant_type=self.quant_config.quant_type,
group_size=self.quant_config.group_size)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
"""Creates quantized weights for use in linear operations.
The function initializes the quantized weights, scales, and zeros
used for quantized matrix multiplication and registers them as
parameters on the given layer.
Args:
input_size_per_partition: The size of the input partition.
output_partition_sizes: The sizes of the output partitions.
input_size: The total size of the input (unused).
output_size: The total size of the output (unused).
params_dtype:
The data type of the parameters (expected to be torch.float16).
Returns:
None. The quantized weights ('qweight'), activation-ordering indices
('g_idx'), scales ('scales'), and zeros ('qzeros') are registered as
parameters on the layer.
Raises:
ValueError: If `params_dtype` is not `torch.float16` or
if the input size per partition is not divisible by the
group size in `quant_config`.
"""
if params_dtype != torch.float16:
raise ValueError("Parameter data type must be torch.float16, "
f"but got {params_dtype}")
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
if input_size_per_partition % group_size != 0:
raise ValueError(
f"Input size per partition ({input_size_per_partition}) must "
f"be divisible by group size ({self.quant_config.group_size})."
)
kernel_type = self.kernel_type
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader")
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_config.quant_type,
act_type=params_dtype,
group_size=self.quant_config.group_size,
zero_points=False,
has_g_idx=self.quant_config.desc_act
)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for GPTQBitBLASLinearMethod",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
# Determine sharding
if bitblas_repeat_scales_on_all_ranks(self.quant_config.desc_act,
self.quant_config.group_size,
is_row_parallel):
# By setting scale_dim == None, weight_loader will
# repeat the scales on each GPU in TP>1 case.
scales_and_zp_input_dim = None
scales_and_zp_size = input_size // group_size
else:
# By setting scale_dim == 0, weight_loader will
# shard the scales in TP>1 case.
scales_and_zp_input_dim = 0
scales_and_zp_size = input_size_per_partition // group_size
# Init buffers
# Quantized weights
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
# Activation order
# Ignore warning from fused linear layers such as QKVParallelLinear.
g_idx = RowvLLMParameter(data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
# Scales
scales = Parameter(
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
},
)
# Quantized zero-points
qzeros_args = {
"data":
torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader":
weight_loader
}
weight_scale_args = {
"data":
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if scales_and_zp_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
self.kernel = kernel_type(
mp_linear_kernel_config,
w_q_param_name="qweight",
w_s_param_name="scales",
w_zp_param_name="qzeros",
w_gidx_param_name="g_idx",
bitblas_quant_config=self.quant_config,
)
# Initialize or retrieve the BitBLAS matrix multiplication operator.
self.kernel.configure_bitblas_matmul(
input_size_per_partition,
output_size_per_partition,
params_dtype=params_dtype,
bias=False,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = self.kernel.apply_gptq_bitblas_linear(layer, x)
if bias is not None:
out.add_(bias)
return out

View File

@ -5,6 +5,8 @@ from typing import List, Optional, Type
import vllm.envs as envs
from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501
AllSparkLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import ( # noqa: E501
BitBLASLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501
ExllamaLinearKernel)
from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501
@ -20,6 +22,7 @@ _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
MacheteLinearKernel,
AllSparkLinearKernel,
MarlinLinearKernel,
BitBLASLinearKernel,
ExllamaLinearKernel,
]
@ -76,4 +79,4 @@ def choose_mp_linear_kernel(
raise ValueError(
"Failed to find a kernel that can implement the "\
"WNA16 linear layer. Reasons: \n"
+ '\n'.join(failure_reasons))
+ '\n'.join(failure_reasons))

View File

@ -0,0 +1,299 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional, Tuple
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES,
MINIMUM_BITBLAS_VERSION, bitblas_make_empty_g_idx, bitblas_sort_g_idx,
check_bitblas_supports_shape, query_bitblas_supported_quant_types,
unpack_gptq_qweight, unpack_gptq_qzeros)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
logger = init_logger(__name__)
class BitBLASLinearKernel(MPLinearKernel):
OPT_FEATURES: List[int] = BITBLAS_OPTIMIZE_FEATURES
ENABLE_TUNING: bool = True
MATMUL_LAYOUT: str = "nt"
BITBLAS_DTYPES: Dict[torch.dtype, str] = {
torch.float32: "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.half: "float16",
torch.int8: "int8",
}
bitblas_matmul: object = None
def __init__(
self,
c: MPLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
w_zp_param_name: Optional[str] = None,
w_gidx_param_name: Optional[str] = None,
bitblas_quant_config: Optional[QuantizationConfig] = None,
):
self.quant_config = bitblas_quant_config
super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name,
w_gidx_param_name)
def repack_bitblas_from_gptq(
self,
b_q_weight: torch.Tensor,
scales: torch.Tensor,
qzeros: Optional[torch.Tensor] = None,
):
from bitblas.quantization.utils import general_compress
assert self.bitblas_matmul is not None, "bitblas_matmul is None"
quant_config = self.quant_config
# qweight in gptq old quant linear stored with
# (outfeatures, infeatures), should be transposed.
qweight = b_q_weight.T.contiguous().view(
quant_config.torch_storage_dtype) # type: ignore[union-attr]
intweight = unpack_gptq_qweight(
qweight,
quant_config.weight_bits).contiguous() # type: ignore[union-attr]
if self.bitblas_matmul.weight_transform is not None: # type: ignore[attr-defined]
qweight = self.bitblas_matmul.weight_transform( # type: ignore[attr-defined]
intweight.cpu()).cuda()
# scales in gptq old quant linear stored with
# (infeatures // group_size, outfeatures), should be transposed.
scales = scales.T.contiguous()
if qzeros is None:
return qweight, scales, None
# qzeros should be de-quantized to int zeros.
weight_bits = quant_config.weight_bits # type: ignore[union-attr]
intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous()
zeros: Optional[torch.Tensor] = None
zeros_mode = self.bitblas_matmul.config.zeros_mode # type: ignore[attr-defined]
if zeros_mode == "original":
zeros = intzeros.to(torch.float16).contiguous()
elif zeros_mode == "rescale":
assert zeros is not None, "zeros should not be None"
zeros[:, :] = intzeros.to(torch.float16)[:, :] * scales[:, :]
elif zeros_mode == "quantized":
zeros = (
torch.Tensor(
general_compress(
intzeros.T.contiguous().cpu().numpy(),
weight_bits,
)).to(qweight.device).
to(quant_config.torch_storage_dtype # type: ignore[union-attr]
).contiguous())
else:
raise ValueError("Unsupported zeros type: {}".format(zeros_mode))
return qweight, scales, zeros
@classmethod
def get_min_capability(cls) -> int:
return 70
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
is_bitblas_installed = True
try:
import bitblas
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
raise ImportError(
"bitblas version is wrong. Please "
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
except ImportError:
is_bitblas_installed = False
if not is_bitblas_installed:
return False, "bitblas is not installed. Please install bitblas "\
"by running `pip install bitblas>="\
f"{MINIMUM_BITBLAS_VERSION}`"
quant_types = query_bitblas_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types:
return False, (f"Quant type ({c.weight_type}) not supported by"
f" BitBLAS, supported types are: {quant_types}")
if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES:
return False, (f"Group size ({c.group_size}) not supported by "
"BitBLAS, supported group sizes are: "
f"{BITBLAS_SUPPORTED_GROUP_SIZES}")
return check_bitblas_supports_shape(
c.partition_weight_shape[1], # out_features
c.partition_weight_shape[0], # in_features
c.full_weight_shape[0], # in_features
c.group_size)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, self.w_q_name).device
c = self.config
quant_config = self.quant_config
# Default names since bitblas requires empty parameters for these,
# TODO: remove this requirement from bitblas (allow optional tensors)
if self.w_gidx_name is None:
self.w_gidx_name = "g_idx"
if self.w_zp_name is None:
self.w_zp_name = "qzeros"
if c.has_g_idx:
g_idx, g_idx_sort_indices = bitblas_sort_g_idx(
getattr(layer, self.w_gidx_name))
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:
setattr(layer, self.w_gidx_name, bitblas_make_empty_g_idx(device))
layer.g_idx_sort_indices = bitblas_make_empty_g_idx(device)
if c.zero_points:
raise NotImplementedError("Zero points not supported by BitBLAS")
else:
setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device))
# Repack weights
bitblas_qweight, bitblas_scales, bitblas_qzeros = (
self.repack_bitblas_from_gptq(
layer.qweight,
layer.scales,
None if quant_config.is_sym else # type: ignore[union-attr]
layer.qzeros, # type: ignore[union-attr]
))
replace_parameter(layer, self.w_q_name, bitblas_qweight)
replace_parameter(layer, self.w_s_name, bitblas_scales)
if bitblas_qzeros is not None:
replace_parameter(layer, self.w_zp_name, bitblas_qzeros)
def configure_bitblas_matmul(
self,
infeatures: int,
outfeatures: int,
params_dtype: torch.dtype,
bias: bool,
) -> None:
enable_tuning = self.ENABLE_TUNING
layout = self.MATMUL_LAYOUT
bits = self.quant_config.weight_bits # type: ignore[union-attr]
self._configure_bitblas_matmul(
infeatures,
outfeatures,
params_dtype,
enable_tuning,
bias,
layout,
bits,
)
def _configure_bitblas_matmul(
self,
infeatures,
outfeatures,
params_dtype,
enable_tuning,
bias,
layout,
bits,
):
from bitblas import MatmulConfig
bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
quant_config = self.quant_config
with_scaling = False
with_zeros = False
group_size = quant_config.group_size # type: ignore[union-attr]
zeros_mode = quant_config.zeros_mode # type: ignore[union-attr]
if quant_config.quant_method == "gptq": # type: ignore[union-attr]
with_scaling = True
with_zeros = True
W_dtype = f"uint{bits}"
if quant_config.is_sym: # type: ignore[union-attr]
with_zeros = False
W_dtype = f"int{bits}"
else:
raise ValueError(
f"Unsupported quant_method {quant_config.quant_method}" # type: ignore[union-attr]
) # type: ignore[union-attr]
matmul_config = MatmulConfig(
M=self.OPT_FEATURES,
N=outfeatures,
K=infeatures,
A_dtype=bitblas_dtype,
W_dtype=W_dtype,
out_dtype=bitblas_dtype,
accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
storage_dtype=quant_config. # type: ignore[union-attr]
storage_dtype, # type: ignore[union-attr]
with_scaling=with_scaling,
with_zeros=with_zeros,
group_size=group_size,
with_bias=bias,
layout=layout,
zeros_mode=zeros_mode,
)
self.bitblas_matmul = self._get_or_create_bitblas_operator(
matmul_config, enable_tuning)
def _get_or_create_bitblas_operator(self, config, enable_tuning):
from bitblas import Matmul, auto_detect_nvidia_target
from bitblas.cache import get_database_path, global_operator_cache
BITBLAS_DATABASE_PATH = get_database_path()
BITBLAS_TARGET = auto_detect_nvidia_target()
if global_operator_cache.size() == 0:
global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
BITBLAS_TARGET)
bitblas_matmul = global_operator_cache.get(config)
if bitblas_matmul is None:
bitblas_matmul = Matmul(config,
target=BITBLAS_TARGET,
enable_tuning=False)
if enable_tuning:
bitblas_matmul.hardware_aware_finetune(topk=20)
global_operator_cache.add(config, bitblas_matmul)
global_operator_cache.save_into_database(
BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
TUNING_MESSAGE = (
f"BitBLAS Operator {config} tuned and saved to database.")
logger.info(TUNING_MESSAGE)
else:
_message = f"BitBLAS Operator {config} created without tuning. "
logger.info(_message)
else:
_message = f"BitBLAS Operator {config} retrieved from cache."
logger.info(_message)
return bitblas_matmul
def apply_gptq_bitblas_linear(
self,
layer: torch.nn.Module,
x: torch.Tensor,
) -> torch.Tensor:
output_size_per_partition = self.config.partition_weight_shape[1]
out_shape = x.shape[:-1] + (output_size_per_partition, )
args = [x, layer.qweight, layer.scales]
if self.bitblas_matmul.config.with_zeros: # type: ignore[attr-defined]
args.append(layer.qzeros)
output = self.bitblas_matmul(*args) # type: ignore[operator]
return output.view(out_shape)
def apply_weights(self, layer, x, bias=None):
NOT_IMPLEMENT_MESSAGE = (
f"{self.__class__.__name__}.apply_weights is not implemented. "
"Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead")
raise NotImplementedError(NOT_IMPLEMENT_MESSAGE)

View File

@ -0,0 +1,198 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import torch
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
MINIMUM_BITBLAS_VERSION = "0.1.0"
BITBLAS_MIN_WEIGHT_SIZE_N = 16
BITBLAS_MIN_WEIGHT_SIZE_K = 16
GPTQ_BITBLAS_MAX_PARALLEL = 16
BITBLAS_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
# For dynamic shape code generation
BITBLAS_OPTIMIZE_FEATURES = [1, 16, 32, 64, 128, 256, 512, 1024]
# For higher contiguous-batching performance, use the following values instead.
BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS = [16, 32, 64, 128, 256, 512, 1024]
BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8]
BITBLAS_SUPPORTED_SYM = [False, True]
# Determines the supported quantization types for BitBLAS based on the
# device's capability and whether zero-point (zp) is used.
def query_bitblas_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
if device_capability is None:
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())
if device_capability < 70:
return []
if has_zp:
# AWQ style, unsigned + runtime zero-point
return [scalar_types.uint4, scalar_types.uint8]
else:
# GPTQ style, unsigned + symmetric bias
# TODO: once fp8_bitblas is merged into "gptq_bitblas" we should be able
# to add `scalar_types.float8_e4m3fn` here
return [scalar_types.uint4b8, scalar_types.uint8b128]
def _check_bitblas_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if device_capability is None:
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())
supported_types = query_bitblas_supported_quant_types(
has_zp, device_capability)
if quant_type not in supported_types:
return (False, f"BitBLAS does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in BITBLAS_SUPPORTED_GROUP_SIZES):
return (False, f"BitBLAS does not support group_size = {group_size}. "
f"Only group_sizes = {BITBLAS_SUPPORTED_GROUP_SIZES} "
"are supported.")
return True, None
def check_bitblas_supported(quant_type: ScalarType,
group_size: int,
has_zp: bool = False,
device_capability: Optional[int] = None) -> bool:
cond, _ = _check_bitblas_supported(quant_type, group_size, has_zp,
device_capability)
return cond
def verify_bitblas_supported(quant_type: ScalarType,
group_size: int,
has_zp: bool = False) -> None:
cond, err_msg = _check_bitblas_supported(quant_type, group_size, has_zp)
if not cond:
assert err_msg is not None
raise ValueError(err_msg)
def verify_bitblas_supports_shape(output_size_per_partition: int,
input_size_per_partition: int,
input_size: int, group_size: int) -> None:
# Validate output_size_per_partition
if output_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_N != 0:
raise ValueError(f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f" min_thread_n = {BITBLAS_MIN_WEIGHT_SIZE_N}. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq.")
# Validate input_size_per_partition
if input_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_K != 0:
raise ValueError(f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible "
f"by min_thread_k = {BITBLAS_MIN_WEIGHT_SIZE_K}. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq.")
if (group_size < input_size
and input_size_per_partition % group_size != 0):
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition}"
f" is not divisible by group_size = {group_size}."
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq.")
def check_bitblas_supports_shape(output_size_per_partition: int,
input_size_per_partition: int,
input_size: int, group_size: int) \
-> Tuple[bool, Optional[str]]:
try:
verify_bitblas_supports_shape(output_size_per_partition,
input_size_per_partition, input_size,
group_size)
except ValueError as e:
return False, e.__str__()
return True, None
def bitblas_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
return (not act_order) or (act_order and not is_row_parallel)
def bitblas_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
is_row_parallel: bool) -> bool:
# Need to repeat scales on every rank if act_ordering or
# channelwise and RowParallelLinear
is_channelwise = group_size == -1
return act_order or (is_channelwise and is_row_parallel)
def bitblas_make_empty_g_idx(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
requires_grad=False)
def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
requires_grad=False)
def bitblas_sort_g_idx(
g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
return g_idx[g_idx_sort_indices], g_idx_sort_indices
def unpack_gptq_qzeros(qzeros, bits, is_gptq_v2=False) -> torch.Tensor:
qzeros = qzeros.view(torch.int32)
elems_per_int32 = 32 // bits
unpacked_zeros = torch.zeros(
(qzeros.shape[0], qzeros.shape[1] * elems_per_int32),
dtype=torch.int8,
device=qzeros.device,
requires_grad=False,
)
for col in range(unpacked_zeros.shape[1]):
i = col % elems_per_int32
unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >>
(bits * i)) & (2**bits - 1)
if not is_gptq_v2:
return unpacked_zeros + 1
return unpacked_zeros
def unpack_gptq_qweight(qweight, bits):
qweight = qweight.view(torch.int8)
elems_per_int8 = 8 // bits
unpacked_weight = torch.zeros(
(qweight.shape[0], qweight.shape[1] * elems_per_int8),
dtype=torch.int8,
device=qweight.device,
requires_grad=False,
)
for col in range(unpacked_weight.shape[1]):
i = col % elems_per_int8
unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >>
(bits * i))
return torch.bitwise_and(unpacked_weight, 2**bits - 1)
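As a quick illustration of the packing convention these helpers assume (lowest bits first within each storage element), the following self-contained sketch, which is an example rather than part of the commit, round-trips a few 4-bit values through `unpack_gptq_qweight`:
```python
import torch

from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
    unpack_gptq_qweight)

bits = 4
# Four 4-bit values per row, packed two per int8 byte, low nibble first.
vals = torch.tensor([[3, 5, 12, 7]], dtype=torch.int32)
packed = (vals[:, 0::2] | (vals[:, 1::2] << bits)).to(torch.int8)
assert torch.equal(unpack_gptq_qweight(packed, bits), vals.to(torch.int8))
```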

View File

@ -282,10 +282,12 @@ class PackedColumnParameter(_ColumnvLLMParameter):
packed_factor: Union[int, Fraction],
packed_dim: int,
marlin_tile_size: Optional[int] = None,
bitblas_tile_size: Optional[int] = None,
**kwargs):
self._packed_factor = packed_factor
self._packed_dim = packed_dim
self._marlin_tile_size = marlin_tile_size
self._bitblas_tile_size = bitblas_tile_size
super().__init__(**kwargs)
@property
@ -300,12 +302,17 @@ class PackedColumnParameter(_ColumnvLLMParameter):
def marlin_tile_size(self):
return self._marlin_tile_size
@property
def bitblas_tile_size(self):
return self._bitblas_tile_size
def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
return _adjust_shard_indexes_for_packing(
shard_size=shard_size,
shard_offset=shard_offset,
packed_factor=self.packed_factor,
marlin_tile_size=self.marlin_tile_size)
marlin_tile_size=self.marlin_tile_size,
bitblas_tile_size=self.bitblas_tile_size)
class PackedvLLMParameter(ModelWeightParameter):
@ -323,10 +330,12 @@ class PackedvLLMParameter(ModelWeightParameter):
packed_factor: Union[int, Fraction],
packed_dim: int,
marlin_tile_size: Optional[int] = None,
bitblas_tile_size: Optional[int] = None,
**kwargs):
self._packed_factor = packed_factor
self._packed_dim = packed_dim
self._marlin_tile_size = marlin_tile_size
self._bitblas_tile_size = bitblas_tile_size
super().__init__(**kwargs)
@property
@ -341,12 +350,17 @@ class PackedvLLMParameter(ModelWeightParameter):
def marlin_tile_size(self):
return self._marlin_tile_size
@property
def bitblas_tile_size(self):
return self._bitblas_tile_size
def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
return _adjust_shard_indexes_for_packing(
shard_size=shard_size,
shard_offset=shard_offset,
packed_factor=self.packed_factor,
marlin_tile_size=self.marlin_tile_size)
marlin_tile_size=self.marlin_tile_size,
bitblas_tile_size=self.bitblas_tile_size)
class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
@ -421,8 +435,13 @@ def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
def _adjust_shard_indexes_for_bitblas(shard_size, shard_offset,
bitblas_tile_size):
return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size
def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
marlin_tile_size):
marlin_tile_size, bitblas_tile_size):
shard_size = shard_size // packed_factor
shard_offset = shard_offset // packed_factor
if marlin_tile_size is not None:
@ -430,4 +449,10 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
shard_size=shard_size,
shard_offset=shard_offset,
marlin_tile_size=marlin_tile_size)
return shard_size, shard_offset
elif bitblas_tile_size is not None:
return _adjust_shard_indexes_for_bitblas(
shard_size=shard_size,
shard_offset=shard_offset,
bitblas_tile_size=bitblas_tile_size)
return shard_size, shard_offset