diff --git a/benchmarks/kernels/benchmark_bitblas.py b/benchmarks/kernels/benchmark_bitblas.py new file mode 100644 index 000000000000..b23b4f3ea685 --- /dev/null +++ b/benchmarks/kernels/benchmark_bitblas.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + MINIMUM_BITBLAS_VERSION) + +try: + import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + raise ImportError("bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}") +except ImportError as e: + bitblas_import_exception = e + raise ValueError("Trying to use the bitblas backend, but could not import" + f"with the following error: {bitblas_import_exception}. " + "Please install bitblas through the following command: " + f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" + ) from bitblas_import_exception + +from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target + +from vllm.utils import FlexibleArgumentParser + +parser = FlexibleArgumentParser( + description="Benchmark BitBLAS int4 on a specific target.") + +# Add arguments to the parser +parser.add_argument( + "--target", + type=str, + default=auto_detect_nvidia_target(), + help="Specify the target device for benchmarking.", +) +parser.add_argument("--group_size", + type=int, + default=None, + help="Group size for grouped quantization.") +parser.add_argument( + "--A_dtype", + type=str, + default="float16", + choices=["float16", "float32", "float64", "int32", "int8"], + help="Data type of activation A.", +) +parser.add_argument( + "--W_dtype", + type=str, + default="int4", + choices=[ + "float16", + "float32", + "float64", + "int32", + "int8", + "int4", + "int2", + "int1", + "nf4", + "fp4_e2m1", + ], + help="Data type of weight W.", +) +parser.add_argument( + "--accum_dtype", + type=str, + default="float16", + choices=["float16", "int32"], + help="Data type for accumulation.", +) +parser.add_argument( + "--out_dtype", + type=str, + default="float16", + choices=["float16", "float32", "int32", "int8"], + help="Data type for output.", +) +parser.add_argument( + "--layout", + type=str, + default="nt", + choices=["nt", "nn"], + help="Matrix layout, 'nt' for non-transpose A and transpose W.", +) +parser.add_argument("--with_bias", + action="store_true", + help="Include bias in the benchmark.") +parser.add_argument( + "--with_scaling", + action="store_true", + help="Include scaling factor in the quantization.", +) +parser.add_argument("--with_zeros", + action="store_true", + help="Include zeros in the quantization.") +parser.add_argument( + "--zeros_mode", + type=str, + default=None, + choices=["original", "rescale", "quantized"], + help="Specify the mode for calculating zeros.", +) + +# Parse the arguments +args = parser.parse_args() + +# Assign arguments to variables +target = args.target +A_dtype = args.A_dtype +W_dtype = args.W_dtype +accum_dtype = args.accum_dtype +out_dtype = args.out_dtype +layout = args.layout +with_bias = args.with_bias +group_size = args.group_size +with_scaling = args.with_scaling +with_zeros = args.with_zeros +zeros_mode = args.zeros_mode + +# Define a list of shared arguments that repeat in every config +shared_args = [ + A_dtype, + W_dtype, + out_dtype, + accum_dtype, + layout, + with_bias, + group_size, + with_scaling, + with_zeros, + zeros_mode, +] + +# Define just the (M, K, N) shapes in a more compact list +shapes = [ + # square test + (1, 16384, 
16384), + # BLOOM-176B + (1, 43008, 14336), + (1, 14336, 14336), + (1, 57344, 14336), + (1, 14336, 57344), + # OPT-65B + (1, 9216, 9216), + (1, 36864, 9216), + (1, 9216, 36864), + (1, 22016, 8192), + # LLAMA-70B/65B + (1, 8192, 22016), + (1, 8192, 8192), + (1, 28672, 8192), + (1, 8192, 28672), + # square test + (16384, 16384, 16384), + # BLOOM-176B + (8192, 43008, 14336), + (8192, 14336, 14336), + (8192, 57344, 14336), + (8192, 14336, 57344), + # OPT-65B + (8192, 9216, 9216), + (8192, 36864, 9216), + (8192, 9216, 36864), + (8192, 22016, 8192), + # LLAMA-70B/65B + (8192, 8192, 22016), + (8192, 8192, 8192), + (8192, 28672, 8192), + (8192, 8192, 28672), +] + +# Build test shapes with all the shared arguments +test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) + for shape in shapes] + +benchmark_sets = [] +benchmark_sets.extend(test_shapes) + +benchmark_results = {} +for config_class, operator, input_args in benchmark_sets: + config = config_class(*input_args) + matmul = operator(config, target=target, enable_tuning=True) + kernel_latency = matmul.profile_latency() + + print("Time cost is: {:.3f} ms".format(kernel_latency)) + + profile_config = { + f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": { + "BitBLAS_top20_latency": kernel_latency, + } + } + + benchmark_results.update(profile_config) + +# Define headers for the table +headers = [ + "PrimFunc", + "Input Arguments", + "BitBLAS Top20 Latency", +] + +# Calculate column widths for pretty printing +col_widths = [0, 0, 0] +for config_key, values in benchmark_results.items(): + args_split = config_key.split("-") + func_name = args_split[0] + input_args_str = "-".join(args_split[1:]) + col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2) + col_widths[1] = max(col_widths[1], + len(input_args_str) + 2, + len(headers[1]) + 2) + col_widths[2] = max(col_widths[2], + len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2, + len(headers[2]) + 2) + # break only if you want to measure widths from a single example; + # otherwise, let it loop over all items. + +# Print header +for i, header in enumerate(headers): + headers[i] = header.ljust(col_widths[i]) +print("".join(headers)) +print("-" * sum(col_widths)) + +# Print rows +for config_key, values in benchmark_results.items(): + args_split = config_key.split("-") + func_name = args_split[0] + input_args_str = "-".join(args_split[1:]) + row = [ + func_name, + input_args_str, + f"{values['BitBLAS_top20_latency']:.3f} ms", + ] + row_str = "".join( + [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)]) + print(row_str) diff --git a/docs/source/features/quantization/bitblas.md b/docs/source/features/quantization/bitblas.md new file mode 100644 index 000000000000..aff917f90ec2 --- /dev/null +++ b/docs/source/features/quantization/bitblas.md @@ -0,0 +1,40 @@ +# BitBLAS + +vLLM now supports [BitBLAS](https://github.com/microsoft/BitBLAS) for more efficient and flexible model inference. Compared to other quantization frameworks, BitBLAS provides more precision combinations. + +Below are the steps to utilize BitBLAS with vLLM. + +```console +pip install bitblas>=0.1.0 +``` + +vLLM reads the model's config file and supports pre-quantized checkpoints. + +You can find pre-quantized models on: + +- [Hugging Face (BitBLAS)](https://huggingface.co/models?other=bitblas) +- [Hugging Face (GPTQ)](https://huggingface.co/models?other=gptq) + +Usually, these repositories have a `quantize_config.json` file that includes a `quantization_config` section. 
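+
+The exact contents vary between exporters, but the keys the BitBLAS backends read from that section are the standard GPTQ ones. Below is a minimal, purely illustrative sketch of such a config; the values are placeholders and depend on how the checkpoint was quantized.
+
+```python
+# Illustrative only: keys consumed by the BitBLAS quantization configs in vLLM.
+example_quantization_config = {
+    "bits": 4,                       # supported weight bits: 1, 2, 4, 8
+    "group_size": 128,               # -1 selects channel-wise quantization
+    "desc_act": False,
+    "sym": False,
+    "quant_method": "gptq",
+    "checkpoint_format": "bitblas",  # marks a BitBLAS-serialized checkpoint
+}
+```
+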
+ +## Read bitblas format checkpoint + +```python +from vllm import LLM +import torch + +# "hxbgsyxh/llama-13b-4bit-g-1-bitblas" is a pre-quantized checkpoint. +model_id = "hxbgsyxh/llama-13b-4bit-g-1-bitblas" +llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True, quantization="bitblas") +``` + +## Read gptq format checkpoint + +```python +from vllm import LLM +import torch + +# "hxbgsyxh/llama-13b-4bit-g-1" is a pre-quantized checkpoint. +model_id = "hxbgsyxh/llama-13b-4bit-g-1" +llm = LLM(model=model_id, dtype=torch.float16, trust_remote_code=True, quantization="bitblas", max_model_len=1024) +``` diff --git a/docs/source/features/quantization/index.md b/docs/source/features/quantization/index.md index 6f539f6e3f48..c7c8aeb662a5 100644 --- a/docs/source/features/quantization/index.md +++ b/docs/source/features/quantization/index.md @@ -11,6 +11,7 @@ Quantization trades off model precision for smaller memory footprint, allowing l supported_hardware auto_awq bnb +bitblas gguf gptqmodel int4 diff --git a/docs/source/features/quantization/supported_hardware.md b/docs/source/features/quantization/supported_hardware.md index 2cbe8779dd8a..984e6626e241 100644 --- a/docs/source/features/quantization/supported_hardware.md +++ b/docs/source/features/quantization/supported_hardware.md @@ -74,6 +74,17 @@ The table below shows the compatibility of various quantization implementations * ❌ * ❌ * ❌ +- * BitBLAS (GPTQ) + * ✅︎ + * ✅︎ + * ✅︎ + * ✅︎ + * ✅︎ + * ✅︎ + * ❌ + * ❌ + * ❌ + * ❌ - * AQLM * ✅︎ * ✅︎ diff --git a/tests/models/test_bitblas.py b/tests/models/test_bitblas.py new file mode 100644 index 000000000000..ae4a52214ad0 --- /dev/null +++ b/tests/models/test_bitblas.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Compare the outputs of a GPTQ model to a bitblas model. + +Note: GPTQ and bitblas do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +bitblas/GPTQ models are in the top 3 selections of each other. + +Note: bitblas internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for bitblas. As a result, we re-run the +test up to 3 times to see if we pass. + +Run `pytest tests/models/test_bitblas.py`. 
+""" +from dataclasses import dataclass + +import pytest + +from .utils import check_logprobs_close + + +@dataclass +class ModelPair: + model_bitblas: str + model_gptq: str + + +model_pairs = [ + ModelPair(model_bitblas="hxbgsyxh/opt-125m-4bit-128g-bitblas", + model_gptq="hxbgsyxh/opt-125m-4bit-128g"), +] + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.") +@pytest.mark.parametrize("model_pair", model_pairs) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model_pair: ModelPair, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + with vllm_runner(model_pair.model_bitblas, + dtype=dtype, + quantization="bitblas") as bitblas_model: + bitblas_outputs = bitblas_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model_pair.model_gptq, dtype=dtype, + quantization="gptq") as gptq_model: + gptq_outputs = gptq_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=bitblas_outputs, + name_0="gptq", + name_1="bitblas", + ) diff --git a/tests/models/test_gptq_bitblas.py b/tests/models/test_gptq_bitblas.py new file mode 100644 index 000000000000..d28442120ea6 --- /dev/null +++ b/tests/models/test_gptq_bitblas.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Compare the outputs of a GPTQ model to a bitblas model. + +Note: GPTQ and bitblas do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +bitblas/GPTQ models are in the top 3 selections of each other. + +Note: bitblas internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for bitblas. As a result, we re-run the +test up to 3 times to see if we pass. + +Run `pytest tests/models/test_bitblas.py`. 
+""" +from dataclasses import dataclass + +import pytest + +from .utils import check_logprobs_close + + +@dataclass +class ModelPair: + model_gptq: str + + +model_pairs = [ + ModelPair(model_gptq="hxbgsyxh/opt-125m-4bit-128g"), +] + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(True, reason="BitBLAS takes too much time for tuning.") +@pytest.mark.parametrize("model_pair", model_pairs) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model_pair: ModelPair, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + with vllm_runner(model_pair.model_gptq, + dtype=dtype, + quantization="bitblas") as bitblas_model: + bitblas_outputs = bitblas_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model_pair.model_gptq, dtype=dtype, + quantization="gptq") as gptq_model: + gptq_outputs = gptq_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=bitblas_outputs, + name_0="gptq", + name_1="gptq_bitblas", + ) diff --git a/vllm/config.py b/vllm/config.py index 20ca20ad2b6d..f9e0ed937604 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -750,7 +750,8 @@ class ModelConfig: optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed_tensors", - "compressed-tensors", "experts_int8", "quark", "nvfp4" + "compressed-tensors", "experts_int8", "quark", "nvfp4", "bitblas", + "gptq_bitblas" ] if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index c5536438f519..16500ab23e0f 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -31,6 +31,8 @@ logger = init_logger(__name__) WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", + "BitBLASLinearMethod", + "GPTQBitBLASLinearMethod", "AWQMarlinLinearMethod", "AWQLinearMethod", "GPTQMarlinLinearMethod", @@ -50,6 +52,15 @@ WEIGHT_LOADER_V2_SUPPORTED = [ ] +def adjust_bitblas_shard(param, shard_size, shard_offset): + bitblas_tile_size = getattr(param, "bitblas_tile_size", None) + if bitblas_tile_size is not None: + return (shard_size // bitblas_tile_size, + shard_offset // bitblas_tile_size) + + return shard_size, shard_offset + + def adjust_marlin_shard(param, shard_size, shard_offset): marlin_tile_size = getattr(param, "marlin_tile_size", None) if marlin_tile_size is None: @@ -615,6 +626,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) + if use_bitsandbytes_4bit: index = list(itertools.accumulate([0] + self.output_sizes)) orig_offsets = { @@ -646,6 +660,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): # Special case for Marlin. 
shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + shard_size, shard_offset = adjust_bitblas_shard( + param, shard_size, shard_offset) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 89533955fd76..9e1bf05dab9e 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -18,9 +18,11 @@ QUANTIZATION_METHODS: List[str] = [ # The order of gptq methods is important for config.py iteration over # override_quantization_method(..) "marlin", + "bitblas", "gguf", "gptq_marlin_24", "gptq_marlin", + "gptq_bitblas", "awq_marlin", "gptq", "compressed-tensors", @@ -85,6 +87,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .aqlm import AQLMConfig from .awq import AWQConfig from .awq_marlin import AWQMarlinConfig + from .bitblas import BitBLASConfig from .bitsandbytes import BitsAndBytesConfig from .compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsConfig) @@ -94,6 +97,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .fp8 import Fp8Config from .gguf import GGUFConfig from .gptq import GPTQConfig + from .gptq_bitblas import GPTQBitBLASConfig from .gptq_marlin import GPTQMarlinConfig from .gptq_marlin_24 import GPTQMarlin24Config from .hqq_marlin import HQQMarlinConfig @@ -119,9 +123,11 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: # The order of gptq methods is important for config.py iteration over # override_quantization_method(..) "marlin": MarlinConfig, + "bitblas": BitBLASConfig, "gguf": GGUFConfig, "gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin": GPTQMarlinConfig, + "gptq_bitblas": GPTQBitBLASConfig, "awq_marlin": AWQMarlinConfig, "gptq": GPTQConfig, "compressed-tensors": CompressedTensorsConfig, @@ -146,4 +152,4 @@ __all__ = [ "QuantizationConfig", "get_quantization_config", "QUANTIZATION_METHODS", -] +] \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/bitblas.py b/vllm/model_executor/layers/quantization/bitblas.py new file mode 100644 index 000000000000..3eaaa6c252ce --- /dev/null +++ b/vllm/model_executor/layers/quantization/bitblas.py @@ -0,0 +1,459 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Optional + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_NUM_BITS, + BITBLAS_SUPPORTED_SYM, MINIMUM_BITBLAS_VERSION) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + + +class BitBLASConfig(QuantizationConfig): + """Config class for BitBLAS. 
+ + Reference: https://github.com/Microsoft/BitBLAS + """ + TORCH_DTYPE = torch.float16 + STORAGE_DTYPE = "int8" # assume int8 storage + TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE) + # "original" or "rescale" or "quantized", + # gptq_with_bitblas prefer "quantized implementation" + ZEROS_MODE = "quantized" + + def __init__( + self, + weight_bits: int, + group_size: Optional[int], + desc_act: Optional[bool], + is_sym: Optional[bool], + quant_method: Optional[str], + lm_head_quantized: bool, + ) -> None: + try: + import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + raise ImportError( + "bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + except ImportError as e: + bitblas_import_exception = e + raise ValueError( + "Trying to use the bitblas backend, but could not import" + f"with the following error: {bitblas_import_exception}. " + "Please install bitblas through the following command: " + f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" + ) from bitblas_import_exception + + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.is_sym = is_sym + self.quant_method = quant_method + self.lm_head_quantized = lm_head_quantized + + # Verify + if self.weight_bits not in BITBLAS_SUPPORTED_NUM_BITS: + raise ValueError( + f"BitBLAS does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {BITBLAS_SUPPORTED_NUM_BITS} " + "are supported.") + + if self.is_sym not in BITBLAS_SUPPORTED_SYM: + raise ValueError( + f"BitBLAS does not support is_sym = {self.is_sym}. " + f"Only sym = {BITBLAS_SUPPORTED_SYM} are supported.") + + storage_dtype = self.STORAGE_DTYPE + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + + self.storage_dtype = storage_dtype + self.storage_torch_dtype = self.TORCH_STORAGE_DTYPE + # 4 Bits packed into 32 bit datatype. + self.pack_factor = storage_nbit // weight_bits + self.nbits = weight_bits + + # Zeros type for the quantized weights. 
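+        # "quantized" keeps the zero-points packed in the storage dtype for
+        # the BitBLAS kernel, while "original"/"rescale" pass them as float16
+        # ("rescale" pre-multiplies them by the scales); see
+        # BitBLASLinearKernel.repack_bitblas_from_gptq.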
+ self.zeros_mode = self.ZEROS_MODE + + def __repr__(self) -> str: + return (f"BitBLASConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}, " + f"is_sym={self.is_sym}, " + f"quant_method={self.quant_method})") + + @classmethod + def get_name(cls) -> str: + return "bitblas" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @staticmethod + def get_from_keys(config: Dict[str, Any], + keys: List[str], + default: Any = None) -> Any: + """Get a value from the model's quantization config.""" + for key in keys: + if key in config: + return config[key] + return default + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"], -1) + desc_act = cls.get_from_keys(config, ["desc_act"], False) + is_sym = cls.get_from_keys(config, ["sym"], False) + quant_method = cls.get_from_keys(config, ["quant_method"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, is_sym, quant_method, + lm_head_quantized) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + # compat: autogptq >=0.8.0 use checkpoint_format: str + # compat: autogptq <=0.7.1 is_bitblas_format: bool + is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas" + or hf_quant_cfg.get("is_bitblas_format", False)) + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "bitblas") + + if is_bitblas_format and is_valid_user_quant: + msg = ("The model is serialized in {} format. Using {} kernel.". + format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["BitBLASLinearMethod"]: + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): + return BitBLASLinearMethod(self) + return None + + +class BitBLASLinearMethod(LinearMethodBase): + """Linear method for BitBLAS. + + Args: + quant_config: The BitBLAS quantization config. + """ + # USE BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS + # Instead of BITBLAS_OPTIMIZE_FEATURES + # If you want to high contiguous batching + # performance + OPT_FEATURES = BITBLAS_OPTIMIZE_FEATURES + ENABLE_TUNING = True + BITBLAS_DTYPES = { + torch.float32: "float32", + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.half: "float16", + torch.int8: "int8", + } + + def __init__(self, quant_config: BitBLASConfig): + self.quant_config = quant_config + + def create_weights_gptq( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """Creates quantized weights for use in linear operations. + + The function initializes and returns a dictionary containing quantized + weights, scales, and zeros + for performing quantized matrix multiplication operations. + + Args: + input_size_per_partition: The size of the input partition. + output_size_per_partition: The size of the output partition. 
+ input_size: The total size of the input (unused). + output_size: The total size of the output (unused). + params_dtype: + The data type of the parameters (expected to be torch.float16). + + Returns: + A dictionary containing the quantized weights ('qweight'), + scales ('scales'), and zeros ('zeros'). + + Raises: + ValueError: If `params_dtype` is not `torch.float16` or if the + input size per partition is not divisible by the group size in + `quant_config`. + """ + del input_size, output_size # Unused arguments. + weight_loader = extra_weight_attrs["weight_loader"] + + if params_dtype not in self.quant_config.get_supported_act_dtypes(): + raise ValueError("Parameter data type must be torch.float16, " + f"but got {params_dtype}") + group_size = self.quant_config.group_size + if group_size is None: + group_size = -1 + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if (group_size != -1 and input_size_per_partition % group_size != 0): + raise ValueError( + f"Input size per partition ({input_size_per_partition}) must " + f"be divisible by group size ({group_size}).") + + # Initialize or retrieve the BitBLAS matrix multiplication operator. + self._configure_bitblas_matmul( + input_size_per_partition, + output_size_per_partition, + params_dtype=params_dtype, + enable_tuning=self.ENABLE_TUNING, + bias=False, + layout="nt", + bits=self.quant_config.weight_bits, + ) + + # Initialize quantized weights with dimensions + # Quantized 4Bit weights packed. + qweight = PackedvLLMParameter( + data=torch.empty( + self.bitblas_matmul.retrieve_weight_shape(), + device="cuda", + dtype=self.quant_config.storage_torch_dtype, + requires_grad=False, + ), + input_dim=1, + output_dim=0, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + bitblas_tile_size=(self.bitblas_matmul.retrieve_weight_shape()[-2] + if self.bitblas_matmul.propagate_b else None), + weight_loader=weight_loader, + ) + + # Compute the number of input groups for channel-wise quantization. + input_groups = (1 if group_size == -1 else input_size_per_partition // + group_size) + + # Initialize scales and zeros for the quantized weights. + weight_scale_args = { + "data": + torch.empty( + output_size_per_partition, + input_groups, + device="cuda", + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + if input_groups == 1: + scales = ChannelQuantScaleParameter(output_dim=0, + **weight_scale_args) + else: + scales = GroupQuantScaleParameter(output_dim=0, + input_dim=1, + **weight_scale_args) + + if self.quant_config.zeros_mode == "quantized": + zeros = PackedvLLMParameter( + data=torch.empty( + input_groups, + output_size_per_partition // self.quant_config.pack_factor, + device="cuda", + dtype=self.quant_config.storage_torch_dtype, + requires_grad=False, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader, + ) + + else: + zeros = BasevLLMParameter( + torch.empty(output_size_per_partition, + input_groups, + device="cuda", + dtype=params_dtype), + weight_loader=weight_loader, + ) + # Set attributes to indicate how scales and zeros are applied. 
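+            # input_dim is None for channel-wise quantization (a single input
+            # group), so the weight loader keeps these zeros whole instead of
+            # sharding them along the input dimension under tensor parallelism.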
+ set_weight_attrs(zeros, { + "input_dim": None if input_groups == 1 else 1, + "output_dim": 0, + }) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("scales", scales) + layer.register_parameter("zeros", zeros) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + if self.quant_config.quant_method == "gptq": + return self.create_weights_gptq(layer, input_size_per_partition, + output_partition_sizes, input_size, + output_size, params_dtype, + **extra_weight_attrs) + else: + raise ValueError( + f"Unsupported quant_method {self.quant_config.quant_method}") + + def _configure_bitblas_matmul( + self, + infeatures, + outfeatures, + params_dtype, + enable_tuning, + bias, + layout, + bits, + out_dtype="float16", + ): + from bitblas import MatmulConfig + bitblas_dtype = self.BITBLAS_DTYPES[params_dtype] + + with_scaling = False + with_zeros = False + group_size = self.quant_config.group_size + zeros_mode = self.quant_config.zeros_mode + if self.quant_config.quant_method == "gptq": + with_scaling = True + with_zeros = True + W_dtype = f"uint{bits}" + if self.quant_config.is_sym: + with_zeros = False + W_dtype = f"int{bits}" + else: + raise ValueError( + f"Unsupported quant_method {self.quant_config.quant_method}") + + matmul_config = MatmulConfig( + N=outfeatures, + K=infeatures, + A_dtype=bitblas_dtype, + W_dtype=W_dtype, + out_dtype=out_dtype, + accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, + storage_dtype=self.quant_config.STORAGE_DTYPE, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + with_bias=bias, + layout=layout, + zeros_mode=zeros_mode, + ) + self.bitblas_matmul = self._get_or_create_bitblas_operator( + matmul_config, enable_tuning) + + def _get_or_create_bitblas_operator(self, config, enable_tuning): + from bitblas import Matmul, auto_detect_nvidia_target + from bitblas.cache import get_database_path, global_operator_cache + BITBLAS_DATABASE_PATH = get_database_path() + BITBLAS_TARGET = auto_detect_nvidia_target() + if global_operator_cache.size() == 0: + global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, + BITBLAS_TARGET) + + bitblas_matmul = global_operator_cache.get(config) + if bitblas_matmul is None: + bitblas_matmul = Matmul(config, + target=BITBLAS_TARGET, + enable_tuning=False) + if enable_tuning: + TUNING_MESSAGE = (f"BitBLAS Operator {config} is tuning ...") + logger.info(TUNING_MESSAGE) + bitblas_matmul.hardware_aware_finetune(topk=20) + global_operator_cache.add(config, bitblas_matmul) + global_operator_cache.save_into_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + TUNED_MESSAGE = ( + f"BitBLAS Operator {config} tuned and saved to database.") + logger.info(TUNED_MESSAGE) + else: + _message = f"BitBLAS Operator {config} created." 
+ logger.info(_message) + else: + _message = ( + f"BitBLAS Operator {config} found in global_operator_cache.") + logger.info(_message) + return bitblas_matmul + + def apply_gptq( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.qweight + scales = layer.scales + qzeros = layer.zeros + + x_2d = x.view(-1, x.shape[-1]) + + if self.quant_config.is_sym: + output_2d = self.bitblas_matmul(x_2d, qweight, scales) + else: + output_2d = self.bitblas_matmul(x_2d, qweight, scales, qzeros) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output + + def apply( + self, + *args: Any, + **kwargs: Any, + ) -> torch.Tensor: + if self.quant_config.quant_method == "gptq": + return self.apply_gptq(*args, **kwargs) + else: + raise ValueError( + f"Unsupported quant_method {self.quant_config.quant_method}") diff --git a/vllm/model_executor/layers/quantization/gptq_bitblas.py b/vllm/model_executor/layers/quantization/gptq_bitblas.py new file mode 100644 index 000000000000..88cada4c61b8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_bitblas.py @@ -0,0 +1,438 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Optional, Set + +import torch +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( + BitBLASLinearKernel, MPLinearLayerConfig) +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + BITBLAS_SUPPORTED_NUM_BITS as GPTQ_BITBLAS_SUPPORTED_NUM_BITS) +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + BITBLAS_SUPPORTED_SYM as GPTQ_BITBLAS_SUPPORTED_SYM) +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + MINIMUM_BITBLAS_VERSION, bitblas_repeat_scales_on_all_ranks, + check_bitblas_supported, verify_bitblas_supported) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter) +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +class GPTQBitBLASConfig(QuantizationConfig): + """Config class for GPTQ BitBLAS""" + + # (num_bits, is_sym) -> quant_type + TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + + TORCH_DTYPE = torch.float16 + GPTQ_CKPT_STORAGE_DTYPE = ( + "int32" # GPTQ Default Checkpoints use int32 as storage dtype + ) + GPTQ_BITBLAS_STORAGE_DTYPE = "int8" # BitBLAS uses int8 as storage dtype + TORCH_BITBLAS_STORAGE_DTYPE = getattr(torch, GPTQ_BITBLAS_STORAGE_DTYPE) + # "original" or "rescale" or "quantized", + # the gptq_bitblas prefer "quantized" + ZEROS_MODE = "quantized" + + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + quant_method: Optional[str], + lm_head_quantized: bool, + ) -> None: + + try: + import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + raise ImportError( + "bitblas version is wrong. 
Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + except ImportError as e: + bitblas_import_exception = e + raise ValueError( + "Trying to use the bitblas backend, but could not import" + f"with the following error: {bitblas_import_exception}. " + "Please install bitblas through the following command: " + f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" + ) from bitblas_import_exception + + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.is_sym = is_sym + self.quant_method = quant_method + self.lm_head_quantized = lm_head_quantized + + # Verify + if self.weight_bits not in GPTQ_BITBLAS_SUPPORTED_NUM_BITS: + raise ValueError( + f"BitBLAS does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {GPTQ_BITBLAS_SUPPORTED_NUM_BITS} " + "are supported.") + + if self.is_sym not in GPTQ_BITBLAS_SUPPORTED_SYM: + raise ValueError( + f"BitBLAS does not support is_sym = {self.is_sym}. " + f"Only sym = {GPTQ_BITBLAS_SUPPORTED_SYM} are supported.") + + self.storage_dtype = self.GPTQ_BITBLAS_STORAGE_DTYPE + + storage_nbit = int("".join(c for c in self.GPTQ_CKPT_STORAGE_DTYPE + if c.isdigit())) + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = storage_nbit // weight_bits + self.nbits = weight_bits + + # Zeros type for the quantized weights. + self.zeros_mode = self.ZEROS_MODE + + if (weight_bits, is_sym) not in self.TYPE_MAP: + raise ValueError("Unsupported quantization config: " + f"bits={weight_bits}, sym={is_sym}") + + self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] + + def __repr__(self) -> str: + return (f"GPTQBitBLASConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})" + f"is_sym={self.is_sym}, " + f"quant_method={self.quant_method})") + + @classmethod + def get_name(cls) -> str: + return "gptq_bitblas" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQBitBLASConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + is_sym = cls.get_from_keys(config, ["sym"]) + quant_method = cls.get_from_keys(config, ["quant_method"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, is_sym, quant_method, + lm_head_quantized) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg) + + is_valid_user_quant = (user_quant is None or user_quant == "bitblas" + or user_quant == "gptq_bitblas") + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "gptq": + logger.info("Detected that the model can run with gptq_bitblas" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. 
Use quantization=gptq_bitblas for" + " faster inference") + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQBitBLASLinearMethod"]: + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): + return GPTQBitBLASLinearMethod(self) + return None + + @property + def torch_storage_dtype(self) -> torch.dtype: + return self.TORCH_BITBLAS_STORAGE_DTYPE + + @classmethod + def is_gptq_bitblas_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + sym = quant_config.get("sym") + desc_act = quant_config.get("desc_act") + + # If we cannot find the info needed in the config, cannot convert. + if (num_bits is None or group_size is None or sym is None + or desc_act is None): + return False + + if (num_bits, sym) not in cls.TYPE_MAP: + return False + + # If the capability of the device is too low, cannot convert. + major, minor = torch.cuda.get_device_capability() + device_capability = major * 10 + minor + if device_capability < cls.get_min_capability(): + return False + + # Otherwise, can convert if model satisfies bitblas constraints. + return check_bitblas_supported(quant_type=cls.TYPE_MAP[(num_bits, + sym)], + group_size=group_size) + + +class GPTQBitBLASLinearMethod(LinearMethodBase): + """Linear method for GPTQ BitBLAS. + + Args: + quant_config: The GPTQ BitBLAS quantization config. + """ + + kernel_type = BitBLASLinearKernel + _kernel_backends_being_used: Set[str] = set() + + def __init__(self, quant_config: GPTQBitBLASConfig) -> None: + self.quant_config = quant_config + # Verify supported on platform. + verify_bitblas_supported(quant_type=self.quant_config.quant_type, + group_size=self.quant_config.group_size) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + """Creates quantized weights for use in linear operations. + + The function initializes and returns a dictionary containing + quantized weights, scales, and zeros + for performing quantized matrix multiplication operations. + + Args: + input_size_per_partition: The size of the input partition. + output_partition_sizes: The size of the output partition. + input_size: The total size of the input (unused). + output_size: The total size of the output (unused). + params_dtype: + The data type of the parameters (expected to be torch.float16). + + Returns: + A dictionary containing the quantized weights ('qweight'), + scales ('scales'), and zeros ('zeros'). + + Raises: + ValueError: If `params_dtype` is not `torch.float16` or + if the input size per partition is not divisible by the + group size in `quant_config`. + """ + if params_dtype != torch.float16: + raise ValueError("Parameter data type must be torch.float16, " + f"but got {params_dtype}") + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + if input_size_per_partition % group_size != 0: + raise ValueError( + f"Input size per partition ({input_size_per_partition}) must " + f"be divisible by group size ({self.quant_config.group_size})." 
+ ) + + kernel_type = self.kernel_type + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + + is_row_parallel = input_size != input_size_per_partition + weight_loader = extra_weight_attrs.get("weight_loader") + + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_config.quant_type, + act_type=params_dtype, + group_size=self.quant_config.group_size, + zero_points=False, + has_g_idx=self.quant_config.desc_act + ) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for GPTQBitBLASLinearMethod", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + # Determine sharding + if bitblas_repeat_scales_on_all_ranks(self.quant_config.desc_act, + self.quant_config.group_size, + is_row_parallel): + # By setting scale_dim == None, weight_loader will + # repeat the scales on each GPU in TP>1 case. + scales_and_zp_input_dim = None + scales_and_zp_size = input_size // group_size + else: + # By setting scale_dim == 0, weight_loader will + # shard the scales in TP>1 case. + scales_and_zp_input_dim = 0 + scales_and_zp_size = input_size_per_partition // group_size + + # Init buffers + # Quantized weights + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + # Activation order + # Ignore warning from fused linear layers such as QKVParallelLinear. 
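+        # g_idx assigns each input channel to its quantization group; with
+        # desc_act the channels are not in monotonic group order, which is why
+        # process_weights_after_loading later sorts it with bitblas_sort_g_idx
+        # and keeps the resulting sort indices on the layer.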
+ g_idx = RowvLLMParameter(data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + + # Scales + scales = Parameter( + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + **extra_weight_attrs, + "input_dim": scales_and_zp_input_dim, + "output_dim": 1, + }, + ) + + # Quantized zero-points + qzeros_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": + weight_loader + } + weight_scale_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + + if scales_and_zp_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("scales", scales) + layer.register_parameter("qzeros", qzeros) + + self.kernel = kernel_type( + mp_linear_kernel_config, + w_q_param_name="qweight", + w_s_param_name="scales", + w_zp_param_name="qzeros", + w_gidx_param_name="g_idx", + bitblas_quant_config=self.quant_config, + ) + + # Initialize or retrieve the BitBLAS matrix multiplication operator. + self.kernel.configure_bitblas_matmul( + input_size_per_partition, + output_size_per_partition, + params_dtype=params_dtype, + bias=False, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + out = self.kernel.apply_gptq_bitblas_linear(layer, x) + if bias is not None: + out.add_(bias) + return out diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index 520e1bc96721..d144bb436104 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -5,6 +5,8 @@ from typing import List, Optional, Type import vllm.envs as envs from vllm.model_executor.layers.quantization.kernels.mixed_precision.allspark import ( # noqa: E501 AllSparkLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import ( # noqa: E501 + BitBLASLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 ExllamaLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 @@ -20,6 +22,7 @@ _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ MacheteLinearKernel, AllSparkLinearKernel, MarlinLinearKernel, + BitBLASLinearKernel, ExllamaLinearKernel, ] @@ -76,4 +79,4 @@ def choose_mp_linear_kernel( raise ValueError( "Failed to find a kernel that can implement the "\ "WNA16 linear layer. 
Reasons: \n" - + '\n'.join(failure_reasons)) + + '\n'.join(failure_reasons)) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py new file mode 100644 index 000000000000..21452d08b8a1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py @@ -0,0 +1,299 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( + BITBLAS_OPTIMIZE_FEATURES, BITBLAS_SUPPORTED_GROUP_SIZES, + MINIMUM_BITBLAS_VERSION, bitblas_make_empty_g_idx, bitblas_sort_g_idx, + check_bitblas_supports_shape, query_bitblas_supported_quant_types, + unpack_gptq_qweight, unpack_gptq_qzeros) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + +logger = init_logger(__name__) + + +class BitBLASLinearKernel(MPLinearKernel): + + OPT_FEATURES: List[int] = BITBLAS_OPTIMIZE_FEATURES + ENABLE_TUNING: bool = True + MATMUL_LAYOUT: str = "nt" + BITBLAS_DTYPES: Dict[torch.dtype, str] = { + torch.float32: "float32", + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.half: "float16", + torch.int8: "int8", + } + bitblas_matmul: object = None + + def __init__( + self, + c: MPLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + w_zp_param_name: Optional[str] = None, + w_gidx_param_name: Optional[str] = None, + bitblas_quant_config: Optional[QuantizationConfig] = None, + ): + self.quant_config = bitblas_quant_config + super().__init__(c, w_q_param_name, w_s_param_name, w_zp_param_name, + w_gidx_param_name) + + def repack_bitblas_from_gptq( + self, + b_q_weight: torch.Tensor, + scales: torch.Tensor, + qzeros: Optional[torch.Tensor] = None, + ): + from bitblas.quantization.utils import general_compress + assert self.bitblas_matmul is not None, "bitblas_matmul is None" + + quant_config = self.quant_config + # qweight in gptq old quant linear stored with + # (outfeatures, infeatures), should be transposed. + qweight = b_q_weight.T.contiguous().view( + quant_config.torch_storage_dtype) # type: ignore[union-attr] + intweight = unpack_gptq_qweight( + qweight, + quant_config.weight_bits).contiguous() # type: ignore[union-attr] + if self.bitblas_matmul.weight_transform is not None: # type: ignore[attr-defined] + qweight = self.bitblas_matmul.weight_transform( # type: ignore[attr-defined] + intweight.cpu()).cuda() + # scales in gptq old quant linear stored with + # (infeatures // group_size, outfeatures), should be transposed. + scales = scales.T.contiguous() + + if qzeros is None: + return qweight, scales, None + + # qzeros should be de-quantized to int zeros. 
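+        # GPTQ packs 32 // weight_bits zero-points into every int32 element;
+        # unpack_gptq_qzeros expands them back to one value per (group, output
+        # channel) before they are converted for the configured zeros_mode.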
+ weight_bits = quant_config.weight_bits # type: ignore[union-attr] + intzeros = unpack_gptq_qzeros(qzeros, weight_bits).T.contiguous() + zeros: Optional[torch.Tensor] = None + zeros_mode = self.bitblas_matmul.config.zeros_mode # type: ignore[attr-defined] + if zeros_mode == "original": + zeros = intzeros.to(torch.float16).contiguous() + elif zeros_mode == "rescale": + assert zeros is not None, "zeros should not be None" + zeros[:, :] = intzeros.to(torch.float16)[:, :] * scales[:, :] + elif zeros_mode == "quantized": + zeros = ( + torch.Tensor( + general_compress( + intzeros.T.contiguous().cpu().numpy(), + weight_bits, + )).to(qweight.device). + to(quant_config.torch_storage_dtype # type: ignore[union-attr] + ).contiguous()) + else: + raise ValueError("Unsupported zeros type: {}".format(zeros_mode)) + + return qweight, scales, zeros + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + is_bitblas_installed = True + + try: + import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: + raise ImportError( + "bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + except ImportError: + is_bitblas_installed = False + + if not is_bitblas_installed: + return False, "bitblas is not installed. Please install bitblas "\ + "by running `pip install bitblas>="\ + f"{MINIMUM_BITBLAS_VERSION}`" + + quant_types = query_bitblas_supported_quant_types(c.zero_points) + if c.weight_type not in quant_types: + return False, (f"Quant type ({c.weight_type}) not supported by" + f" BitBLAS, supported types are: {quant_types}") + + if c.group_size not in BITBLAS_SUPPORTED_GROUP_SIZES: + return False, (f"Group size ({c.group_size}) not supported by " + "BitBLAS, supported group sizes are: " + f"{BITBLAS_SUPPORTED_GROUP_SIZES}") + + return check_bitblas_supports_shape( + c.partition_weight_shape[1], # out_features + c.partition_weight_shape[0], # in_features + c.full_weight_shape[0], # in_features + c.group_size) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + quant_config = self.quant_config + + # Default names since bitblas requires empty parameters for these, + # TODO: remove this requirement from bitblas (allow optional tensors) + if self.w_gidx_name is None: + self.w_gidx_name = "g_idx" + if self.w_zp_name is None: + self.w_zp_name = "qzeros" + + if c.has_g_idx: + g_idx, g_idx_sort_indices = bitblas_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, bitblas_make_empty_g_idx(device)) + layer.g_idx_sort_indices = bitblas_make_empty_g_idx(device) + + if c.zero_points: + raise NotImplementedError("Zero points not supported by BitBLAS") + else: + setattr(layer, self.w_zp_name, bitblas_make_empty_g_idx(device)) + + # Repack weights + bitblas_qweight, bitblas_scales, bitblas_qzeros = ( + self.repack_bitblas_from_gptq( + layer.qweight, + layer.scales, + None if quant_config.is_sym else # type: ignore[union-attr] + layer.qzeros, # type: ignore[union-attr] + )) + replace_parameter(layer, self.w_q_name, bitblas_qweight) + replace_parameter(layer, self.w_s_name, bitblas_scales) + if 
bitblas_qzeros is not None: + replace_parameter(layer, self.w_zp_name, bitblas_qzeros) + + def configure_bitblas_matmul( + self, + infeatures: int, + outfeatures: int, + params_dtype: torch.dtype, + bias: bool, + ) -> None: + enable_tuning = self.ENABLE_TUNING + layout = self.MATMUL_LAYOUT + bits = self.quant_config.weight_bits # type: ignore[union-attr] + self._configure_bitblas_matmul( + infeatures, + outfeatures, + params_dtype, + enable_tuning, + bias, + layout, + bits, + ) + + def _configure_bitblas_matmul( + self, + infeatures, + outfeatures, + params_dtype, + enable_tuning, + bias, + layout, + bits, + ): + from bitblas import MatmulConfig + bitblas_dtype = self.BITBLAS_DTYPES[params_dtype] + quant_config = self.quant_config + with_scaling = False + with_zeros = False + group_size = quant_config.group_size # type: ignore[union-attr] + zeros_mode = quant_config.zeros_mode # type: ignore[union-attr] + if quant_config.quant_method == "gptq": # type: ignore[union-attr] + with_scaling = True + with_zeros = True + W_dtype = f"uint{bits}" + if quant_config.is_sym: # type: ignore[union-attr] + with_zeros = False + W_dtype = f"int{bits}" + else: + raise ValueError( + f"Unsupported quant_method {quant_config.quant_method}" # type: ignore[union-attr] + ) # type: ignore[union-attr] + + matmul_config = MatmulConfig( + M=self.OPT_FEATURES, + N=outfeatures, + K=infeatures, + A_dtype=bitblas_dtype, + W_dtype=W_dtype, + out_dtype=bitblas_dtype, + accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, + storage_dtype=quant_config. # type: ignore[union-attr] + storage_dtype, # type: ignore[union-attr] + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=group_size, + with_bias=bias, + layout=layout, + zeros_mode=zeros_mode, + ) + self.bitblas_matmul = self._get_or_create_bitblas_operator( + matmul_config, enable_tuning) + + def _get_or_create_bitblas_operator(self, config, enable_tuning): + from bitblas import Matmul, auto_detect_nvidia_target + from bitblas.cache import get_database_path, global_operator_cache + BITBLAS_DATABASE_PATH = get_database_path() + BITBLAS_TARGET = auto_detect_nvidia_target() + + if global_operator_cache.size() == 0: + global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, + BITBLAS_TARGET) + + bitblas_matmul = global_operator_cache.get(config) + if bitblas_matmul is None: + bitblas_matmul = Matmul(config, + target=BITBLAS_TARGET, + enable_tuning=False) + if enable_tuning: + bitblas_matmul.hardware_aware_finetune(topk=20) + global_operator_cache.add(config, bitblas_matmul) + global_operator_cache.save_into_database( + BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + TUNING_MESSAGE = ( + f"BitBLAS Operator {config} tuned and saved to database.") + logger.info(TUNING_MESSAGE) + else: + _message = f"BitBLAS Operator {config} created without tuning. " + logger.info(_message) + else: + _message = f"BitBLAS Operator {config} retrieved from cache." 
+ logger.info(_message) + return bitblas_matmul + + def apply_gptq_bitblas_linear( + self, + layer: torch.nn.Module, + x: torch.Tensor, + ) -> torch.Tensor: + output_size_per_partition = self.config.partition_weight_shape[1] + out_shape = x.shape[:-1] + (output_size_per_partition, ) + args = [x, layer.qweight, layer.scales] + if self.bitblas_matmul.config.with_zeros: # type: ignore[attr-defined] + args.append(layer.qzeros) + output = self.bitblas_matmul(*args) # type: ignore[operator] + return output.view(out_shape) + + def apply_weights(self, layer, x, bias=None): + NOT_IMPLEMENT_MESSAGE = ( + f"{self.__class__.__name__}.apply_weights is not implemented. " + "Please use BitBLASLinearKernel.apply_gptq_bitblas_linear instead") + raise NotImplementedError(NOT_IMPLEMENT_MESSAGE) diff --git a/vllm/model_executor/layers/quantization/utils/bitblas_utils.py b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py new file mode 100644 index 000000000000..5d28d327e8a2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/bitblas_utils.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import torch + +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +MINIMUM_BITBLAS_VERSION = "0.1.0" + +BITBLAS_MIN_WEIGHT_SIZE_N = 16 +BITBLAS_MIN_WEIGHT_SIZE_K = 16 +GPTQ_BITBLAS_MAX_PARALLEL = 16 + +BITBLAS_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# For dynamic shape code generation +BITBLAS_OPTIMIZE_FEATURES = [1, 16, 32, 64, 128, 256, 512, 1024] +# If want to enable high performance for contiguous batching +# Please use the following values +BITBLAS_OPTIMIZE_FEATURES_CONTIGUOUS = [16, 32, 64, 128, 256, 512, 1024] + +BITBLAS_SUPPORTED_NUM_BITS = [1, 2, 4, 8] +BITBLAS_SUPPORTED_SYM = [False, True] + + +# Determines the supported quantization types for BitBLAS based on the +# device's capability and whether zero-point (zp) is used. +def query_bitblas_supported_quant_types(has_zp: bool, + device_capability: Optional[int] = None + ): + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + if device_capability < 70: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_bitblas is merged into "gptq_bitblas" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_bitblas_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + supported_types = query_bitblas_supported_quant_types( + has_zp, device_capability) + + if quant_type not in supported_types: + return (False, f"BitBLAS does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in BITBLAS_SUPPORTED_GROUP_SIZES): + return (False, f"BitBLAS does not support group_size = {group_size}. 
" + f"Only group_sizes = {BITBLAS_SUPPORTED_GROUP_SIZES} " + "are supported.") + + return True, None + + +def check_bitblas_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_bitblas_supported(quant_type, group_size, has_zp, + device_capability) + return cond + + +def verify_bitblas_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: + cond, err_msg = _check_bitblas_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_bitblas_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_N != 0: + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {BITBLAS_MIN_WEIGHT_SIZE_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + # Validate input_size_per_partition + if input_size_per_partition % BITBLAS_MIN_WEIGHT_SIZE_K != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {BITBLAS_MIN_WEIGHT_SIZE_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + +def check_bitblas_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> Tuple[bool, Optional[str]]: + try: + verify_bitblas_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) + except ValueError as e: + return False, e.__str__() + return True, None + + +def bitblas_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def bitblas_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def bitblas_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + +def bitblas_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + +def bitblas_sort_g_idx( + g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def unpack_gptq_qzeros(qzeros, bits, is_gptq_v2=False) -> torch.Tensor: + qzeros = qzeros.view(torch.int32) + elems_per_int32 = 32 // bits + unpacked_zeros = torch.zeros( + (qzeros.shape[0], qzeros.shape[1] * elems_per_int32), + dtype=torch.int8, + device=qzeros.device, + requires_grad=False, + ) + + for col in range(unpacked_zeros.shape[1]): + i = col % elems_per_int32 + unpacked_zeros[:, col] = (qzeros[:, col // 
elems_per_int32] >> + (bits * i)) & 0xF + if not is_gptq_v2: + return unpacked_zeros + 1 + return unpacked_zeros + + +def unpack_gptq_qweight(qweight, bits): + qweight = qweight.view(torch.int8) + elems_per_int8 = 8 // bits + unpacked_weight = torch.zeros( + (qweight.shape[0], qweight.shape[1] * elems_per_int8), + dtype=torch.int8, + device=qweight.device, + requires_grad=False, + ) + for col in range(unpacked_weight.shape[1]): + i = col % elems_per_int8 + unpacked_weight[:, col] = (qweight[:, col // elems_per_int8] >> + (bits * i)) + + return torch.bitwise_and(unpacked_weight, 2**bits - 1) diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 2b1294bf7baa..34a0b527b585 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -282,10 +282,12 @@ class PackedColumnParameter(_ColumnvLLMParameter): packed_factor: Union[int, Fraction], packed_dim: int, marlin_tile_size: Optional[int] = None, + bitblas_tile_size: Optional[int] = None, **kwargs): self._packed_factor = packed_factor self._packed_dim = packed_dim self._marlin_tile_size = marlin_tile_size + self._bitblas_tile_size = bitblas_tile_size super().__init__(**kwargs) @property @@ -300,12 +302,17 @@ class PackedColumnParameter(_ColumnvLLMParameter): def marlin_tile_size(self): return self._marlin_tile_size + @property + def bitblas_tile_size(self): + return self._bitblas_tile_size + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): return _adjust_shard_indexes_for_packing( shard_size=shard_size, shard_offset=shard_offset, packed_factor=self.packed_factor, - marlin_tile_size=self.marlin_tile_size) + marlin_tile_size=self.marlin_tile_size, + bitblas_tile_size=self.bitblas_tile_size) class PackedvLLMParameter(ModelWeightParameter): @@ -323,10 +330,12 @@ class PackedvLLMParameter(ModelWeightParameter): packed_factor: Union[int, Fraction], packed_dim: int, marlin_tile_size: Optional[int] = None, + bitblas_tile_size: Optional[int] = None, **kwargs): self._packed_factor = packed_factor self._packed_dim = packed_dim self._marlin_tile_size = marlin_tile_size + self._bitblas_tile_size = bitblas_tile_size super().__init__(**kwargs) @property @@ -341,12 +350,17 @@ class PackedvLLMParameter(ModelWeightParameter): def marlin_tile_size(self): return self._marlin_tile_size + @property + def bitblas_tile_size(self): + return self._bitblas_tile_size + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): return _adjust_shard_indexes_for_packing( shard_size=shard_size, shard_offset=shard_offset, packed_factor=self.packed_factor, - marlin_tile_size=self.marlin_tile_size) + marlin_tile_size=self.marlin_tile_size, + bitblas_tile_size=self.bitblas_tile_size) class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): @@ -421,8 +435,13 @@ def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, return shard_size * marlin_tile_size, shard_offset * marlin_tile_size +def _adjust_shard_indexes_for_bitblas(shard_size, shard_offset, + bitblas_tile_size): + return shard_size // bitblas_tile_size, shard_offset // bitblas_tile_size + + def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, - marlin_tile_size): + marlin_tile_size, bitblas_tile_size): shard_size = shard_size // packed_factor shard_offset = shard_offset // packed_factor if marlin_tile_size is not None: @@ -430,4 +449,10 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, shard_size=shard_size, shard_offset=shard_offset, 
             marlin_tile_size=marlin_tile_size)
-    return shard_size, shard_offset
+    elif bitblas_tile_size is not None:
+        return _adjust_shard_indexes_for_bitblas(
+            shard_size=shard_size,
+            shard_offset=shard_offset,
+            bitblas_tile_size=bitblas_tile_size)
+
+    return shard_size, shard_offset
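As a quick sanity check of the helpers introduced above, the following standalone snippet exercises `unpack_gptq_qzeros` from the new `bitblas_utils.py` and the extended `_adjust_shard_indexes_for_packing` in `vllm/model_executor/parameter.py`. It is an illustrative sketch only, not part of the patch: the import paths and signatures come from the diff, while the input values are arbitrary toy numbers.

```python
import torch

from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
    unpack_gptq_qzeros)
from vllm.model_executor.parameter import _adjust_shard_indexes_for_packing

# Eight 4-bit zero-points, each stored as 7, packed into one int32 (0x77777777).
# GPTQ v1 checkpoints store zp - 1, so the unpacked zero-points come back as 8.
packed_qzeros = torch.tensor([[0x77777777]], dtype=torch.int32)
print(unpack_gptq_qzeros(packed_qzeros, bits=4))
# tensor([[8, 8, 8, 8, 8, 8, 8, 8]], dtype=torch.int8)

# Shard indexes of a packed weight are first divided by the packing factor
# (8 int4 values per int32) and then by the BitBLAS tile size; a Marlin tile
# size would instead multiply after the packing division.
print(_adjust_shard_indexes_for_packing(shard_size=4096,
                                        shard_offset=8192,
                                        packed_factor=8,
                                        marlin_tile_size=None,
                                        bitblas_tile_size=16))
# (32, 64)
```

Running this requires a vLLM checkout that includes this patch; the printed values are simply the arithmetic consequences of the int4 packing and tiling scheme described above.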