[Quantization] Modify the logic of BNB double quantization (#19742)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 8d1e89d946
commit 4959915089
@@ -492,8 +492,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             raise ValueError("Following weights were not initialized from "
                              f"checkpoint: {weights_not_loaded}")
 
-        torch.cuda.empty_cache()
-
         param_dict = dict(model.named_parameters())
         stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
         # TODO: Change this lazy import to normal import
@@ -545,6 +543,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         for param_name, param in param_dict.items():
             if param_name in stacked_quant_state_dict:
                 quant_states = stacked_quant_state_dict[param_name]
+                # Dequantize double quantized values during weight loading.
+                dequantize_dq(quant_states)
                 set_weight_attrs(param, {"bnb_quant_state": quant_states})
 
                 pack_ratio = getattr(param, "pack_factor", -1)
@@ -565,6 +565,28 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             if load_8bit:
                 set_weight_attrs(
                     param, {"matmul_state": [None] * len(quant_states)})
 
+        torch.cuda.empty_cache()
+
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
+
+
+def dequantize_dq(quant_states: dict) -> None:
+    """
+    When BNB employs Double Quantization, we perform the dequantization of
+    these constants during weight loading rather than at inference time,
+    thereby avoiding this computational overhead during inference. This comes
+    at the cost of increased memory usage.
+    """
+    from bitsandbytes.functional import dequantize_blockwise
+    for _, quant_state in quant_states.items():
+        # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
+        if quant_state.nested:
+            absmax = dequantize_blockwise(quant_state.absmax,
+                                          quant_state.state2)
+            absmax += quant_state.offset
+            if absmax.dtype != torch.float32:
+                absmax = absmax.float()
+            quant_state.absmax = absmax
+            quant_state.nested = False
+            quant_state.offset = None
+            quant_state.state2 = None
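
For context, here is a minimal standalone sketch (not part of this commit) of what the new dequantize_dq helper removes from the inference path. With compress_statistics=True, bitsandbytes double-quantizes the per-block absmax constants into quant_state.state2, so every weight dequantization must first decompress them; flattening the state once at load time, as the diff does, trades a little extra memory for skipping that step on every call. The tensor shape, dtype, and quant_type below are arbitrary choices, and the snippet assumes a CUDA device with bitsandbytes installed.

import torch
from bitsandbytes.functional import (dequantize_4bit, dequantize_blockwise,
                                     quantize_4bit)

weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

# compress_statistics=True enables Double Quantization: the per-block absmax
# values are themselves blockwise-quantized into quant_state.state2.
packed, quant_state = quantize_4bit(weight,
                                    quant_type="nf4",
                                    compress_statistics=True)
assert quant_state.nested

# Reference: dequantize while the state is still nested; this decompresses
# absmax internally on every call (the overhead moved to load time above).
ref = dequantize_4bit(packed, quant_state)

# Flatten the state once, mirroring dequantize_dq in the diff above.
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
quant_state.absmax = absmax.float()
quant_state.nested = False
quant_state.offset = None
quant_state.state2 = None

# Later dequantizations skip the absmax decompression and match the reference.
out = dequantize_4bit(packed, quant_state)
assert torch.allclose(ref, out)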