[Kernel] Support Microsoft Runtime Kernel Lib for our Low Precision Computation - BitBLAS (#6036)
Signed-off-by: xinyuxiao <xinyuxiao2024@gmail.com>
Co-authored-by: xinyuxiao <xinyuxiao2024@gmail.com>

This commit is contained in:
parent c4ab9f3e71, commit 8d32dc603d

236  benchmarks/kernels/benchmark_bitblas.py  (new file)

@@ -0,0 +1,236 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
    MINIMUM_BITBLAS_VERSION)

try:
    import bitblas
    if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
        raise ImportError("bitblas version is wrong. Please "
                          f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
except ImportError as e:
    bitblas_import_exception = e
    raise ValueError("Trying to use the bitblas backend, but could not import "
                     f"with the following error: {bitblas_import_exception}. "
                     "Please install bitblas through the following command: "
                     f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
                     ) from bitblas_import_exception

from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target

from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser(
    description="Benchmark BitBLAS int4 on a specific target.")

# Add arguments to the parser
parser.add_argument(
    "--target",
    type=str,
    default=auto_detect_nvidia_target(),
    help="Specify the target device for benchmarking.",
)
parser.add_argument("--group_size",
                    type=int,
                    default=None,
                    help="Group size for grouped quantization.")
parser.add_argument(
    "--A_dtype",
    type=str,
    default="float16",
    choices=["float16", "float32", "float64", "int32", "int8"],
    help="Data type of activation A.",
)
parser.add_argument(
    "--W_dtype",
    type=str,
    default="int4",
    choices=[
        "float16",
        "float32",
        "float64",
        "int32",
        "int8",
        "int4",
        "int2",
        "int1",
        "nf4",
        "fp4_e2m1",
    ],
    help="Data type of weight W.",
)
parser.add_argument(
    "--accum_dtype",
    type=str,
    default="float16",
    choices=["float16", "int32"],
    help="Data type for accumulation.",
)
parser.add_argument(
    "--out_dtype",
    type=str,
    default="float16",
    choices=["float16", "float32", "int32", "int8"],
    help="Data type for output.",
)
parser.add_argument(
    "--layout",
    type=str,
    default="nt",
    choices=["nt", "nn"],
    help="Matrix layout, 'nt' for non-transpose A and transpose W.",
)
parser.add_argument("--with_bias",
                    action="store_true",
                    help="Include bias in the benchmark.")
parser.add_argument(
    "--with_scaling",
    action="store_true",
    help="Include scaling factor in the quantization.",
)
parser.add_argument("--with_zeros",
                    action="store_true",
                    help="Include zeros in the quantization.")
parser.add_argument(
    "--zeros_mode",
    type=str,
    default=None,
    choices=["original", "rescale", "quantized"],
    help="Specify the mode for calculating zeros.",
)

# Parse the arguments
args = parser.parse_args()

# Assign arguments to variables
target = args.target
A_dtype = args.A_dtype
W_dtype = args.W_dtype
accum_dtype = args.accum_dtype
out_dtype = args.out_dtype
layout = args.layout
with_bias = args.with_bias
group_size = args.group_size
with_scaling = args.with_scaling
with_zeros = args.with_zeros
zeros_mode = args.zeros_mode

# Define a list of shared arguments that repeat in every config
shared_args = [
    A_dtype,
    W_dtype,
    out_dtype,
    accum_dtype,
    layout,
    with_bias,
    group_size,
    with_scaling,
    with_zeros,
    zeros_mode,
]

# Define just the (M, K, N) shapes in a more compact list
shapes = [
    # square test
    (1, 16384, 16384),
    # BLOOM-176B
    (1, 43008, 14336),
    (1, 14336, 14336),
    (1, 57344, 14336),
    (1, 14336, 57344),
    # OPT-65B
    (1, 9216, 9216),
    (1, 36864, 9216),
    (1, 9216, 36864),
    (1, 22016, 8192),
    # LLAMA-70B/65B
    (1, 8192, 22016),
    (1, 8192, 8192),
    (1, 28672, 8192),
    (1, 8192, 28672),
    # square test
    (16384, 16384, 16384),
    # BLOOM-176B
    (8192, 43008, 14336),
    (8192, 14336, 14336),
    (8192, 57344, 14336),
    (8192, 14336, 57344),
    # OPT-65B
    (8192, 9216, 9216),
    (8192, 36864, 9216),
    (8192, 9216, 36864),
    (8192, 22016, 8192),
    # LLAMA-70B/65B
    (8192, 8192, 22016),
    (8192, 8192, 8192),
    (8192, 28672, 8192),
    (8192, 8192, 28672),
]

# Build test shapes with all the shared arguments
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args))
               for shape in shapes]

benchmark_sets = []
benchmark_sets.extend(test_shapes)

benchmark_results = {}
for config_class, operator, input_args in benchmark_sets:
    config = config_class(*input_args)
    matmul = operator(config, target=target, enable_tuning=True)
    kernel_latency = matmul.profile_latency()

    print("Time cost is: {:.3f} ms".format(kernel_latency))

    profile_config = {
        f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": {
            "BitBLAS_top20_latency": kernel_latency,
        }
    }

    benchmark_results.update(profile_config)

# Define headers for the table
headers = [
    "PrimFunc",
    "Input Arguments",
    "BitBLAS Top20 Latency",
]

# Calculate column widths for pretty printing
col_widths = [0, 0, 0]
for config_key, values in benchmark_results.items():
    args_split = config_key.split("-")
    func_name = args_split[0]
    input_args_str = "-".join(args_split[1:])
    col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
    col_widths[1] = max(col_widths[1],
                        len(input_args_str) + 2,
                        len(headers[1]) + 2)
    col_widths[2] = max(col_widths[2],
                        len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
                        len(headers[2]) + 2)
    # break only if you want to measure widths from a single example;
    # otherwise, let it loop over all items.

# Print header
for i, header in enumerate(headers):
    headers[i] = header.ljust(col_widths[i])
print("".join(headers))
print("-" * sum(col_widths))

# Print rows
for config_key, values in benchmark_results.items():
    args_split = config_key.split("-")
    func_name = args_split[0]
    input_args_str = "-".join(args_split[1:])
    row = [
        func_name,
        input_args_str,
        f"{values['BitBLAS_top20_latency']:.3f} ms",
    ]
    row_str = "".join(
        [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)])
    print(row_str)
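For context, the new benchmark is a standalone script that tunes and profiles one BitBLAS matmul per shape; a minimal invocation sketch (the flag values below are illustrative, not part of the commit):

```console
python benchmarks/kernels/benchmark_bitblas.py \
    --A_dtype float16 --W_dtype int4 --accum_dtype float16 \
    --group_size 128 --with_scaling --with_zeros --zeros_mode original
```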
40  docs/source/features/quantization/bitblas.md  (new file)

@@ -0,0 +1,40 @@
# BitBLAS

vLLM now supports [BitBLAS](https://github.com/microsoft/BitBLAS) for more efficient and flexible model inference. Compared to other quantization frameworks, BitBLAS provides more precision combinations.

Below are the steps to utilize BitBLAS with vLLM.

```console
pip install bitblas>=0.1.0
```

vLLM reads the model's config file and supports pre-quantized checkpoints.

You can find pre-quantized models on:

- [Hugging Face (BitBLAS)](https://huggingface.co/models?other=bitblas)
- [Hugging Face (GPTQ)](https://huggingface.co/models?other=gptq)

Usually, these repositories have a `quantize_config.json` file that includes a `quantization_config` section.
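For orientation, such a config carries the fields that the BitBLAS configs in this commit read (`bits`, `group_size`, `sym`, `desc_act`, `quant_method`, `checkpoint_format`); the exact file layout varies between checkpoints, and the values below are illustrative only:

```json
{
  "bits": 4,
  "group_size": 128,
  "sym": false,
  "desc_act": false,
  "quant_method": "gptq",
  "checkpoint_format": "bitblas"
}
```
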
## Read bitblas format checkpoint

```python
from vllm import LLM
import torch

# "hxbgsyxh/llama-13b-4bit-g-1-bitblas" is a pre-quantized checkpoint.
model_id = "hxbgsyxh/llama-13b-4bit-g-1-bitblas"
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, quantization="bitblas")
```

## Read gptq format checkpoint

```python
from vllm import LLM
import torch

# "hxbgsyxh/llama-13b-4bit-g-1" is a pre-quantized checkpoint.
model_id = "hxbgsyxh/llama-13b-4bit-g-1"
llm = LLM(model=model_id, dtype=torch.float16, trust_remote_code=True, quantization="bitblas", max_model_len=1024)
```
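Either way, generation then works as with any other vLLM model; a minimal sketch continuing the snippets above (prompt and sampling values are illustrative):

```python
from vllm import SamplingParams

prompts = ["Hello, my name is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# `llm` is the LLM instance created in either example above.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```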
@@ -11,6 +11,7 @@ Quantization trades off model precision for smaller memory footprint, allowing l
 supported_hardware
 auto_awq
 bnb
+bitblas
 gguf
 gptqmodel
 int4
@@ -74,6 +74,17 @@ The table below shows the compatibility of various quantization implementations
   * ❌
   * ❌
   * ❌
+- * BitBLAS (GPTQ)
+  * ✅︎
+  * ✅︎
+  * ✅︎
+  * ✅︎
+  * ✅︎
+  * ✅︎
+  * ❌
+  * ❌
+  * ❌
+  * ❌
 - * AQLM
   * ✅︎
   * ✅︎
63  tests/models/test_bitblas.py  (new file)

@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of a GPTQ model to a bitblas model.

Note: GPTQ and bitblas do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
bitblas/GPTQ models are in the top 3 selections of each other.

Note: bitblas internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for bitblas. As a result, we re-run the
test up to 3 times to see if we pass.

Run `pytest tests/models/test_bitblas.py`.
"""
from dataclasses import dataclass

import pytest

from .utils import check_logprobs_close


@dataclass
class ModelPair:
    model_bitblas: str
    model_gptq: str


model_pairs = [
    ModelPair(model_bitblas="hxbgsyxh/opt-125m-4bit-128g-bitblas",
              model_gptq="hxbgsyxh/opt-125m-4bit-128g"),
]


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    vllm_runner,
    example_prompts,
    model_pair: ModelPair,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    with vllm_runner(model_pair.model_bitblas,
                     dtype=dtype,
                     quantization="bitblas") as bitblas_model:
        bitblas_outputs = bitblas_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model_pair.model_gptq, dtype=dtype,
                     quantization="gptq") as gptq_model:
        gptq_outputs = gptq_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=gptq_outputs,
        outputs_1_lst=bitblas_outputs,
        name_0="gptq",
        name_1="bitblas",
    )
61  tests/models/test_gptq_bitblas.py  (new file)

@@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
"""Compare the outputs of a GPTQ model to a bitblas model.

Note: GPTQ and bitblas do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
bitblas/GPTQ models are in the top 3 selections of each other.

Note: bitblas internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for bitblas. As a result, we re-run the
test up to 3 times to see if we pass.

Run `pytest tests/models/test_gptq_bitblas.py`.
"""
from dataclasses import dataclass

import pytest

from .utils import check_logprobs_close


@dataclass
class ModelPair:
    model_gptq: str


model_pairs = [
    ModelPair(model_gptq="hxbgsyxh/opt-125m-4bit-128g"),
]


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    vllm_runner,
    example_prompts,
    model_pair: ModelPair,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    with vllm_runner(model_pair.model_gptq,
                     dtype=dtype,
                     quantization="bitblas") as bitblas_model:
        bitblas_outputs = bitblas_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model_pair.model_gptq, dtype=dtype,
                     quantization="gptq") as gptq_model:
        gptq_outputs = gptq_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=gptq_outputs,
        outputs_1_lst=bitblas_outputs,
        name_0="gptq",
        name_1="gptq_bitblas",
    )
@@ -750,7 +750,8 @@ class ModelConfig:
         optimized_quantization_methods = [
             "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
             "awq_marlin", "fbgemm_fp8", "compressed_tensors",
-            "compressed-tensors", "experts_int8", "quark", "nvfp4"
+            "compressed-tensors", "experts_int8", "quark", "nvfp4", "bitblas",
+            "gptq_bitblas"
         ]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
@@ -31,6 +31,8 @@ logger = init_logger(__name__)
 
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod",
+    "BitBLASLinearMethod",
+    "GPTQBitBLASLinearMethod",
     "AWQMarlinLinearMethod",
     "AWQLinearMethod",
     "GPTQMarlinLinearMethod",
@@ -50,6 +52,15 @@ WEIGHT_LOADER_V2_SUPPORTED = [
 ]
 
 
+def adjust_bitblas_shard(param, shard_size, shard_offset):
+    bitblas_tile_size = getattr(param, "bitblas_tile_size", None)
+    if bitblas_tile_size is not None:
+        return (shard_size // bitblas_tile_size,
+                shard_offset // bitblas_tile_size)
+
+    return shard_size, shard_offset
+
+
 def adjust_marlin_shard(param, shard_size, shard_offset):
     marlin_tile_size = getattr(param, "marlin_tile_size", None)
     if marlin_tile_size is None:
@@ -615,6 +626,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                     shard_size, shard_offset = adjust_marlin_shard(
                         param, shard_size, shard_offset)
 
+                    shard_size, shard_offset = adjust_bitblas_shard(
+                        param, shard_size, shard_offset)
+
             if use_bitsandbytes_4bit:
                 index = list(itertools.accumulate([0] + self.output_sizes))
                 orig_offsets = {
@@ -646,6 +660,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 # Special case for Marlin.
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)
+                shard_size, shard_offset = adjust_bitblas_shard(
+                    param, shard_size, shard_offset)
 
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                             False)
@@ -18,9 +18,11 @@ QUANTIZATION_METHODS: List[str] = [
     # The order of gptq methods is important for config.py iteration over
     # override_quantization_method(..)
     "marlin",
+    "bitblas",
     "gguf",
     "gptq_marlin_24",
     "gptq_marlin",
+    "gptq_bitblas",
     "awq_marlin",
     "gptq",
     "compressed-tensors",
@@ -85,6 +87,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     from .aqlm import AQLMConfig
     from .awq import AWQConfig
     from .awq_marlin import AWQMarlinConfig
+    from .bitblas import BitBLASConfig
     from .bitsandbytes import BitsAndBytesConfig
     from .compressed_tensors.compressed_tensors import (  # noqa: E501
         CompressedTensorsConfig)
@@ -94,6 +97,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     from .fp8 import Fp8Config
     from .gguf import GGUFConfig
     from .gptq import GPTQConfig
+    from .gptq_bitblas import GPTQBitBLASConfig
     from .gptq_marlin import GPTQMarlinConfig
     from .gptq_marlin_24 import GPTQMarlin24Config
     from .hqq_marlin import HQQMarlinConfig
@@ -119,9 +123,11 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
         # The order of gptq methods is important for config.py iteration over
         # override_quantization_method(..)
         "marlin": MarlinConfig,
+        "bitblas": BitBLASConfig,
         "gguf": GGUFConfig,
         "gptq_marlin_24": GPTQMarlin24Config,
         "gptq_marlin": GPTQMarlinConfig,
+        "gptq_bitblas": GPTQBitBLASConfig,
         "awq_marlin": AWQMarlinConfig,
         "gptq": GPTQConfig,
         "compressed-tensors": CompressedTensorsConfig,
@@ -146,4 +152,4 @@ __all__ = [
     "QuantizationConfig",
     "get_quantization_config",
     "QUANTIZATION_METHODS",
-]
+]
459  vllm/model_executor/layers/quantization/bitblas.py  (new file)

@@ -0,0 +1,459 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
    BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS,
    BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedvLLMParameter)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)


class BitBLASConfig(QuantizationConfig):
    """Config class for BitBLAS.

    Reference: https://github.com/Microsoft/BitBLAS
    """
    TORCH_DTYPE = torch.float16
    STORAGE_DTYPE = "int8"  # assume int8 storage
    TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE)
    # "original" or "rescale" or "quantized";
    # gptq with bitblas prefers the "quantized" implementation.
    ZEROS_MODE = "quantized"

    def __init__(
        self,
        weight_bits: int,
        group_size: Optional[int],
        desc_act: Optional[bool],
        is_sym: Optional[bool],
        quant_method: Optional[str],
        lm_head_quantized: bool,
    ) -> None:
        try:
            import bitblas
            if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
                raise ImportError(
                    "bitblas version is wrong. Please "
                    f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
        except ImportError as e:
            bitblas_import_exception = e
            raise ValueError(
                "Trying to use the bitblas backend, but could not import "
                f"with the following error: {bitblas_import_exception}. "
                "Please install bitblas through the following command: "
                f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
            ) from bitblas_import_exception

        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym
        self.quant_method = quant_method
        self.lm_head_quantized = lm_head_quantized

        # Verify
        if self.weight_bits not in BITBLAS_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"BitBLAS does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} "
                "are supported.")

        if self.is_sym not in BITBLAS_SUPPORTED_SYM:
            raise ValueError(
                f"BitBLAS does not support is_sym = {self.is_sym}. "
                f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported.")

        storage_dtype = self.STORAGE_DTYPE
        storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))

        self.storage_dtype = storage_dtype
        self.storage_torch_dtype = self.TORCH_STORAGE_DTYPE
        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = storage_nbit // weight_bits
        self.nbits = weight_bits

        # Zeros type for the quantized weights.
        self.zeros_mode = self.ZEROS_MODE

    def __repr__(self) -> str:
        return (f"BitBLASConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}, "
                f"is_sym={self.is_sym}, "
                f"quant_method={self.quant_method})")

    @classmethod
    def get_name(cls) -> str:
        return "bitblas"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        return 70

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @staticmethod
    def get_from_keys(config: Dict[str, Any],
                      keys: List[str],
                      default: Any = None) -> Any:
        """Get a value from the model's quantization config."""
        for key in keys:
            if key in config:
                return config[key]
        return default

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"], -1)
        desc_act = cls.get_from_keys(config, ["desc_act"], False)
        is_sym = cls.get_from_keys(config, ["sym"], False)
        quant_method = cls.get_from_keys(config, ["quant_method"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
                   lm_head_quantized)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        # compat: autogptq >=0.8.0 use checkpoint_format: str
        # compat: autogptq <=0.7.1 is_bitblas_format: bool
        is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"
                             or hf_quant_cfg.get("is_bitblas_format", False))

        is_valid_user_quant = (user_quant is None or user_quant == "gptq"
                               or user_quant == "bitblas")

        if is_bitblas_format and is_valid_user_quant:
            msg = ("The model is serialized in {} format. Using {} kernel.".
                   format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["BitBLASLinearMethod"]:
        if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
                                             and self.lm_head_quantized):
            return BitBLASLinearMethod(self)
        return None


class BitBLASLinearMethod(LinearMethodBase):
    """Linear method for BitBLAS.

    Args:
        quant_config: The BitBLAS quantization config.
    """
    # Use BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS instead of
    # BITBLAS_OPTIMIZE_FEATURES if you want higher contiguous
    # batching performance.
    OPT_FEATURES = BITBLAS_OPTIMIZE_FEATURES
    ENABLE_TUNING = True
    BITBLAS_DTYPES = {
        torch.float32: "float32",
        torch.float16: "float16",
        torch.bfloat16: "bfloat16",
        torch.half: "float16",
        torch.int8: "int8",
    }

    def __init__(self, quant_config: BitBLASConfig):
        self.quant_config = quant_config

    def create_weights_gptq(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Creates quantized weights for use in linear operations.

        The function initializes and returns a dictionary containing quantized
        weights, scales, and zeros
        for performing quantized matrix multiplication operations.

        Args:
            input_size_per_partition: The size of the input partition.
            output_size_per_partition: The size of the output partition.
            input_size: The total size of the input (unused).
            output_size: The total size of the output (unused).
            params_dtype:
                The data type of the parameters (expected to be torch.float16).

        Returns:
            A dictionary containing the quantized weights ('qweight'),
            scales ('scales'), and zeros ('zeros').

        Raises:
            ValueError: If `params_dtype` is not `torch.float16` or if the
            input size per partition is not divisible by the group size in
            `quant_config`.
        """
        del input_size, output_size  # Unused arguments.
        weight_loader = extra_weight_attrs["weight_loader"]

        if params_dtype not in self.quant_config.get_supported_act_dtypes():
            raise ValueError("Parameter data type must be torch.float16, "
                             f"but got {params_dtype}")
        group_size = self.quant_config.group_size
        if group_size is None:
            group_size = -1
        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if (group_size != -1 and input_size_per_partition % group_size != 0):
            raise ValueError(
                f"Input size per partition ({input_size_per_partition}) must "
                f"be divisible by group size ({group_size}).")

        # Initialize or retrieve the BitBLAS matrix multiplication operator.
        self._configure_bitblas_matmul(
            input_size_per_partition,
            output_size_per_partition,
            params_dtype=params_dtype,
            enable_tuning=self.ENABLE_TUNING,
            bias=False,
            layout="nt",
            bits=self.quant_config.weight_bits,
        )

        # Initialize quantized weights with dimensions
        # Quantized 4Bit weights packed.
        qweight = PackedvLLMParameter(
            data=torch.empty(
                self.bitblas_matmul.retrieve_weight_shape(),
                device="cuda",
                dtype=self.quant_config.storage_torch_dtype,
                requires_grad=False,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            bitblas_tile_size=(self.bitblas_matmul.retrieve_weight_shape()[-2]
                               if self.bitblas_matmul.propagate_b else None),
            weight_loader=weight_loader,
        )

        # Compute the number of input groups for channel-wise quantization.
        input_groups = (1 if group_size == -1 else input_size_per_partition //
                        group_size)

        # Initialize scales and zeros for the quantized weights.
        weight_scale_args = {
            "data":
            torch.empty(
                output_size_per_partition,
                input_groups,
                device="cuda",
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }
        if input_groups == 1:
            scales = ChannelQuantScaleParameter(output_dim=0,
                                                **weight_scale_args)
        else:
            scales = GroupQuantScaleParameter(output_dim=0,
                                              input_dim=1,
                                              **weight_scale_args)

        if self.quant_config.zeros_mode == "quantized":
            zeros = PackedvLLMParameter(
                data=torch.empty(
                    input_groups,
                    output_size_per_partition // self.quant_config.pack_factor,
                    device="cuda",
                    dtype=self.quant_config.storage_torch_dtype,
                    requires_grad=False,
                ),
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                weight_loader=weight_loader,
            )

        else:
            zeros = BasevLLMParameter(
                torch.empty(output_size_per_partition,
                            input_groups,
                            device="cuda",
                            dtype=params_dtype),
                weight_loader=weight_loader,
            )
        # Set attributes to indicate how scales and zeros are applied.
        set_weight_attrs(zeros, {
            "input_dim": None if input_groups == 1 else 1,
            "output_dim": 0,
        })

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("scales", scales)
        layer.register_parameter("zeros", zeros)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        if self.quant_config.quant_method == "gptq":
            return self.create_weights_gptq(layer, input_size_per_partition,
                                            output_partition_sizes, input_size,
                                            output_size, params_dtype,
                                            **extra_weight_attrs)
        else:
            raise ValueError(
                f"Unsupported quant_method {self.quant_config.quant_method}")

    def _configure_bitblas_matmul(
        self,
        infeatures,
        outfeatures,
        params_dtype,
        enable_tuning,
        bias,
        layout,
        bits,
        out_dtype="float16",
    ):
        from bitblas import MatmulConfig
        bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]

        with_scaling = False
        with_zeros = False
        group_size = self.quant_config.group_size
        zeros_mode = self.quant_config.zeros_mode
        if self.quant_config.quant_method == "gptq":
            with_scaling = True
            with_zeros = True
            W_dtype = f"uint{bits}"
            if self.quant_config.is_sym:
                with_zeros = False
                W_dtype = f"int{bits}"
        else:
            raise ValueError(
                f"Unsupported quant_method {self.quant_config.quant_method}")

        matmul_config = MatmulConfig(
            N=outfeatures,
            K=infeatures,
            A_dtype=bitblas_dtype,
            W_dtype=W_dtype,
            out_dtype=out_dtype,
            accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
            storage_dtype=self.quant_config.STORAGE_DTYPE,
            with_scaling=with_scaling,
            with_zeros=with_zeros,
            group_size=group_size,
            with_bias=bias,
            layout=layout,
            zeros_mode=zeros_mode,
        )
        self.bitblas_matmul = self._get_or_create_bitblas_operator(
            matmul_config, enable_tuning)

    def _get_or_create_bitblas_operator(self, config, enable_tuning):
        from bitblas import Matmul, auto_detect_nvidia_target
        from bitblas.cache import get_database_path, global_operator_cache
        BITBLAS_DATABASE_PATH = get_database_path()
        BITBLAS_TARGET = auto_detect_nvidia_target()
        if global_operator_cache.size() == 0:
            global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
                                                     BITBLAS_TARGET)

        bitblas_matmul = global_operator_cache.get(config)
        if bitblas_matmul is None:
            bitblas_matmul = Matmul(config,
                                    target=BITBLAS_TARGET,
                                    enable_tuning=False)
            if enable_tuning:
                TUNING_MESSAGE = (f"BitBLAS Operator {config} is tuning ...")
                logger.info(TUNING_MESSAGE)
                bitblas_matmul.hardware_aware_finetune(topk=20)
                global_operator_cache.add(config, bitblas_matmul)
                global_operator_cache.save_into_database(
                    BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
                TUNED_MESSAGE = (
                    f"BitBLAS Operator {config} tuned and saved to database.")
                logger.info(TUNED_MESSAGE)
            else:
                _message = f"BitBLAS Operator {config} created."
                logger.info(_message)
        else:
            _message = (
                f"BitBLAS Operator {config} found in global_operator_cache.")
            logger.info(_message)
        return bitblas_matmul

    def apply_gptq(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.zeros

        x_2d = x.view(-1, x.shape[-1])

        if self.quant_config.is_sym:
            output_2d = self.bitblas_matmul(x_2d, qweight, scales)
        else:
            output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output

    def apply(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:
        if self.quant_config.quant_method == "gptq":
            return self.apply_gptq(*args, **kwargs)
        else:
            raise ValueError(
                f"Unsupported quant_method {self.quant_config.quant_method}")
438  vllm/model_executor/layers/quantization/gptq_bitblas.py  (new file)

@@ -0,0 +1,438 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Set

import torch
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
    BitBLASLinearKernel, MPLinearLayerConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
    BITBLAS_SUPPORTED_NUM_BITS as GPTQ_BITBLAS_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
    BITBLAS_SUPPORTED_SYM as GPTQ_BITBLAS_SUPPORTED_SYM)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
    MINIMUM_BITBLAS_VERSION, bitblas_repeat_scales_on_all_ranks,
    check_bitblas_supported, verify_bitblas_supported)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedColumnParameter,
                                           PackedvLLMParameter,
                                           RowvLLMParameter)
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)


class GPTQBitBLASConfig(QuantizationConfig):
    """Config class for GPTQ BitBLAS"""

    # (num_bits, is_sym) -> quant_type
    TYPE_MAP = {
        (4, True): scalar_types.uint4b8,
        (8, True): scalar_types.uint8b128,
    }

    TORCH_DTYPE = torch.float16
    GPTQ_CKPT_STORAGE_DTYPE = (
        "int32"  # GPTQ Default Checkpoints use int32 as storage dtype
    )
    GPTQ_BITBLAS_STORAGE_DTYPE = "int8"  # BitBLAS uses int8 as storage dtype
    TORCH_BITBLAS_STORAGE_DTYPE = getattr(torch, GPTQ_BITBLAS_STORAGE_DTYPE)
    # "original" or "rescale" or "quantized";
    # gptq_bitblas prefers "quantized".
    ZEROS_MODE = "quantized"

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
        is_sym: bool,
        quant_method: Optional[str],
        lm_head_quantized: bool,
    ) -> None:

        try:
            import bitblas
            if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
                raise ImportError(
                    "bitblas version is wrong. Please "
                    f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
        except ImportError as e:
            bitblas_import_exception = e
            raise ValueError(
                "Trying to use the bitblas backend, but could not import "
                f"with the following error: {bitblas_import_exception}. "
                "Please install bitblas through the following command: "
                f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
            ) from bitblas_import_exception

        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym
        self.quant_method = quant_method
        self.lm_head_quantized = lm_head_quantized

        # Verify
        if self.weight_bits not in GPTQ_BITBLAS_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"BitBLAS does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {GPTQ_BITBLAS_SUPPORTED_NUM_BITS} "
                "are supported.")

        if self.is_sym not in GPTQ_BITBLAS_SUPPORTED_SYM:
            raise ValueError(
                f"BitBLAS does not support is_sym = {self.is_sym}. "
                f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported.")

        self.storage_dtype = self.GPTQ_BITBLAS_STORAGE_DTYPE

        storage_nbit = int("".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE
                                   if c.isdigit()))

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = storage_nbit // weight_bits
        self.nbits = weight_bits

        # Zeros type for the quantized weights.
        self.zeros_mode = self.ZEROS_MODE

        if (weight_bits, is_sym) not in self.TYPE_MAP:
            raise ValueError("Unsupported quantization config: "
                             f"bits={weight_bits}, sym={is_sym}")

        self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

    def __repr__(self) -> str:
        return (f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}, "
                f"is_sym={self.is_sym}, "
                f"quant_method={self.quant_method})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq_bitblas"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQBitBLASConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        quant_method = cls.get_from_keys(config, ["quant_method"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, desc_act, is_sym, quant_method,
                   lm_head_quantized)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)

        is_valid_user_quant = (user_quant is None or user_quant == "bitblas"
                               or user_quant == "gptq_bitblas")

        if can_convert and is_valid_user_quant:
            msg = ("The model is convertible to {} during runtime."
                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info("Detected that the model can run with gptq_bitblas"
                        ", however you specified quantization=gptq explicitly,"
                        " so forcing gptq. Use quantization=gptq_bitblas for"
                        " faster inference")
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["GPTQBitBLASLinearMethod"]:
        if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
                                             and self.lm_head_quantized):
            return GPTQBitBLASLinearMethod(self)
        return None

    @property
    def torch_storage_dtype(self) -> torch.dtype:
        return self.TORCH_BITBLAS_STORAGE_DTYPE

    @classmethod
    def is_gptq_bitblas_compatible(cls, quant_config: Dict[str, Any]):
        # Extract data from quant config.
        num_bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")
        sym = quant_config.get("sym")
        desc_act = quant_config.get("desc_act")

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        if (num_bits, sym) not in cls.TYPE_MAP:
            return False

        # If the capability of the device is too low, cannot convert.
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor
        if device_capability < cls.get_min_capability():
            return False

        # Otherwise, can convert if model satisfies bitblas constraints.
        return check_bitblas_supported(quant_type=cls.TYPE_MAP[(num_bits,
                                                                sym)],
                                       group_size=group_size)


class GPTQBitBLASLinearMethod(LinearMethodBase):
    """Linear method for GPTQ BitBLAS.

    Args:
        quant_config: The GPTQ BitBLAS quantization config.
    """

    kernel_type = BitBLASLinearKernel
    _kernel_backends_being_used: Set[str] = set()

    def __init__(self, quant_config: GPTQBitBLASConfig) -> None:
        self.quant_config = quant_config
        # Verify supported on platform.
        verify_bitblas_supported(quant_type=self.quant_config.quant_type,
                                 group_size=self.quant_config.group_size)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Creates quantized weights for use in linear operations.

        The function initializes and returns a dictionary containing
        quantized weights, scales, and zeros
        for performing quantized matrix multiplication operations.

        Args:
            input_size_per_partition: The size of the input partition.
            output_partition_sizes: The size of the output partition.
            input_size: The total size of the input (unused).
            output_size: The total size of the output (unused).
            params_dtype:
                The data type of the parameters (expected to be torch.float16).

        Returns:
            A dictionary containing the quantized weights ('qweight'),
            scales ('scales'), and zeros ('zeros').

        Raises:
            ValueError: If `params_dtype` is not `torch.float16` or
            if the input size per partition is not divisible by the
            group size in `quant_config`.
        """
        if params_dtype != torch.float16:
            raise ValueError("Parameter data type must be torch.float16, "
                             f"but got {params_dtype}")

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        if input_size_per_partition % group_size != 0:
            raise ValueError(
                f"Input size per partition ({input_size_per_partition}) must "
                f"be divisible by group size ({self.quant_config.group_size})."
            )

        kernel_type = self.kernel_type
        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)

        is_row_parallel = input_size != input_size_per_partition
        weight_loader = extra_weight_attrs.get("weight_loader")

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=\
            (input_size_per_partition, output_size_per_partition),
            weight_type=self.quant_config.quant_type,
            act_type=params_dtype,
            group_size=self.quant_config.group_size,
            zero_points=False,
            has_g_idx=self.quant_config.desc_act
        )

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for GPTQBitBLASLinearMethod",
                        kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Determine sharding
        if bitblas_repeat_scales_on_all_ranks(self.quant_config.desc_act,
                                              self.quant_config.group_size,
                                              is_row_parallel):
            # By setting scale_dim == None, weight_loader will
            # repeat the scales on each GPU in TP>1 case.
            scales_and_zp_input_dim = None
            scales_and_zp_size = input_size // group_size
        else:
            # By setting scale_dim == 0, weight_loader will
            # shard the scales in TP>1 case.
            scales_and_zp_input_dim = 0
            scales_and_zp_size = input_size_per_partition // group_size

        # Init buffers
        # Quantized weights
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        # Activation order
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        g_idx = RowvLLMParameter(data=torch.empty(
            input_size_per_partition,
            dtype=torch.int32,
        ),
                                 input_dim=0,
                                 weight_loader=weight_loader)

        # Scales
        scales = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
            },
        )

        # Quantized zero-points
        qzeros_args = {
            "data":
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader":
            weight_loader
        }
        weight_scale_args = {
            "data":
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }

        if scales_and_zp_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1,
                                                **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)

        else:
            scales = GroupQuantScaleParameter(output_dim=1,
                                              input_dim=0,
                                              **weight_scale_args)
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)

        self.kernel = kernel_type(
            mp_linear_kernel_config,
            w_q_param_name="qweight",
            w_s_param_name="scales",
            w_zp_param_name="qzeros",
            w_gidx_param_name="g_idx",
            bitblas_quant_config=self.quant_config,
        )

        # Initialize or retrieve the BitBLAS matrix multiplication operator.
        self.kernel.configure_bitblas_matmul(
            input_size_per_partition,
            output_size_per_partition,
            params_dtype=params_dtype,
            bias=False,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.kernel.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        out = self.kernel.apply_gptq_bitblas_linear(layer, x)
        if bias is not None:
            out.add_(bias)
        return out
@@ -5,6 +5,8 @@ from typing import List, Optional, Type
 import vllm.envs as envs
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import (  # noqa: E501
     AllSparkLinearKernel)
+from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import (  # noqa: E501
+    BitBLASLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import (  # noqa: E501
     ExllamaLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import (  # noqa: E501
@@ -20,6 +22,7 @@ _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
     MacheteLinearKernel,
     AllSparkLinearKernel,
     MarlinLinearKernel,
+    BitBLASLinearKernel,
     ExllamaLinearKernel,
 ]
 
@@ -76,4 +79,4 @@ def choose_mp_linear_kernel(
     raise ValueError(
         "Failed to find a kernel that can implement the "\
         "WNA16 linear layer. Reasons: \n"
-        + '\n'.join(failure_reasons))
+        + '\n'.join(failure_reasons))
@ -0,0 +1,299 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES,
|
||||
MINIMUM_BITBLAS_VERSION, bitblas_make_empty_g_idx, bitblas_sort_g_idx,
|
||||
check_bitblas_supports_shape, query_bitblas_supported_quant_types,
|
||||
unpack_gptq_qweight, unpack_gptq_qzeros)
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BitBLASLinearKernel(MPLinearKernel):
|
||||
|
||||
OPT_FEATURES: List[int] = BITBLAS_OPTIMIZE_FEATURES
|
||||
ENABLE_TUNING: bool = True
|
||||
MATMUL_LAYOUT: str = "nt"
|
||||
BITBLAS_DTYPES: Dict[torch.dtype, str] = {
|
||||
torch.float32: "float32",
|
||||
torch.float16: "float16",
|
||||
torch.bfloat16: "bfloat16",
|
||||
torch.half: "float16",
|
||||
torch.int8: "int8",
|
||||
}
|
||||
bitblas_matmul: object = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
c: MPLinearLayerConfig,
|
||||
w_q_param_name: str,
|
||||
w_s_param_name: str,
|
||||
w_zp_param_name: Optional[str] = None,
|
||||
w_gidx_param_name: Optional[str] = None,
|
||||
bitblas_quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
self.quant_config = bitblas_quant_config
|
||||
super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name,
|
||||
w_gidx_param_name)
|
||||
|
||||
def repack_bitblas_from_gptq(
|
||||
self,
|
||||
b_q_weight: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
qzeros: Optional[torch.Tensor] = None,
|
||||
):
|
||||
from bitblas.quantization.utils import general_compress
|
||||
assert self.bitblas_matmul is not None, "bitblas_matmul is None"
|
||||
|
||||
quant_config = self.quant_config
|
||||
# qweight in gptq old quant linear stored with
|
||||
# (outfeatures, infeatures), should be transposed.
|
||||
qweight = b_q_weight.T.contiguous().view(
|
||||
quant_config.torch_storage_dtype) # type: ignore[union-attr]
|
||||
intweight = unpack_gptq_qweight(
|
||||
qweight,
|
||||
quant_config.weight_bits).contiguous() # type: ignore[union-attr]
|
||||
if self.bitblas_matmul.weight_transform is not None: # type: ignore[attr-defined]
|
||||
qweight = self.bitblas_matmul.weight_transform( # type: ignore[attr-defined]
|
||||
intweight.cpu()).cuda()
|
||||
# scales in gptq old quant linear stored with
|
||||
# (infeatures // group_size, outfeatures), should be transposed.
|
||||
scales = scales.T.contiguous()
|
||||
|
||||
if qzeros is None:
|
||||
return qweight, scales, None
|
||||
|
||||
# qzeros should be de-quantized to int zeros.
|
||||
weight_bits = quant_config.weight_bits # type: ignore[union-attr]
|
||||
intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous()
|
||||
zeros: Optional[torch.Tensor] = None
|
||||
zeros_mode = self.bitblas_matmul.config.zeros_mode # type: ignore[attr-defined]
|
||||
if zeros_mode == "original":
|
||||
zeros = intzeros.to(torch.float16).contiguous()
|
||||
elif zeros_mode == "rescale":
|
||||
assert zeros is not None, "zeros should not be None"
|
||||
zeros[:, :] = intzeros.to(torch.float16)[:, :] * scales[:, :]
|
||||
elif zeros_mode == "quantized":
|
||||
zeros = (
|
||||
torch.Tensor(
|
||||
general_compress(
|
||||
intzeros.T.contiguous().cpu().numpy(),
|
||||
weight_bits,
|
||||
)).to(qweight.device).
|
||||
to(quant_config.torch_storage_dtype # type: ignore[union-attr]
|
||||
).contiguous())
|
||||
else:
|
||||
raise ValueError("Unsupported zeros type: {}".format(zeros_mode))
|
||||
|
||||
return qweight, scales, zeros
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
|
||||
is_bitblas_installed = True
|
||||
|
||||
try:
|
||||
import bitblas
|
||||
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
|
||||
raise ImportError(
|
||||
"bitblas version is wrong. Please "
|
||||
f"install bitblas>={MINIMUM_BITBLAS_VERSION}")
|
||||
except ImportError:
|
||||
is_bitblas_installed = False
|
||||
|
||||
if not is_bitblas_installed:
|
||||
return False, "bitblas is not installed. Please install bitblas "\
|
||||
"by running `pip install bitblas>="\
|
||||
f"{MINIMUM_BITBLAS_VERSION}`"
|
||||
|
||||
quant_types = query_bitblas_supported_quant_types(c.zero_points)
|
||||
if c.weight_type not in quant_types:
|
||||
return False, (f"Quant type ({c.weight_type}) not supported by"
|
||||
f" BitBLAS, supported types are: {quant_types}")
|
||||
|
||||
if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES:
|
||||
return False, (f"Group size ({c.group_size}) not supported by "
|
||||
"BitBLAS, supported group sizes are: "
|
||||
f"{BITBLAS_SUPPORTED_GROUP_SIZES}")
|
||||
|
||||
return check_bitblas_supports_shape(
|
||||
c.partition_weight_shape[1], # out_features
|
||||
c.partition_weight_shape[0], # in_features
|
||||
c.full_weight_shape[0], # in_features
|
||||
c.group_size)

    # note: assumes that
    #   `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #   `weight_scale`  is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = getattr(layer, self.w_q_name).device
        c = self.config
        quant_config = self.quant_config

        # Default names since bitblas requires empty parameters for these,
        # TODO: remove this requirement from bitblas (allow optional tensors)
        if self.w_gidx_name is None:
            self.w_gidx_name = "g_idx"
        if self.w_zp_name is None:
            self.w_zp_name = "qzeros"

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = bitblas_sort_g_idx(
                getattr(layer, self.w_gidx_name))
            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, bitblas_make_empty_g_idx(device))
            layer.g_idx_sort_indices = bitblas_make_empty_g_idx(device)

        if c.zero_points:
            raise NotImplementedError("Zero points not supported by BitBLAS")
        else:
            setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device))

        # Repack weights
        bitblas_qweight, bitblas_scales, bitblas_qzeros = (
            self.repack_bitblas_from_gptq(
                layer.qweight,
                layer.scales,
                None if quant_config.is_sym else  # type: ignore[union-attr]
                layer.qzeros,
            ))
        replace_parameter(layer, self.w_q_name, bitblas_qweight)
        replace_parameter(layer, self.w_s_name, bitblas_scales)
        if bitblas_qzeros is not None:
            replace_parameter(layer, self.w_zp_name, bitblas_qzeros)

    def configure_bitblas_matmul(
        self,
        infeatures: int,
        outfeatures: int,
        params_dtype: torch.dtype,
        bias: bool,
    ) -> None:
        enable_tuning = self.ENABLE_TUNING
        layout = self.MATMUL_LAYOUT
        bits = self.quant_config.weight_bits  # type: ignore[union-attr]
        self._configure_bitblas_matmul(
            infeatures,
            outfeatures,
            params_dtype,
            enable_tuning,
            bias,
            layout,
            bits,
        )

    def _configure_bitblas_matmul(
        self,
        infeatures,
        outfeatures,
        params_dtype,
        enable_tuning,
        bias,
        layout,
        bits,
    ):
        from bitblas import MatmulConfig

        bitblas_dtype = self.BITBLAS_DTYPES[params_dtype]
        quant_config = self.quant_config
        with_scaling = False
        with_zeros = False
        group_size = quant_config.group_size  # type: ignore[union-attr]
        zeros_mode = quant_config.zeros_mode  # type: ignore[union-attr]
        if quant_config.quant_method == "gptq":  # type: ignore[union-attr]
            with_scaling = True
            with_zeros = True
            W_dtype = f"uint{bits}"
            if quant_config.is_sym:  # type: ignore[union-attr]
                with_zeros = False
                W_dtype = f"int{bits}"
        else:
            raise ValueError(
                f"Unsupported quant_method {quant_config.quant_method}"  # type: ignore[union-attr]
            )

        matmul_config = MatmulConfig(
            M=self.OPT_FEATURES,
            N=outfeatures,
            K=infeatures,
            A_dtype=bitblas_dtype,
            W_dtype=W_dtype,
            out_dtype=bitblas_dtype,
            accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype,
            storage_dtype=quant_config.storage_dtype,  # type: ignore[union-attr]
            with_scaling=with_scaling,
            with_zeros=with_zeros,
            group_size=group_size,
            with_bias=bias,
            layout=layout,
            zeros_mode=zeros_mode,
        )
        self.bitblas_matmul = self._get_or_create_bitblas_operator(
            matmul_config, enable_tuning)
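
Illustration (not part of the patch): a MatmulConfig roughly matching the symmetric GPTQ int4 path above, with hypothetical layer sizes (N=11008, K=4096, group_size=128); the field names mirror the call in _configure_bitblas_matmul, the values are assumptions.

example_config = MatmulConfig(
    M=[1, 16, 32, 64, 128, 256, 512, 1024],  # dynamic batch dims (OPT_FEATURES)
    N=11008,                                 # out_features (assumed)
    K=4096,                                  # in_features (assumed)
    A_dtype="float16",
    W_dtype="int4",                          # is_sym -> signed int4, no zeros
    out_dtype="float16",
    accum_dtype="float16",
    storage_dtype="int8",
    with_scaling=True,
    with_zeros=False,
    group_size=128,
    with_bias=False,
    layout="nt",
    zeros_mode=None,
)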

    def _get_or_create_bitblas_operator(self, config, enable_tuning):
        from bitblas import Matmul, auto_detect_nvidia_target
        from bitblas.cache import get_database_path, global_operator_cache

        BITBLAS_DATABASE_PATH = get_database_path()
        BITBLAS_TARGET = auto_detect_nvidia_target()

        if global_operator_cache.size() == 0:
            global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH,
                                                     BITBLAS_TARGET)

        bitblas_matmul = global_operator_cache.get(config)
        if bitblas_matmul is None:
            bitblas_matmul = Matmul(config,
                                    target=BITBLAS_TARGET,
                                    enable_tuning=False)
            if enable_tuning:
                bitblas_matmul.hardware_aware_finetune(topk=20)
                global_operator_cache.add(config, bitblas_matmul)
                global_operator_cache.save_into_database(
                    BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
                TUNING_MESSAGE = (
                    f"BitBLAS Operator {config} tuned and saved to database.")
                logger.info(TUNING_MESSAGE)
            else:
                _message = f"BitBLAS Operator {config} created without tuning."
                logger.info(_message)
        else:
            _message = f"BitBLAS Operator {config} retrieved from cache."
            logger.info(_message)
        return bitblas_matmul

    def apply_gptq_bitblas_linear(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
    ) -> torch.Tensor:
        output_size_per_partition = self.config.partition_weight_shape[1]
        out_shape = x.shape[:-1] + (output_size_per_partition, )
        args = [x, layer.qweight, layer.scales]
        if self.bitblas_matmul.config.with_zeros:  # type: ignore[attr-defined]
            args.append(layer.qzeros)
        output = self.bitblas_matmul(*args)  # type: ignore[operator]
        return output.view(out_shape)
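
Illustration (not part of the patch): the shape bookkeeping above on toy sizes; only the last activation dimension is replaced by the partition's output size.

x = torch.randn(2, 8, 4096, dtype=torch.float16)  # (batch, seq, in_features), made-up sizes
output_size_per_partition = 11008
out_shape = x.shape[:-1] + (output_size_per_partition, )
# out_shape is (2, 8, 11008); the 2-D matmul result is viewed back into it.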

    def apply_weights(self, layer, x, bias=None):
        NOT_IMPLEMENT_MESSAGE = (
            f"{self.__class__.__name__}.apply_weights is not implemented. "
            "Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead")
        raise NotImplementedError(NOT_IMPLEMENT_MESSAGE)

198
vllm/model_executor/layers/quantization/utils/bitblas_utils.py
Normal file
@ -0,0 +1,198 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple

import torch

from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

MINIMUM_BITBLAS_VERSION = "0.1.0"

BITBLAS_MIN_WEIGHT_SIZE_N = 16
BITBLAS_MIN_WEIGHT_SIZE_K = 16
GPTQ_BITBLAS_MAX_PARALLEL = 16

BITBLAS_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# For dynamic shape code generation
BITBLAS_OPTIMIZE_FEATURES = [1, 16, 32, 64, 128, 256, 512, 1024]
# For the best performance with contiguous batching,
# use the following values instead.
BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS = [16, 32, 64, 128, 256, 512, 1024]

BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8]
BITBLAS_SUPPORTED_SYM = [False, True]


# Determines the supported quantization types for BitBLAS based on the
# device's capability and whether zero-point (zp) is used.
def query_bitblas_supported_quant_types(has_zp: bool,
                                        device_capability: Optional[int] = None
                                        ):
    if device_capability is None:
        capability_tuple = current_platform.get_device_capability()
        device_capability = (-1 if capability_tuple is None else
                             capability_tuple.to_int())

    if device_capability < 70:
        return []

    if has_zp:
        # AWQ style, unsigned + runtime zero-point
        return [scalar_types.uint4, scalar_types.uint8]
    else:
        # GPTQ style, unsigned + symmetric bias
        # TODO: once fp8_bitblas is merged into "gptq_bitblas" we should be
        # able to add `scalar_types.float8_e4m3fn` here
        return [scalar_types.uint4b8, scalar_types.uint8b128]
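
Illustration (not part of the patch): what the query returns for the two styles, assuming a device of capability >= 70.

query_bitblas_supported_quant_types(has_zp=False, device_capability=80)
# -> [scalar_types.uint4b8, scalar_types.uint8b128]   (GPTQ-style, symmetric bias)
query_bitblas_supported_quant_types(has_zp=True, device_capability=80)
# -> [scalar_types.uint4, scalar_types.uint8]         (AWQ-style, runtime zero-point)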


def _check_bitblas_supported(
        quant_type: ScalarType,
        group_size: Optional[int],
        has_zp: bool,
        device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:

    if device_capability is None:
        capability_tuple = current_platform.get_device_capability()
        device_capability = (-1 if capability_tuple is None else
                             capability_tuple.to_int())

    supported_types = query_bitblas_supported_quant_types(
        has_zp, device_capability)

    if quant_type not in supported_types:
        return (False, f"BitBLAS does not support weight_bits = {quant_type}. "
                f"Only types = {supported_types} "
                f"are supported (for group_size = {group_size}, "
                f"device_capability = {device_capability}, zp = {has_zp}).")
    if (group_size is None
            or group_size not in BITBLAS_SUPPORTED_GROUP_SIZES):
        return (False, f"BitBLAS does not support group_size = {group_size}. "
                f"Only group_sizes = {BITBLAS_SUPPORTED_GROUP_SIZES} "
                "are supported.")

    return True, None


def check_bitblas_supported(quant_type: ScalarType,
                            group_size: int,
                            has_zp: bool = False,
                            device_capability: Optional[int] = None) -> bool:
    cond, _ = _check_bitblas_supported(quant_type, group_size, has_zp,
                                       device_capability)
    return cond


def verify_bitblas_supported(quant_type: ScalarType,
                             group_size: int,
                             has_zp: bool = False) -> None:
    cond, err_msg = _check_bitblas_supported(quant_type, group_size, has_zp)
    if not cond:
        assert err_msg is not None
        raise ValueError(err_msg)


def verify_bitblas_supports_shape(output_size_per_partition: int,
                                  input_size_per_partition: int,
                                  input_size: int, group_size: int) -> None:

    # Validate output_size_per_partition
    if output_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_N != 0:
        raise ValueError(f"Weight output_size_per_partition = "
                         f"{output_size_per_partition} is not divisible by "
                         f"min_thread_n = {BITBLAS_MIN_WEIGHT_SIZE_N}. "
                         "Consider reducing tensor_parallel_size or running "
                         "with --quantization gptq.")

    # Validate input_size_per_partition
    if input_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_K != 0:
        raise ValueError(f"Weight input_size_per_partition = "
                         f"{input_size_per_partition} is not divisible "
                         f"by min_thread_k = {BITBLAS_MIN_WEIGHT_SIZE_K}. "
                         "Consider reducing tensor_parallel_size or running "
                         "with --quantization gptq.")

    if (group_size < input_size
            and input_size_per_partition % group_size != 0):
        raise ValueError(
            f"Weight input_size_per_partition = {input_size_per_partition}"
            f" is not divisible by group_size = {group_size}. "
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq.")


def check_bitblas_supports_shape(output_size_per_partition: int,
                                 input_size_per_partition: int,
                                 input_size: int, group_size: int) \
                                    -> Tuple[bool, Optional[str]]:
    try:
        verify_bitblas_supports_shape(output_size_per_partition,
                                      input_size_per_partition, input_size,
                                      group_size)
    except ValueError as e:
        return False, str(e)
    return True, None
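
Illustration (not part of the patch): the shape gate on hypothetical partition sizes.

ok, err = check_bitblas_supports_shape(
    output_size_per_partition=11008,  # multiple of 16
    input_size_per_partition=4096,    # multiple of 16 and of group_size
    input_size=4096,
    group_size=128)
# ok is True, err is None

bad, msg = check_bitblas_supports_shape(100, 4096, 4096, 128)
# bad is False; msg explains that 100 is not divisible by min_thread_n = 16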


def bitblas_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    return (not act_order) or (act_order and not is_row_parallel)


def bitblas_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
                                       is_row_parallel: bool) -> bool:
    # Need to repeat scales on every rank if act_ordering or
    # channelwise and RowParallelLinear
    is_channelwise = group_size == -1
    return act_order or (is_channelwise and is_row_parallel)
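
Illustration (not part of the patch): which configurations replicate scales on every rank.

bitblas_repeat_scales_on_all_ranks(act_order=False, group_size=-1,
                                   is_row_parallel=True)    # True: channelwise + row-parallel
bitblas_repeat_scales_on_all_ranks(act_order=False, group_size=128,
                                   is_row_parallel=True)    # False: grouped scales can be sharded
bitblas_repeat_scales_on_all_ranks(act_order=True, group_size=128,
                                   is_row_parallel=False)   # True: act_order always repeats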


def bitblas_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
                              requires_grad=False)


def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor:
    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
                              requires_grad=False)


def bitblas_sort_g_idx(
        g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
    return g_idx[g_idx_sort_indices], g_idx_sort_indices
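
Illustration (not part of the patch): sorting a toy act-order g_idx while keeping the permutation that was applied.

g_idx = torch.tensor([2, 0, 3, 1], dtype=torch.int32)
sorted_g_idx, perm = bitblas_sort_g_idx(g_idx)
# sorted_g_idx == tensor([0, 1, 2, 3]), perm == tensor([1, 3, 0, 2])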


def unpack_gptq_qzeros(qzeros, bits, is_gptq_v2=False) -> torch.Tensor:
    qzeros = qzeros.view(torch.int32)
    elems_per_int32 = 32 // bits
    # Mask for one packed element (the original hard-coded 0xF, which is
    # only correct for bits == 4).
    mask = (1 << bits) - 1
    unpacked_zeros = torch.zeros(
        (qzeros.shape[0], qzeros.shape[1] * elems_per_int32),
        dtype=torch.int8,
        device=qzeros.device,
        requires_grad=False,
    )

    for col in range(unpacked_zeros.shape[1]):
        i = col % elems_per_int32
        unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >>
                                  (bits * i)) & mask
    # GPTQ v1 stores zero-points offset by one; undo that here.
    if not is_gptq_v2:
        return unpacked_zeros + 1
    return unpacked_zeros
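
Illustration (not part of the patch): a toy 4-bit round trip. GPTQ v1 stores zero_point - 1 per nibble, so a column of zero-points equal to 8 is packed as 0x77777777.

packed_zeros = torch.tensor([[0x77777777]], dtype=torch.int32)
unpack_gptq_qzeros(packed_zeros, bits=4)
# -> tensor of shape (1, 8) with every element equal to 8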


def unpack_gptq_qweight(qweight, bits):
    qweight = qweight.view(torch.int8)
    elems_per_int8 = 8 // bits
    unpacked_weight = torch.zeros(
        (qweight.shape[0], qweight.shape[1] * elems_per_int8),
        dtype=torch.int8,
        device=qweight.device,
        requires_grad=False,
    )
    for col in range(unpacked_weight.shape[1]):
        i = col % elems_per_int8
        unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >>
                                   (bits * i))

    return torch.bitwise_and(unpacked_weight, 2**bits - 1)
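
Illustration (not part of the patch): two 4-bit weights live in each int8 byte, low nibble first.

packed_w = torch.tensor([[0x21, 0x43]], dtype=torch.int8)
unpack_gptq_qweight(packed_w, bits=4)
# -> tensor([[1, 2, 3, 4]], dtype=torch.int8)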

@ -282,10 +282,12 @@ class PackedColumnParameter(_ColumnvLLMParameter):
                 packed_factor: Union[int, Fraction],
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 bitblas_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        self._bitblas_tile_size = bitblas_tile_size
        super().__init__(**kwargs)

    @property
@ -300,12 +302,17 @@ class PackedColumnParameter(_ColumnvLLMParameter):
    def marlin_tile_size(self):
        return self._marlin_tile_size

    @property
    def bitblas_tile_size(self):
        return self._bitblas_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size)
            marlin_tile_size=self.marlin_tile_size,
            bitblas_tile_size=self.bitblas_tile_size)


class PackedvLLMParameter(ModelWeightParameter):
@ -323,10 +330,12 @@ class PackedvLLMParameter(ModelWeightParameter):
                 packed_factor: Union[int, Fraction],
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 bitblas_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        self._bitblas_tile_size = bitblas_tile_size
        super().__init__(**kwargs)

    @property
@ -341,12 +350,17 @@ class PackedvLLMParameter(ModelWeightParameter):
    def marlin_tile_size(self):
        return self._marlin_tile_size

    @property
    def bitblas_tile_size(self):
        return self._bitblas_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size)
            marlin_tile_size=self.marlin_tile_size,
            bitblas_tile_size=self.bitblas_tile_size)


class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
@ -421,8 +435,13 @@ def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


def _adjust_shard_indexes_for_bitblas(shard_size, shard_offset,
                                      bitblas_tile_size):
    return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size


def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
                                      marlin_tile_size):
                                      marlin_tile_size, bitblas_tile_size):
    shard_size = shard_size // packed_factor
    shard_offset = shard_offset // packed_factor
    if marlin_tile_size is not None:
@ -430,4 +449,10 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
            shard_size=shard_size,
            shard_offset=shard_offset,
            marlin_tile_size=marlin_tile_size)
    return shard_size, shard_offset
    elif bitblas_tile_size is not None:
        return _adjust_shard_indexes_for_bitblas(
            shard_size=shard_size,
            shard_offset=shard_offset,
            bitblas_tile_size=bitblas_tile_size)

    return shard_size, shard_offset
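
Illustration (not part of the patch): shard-index adjustment for a hypothetical int4 shard packed 8-per-int32 with a BitBLAS tile size of 16.

_adjust_shard_indexes_for_packing(
    shard_size=4096, shard_offset=8192,
    packed_factor=8, marlin_tile_size=None, bitblas_tile_size=16)
# -> (4096 // 8 // 16, 8192 // 8 // 16) == (32, 64)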