From 726efc6a320ad9a4ef0b0378b40abbd0561ea394 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 28 Mar 2025 10:12:47 +0800 Subject: [PATCH] [Quantization][V1] BitsAndBytes support V1 (#15611) Signed-off-by: Jee Jee Li --- .../vision_language/test_mllama.py | 1 - tests/models/test_transformers.py | 1 - tests/quantization/test_bitsandbytes.py | 3 - vllm/config.py | 6 +- vllm/engine/arg_utils.py | 2 +- .../layers/quantization/bitsandbytes.py | 61 ++++++++++++++----- vllm/model_executor/model_loader/loader.py | 2 + 7 files changed, 52 insertions(+), 24 deletions(-) diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py index ae7a7b028b152..260d2c1093879 100644 --- a/tests/models/encoder_decoder/vision_language/test_mllama.py +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -425,7 +425,6 @@ def test_bnb_regression( max_model_len=4096, max_num_seqs=2, quantization="bitsandbytes", - load_format="bitsandbytes", ) sampling_params = SamplingParams( temperature=0, diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index c45fc7e649ec8..65bb11d6b5e4e 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -72,7 +72,6 @@ def test_distributed( "meta-llama/Llama-3.2-1B-Instruct", { "quantization": "bitsandbytes", - "load_format": "bitsandbytes", }, ), ]) diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index 1b6a918401487..533b055ee6d53 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -101,8 +101,6 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None: "--enable-prefix-caching", "--quantization", "bitsandbytes", - "--load-format", - "bitsandbytes", "--gpu-memory-utilization", "0.7", ] @@ -137,7 +135,6 @@ def validate_generated_texts(hf_runner, # when using distributed inference with vllm_runner(model_name, quantization='bitsandbytes', - load_format='bitsandbytes', tensor_parallel_size=vllm_tp_size, enforce_eager=False) as llm: vllm_outputs = llm.generate_greedy(prompts, 8) diff --git a/vllm/config.py b/vllm/config.py index 831fa2e4b06eb..5c73ff56ebbcf 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -682,8 +682,9 @@ class ModelConfig: def _verify_bnb_config(self) -> None: """ - The current version of bitsandbytes (0.44.0) with 8-bit models does not + The current version of bitsandbytes (0.45.3) with 8-bit models does not yet support CUDA graph. + # TODO Remove this when bitsandbytes supports. """ is_bitsandbytes = self.quantization == "bitsandbytes" has_quantization_config = (getattr(self.hf_config, @@ -698,8 +699,9 @@ class ModelConfig: not self.enforce_eager, ]): logger.warning( - "CUDA graph is not supported on BitAndBytes 8bit yet, " + "CUDA graph is not supported on BitsAndBytes 8bit yet, " "fallback to the eager mode.") + self.enforce_eager = True def _verify_with_expert_parallelism(self) -> None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a3b83c65a604a..d049f773caccd 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1616,7 +1616,7 @@ class EngineArgs: return False # Some quantization is not compatible with torch.compile. 
- V1_UNSUPPORTED_QUANT = ["bitsandbytes", "gguf"] + V1_UNSUPPORTED_QUANT = ["gguf"] if model_config.quantization in V1_UNSUPPORTED_QUANT: _raise_or_fallback( feature_name=f"--quantization {model_config.quantization}", diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 1e8e7aa1b8c12..f5d32efe83688 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.utils import direct_register_custom_op class BitsAndBytesConfig(QuantizationConfig): @@ -321,9 +322,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase): x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - # only load the bitsandbytes module when needed - from bitsandbytes import matmul_4bit - original_type = x.dtype original_shape = x.shape reshape_after_matmul = False @@ -343,19 +341,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase): out_dim_1, dtype=torch.bfloat16, device=x.device) - - current_index = 0 - for i in range(len(quant_states)): - output_size = quant_states[i].shape[0] - # It is more efficient to use out kwarg like - # matmul_4bit(..., out = ...). Infeasible now due to the bug - # https://github.com/TimDettmers/bitsandbytes/issues/1235. - # Need to change after the bug is fixed. - out[:, current_index:current_index + output_size] = matmul_4bit( - bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i]) - - current_index += output_size - + apply_bnb_4bit(bf_x, qweight, offsets, out) out = out.to(original_type) if reshape_after_matmul: @@ -365,3 +351,46 @@ class BitsAndBytesLinearMethod(LinearMethodBase): out += bias return out + + +def _apply_bnb_4bit( + x: torch.Tensor, + weight: torch.Tensor, + offsets: torch.Tensor, + out: torch.Tensor, +) -> None: + # only load the bitsandbytes module when needed + from bitsandbytes import matmul_4bit + quant_states = weight.bnb_quant_state + current_index = 0 + for i in range(len(quant_states)): + output_size = quant_states[i].shape[0] + # It is more efficient to use out kwarg like + # matmul_4bit(..., out = ...). Infeasible now due to the bug + # https://github.com/TimDettmers/bitsandbytes/issues/1235. + # Need to change after the bug is fixed. + out[:, current_index:current_index + output_size] = matmul_4bit( + x, weight[offsets[i]:offsets[i + 1]].t(), quant_states[i]) + current_index += output_size + + +def _apply_bnb_4bit_fake( + x: torch.Tensor, + weight: torch.Tensor, + offsets: torch.Tensor, + out: torch.Tensor, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="apply_bnb_4bit", + op_func=_apply_bnb_4bit, + mutates_args=["out"], + fake_impl=_apply_bnb_4bit_fake, + ) + apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit + +except AttributeError as error: + raise error diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index c969f18b822c4..5649cf2dd2cf1 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1259,6 +1259,8 @@ class BitsAndBytesModelLoader(BaseModelLoader): pack_ratio) offsets = np.concatenate(([0], np.cumsum(num_elements))) + # Make torch infer_schema happy + offsets = torch.tensor(offsets).cpu() set_weight_attrs(param, {"bnb_shard_offsets": offsets}) if load_8bit:
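
For reference, a minimal usage sketch mirroring the updated tests above, assuming the public vllm.LLM API; the model name, context length, and sampling settings are taken from the touched test files and are illustrative only. With this patch, passing quantization="bitsandbytes" alone is sufficient (the separate load_format="bitsandbytes" argument removed from the tests is no longer required), and V1 no longer falls back for bitsandbytes-quantized models. Note that, per the config.py change above, 8-bit bitsandbytes models still fall back to eager mode.

from vllm import LLM, SamplingParams

# Hypothetical end-to-end check; the model and limits follow the touched tests.
llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    quantization="bitsandbytes",  # load_format="bitsandbytes" no longer needed
    max_model_len=4096,
    max_num_seqs=2,
)

sampling_params = SamplingParams(temperature=0, max_tokens=8)
outputs = llm.generate(["Hello, my name is"], sampling_params)
print(outputs[0].outputs[0].text)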