[Quantization][V1] BitsAndBytes support V1 (#15611)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-03-28 10:12:47 +08:00 committed by GitHub
parent bd45912b99
commit 726efc6a32
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 52 additions and 24 deletions

View File

@ -425,7 +425,6 @@ def test_bnb_regression(
max_model_len=4096, max_model_len=4096,
max_num_seqs=2, max_num_seqs=2,
quantization="bitsandbytes", quantization="bitsandbytes",
load_format="bitsandbytes",
) )
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0, temperature=0,

View File

@ -72,7 +72,6 @@ def test_distributed(
"meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
{ {
"quantization": "bitsandbytes", "quantization": "bitsandbytes",
"load_format": "bitsandbytes",
}, },
), ),
]) ])

View File

@ -101,8 +101,6 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
"--enable-prefix-caching", "--enable-prefix-caching",
"--quantization", "--quantization",
"bitsandbytes", "bitsandbytes",
"--load-format",
"bitsandbytes",
"--gpu-memory-utilization", "--gpu-memory-utilization",
"0.7", "0.7",
] ]
@ -137,7 +135,6 @@ def validate_generated_texts(hf_runner,
# when using distributed inference # when using distributed inference
with vllm_runner(model_name, with vllm_runner(model_name,
quantization='bitsandbytes', quantization='bitsandbytes',
load_format='bitsandbytes',
tensor_parallel_size=vllm_tp_size, tensor_parallel_size=vllm_tp_size,
enforce_eager=False) as llm: enforce_eager=False) as llm:
vllm_outputs = llm.generate_greedy(prompts, 8) vllm_outputs = llm.generate_greedy(prompts, 8)

View File

@ -682,8 +682,9 @@ class ModelConfig:
def _verify_bnb_config(self) -> None: def _verify_bnb_config(self) -> None:
""" """
The current version of bitsandbytes (0.44.0) with 8-bit models does not The current version of bitsandbytes (0.45.3) with 8-bit models does not
yet support CUDA graph. yet support CUDA graph.
# TODO Remove this when bitsandbytes supports.
""" """
is_bitsandbytes = self.quantization == "bitsandbytes" is_bitsandbytes = self.quantization == "bitsandbytes"
has_quantization_config = (getattr(self.hf_config, has_quantization_config = (getattr(self.hf_config,
@ -698,8 +699,9 @@ class ModelConfig:
not self.enforce_eager, not self.enforce_eager,
]): ]):
logger.warning( logger.warning(
"CUDA graph is not supported on BitAndBytes 8bit yet, " "CUDA graph is not supported on BitsAndBytes 8bit yet, "
"fallback to the eager mode.") "fallback to the eager mode.")
self.enforce_eager = True self.enforce_eager = True
def _verify_with_expert_parallelism(self) -> None: def _verify_with_expert_parallelism(self) -> None:

View File

@ -1616,7 +1616,7 @@ class EngineArgs:
return False return False
# Some quantization is not compatible with torch.compile. # Some quantization is not compatible with torch.compile.
V1_UNSUPPORTED_QUANT = ["bitsandbytes", "gguf"] V1_UNSUPPORTED_QUANT = ["gguf"]
if model_config.quantization in V1_UNSUPPORTED_QUANT: if model_config.quantization in V1_UNSUPPORTED_QUANT:
_raise_or_fallback( _raise_or_fallback(
feature_name=f"--quantization {model_config.quantization}", feature_name=f"--quantization {model_config.quantization}",

View File

@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.utils import direct_register_custom_op
class BitsAndBytesConfig(QuantizationConfig): class BitsAndBytesConfig(QuantizationConfig):
@ -321,9 +322,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# only load the bitsandbytes module when needed
from bitsandbytes import matmul_4bit
original_type = x.dtype original_type = x.dtype
original_shape = x.shape original_shape = x.shape
reshape_after_matmul = False reshape_after_matmul = False
@ -343,19 +341,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
out_dim_1, out_dim_1,
dtype=torch.bfloat16, dtype=torch.bfloat16,
device=x.device) device=x.device)
apply_bnb_4bit(bf_x, qweight, offsets, out)
current_index = 0
for i in range(len(quant_states)):
output_size = quant_states[i].shape[0]
# It is more efficient to use out kwarg like
# matmul_4bit(..., out = ...). Infeasible now due to the bug
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
# Need to change after the bug is fixed.
out[:, current_index:current_index + output_size] = matmul_4bit(
bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
current_index += output_size
out = out.to(original_type) out = out.to(original_type)
if reshape_after_matmul: if reshape_after_matmul:
@ -365,3 +351,46 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
out += bias out += bias
return out return out
def _apply_bnb_4bit(
x: torch.Tensor,
weight: torch.Tensor,
offsets: torch.Tensor,
out: torch.Tensor,
) -> None:
# only load the bitsandbytes module when needed
from bitsandbytes import matmul_4bit
quant_states = weight.bnb_quant_state
current_index = 0
for i in range(len(quant_states)):
output_size = quant_states[i].shape[0]
# It is more efficient to use out kwarg like
# matmul_4bit(..., out = ...). Infeasible now due to the bug
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
# Need to change after the bug is fixed.
out[:, current_index:current_index + output_size] = matmul_4bit(
x, weight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
current_index += output_size
def _apply_bnb_4bit_fake(
x: torch.Tensor,
weight: torch.Tensor,
offsets: torch.Tensor,
out: torch.Tensor,
) -> None:
return
try:
direct_register_custom_op(
op_name="apply_bnb_4bit",
op_func=_apply_bnb_4bit,
mutates_args=["out"],
fake_impl=_apply_bnb_4bit_fake,
)
apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
except AttributeError as error:
raise error

View File

@ -1259,6 +1259,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
pack_ratio) pack_ratio)
offsets = np.concatenate(([0], np.cumsum(num_elements))) offsets = np.concatenate(([0], np.cumsum(num_elements)))
# Make torch infer_schema happy
offsets = torch.tensor(offsets).cpu()
set_weight_attrs(param, {"bnb_shard_offsets": offsets}) set_weight_attrs(param, {"bnb_shard_offsets": offsets})
if load_8bit: if load_8bit: