mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 18:35:42 +08:00
[Quantization][V1] BitsAndBytes support V1 (#15611)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
bd45912b99
commit
726efc6a32
@ -425,7 +425,6 @@ def test_bnb_regression(
|
|||||||
max_model_len=4096,
|
max_model_len=4096,
|
||||||
max_num_seqs=2,
|
max_num_seqs=2,
|
||||||
quantization="bitsandbytes",
|
quantization="bitsandbytes",
|
||||||
load_format="bitsandbytes",
|
|
||||||
)
|
)
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
temperature=0,
|
temperature=0,
|
||||||
|
|||||||
@ -72,7 +72,6 @@ def test_distributed(
|
|||||||
"meta-llama/Llama-3.2-1B-Instruct",
|
"meta-llama/Llama-3.2-1B-Instruct",
|
||||||
{
|
{
|
||||||
"quantization": "bitsandbytes",
|
"quantization": "bitsandbytes",
|
||||||
"load_format": "bitsandbytes",
|
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
])
|
])
|
||||||
|
|||||||
@ -101,8 +101,6 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
|
|||||||
"--enable-prefix-caching",
|
"--enable-prefix-caching",
|
||||||
"--quantization",
|
"--quantization",
|
||||||
"bitsandbytes",
|
"bitsandbytes",
|
||||||
"--load-format",
|
|
||||||
"bitsandbytes",
|
|
||||||
"--gpu-memory-utilization",
|
"--gpu-memory-utilization",
|
||||||
"0.7",
|
"0.7",
|
||||||
]
|
]
|
||||||
@ -137,7 +135,6 @@ def validate_generated_texts(hf_runner,
|
|||||||
# when using distributed inference
|
# when using distributed inference
|
||||||
with vllm_runner(model_name,
|
with vllm_runner(model_name,
|
||||||
quantization='bitsandbytes',
|
quantization='bitsandbytes',
|
||||||
load_format='bitsandbytes',
|
|
||||||
tensor_parallel_size=vllm_tp_size,
|
tensor_parallel_size=vllm_tp_size,
|
||||||
enforce_eager=False) as llm:
|
enforce_eager=False) as llm:
|
||||||
vllm_outputs = llm.generate_greedy(prompts, 8)
|
vllm_outputs = llm.generate_greedy(prompts, 8)
|
||||||
|
|||||||
@ -682,8 +682,9 @@ class ModelConfig:
|
|||||||
|
|
||||||
def _verify_bnb_config(self) -> None:
|
def _verify_bnb_config(self) -> None:
|
||||||
"""
|
"""
|
||||||
The current version of bitsandbytes (0.44.0) with 8-bit models does not
|
The current version of bitsandbytes (0.45.3) with 8-bit models does not
|
||||||
yet support CUDA graph.
|
yet support CUDA graph.
|
||||||
|
# TODO Remove this when bitsandbytes supports.
|
||||||
"""
|
"""
|
||||||
is_bitsandbytes = self.quantization == "bitsandbytes"
|
is_bitsandbytes = self.quantization == "bitsandbytes"
|
||||||
has_quantization_config = (getattr(self.hf_config,
|
has_quantization_config = (getattr(self.hf_config,
|
||||||
@ -698,8 +699,9 @@ class ModelConfig:
|
|||||||
not self.enforce_eager,
|
not self.enforce_eager,
|
||||||
]):
|
]):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"CUDA graph is not supported on BitAndBytes 8bit yet, "
|
"CUDA graph is not supported on BitsAndBytes 8bit yet, "
|
||||||
"fallback to the eager mode.")
|
"fallback to the eager mode.")
|
||||||
|
|
||||||
self.enforce_eager = True
|
self.enforce_eager = True
|
||||||
|
|
||||||
def _verify_with_expert_parallelism(self) -> None:
|
def _verify_with_expert_parallelism(self) -> None:
|
||||||
|
|||||||
@ -1616,7 +1616,7 @@ class EngineArgs:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Some quantization is not compatible with torch.compile.
|
# Some quantization is not compatible with torch.compile.
|
||||||
V1_UNSUPPORTED_QUANT = ["bitsandbytes", "gguf"]
|
V1_UNSUPPORTED_QUANT = ["gguf"]
|
||||||
if model_config.quantization in V1_UNSUPPORTED_QUANT:
|
if model_config.quantization in V1_UNSUPPORTED_QUANT:
|
||||||
_raise_or_fallback(
|
_raise_or_fallback(
|
||||||
feature_name=f"--quantization {model_config.quantization}",
|
feature_name=f"--quantization {model_config.quantization}",
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
|||||||
set_weight_attrs)
|
set_weight_attrs)
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
class BitsAndBytesConfig(QuantizationConfig):
|
class BitsAndBytesConfig(QuantizationConfig):
|
||||||
@ -321,9 +322,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
|
||||||
# only load the bitsandbytes module when needed
|
|
||||||
from bitsandbytes import matmul_4bit
|
|
||||||
|
|
||||||
original_type = x.dtype
|
original_type = x.dtype
|
||||||
original_shape = x.shape
|
original_shape = x.shape
|
||||||
reshape_after_matmul = False
|
reshape_after_matmul = False
|
||||||
@ -343,19 +341,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
|||||||
out_dim_1,
|
out_dim_1,
|
||||||
dtype=torch.bfloat16,
|
dtype=torch.bfloat16,
|
||||||
device=x.device)
|
device=x.device)
|
||||||
|
apply_bnb_4bit(bf_x, qweight, offsets, out)
|
||||||
current_index = 0
|
|
||||||
for i in range(len(quant_states)):
|
|
||||||
output_size = quant_states[i].shape[0]
|
|
||||||
# It is more efficient to use out kwarg like
|
|
||||||
# matmul_4bit(..., out = ...). Infeasible now due to the bug
|
|
||||||
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
|
|
||||||
# Need to change after the bug is fixed.
|
|
||||||
out[:, current_index:current_index + output_size] = matmul_4bit(
|
|
||||||
bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
|
|
||||||
|
|
||||||
current_index += output_size
|
|
||||||
|
|
||||||
out = out.to(original_type)
|
out = out.to(original_type)
|
||||||
|
|
||||||
if reshape_after_matmul:
|
if reshape_after_matmul:
|
||||||
@ -365,3 +351,46 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
|
|||||||
out += bias
|
out += bias
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_bnb_4bit(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
offsets: torch.Tensor,
|
||||||
|
out: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
# only load the bitsandbytes module when needed
|
||||||
|
from bitsandbytes import matmul_4bit
|
||||||
|
quant_states = weight.bnb_quant_state
|
||||||
|
current_index = 0
|
||||||
|
for i in range(len(quant_states)):
|
||||||
|
output_size = quant_states[i].shape[0]
|
||||||
|
# It is more efficient to use out kwarg like
|
||||||
|
# matmul_4bit(..., out = ...). Infeasible now due to the bug
|
||||||
|
# https://github.com/TimDettmers/bitsandbytes/issues/1235.
|
||||||
|
# Need to change after the bug is fixed.
|
||||||
|
out[:, current_index:current_index + output_size] = matmul_4bit(
|
||||||
|
x, weight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
|
||||||
|
current_index += output_size
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_bnb_4bit_fake(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
offsets: torch.Tensor,
|
||||||
|
out: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="apply_bnb_4bit",
|
||||||
|
op_func=_apply_bnb_4bit,
|
||||||
|
mutates_args=["out"],
|
||||||
|
fake_impl=_apply_bnb_4bit_fake,
|
||||||
|
)
|
||||||
|
apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
|
||||||
|
|
||||||
|
except AttributeError as error:
|
||||||
|
raise error
|
||||||
|
|||||||
@ -1259,6 +1259,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
pack_ratio)
|
pack_ratio)
|
||||||
|
|
||||||
offsets = np.concatenate(([0], np.cumsum(num_elements)))
|
offsets = np.concatenate(([0], np.cumsum(num_elements)))
|
||||||
|
# Make torch infer_schema happy
|
||||||
|
offsets = torch.tensor(offsets).cpu()
|
||||||
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
|
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
|
||||||
|
|
||||||
if load_8bit:
|
if load_8bit:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user