[bitsandbytes]: support read bnb pre-quantized model (#5753)

Co-authored-by: Michael Goin <michael@neuralmagic.com>
dongmao zhang 2024-07-23 16:45:09 -07:00 committed by GitHub
parent 2f808e69ab
commit 87525fab92
8 changed files with 143 additions and 39 deletions

View File

@ -105,6 +105,7 @@ Documentation
quantization/supported_hardware
quantization/auto_awq
quantization/bnb
quantization/fp8
quantization/fp8_e5m2_kvcache
quantization/fp8_e4m3_kvcache

View File

@ -0,0 +1,43 @@
.. _bits_and_bytes:
BitsAndBytes
==================
vLLM now supports `BitsAndBytes <https://github.com/TimDettmers/bitsandbytes>`_ for more efficient model inference.
BitsAndBytes quantizes models to reduce memory usage and enhance performance without significantly sacrificing accuracy.
Compared to other quantization methods, BitsAndBytes eliminates the need for calibrating the quantized model with input data.
Below are the steps to utilize BitsAndBytes with vLLM.
.. code-block:: console
$ pip install "bitsandbytes>=0.42.0"
vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoints.
You can find bitsandbytes quantized models on https://huggingface.co/models?other=bitsandbytes.
These repositories usually include a config.json file with a quantization_config section.
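For example, you can check whether a checkpoint is pre-quantized by inspecting its config.json, mirroring the detection vLLM performs internally (a minimal sketch; it assumes quantization_config is kept as a plain dict on the loaded config):
.. code-block:: python

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("unsloth/tinyllama-bnb-4bit")
    quant_config = getattr(config, "quantization_config", None)
    if quant_config is not None and quant_config.get("quant_method") == "bitsandbytes":
        print("This checkpoint is pre-quantized with bitsandbytes.")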
Read quantized checkpoint
--------------------------
.. code-block:: python
from vllm import LLM
import torch
# unsloth/tinyllama-bnb-4bit is a pre-quantized checkpoint.
model_id = "unsloth/tinyllama-bnb-4bit"
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True,
          quantization="bitsandbytes", load_format="bitsandbytes")
In-flight quantization: load as 4-bit quantization
--------------------------------------------------
.. code-block:: python
from vllm import LLM
import torch
model_id = "huggyllama/llama-7b"
llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True,
          quantization="bitsandbytes", load_format="bitsandbytes")

View File

@ -8,15 +8,20 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import SamplingParams
models_to_test = [
('huggyllama/llama-7b', 'quantize model inflight'),
('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'),
]
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
def test_load_bnb_model(vllm_runner) -> None:
with vllm_runner('huggyllama/llama-7b',
@pytest.mark.parametrize("model_name, description", models_to_test)
def test_load_bnb_model(vllm_runner, model_name, description) -> None:
with vllm_runner(model_name,
quantization='bitsandbytes',
load_format='bitsandbytes',
enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
# check the weights in MLP & SelfAttention are quantized to torch.uint8
@ -65,12 +70,17 @@ def test_load_bnb_model(vllm_runner) -> None:
'To be or not to be, that is the question.'
]
outputs = llm.generate(prompts, sampling_params=sampling_params)
assert len(outputs) == len(prompts)
for index in range(len(outputs)):
# compare the first line of the output
actual_output = outputs[index][1][0].split('\n', 1)[0]
expected_output = expected_outputs[index].split('\n', 1)[0]
assert len(actual_output) >= len(expected_output), (
f'Actual output {actual_output!r} should be at least as long as '
f'expected output {expected_output!r}')
actual_output = actual_output[:len(expected_output)]
assert actual_output == expected_output, (
f'Expected: {expected_output}, but got: {actual_output}')

View File

@ -591,9 +591,11 @@ class LoadConfig:
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
"bitsandbytes" will load nf4 type weights.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""
load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
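For context, load_format and quantization are used together for BitsAndBytes. A minimal sketch of how they reach this config through the public API, with the same arguments as in the documentation above:
from vllm import LLM

# "bitsandbytes" as load_format selects the BitsAndBytes model loader, and
# the matching quantization value selects its linear method; pass both.
llm = LLM(model="unsloth/tinyllama-bnb-4bit",
          quantization="bitsandbytes",
          load_format="bitsandbytes")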

View File

@ -15,19 +15,11 @@ class BitsAndBytesConfig(QuantizationConfig):
Reference: https://arxiv.org/abs/2305.14314
"""
def __init__(
self,
adapter_name_or_path: str,
target_modules: List[str],
) -> None:
self.adapter_name_or_path = adapter_name_or_path
self.target_modules = target_modules
def __init__(self) -> None:
pass
def __repr__(self) -> str:
return (
f"BitsAndBytesConfig(adapter_name_or_path={self.adapter_name_or_path}"
)
return "BitsAndBytesConfig"
@classmethod
def get_name(self) -> str:
@ -49,16 +41,7 @@ class BitsAndBytesConfig(QuantizationConfig):
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
adapter_name = cls.get_from_keys(config, ["adapter_name_or_path"])
default_target_modules = [
"gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
"o_proj"
]
if adapter_name == "":
target_modules = default_target_modules
else:
target_modules = cls.get_from_keys(config, ["target_modules"])
return cls(adapter_name, target_modules)
return cls()
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
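For reference, a hedged sketch of how from_config is driven by a pre-quantized checkpoint's quantization_config. The dict keys below are illustrative of how transformers serializes its BitsAndBytes settings, and the import path is assumed from the layout of this diff; with this change the config contents are no longer inspected and a default instance is returned:
from vllm.model_executor.layers.quantization.bitsandbytes import (
    BitsAndBytesConfig)

# Illustrative quantization_config from a config.json; exact keys can vary.
hf_quant_config = {
    "quant_method": "bitsandbytes",
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
}

quant_config = BitsAndBytesConfig.from_config(hf_quant_config)
print(quant_config)  # BitsAndBytesConfig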

View File

@ -702,8 +702,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return hf_weights_files, matched_pattern == "*.safetensors"
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
if use_safetensors:
return safetensors_weights_iterator(hf_weights_files)
else:
return pt_weights_iterator(hf_weights_files)
def _get_quantized_weights_iterator(
self, model_name_or_path: str, revision: Optional[str]
self, model_name_or_path: str, revision: Optional[str], pre_quant: bool
) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
Any]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
@ -712,6 +718,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# only load the bitsandbytes module when needed
try:
import bitsandbytes
from bitsandbytes.functional import QuantState
if bitsandbytes.__version__ < "0.42.0":
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.42.0.")
@ -725,13 +732,59 @@ class BitsAndBytesModelLoader(BaseModelLoader):
model_name_or_path, revision)
quant_state_dict = {}
if use_safetensors:
weight_iterator = safetensors_weights_iterator(hf_weights_files)
else:
weight_iterator = pt_weights_iterator(hf_weights_files)
def generator():
def quantized_checkpoint() -> Generator:
# First iterate over all quant state weights
weight_iterator = self._hf_weight_iter(hf_weights_files,
use_safetensors)
temp_state_dict = {}
for weight_name, weight_tensor in weight_iterator:
if weight_name.endswith(".weight"):
continue
# TODO: only nf4 quantization is supported for now
if weight_name.endswith(".quant_state.bitsandbytes__fp4"):
raise NotImplementedError(
"Only bitsandbytes_nf4 quantization"
f"is supported for now. {weight_name} is fp4 quantized"
)
temp_state_dict[weight_name] = weight_tensor
# Closure to parse quant_state for each prequant weight
def _parse_quant_state(param_name: str,
temp_state_dict: Dict) -> QuantState:
quant_state = {}
for k in temp_state_dict:
if param_name + "." in k:
quant_state[k] = temp_state_dict[k]
# bitsandbytes requires the serialized
# weight.quant_state.bitsandbytes__nf4 tensor to be on CPU
quant_state[param_name +
".quant_state.bitsandbytes__nf4"] = quant_state[
param_name +
".quant_state.bitsandbytes__nf4"].cpu().data
return QuantState.from_dict(quant_state, device="cuda")
# Second, iterate over all pre-quantized and normal weights;
# pre-quantized weights have an associated quant_state
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
# Filter out all weights whose suffix is not ".weight"
if not weight_name.endswith(".weight"):
continue
if weight_name + ".quant_state.bitsandbytes__nf4" \
in temp_state_dict:
quant_state = _parse_quant_state(weight_name,
temp_state_dict)
weight_name = weight_name.replace(".weight", ".qweight")
quant_state_dict[weight_name] = quant_state
yield weight_name.replace(".weight",
".qweight"), weight_tensor
else:
yield weight_name, weight_tensor
def generator() -> Generator:
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if any(target_module in weight_name
for target_module in self.target_modules):
weight_name = weight_name.replace(".weight", ".qweight")
@ -749,6 +802,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
yield weight_name, processed_weight
if pre_quant:
return quantized_checkpoint(), quant_state_dict
return generator(), quant_state_dict
def _load_weights(self, model_config: ModelConfig,
@ -766,12 +821,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
logger.info("Loading weights with BitsAndBytes quantization. "
"May take a while ...")
qweight_iterator, quant_state_dict = (
self._get_quantized_weights_iterator(model_config.model,
model_config.revision))
is_quantized_checkpoint = False
quant_config = getattr(model_config.hf_config, "quantization_config",
None)
if quant_config is not None and quant_config.get(
'quant_method') == "bitsandbytes":
is_quantized_checkpoint = True
qweight_iterator, quant_state_dict = \
self._get_quantized_weights_iterator(
model_config.model, model_config.revision, is_quantized_checkpoint)
model.load_weights(qweight_iterator)
torch.cuda.empty_cache()
param_dict = dict(model.named_parameters())
stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
for quant_param_name in quant_state_dict:
@ -809,9 +873,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
f"pack_factor not set for parameter {param_name}.")
num_elements = [0] * len(quant_states)
for seq, quant_state in enumerate(quant_states.items()):
for seq, quant_state in quant_states.items():
num_elements[seq] = math.prod(
quant_state[1].shape) // pack_ratio
quant_state.shape) // pack_ratio
offsets = np.concatenate(([0], np.cumsum(num_elements)))
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
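To see the on-disk layout that quantized_checkpoint() walks over: each quantized <param>.weight in a pre-quantized nf4 checkpoint is accompanied by companion tensors such as <param>.weight.absmax and <param>.weight.quant_state.bitsandbytes__nf4. A minimal sketch for listing them (the file name is a placeholder, and the companion-key names beyond the nf4 suffix are assumptions based on typical transformers/bitsandbytes exports):
from safetensors import safe_open

# List the quant-state entries in one shard of a pre-quantized checkpoint.
# "model.safetensors" is a placeholder; real repos may shard the weights.
with safe_open("model.safetensors", framework="pt") as f:
    for name in f.keys():
        # The loader above keys its quant states off this exact suffix.
        if name.endswith(".quant_state.bitsandbytes__nf4"):
            print(name)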

View File

@ -118,6 +118,7 @@ def convert_bin_to_safetensor_file(
# TODO(woosuk): Move this to other place.
def get_quant_config(model_config: ModelConfig,
load_config: LoadConfig) -> QuantizationConfig:
quant_cls = get_quantization_config(model_config.quantization)
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config",