[Quantization] Modify the logic of BNB double quantization (#19742)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2025-06-19 11:52:09 +08:00 committed by GitHub
parent 8d1e89d946
commit 4959915089

@@ -492,8 +492,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             raise ValueError("Following weights were not initialized from "
                              f"checkpoint: {weights_not_loaded}")
 
-        torch.cuda.empty_cache()
-
         param_dict = dict(model.named_parameters())
         stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
         # TODO: Change this lazy import to normal import
@@ -545,6 +543,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         for param_name, param in param_dict.items():
             if param_name in stacked_quant_state_dict:
                 quant_states = stacked_quant_state_dict[param_name]
+                # Dequantize double quantized values during weight loading.
+                dequantize_dq(quant_states)
                 set_weight_attrs(param, {"bnb_quant_state": quant_states})
                 pack_ratio = getattr(param, "pack_factor", -1)
@@ -565,6 +565,28 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             if load_8bit:
                 set_weight_attrs(
                     param, {"matmul_state": [None] * len(quant_states)})
 
+        torch.cuda.empty_cache()
+
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
+
+
+def dequantize_dq(quant_states: dict) -> None:
+    """
+    When BNB employs Double Quantization, we perform the dequantization of
+    these constants during weight loading rather than at inference time,
+    thereby avoiding this computational overhead during inference. This comes
+    at the cost of increased memory usage.
+    """
+    from bitsandbytes.functional import dequantize_blockwise
+    for _, quant_state in quant_states.items():
+        # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
+        if quant_state.nested:
+            absmax = dequantize_blockwise(quant_state.absmax,
+                                          quant_state.state2)
+            absmax += quant_state.offset
+            if absmax.dtype != torch.float32:
+                absmax = absmax.float()
+            quant_state.absmax = absmax
+            quant_state.nested = False
+            quant_state.offset = None
+            quant_state.state2 = None
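
For context, here is a small standalone sketch (not part of this commit) of the double-quantized state that dequantize_dq flattens. It assumes bitsandbytes around 0.45 and an available CUDA device; the weight shape and variable names are illustrative only, not taken from vLLM.

import torch
from bitsandbytes.functional import (dequantize_4bit, dequantize_blockwise,
                                     quantize_4bit)

# Illustrative only: shape and names are not from vLLM; requires a CUDA GPU.
weight = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")

# compress_statistics=True enables Double Quantization: the per-block absmax
# constants are themselves block-quantized, so the returned QuantState has
# nested=True plus second-level state2/offset fields.
packed, quant_state = quantize_4bit(weight,
                                    quant_type="nf4",
                                    compress_statistics=True)
assert quant_state.nested

# Reference result while the state is still nested: dequantize_4bit has to
# dequantize absmax on every call.
ref = dequantize_4bit(packed, quant_state)

# Flatten the state once, mirroring dequantize_dq above: recover the fp32
# absmax and drop the second-level constants.
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
quant_state.absmax = absmax.float()
quant_state.nested = False
quant_state.offset = None
quant_state.state2 = None

# Inference-time dequantization now skips that extra step and still produces
# the same tensor.
out = dequantize_4bit(packed, quant_state)
torch.testing.assert_close(out, ref)

With the default 4-bit block size of 64, keeping a flattened fp32 absmax costs roughly 0.5 bits per weight element, versus roughly 0.127 bits while double quantization is retained; that is the extra memory the docstring trades for skipping the absmax dequantization on every forward pass.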