[bitsandbytes]: support read bnb pre-quantized model (#5753)

Co-authored-by: Michael Goin <michael@neuralmagic.com>
Author: dongmao zhang, 2024-07-23 16:45:09 -07:00 (committed by GitHub)
parent 2f808e69ab
commit 87525fab92
8 changed files with 143 additions and 39 deletions


@@ -105,6 +105,7 @@ Documentation
    quantization/supported_hardware
    quantization/auto_awq
+   quantization/bnb
    quantization/fp8
    quantization/fp8_e5m2_kvcache
    quantization/fp8_e4m3_kvcache


@@ -0,0 +1,43 @@
+.. _bits_and_bytes:
+
+BitsAndBytes
+==================
+
+vLLM now supports `BitsAndBytes <https://github.com/TimDettmers/bitsandbytes>`_ for more efficient model inference.
+BitsAndBytes quantizes models to reduce memory usage and enhance performance without significantly sacrificing accuracy.
+Compared to other quantization methods, BitsAndBytes eliminates the need for calibrating the quantized model with input data.
+
+Below are the steps to utilize BitsAndBytes with vLLM.
+
+.. code-block:: console
+
+   $ pip install bitsandbytes>=0.42.0
+
+vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoints.
+
+You can find bitsandbytes quantized models on https://huggingface.co/models?other=bitsandbytes.
+These repositories usually have a config.json file that includes a quantization_config section.
+
+Read quantized checkpoint
+--------------------------
+
+.. code-block:: python
+
+   from vllm import LLM
+   import torch
+
+   # unsloth/tinyllama-bnb-4bit is a pre-quantized checkpoint.
+   model_id = "unsloth/tinyllama-bnb-4bit"
+   llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True,
+             quantization="bitsandbytes", load_format="bitsandbytes")
+
+In-flight quantization: load as 4-bit quantization
+---------------------------------------------------
+
+.. code-block:: python
+
+   from vllm import LLM
+   import torch
+
+   model_id = "huggyllama/llama-7b"
+   llm = LLM(model=model_id, dtype=torch.bfloat16, trust_remote_code=True,
+             quantization="bitsandbytes", load_format="bitsandbytes")


@@ -8,15 +8,20 @@ import torch
 from tests.quantization.utils import is_quant_method_supported
 from vllm import SamplingParams

+models_to_test = [
+    ('huggyllama/llama-7b', 'quantize model inflight'),
+    ('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'),
+]
+

 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
-def test_load_bnb_model(vllm_runner) -> None:
-    with vllm_runner('huggyllama/llama-7b',
+@pytest.mark.parametrize("model_name, description", models_to_test)
+def test_load_bnb_model(vllm_runner, model_name, description) -> None:
+    with vllm_runner(model_name,
                      quantization='bitsandbytes',
                      load_format='bitsandbytes',
                      enforce_eager=True) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501

         # check the weights in MLP & SelfAttention are quantized to torch.uint8
@@ -65,12 +70,17 @@ def test_load_bnb_model(vllm_runner) -> None:
             'To be or not to be, that is the question.'
         ]
         outputs = llm.generate(prompts, sampling_params=sampling_params)

         assert len(outputs) == len(prompts)

         for index in range(len(outputs)):
             # compare the first line of the output
             actual_output = outputs[index][1][0].split('\n', 1)[0]
             expected_output = expected_outputs[index].split('\n', 1)[0]
+
+            assert len(actual_output) >= len(expected_output), (
+                f'Actual {actual_output} should be larger than or equal to '
+                f'expected {expected_output}')
+
+            actual_output = actual_output[:len(expected_output)]

             assert actual_output == expected_output, (
                 f'Expected: {expected_output}, but got: {actual_output}')


@@ -591,9 +591,11 @@ class LoadConfig:
             mainly for profiling.
         "tensorizer" will use CoreWeave's tensorizer library for
             fast weight loading.
+        "bitsandbytes" will load nf4 type weights.
     ignore_patterns: The list of patterns to ignore when loading the model.
         Default to "original/**/*" to avoid repeated loading of llama's
         checkpoints.
     """

     load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
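For reference, the documented string maps onto the existing LoadFormat.BITSANDBYTES value; a minimal sketch (assuming LoadConfig and LoadFormat are importable from vllm.config as shown in this diff, and that LoadConfig normalizes string values in __post_init__):

from vllm.config import LoadConfig, LoadFormat

# "bitsandbytes" is lowercased and mapped onto the enum, selecting the
# nf4-weight loading path described in the docstring above.
load_config = LoadConfig(load_format="bitsandbytes")
assert load_config.load_format == LoadFormat.BITSANDBYTES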


@@ -15,19 +15,11 @@ class BitsAndBytesConfig(QuantizationConfig):

     Reference: https://arxiv.org/abs/2305.14314
     """

-    def __init__(
-        self,
-        adapter_name_or_path: str,
-        target_modules: List[str],
-    ) -> None:
-        self.adapter_name_or_path = adapter_name_or_path
-        self.target_modules = target_modules
+    def __init__(self, ) -> None:
+        pass

     def __repr__(self) -> str:
-        return (
-            f"BitsAndBytesConfig(adapter_name_or_path={self.adapter_name_or_path}"
-        )
+        return "BitsAndBytesConfig"

     @classmethod
     def get_name(self) -> str:
@@ -49,16 +41,7 @@ class BitsAndBytesConfig(QuantizationConfig):

     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig":
-        adapter_name = cls.get_from_keys(config, ["adapter_name_or_path"])
-        default_target_modules = [
-            "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
-            "o_proj"
-        ]
-        if adapter_name == "":
-            target_modules = default_target_modules
-        else:
-            target_modules = cls.get_from_keys(config, ["target_modules"])
-        return cls(adapter_name, target_modules)
+        return cls()

     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
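The effect of the simplification is that any HF bitsandbytes quantization_config now maps to the same empty vLLM-side config object. A small illustration (editorial sketch; the import path is assumed from this file's location in the repo):

# Assumed module path for the file shown in the hunk above.
from vllm.model_executor.layers.quantization.bitsandbytes import (
    BitsAndBytesConfig)

# The adapter_name_or_path / target_modules keys are no longer read; the dict
# below is a typical quantization_config from a pre-quantized checkpoint.
hf_quant_cfg = {"quant_method": "bitsandbytes", "bnb_4bit_quant_type": "nf4"}
cfg = BitsAndBytesConfig.from_config(hf_quant_cfg)
print(repr(cfg))  # -> BitsAndBytesConfig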


@@ -702,8 +702,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):

         return hf_weights_files, matched_pattern == "*.safetensors"

+    def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
+        if use_safetensors:
+            return safetensors_weights_iterator(hf_weights_files)
+        else:
+            return pt_weights_iterator(hf_weights_files)
+
     def _get_quantized_weights_iterator(
-        self, model_name_or_path: str, revision: Optional[str]
+        self, model_name_or_path: str, revision: Optional[str], pre_quant: bool
     ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                      Any]]:
         """Get an iterator to the model weights with bitsandbytes quantization,
@@ -712,6 +718,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         # only load the bitsandbytes module when needed
         try:
             import bitsandbytes
+            from bitsandbytes.functional import QuantState
             if bitsandbytes.__version__ < "0.42.0":
                 raise ImportError("bitsandbytes version is wrong. Please "
                                   "install bitsandbytes>=0.42.0.")
@@ -725,13 +732,59 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             model_name_or_path, revision)

         quant_state_dict = {}
-        if use_safetensors:
-            weight_iterator = safetensors_weights_iterator(hf_weights_files)
-        else:
-            weight_iterator = pt_weights_iterator(hf_weights_files)

-        def generator():
-            for weight_name, weight_tensor in weight_iterator:
+        def quantized_checkpoint() -> Generator:
+            # First iterate over all quant state weights
+            weight_iterator = self._hf_weight_iter(hf_weights_files,
+                                                   use_safetensors)
+            temp_state_dict = {}
+            for weight_name, weight_tensor in weight_iterator:
+                if weight_name.endswith(".weight"):
+                    continue
+                # TODO: only nf4 quantization is supported for now
+                if weight_name.endswith(".quant_state.bitsandbytes__fp4"):
+                    raise NotImplementedError(
+                        "Only bitsandbytes_nf4 quantization is supported "
+                        f"for now. {weight_name} is fp4 quantized")
+                temp_state_dict[weight_name] = weight_tensor
+
+            # Closure to parse quant_state for each prequant weight
+            def _parse_quant_state(param_name: str,
+                                   temp_state_dict: Dict) -> QuantState:
+                quant_state = {}
+                for k in temp_state_dict:
+                    if param_name + "." in k:
+                        quant_state[k] = temp_state_dict[k]
+                # bitsandbytes library requires
+                # weight.quant_state.bitsandbytes__nf4 in CPU
+                quant_state[param_name +
+                            ".quant_state.bitsandbytes__nf4"] = quant_state[
+                                param_name +
+                                ".quant_state.bitsandbytes__nf4"].cpu().data
+                return QuantState.from_dict(quant_state, device="cuda")
+
+            # Second iterate over all prequant and normal weights
+            # pre quantized weights would have a quant_state
+            for weight_name, weight_tensor in self._hf_weight_iter(
+                    hf_weights_files, use_safetensors):
+                # Filter out all weights whose suffix is not ".weight"
+                if not weight_name.endswith(".weight"):
+                    continue
+                if weight_name + ".quant_state.bitsandbytes__nf4" \
+                        in temp_state_dict:
+                    quant_state = _parse_quant_state(weight_name,
+                                                     temp_state_dict)
+                    weight_name = weight_name.replace(".weight", ".qweight")
+                    quant_state_dict[weight_name] = quant_state
+                    yield weight_name.replace(".weight",
+                                              ".qweight"), weight_tensor
+                else:
+                    yield weight_name, weight_tensor
+
+        def generator() -> Generator:
+            for weight_name, weight_tensor in self._hf_weight_iter(
+                    hf_weights_files, use_safetensors):
                 if any(target_module in weight_name
                        for target_module in self.target_modules):
                     weight_name = weight_name.replace(".weight", ".qweight")
@@ -749,6 +802,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):

                     yield weight_name, processed_weight

+        if pre_quant:
+            return quantized_checkpoint(), quant_state_dict
+
         return generator(), quant_state_dict

     def _load_weights(self, model_config: ModelConfig,
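To make the new quantized_checkpoint() path easier to follow, here is an editorial sketch (not part of the commit) of the key layout a pre-quantized nf4 checkpoint exposes and how a QuantState is rebuilt from it. The tensor names are illustrative, and the helper simply mirrors _parse_quant_state above.

from typing import Dict

import torch
from bitsandbytes.functional import QuantState

# Illustrative keys for one nf4-quantized linear layer (names are hypothetical):
#   model.layers.0.self_attn.q_proj.weight                                 <- packed uint8 data
#   model.layers.0.self_attn.q_proj.weight.absmax
#   model.layers.0.self_attn.q_proj.weight.quant_map
#   model.layers.0.self_attn.q_proj.weight.quant_state.bitsandbytes__nf4   <- serialized quant state
# Everything that does not end in ".weight" is buffered into temp_state_dict,
# then one QuantState is rebuilt per ".weight" tensor:

def rebuild_quant_state(param_name: str,
                        temp_state_dict: Dict[str, torch.Tensor]) -> QuantState:
    # gather every auxiliary tensor that belongs to this parameter
    qs_dict = {k: v for k, v in temp_state_dict.items() if param_name + "." in k}
    # bitsandbytes expects the serialized quant_state tensor on CPU
    key = param_name + ".quant_state.bitsandbytes__nf4"
    qs_dict[key] = qs_dict[key].cpu().data
    return QuantState.from_dict(qs_dict, device="cuda")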
@@ -766,12 +821,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         logger.info("Loading weights with BitsAndBytes quantization. "
                     " May take a while ...")

-        qweight_iterator, quant_state_dict = (
-            self._get_quantized_weights_iterator(model_config.model,
-                                                 model_config.revision))
+        is_quantized_checkpoint = False
+        quant_config = getattr(model_config.hf_config, "quantization_config",
+                               None)
+        if quant_config is not None and quant_config.get(
+                'quant_method') == "bitsandbytes":
+            is_quantized_checkpoint = True
+
+        qweight_iterator, quant_state_dict = \
+            self._get_quantized_weights_iterator(
+                model_config.model, model_config.revision,
+                is_quantized_checkpoint)

         model.load_weights(qweight_iterator)
+        torch.cuda.empty_cache()

         param_dict = dict(model.named_parameters())
         stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
         for quant_param_name in quant_state_dict:
@@ -809,9 +873,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                     f"pack_factor not set for parameter {param_name}.")

             num_elements = [0] * len(quant_states)
-            for seq, quant_state in enumerate(quant_states.items()):
+            for seq, quant_state in quant_states.items():
                 num_elements[seq] = math.prod(
-                    quant_state[1].shape) // pack_ratio
+                    quant_state.shape) // pack_ratio

             offsets = np.concatenate(([0], np.cumsum(num_elements)))
             set_weight_attrs(param, {"bnb_shard_offsets": offsets})
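The last hunk fixes the iteration over quant_states (a Dict[int, QuantState] keyed by shard index). A standalone worked example with hypothetical 7B-style shapes shows the bnb_shard_offsets it produces:

import math
from types import SimpleNamespace

import numpy as np

# Stand-ins exposing only the attribute the loader reads (.shape); the shapes
# are hypothetical q/k/v shards of a stacked qkv_proj parameter.
quant_states = {
    0: SimpleNamespace(shape=(4096, 4096)),  # q_proj
    1: SimpleNamespace(shape=(4096, 4096)),  # k_proj
    2: SimpleNamespace(shape=(4096, 4096)),  # v_proj
}
pack_ratio = 2  # two 4-bit values packed into each uint8 element

num_elements = [0] * len(quant_states)
for seq, quant_state in quant_states.items():
    num_elements[seq] = math.prod(quant_state.shape) // pack_ratio

offsets = np.concatenate(([0], np.cumsum(num_elements)))
print(offsets)  # [0, 8388608, 16777216, 25165824]
# offsets[i]:offsets[i+1] addresses shard i inside the packed qweight tensor.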


@@ -118,6 +118,7 @@ def convert_bin_to_safetensor_file(
 # TODO(woosuk): Move this to other place.
 def get_quant_config(model_config: ModelConfig,
                      load_config: LoadConfig) -> QuantizationConfig:
     quant_cls = get_quantization_config(model_config.quantization)
     # Read the quantization config from the HF model config, if available.
     hf_quant_config = getattr(model_config.hf_config, "quantization_config",