mirror of https://git.datalinker.icu/vllm-project/vllm.git
[bitsandbytes]: support read bnb pre-quantized model (#5753)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
parent 2f808e69ab
commit 87525fab92
@@ -105,6 +105,7 @@ Documentation

   quantization/supported_hardware
   quantization/auto_awq
   quantization/bnb
   quantization/fp8
   quantization/fp8_e5m2_kvcache
   quantization/fp8_e4m3_kvcache

43 docs/source/quantization/bnb.rst (new file)
@@ -0,0 +1,43 @@
.. _bits_and_bytes:

BitsAndBytes
============

vLLM now supports `BitsAndBytes <https://github.com/TimDettmers/bitsandbytes>`_ for more efficient model inference.
BitsAndBytes quantizes models to reduce memory usage and improve performance without significantly sacrificing accuracy.
Unlike many other quantization methods, BitsAndBytes does not require calibrating the quantized model with input data.

Below are the steps to use BitsAndBytes with vLLM.

.. code-block:: console

   $ pip install 'bitsandbytes>=0.42.0'

vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoints.

You can find bitsandbytes-quantized models at https://huggingface.co/models?other=bitsandbytes.
These repositories usually contain a config.json file with a quantization_config section.
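To check whether a checkpoint is pre-quantized, you can inspect its config before loading it. The snippet below is a minimal sketch, assuming the Hugging Face ``transformers`` library is installed (it is not required for the vLLM workflow itself): a pre-quantized bitsandbytes checkpoint exposes a ``quantization_config`` whose ``quant_method`` is ``"bitsandbytes"``.

.. code-block:: python

   from transformers import AutoConfig

   # Pre-quantized checkpoints carry a quantization_config in config.json.
   config = AutoConfig.from_pretrained("unsloth/tinyllama-bnb-4bit")
   quant_config = getattr(config, "quantization_config", None)
   if quant_config is not None and quant_config.get("quant_method") == "bitsandbytes":
       print("pre-quantized bitsandbytes checkpoint")
   else:
       print("no bitsandbytes quantization_config; use in-flight quantization")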
Read a quantized checkpoint
---------------------------

.. code-block:: python

   from vllm import LLM
   import torch

   # unsloth/tinyllama-bnb-4bit is a pre-quantized checkpoint.
   model_id = "unsloth/tinyllama-bnb-4bit"
   llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True,
             quantization="bitsandbytes", load_format="bitsandbytes")

In-flight quantization: load as 4-bit quantization
--------------------------------------------------

.. code-block:: python

   from vllm import LLM
   import torch

   model_id = "huggyllama/llama-7b"
   llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True,
             quantization="bitsandbytes", load_format="bitsandbytes")
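In both cases the resulting ``LLM`` object is used like any other vLLM model. The short continuation below is a hypothetical example (the prompt and sampling settings are illustrative, not part of the steps above):

.. code-block:: python

   from vllm import SamplingParams

   sampling_params = SamplingParams(temperature=0.8, max_tokens=64)
   outputs = llm.generate(["To be or not to be,"], sampling_params)
   print(outputs[0].outputs[0].text)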
@@ -8,15 +8,20 @@ import torch

from tests.quantization.utils import is_quant_method_supported
from vllm import SamplingParams

models_to_test = [
    ('huggyllama/llama-7b', 'quantize model inflight'),
    ('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'),
]


@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
def test_load_bnb_model(vllm_runner) -> None:
    with vllm_runner('huggyllama/llama-7b',
@pytest.mark.parametrize("model_name, description", models_to_test)
def test_load_bnb_model(vllm_runner, model_name, description) -> None:
    with vllm_runner(model_name,
                     quantization='bitsandbytes',
                     load_format='bitsandbytes',
                     enforce_eager=True) as llm:

        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501

        # check the weights in MLP & SelfAttention are quantized to torch.uint8
@@ -65,12 +70,17 @@ def test_load_bnb_model(vllm_runner) -> None:
            'To be or not to be, that is the question.'
        ]
        outputs = llm.generate(prompts, sampling_params=sampling_params)

        assert len(outputs) == len(prompts)

        for index in range(len(outputs)):
            # compare the first line of the output
            actual_output = outputs[index][1][0].split('\n', 1)[0]
            expected_output = expected_outputs[index].split('\n', 1)[0]

            assert len(actual_output) >= len(expected_output), (
                f'Actual {actual_output} should be larger than or equal to '
                f'expected {expected_output}')
            actual_output = actual_output[:len(expected_output)]

            assert actual_output == expected_output, (
                f'Expected: {expected_output}, but got: {actual_output}')
@@ -591,9 +591,11 @@ class LoadConfig:
            mainly for profiling.
        "tensorizer" will use CoreWeave's tensorizer library for
            fast weight loading.
        "bitsandbytes" will load nf4 type weights.
        ignore_patterns: The list of patterns to ignore when loading the model.
            Default to "original/**/*" to avoid repeated loading of llama's
            checkpoints.
    """

    load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
@@ -676,8 +676,8 @@ class EngineArgs:
        # bitsandbytes quantization needs a specific model loader
        # so we make sure the quant method and the load format are consistent
        if (self.quantization == "bitsandbytes" or
                self.qlora_adapter_name_or_path is not None) and \
                self.load_format != "bitsandbytes":
            raise ValueError(
                "BitsAndBytes quantization and QLoRA adapter only support "
                f"'bitsandbytes' load format, but got {self.load_format}")
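A hypothetical caller-side sketch of the check above (the model name and arguments are illustrative): requesting bitsandbytes quantization without the matching load format trips the ValueError, while pairing the two settings passes.

    from vllm import LLM

    # Raises ValueError: load_format defaults to "auto", which is rejected
    # when quantization="bitsandbytes" is requested.
    # llm = LLM(model="unsloth/tinyllama-bnb-4bit",
    #           quantization="bitsandbytes")

    # Consistent pairing accepted by the check above.
    llm = LLM(model="unsloth/tinyllama-bnb-4bit",
              quantization="bitsandbytes",
              load_format="bitsandbytes")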
@@ -15,19 +15,11 @@ class BitsAndBytesConfig(QuantizationConfig):

    Reference: https://arxiv.org/abs/2305.14314
    """

    def __init__(
        self,
        adapter_name_or_path: str,
        target_modules: List[str],
    ) -> None:

        self.adapter_name_or_path = adapter_name_or_path
        self.target_modules = target_modules
    def __init__(self) -> None:
        pass

    def __repr__(self) -> str:
        return (
            f"BitsAndBytesConfig(adapter_name_or_path={self.adapter_name_or_path}"
        )
        return "BitsAndBytesConfig"

    @classmethod
    def get_name(self) -> str:
@@ -49,16 +41,7 @@ class BitsAndBytesConfig(QuantizationConfig):

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
        adapter_name = cls.get_from_keys(config, ["adapter_name_or_path"])
        default_target_modules = [
            "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
            "o_proj"
        ]
        if adapter_name == "":
            target_modules = default_target_modules
        else:
            target_modules = cls.get_from_keys(config, ["target_modules"])
        return cls(adapter_name, target_modules)
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
@@ -702,8 +702,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):

        return hf_weights_files, matched_pattern == "*.safetensors"

    def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
        if use_safetensors:
            return safetensors_weights_iterator(hf_weights_files)
        else:
            return pt_weights_iterator(hf_weights_files)

    def _get_quantized_weights_iterator(
        self, model_name_or_path: str, revision: Optional[str]
        self, model_name_or_path: str, revision: Optional[str], pre_quant: bool
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                     Any]]:
        """Get an iterator to the model weights with bitsandbytes quantization,
@@ -712,6 +718,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
        # only load the bitsandbytes module when needed
        try:
            import bitsandbytes
            from bitsandbytes.functional import QuantState
            if bitsandbytes.__version__ < "0.42.0":
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.42.0.")
@@ -725,17 +732,63 @@ class BitsAndBytesModelLoader(BaseModelLoader):
            model_name_or_path, revision)

        quant_state_dict = {}
        if use_safetensors:
            weight_iterator = safetensors_weights_iterator(hf_weights_files)
        else:
            weight_iterator = pt_weights_iterator(hf_weights_files)

        def generator():
        def quantized_checkpoint() -> Generator:
            # First iterate over all quant state weights
            weight_iterator = self._hf_weight_iter(hf_weights_files,
                                                   use_safetensors)
            temp_state_dict = {}
            for weight_name, weight_tensor in weight_iterator:
                if weight_name.endswith(".weight"):
                    continue
                # TODO: only nf4 quantization is supported for now
                if weight_name.endswith(".quant_state.bitsandbytes__fp4"):
                    raise NotImplementedError(
                        "Only bitsandbytes_nf4 quantization "
                        f"is supported for now. {weight_name} is fp4 quantized"
                    )
                temp_state_dict[weight_name] = weight_tensor

            # Closure to parse quant_state for each prequant weight
            def _parse_quant_state(param_name: str,
                                   temp_state_dict: Dict) -> QuantState:
                quant_state = {}
                for k in temp_state_dict:
                    if param_name + "." in k:
                        quant_state[k] = temp_state_dict[k]
                # bitsandbytes library requires
                # weight.quant_state.bitsandbytes__nf4 in CPU
                quant_state[param_name +
                            ".quant_state.bitsandbytes__nf4"] = quant_state[
                                param_name +
                                ".quant_state.bitsandbytes__nf4"].cpu().data
                return QuantState.from_dict(quant_state, device="cuda")

            # Second iterate over all prequant and normal weights
            # pre quantized weights would have a quant_state
            for weight_name, weight_tensor in self._hf_weight_iter(
                    hf_weights_files, use_safetensors):
                # Filter out all weights whose suffix is not ".weight"
                if not weight_name.endswith(".weight"):
                    continue
                if weight_name + ".quant_state.bitsandbytes__nf4" \
                        in temp_state_dict:
                    quant_state = _parse_quant_state(weight_name,
                                                     temp_state_dict)
                    weight_name = weight_name.replace(".weight", ".qweight")
                    quant_state_dict[weight_name] = quant_state
                    yield weight_name.replace(".weight",
                                              ".qweight"), weight_tensor
                else:
                    yield weight_name, weight_tensor

        def generator() -> Generator:
            for weight_name, weight_tensor in self._hf_weight_iter(
                    hf_weights_files, use_safetensors):
                if any(target_module in weight_name
                       for target_module in self.target_modules):
                    weight_name = weight_name.replace(".weight", ".qweight")
                    # bitsandbytes requires data in GPU
                    loaded_weight = weight_tensor.cuda().data
                    with set_default_torch_dtype(torch.float32):
                        processed_weight, quant_state = quantize_4bit(
@@ -749,6 +802,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):

                yield weight_name, processed_weight

        if pre_quant:
            return quantized_checkpoint(), quant_state_dict
        return generator(), quant_state_dict
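For intuition about the two paths above, here is a standalone sketch (an illustration assuming bitsandbytes>=0.42.0 and a CUDA device; the shapes are arbitrary): quantize_4bit packs a floating-point weight into uint8 NF4 storage plus a QuantState, which is the (qweight, quant_state) pair the loader either computes in flight or reads back from a pre-quantized checkpoint.

    import torch
    from bitsandbytes.functional import dequantize_4bit, quantize_4bit

    weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    # Packed uint8 tensor (two 4-bit values per byte) plus quantization state.
    qweight, quant_state = quantize_4bit(weight,
                                         compress_statistics=True,
                                         quant_type="nf4")
    print(qweight.dtype, qweight.numel())  # torch.uint8, half the original element count
    restored = dequantize_4bit(qweight, quant_state)
    print(restored.shape)  # torch.Size([4096, 4096])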
    def _load_weights(self, model_config: ModelConfig,
@@ -766,12 +821,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
        logger.info("Loading weights with BitsAndBytes quantization. "
                    "May take a while ...")

        qweight_iterator, quant_state_dict = (
            self._get_quantized_weights_iterator(model_config.model,
                                                 model_config.revision))
        is_quantized_checkpoint = False
        quant_config = getattr(model_config.hf_config, "quantization_config",
                               None)
        if quant_config is not None and quant_config.get(
                'quant_method') == "bitsandbytes":
            is_quantized_checkpoint = True

        qweight_iterator, quant_state_dict = \
            self._get_quantized_weights_iterator(
                model_config.model, model_config.revision, is_quantized_checkpoint)

        model.load_weights(qweight_iterator)

        torch.cuda.empty_cache()

        param_dict = dict(model.named_parameters())
        stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
        for quant_param_name in quant_state_dict:
@@ -809,9 +873,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                        f"pack_factor not set for parameter {param_name}.")

                num_elements = [0] * len(quant_states)
                for seq, quant_state in enumerate(quant_states.items()):
                for seq, quant_state in quant_states.items():
                    num_elements[seq] = math.prod(
                        quant_state[1].shape) // pack_ratio
                        quant_state.shape) // pack_ratio

                offsets = np.concatenate(([0], np.cumsum(num_elements)))
                set_weight_attrs(param, {"bnb_shard_offsets": offsets})
@@ -118,6 +118,7 @@ def convert_bin_to_safetensor_file(

# TODO(woosuk): Move this to other place.
def get_quant_config(model_config: ModelConfig,
                     load_config: LoadConfig) -> QuantizationConfig:

    quant_cls = get_quantization_config(model_config.quantization)
    # Read the quantization config from the HF model config, if available.
    hf_quant_config = getattr(model_config.hf_config, "quantization_config",