[Quantization] Improve BitsAndBytesModelLoader (#20242)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li authored on 2025-06-30 18:22:09 +08:00, committed by GitHub
parent e936e401de
commit 8fe7fc8634


@@ -20,8 +20,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
# yapf: enable
from vllm.logger import init_logger
- # yapf conflicts with isort for this block
- # yapf: disable
from vllm.model_executor.layers.linear import (LinearBase,
MergedColumnParallelLinear,
QKVParallelLinear,
@@ -39,6 +37,8 @@ from vllm.model_executor.utils import (get_packed_modules_mapping,
set_weight_attrs)
from vllm.platforms import current_platform
+ # yapf conflicts with isort for this block
+ # yapf: disable
logger = init_logger(__name__)
@@ -54,11 +54,17 @@ class BitsAndBytesModelLoader(BaseModelLoader):
self.unsharded_weights_modules: list[str] = []
# Save the module names that are sharded by column.
self.column_sharded_weights_modules: list[str] = []
+ # Modules whose weights might be fused on disk;
+ # we need their output_sizes to shard them in flight correctly with TP.
+ self.maybe_fused_weights_modules: dict[str, list[int]] = {}
# Store all module names (from transformers) that support
# BNB quantization.
self.target_modules: list[str] = []
# mapping weight names from transformers to vllm.
self.weight_mapper: Callable = lambda name: name
+ self.pre_quant: bool = False
+ self.load_8bit: bool = False
+ self.is_pool_model: bool = False
def _get_weight_files(
self,
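
Note on the three `maybe_fused_weights_modules` lines added above: some checkpoints store several projections (e.g. gate and up) fused in one on-disk tensor, and the recorded output_sizes let each tensor-parallel rank un-fuse first and take its shard second. A minimal sketch of that order of operations, with illustrative shapes and names rather than vLLM's actual code:

    import torch

    # Hypothetical fused gate_up_proj weight: rows are [gate | up].
    fused = torch.randn(2 * 4096, 1024)
    output_sizes = [4096, 4096]    # what maybe_fused_weights_modules records
    tp_rank, tp_size = 0, 2

    shards, offset = [], 0
    for size in output_sizes:
        sub = fused[offset:offset + size]              # un-fuse one projection
        step = size // tp_size
        shards.append(sub[tp_rank * step:(tp_rank + 1) * step])
        offset += size
    local = torch.cat(shards, dim=0)                   # this rank's shard
    assert local.shape == (fused.shape[0] // tp_size, fused.shape[1])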
@@ -134,13 +140,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return hf_weights_files, use_safetensors
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
- def _maybe_pool_model(module_name:str):
+ def _maybe_pool_model(module_name: str):
# For pooling models, we need to add the prefix `model.`
# to the weight name when applicable.
if self.is_pool_model and self.target_modules[0]. \
startswith("model.") and not module_name.startswith(
"model."):
return "model."+module_name
return "model." + module_name
return module_name
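
Illustrating the prefix logic above as a standalone function (a rewrite of the closure for demonstration only; module names are made up):

    def maybe_pool_model(module_name: str,
                         is_pool_model: bool = True,
                         first_target: str = "model.layers.0.self_attn.q_proj") -> str:
        # Pooling models collect target modules under a `model.` prefix,
        # so bare checkpoint names get the same prefix added.
        if (is_pool_model and first_target.startswith("model.")
                and not module_name.startswith("model.")):
            return "model." + module_name
        return module_name

    assert maybe_pool_model("layers.0.mlp.down_proj") == "model.layers.0.mlp.down_proj"
    assert maybe_pool_model("model.layers.0.self_attn.q_proj") == "model.layers.0.self_attn.q_proj"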
@@ -159,8 +166,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# mapping weight names from transformers to vllm while preserving
# original names.
mapped_name = self.weight_mapper(org_name)
- mapped_name=_maybe_pool_model(mapped_name)
+ mapped_name = _maybe_pool_model(mapped_name)
yield org_name, mapped_name, param
@@ -168,8 +174,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
self,
model_name_or_path: str,
revision: Optional[str],
- pre_quant: bool,
- load_8bit: bool,
) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str,
Any]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
@@ -192,8 +196,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
quant_state_dict: dict[str, Any] = {}
- if pre_quant:
- if load_8bit:
+ if self.pre_quant:
+ if self.load_8bit:
return self._quantized_8bit_generator(
hf_weights_files, use_safetensors,
quant_state_dict), quant_state_dict
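
For orientation: `quant_state_dict` collects the auxiliary tensors that a pre-quantized bitsandbytes checkpoint stores alongside each packed weight, which is what the 8-bit and 4-bit generators above deserialize. For one 4-bit NF4 layer the keys look roughly like this (names as bitsandbytes commonly serializes them, from memory; verify against an actual checkpoint):

    # model.layers.0.self_attn.q_proj.weight                                 packed uint8 data
    # model.layers.0.self_attn.q_proj.weight.absmax                          per-block scales
    # model.layers.0.self_attn.q_proj.weight.quant_map                       NF4 codebook
    # model.layers.0.self_attn.q_proj.weight.quant_state.bitsandbytes__nf4   serialized QuantState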
@@ -390,10 +394,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
yield org_weight_name, processed_weight
def _get_bnb_target_modules(self, model: nn.Module) -> None:
"""
Identify and collect all modules that support BitsAndBytes
quantization.
"""
for name, module in model.named_modules():
- if (isinstance(module, LinearBase) and
- hasattr(module.quant_method, "quant_config")):
+ if (isinstance(module, LinearBase)
+ and hasattr(module.quant_method, "quant_config")):
if modules_info := self.modules_mapping.get_sub_modules(name):
# Map vLLM's names to transformers' names.
rep_name, sub_modules = modules_info
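
The `get_sub_modules` lookup above is driven by the model's `packed_modules_mapping`. For a Llama-style decoder it typically looks like the following, so a fused vLLM module expands to the per-projection names found in HF checkpoints (a common example; the exact mapping is model-specific):

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # "model.layers.0.self_attn.qkv_proj" therefore targets the
    # transformers-side q_proj, k_proj and v_proj of the same layer.
    vllm_name = "model.layers.0.self_attn.qkv_proj"
    sub_modules = packed_modules_mapping[vllm_name.rsplit(".", 1)[1]]
    assert sub_modules == ["q_proj", "k_proj", "v_proj"]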
@@ -409,29 +416,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
), "vllm currently does not support BNB quantization for"
f" {type(model).__name__}"
- def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
- if not hasattr(model, "load_weights"):
- raise AttributeError(
- "The required method 'load_weights' is not defined in class"
- f" {type(model).__name__}.")
- if not hasattr(model, "packed_modules_mapping"):
- raise AttributeError(
- f"Model {type(model).__name__} does not support BitsAndBytes "
- "quantization yet. No 'packed_modules_mapping' found.")
- self.is_pool_model=is_pooling_model(model)
- self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
- # For some models like Molmo, we need to use hf_to_vllm_mapper
- # to ensure correct loading of weights.
- if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
- self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
- # Modules whose weights might have fused on disk
- # we need their output_sizes to make shard in flight correctly with TP
- self.maybe_fused_weights_modules: dict[str, list[int]] = {}
- self._get_bnb_target_modules(model)
+ def _classify_module_sharding(self, model: nn.Module):
+ """
+ Categorize modules based on their weight sharding requirements
+ for tensor parallelism.
+ """
for name, module in model.named_modules():
# Some modules like `ReplicatedLinear` should not have their weights
# sharded. The reason for implementing it this way is to avoid new
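
For readers new to the layout rules that the extracted `_classify_module_sharding` encodes: replicated modules keep whole weights on every rank, while `RowParallelLinear` weights are sharded along the column (input) dimension on disk. A toy classifier of the same shape, using stand-in classes instead of vLLM's and an abridged rule set:

    class ReplicatedLinear: ...
    class RowParallelLinear: ...

    unsharded_modules: list[str] = []
    column_sharded_modules: list[str] = []

    def classify(name: str, module: object) -> None:
        # Replicated weights load whole; row-parallel weights slice dim 1.
        if isinstance(module, ReplicatedLinear):
            unsharded_modules.append(name)
        elif isinstance(module, RowParallelLinear):
            column_sharded_modules.append(name)

    classify("model.layers.0.self_attn.o_proj", RowParallelLinear())
    assert column_sharded_modules == ["model.layers.0.self_attn.o_proj"]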
@@ -449,19 +438,27 @@ class BitsAndBytesModelLoader(BaseModelLoader):
elif isinstance(module, (RowParallelLinear, )):
self.column_sharded_weights_modules.append(name)
self.model_type = type(model).__name__
+ def _verify_model_compatibility(self, model: nn.Module,
+ model_config: ModelConfig) -> None:
+ """
+ Verify that the model is compatible with BitsAndBytes quantization.
+ """
+ if not hasattr(model, "load_weights"):
+ raise AttributeError(
+ "The required method 'load_weights' is not defined in class"
+ f" {type(model).__name__}.")
- logger.info("Loading weights with BitsAndBytes quantization. "
- "May take a while ...")
+ if not hasattr(model, "packed_modules_mapping"):
+ raise AttributeError(
+ f"Model {type(model).__name__} does not support BitsAndBytes "
+ "quantization yet. No 'packed_modules_mapping' found.")
quant_config = getattr(model_config.hf_config, "quantization_config",
None)
- pre_quant = False
if quant_config is not None:
quant_method = quant_config.get("quant_method")
if quant_method == "bitsandbytes":
- pre_quant = True
+ self.pre_quant = True
else:
raise ValueError(
f"BitsAndBytes loader does not support {quant_method} "
@@ -469,20 +466,43 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# The quant_states in pre_quantized models cannot work with a split
# weight tensor. So TP does not work with pre_quantized bnb models.
- if pre_quant and get_tensor_model_parallel_world_size() > 1:
+ if self.pre_quant and get_tensor_model_parallel_world_size() > 1:
raise ValueError(
"Prequant BitsAndBytes models with tensor parallelism is not "
"supported. Please try with pipeline parallelism.")
- load_8bit = False
- if pre_quant:
- load_8bit = quant_config.get("load_in_8bit", False)
+ if self.pre_quant:
+ self.load_8bit = quant_config.get("load_in_8bit", False)
+ def _initialize_loader_state(self, model: nn.Module,
+ model_config: ModelConfig) -> None:
+ """
+ Initialize the loader's internal state based on the model and
+ configuration.
+ """
+ self.is_pool_model = is_pooling_model(model)
+ self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
+ # For some models like Molmo, we need to use hf_to_vllm_mapper
+ # to ensure correct loading of weights.
+ if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
+ self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
+ self._get_bnb_target_modules(model)
+ self._classify_module_sharding(model)
+ def load_weights(self, model: nn.Module,
+ model_config: ModelConfig) -> None:
+ self._verify_model_compatibility(model, model_config)
+ self._initialize_loader_state(model, model_config)
+ logger.info("Loading weights with BitsAndBytes quantization. "
+ "May take a while ...")
qweight_iterator, quant_state_dict = (
- self._get_quantized_weights_iterator(model_config.model,
- model_config.revision,
- pre_quant, load_8bit))
+ self._get_quantized_weights_iterator(
+ model_config.model,
+ model_config.revision,
+ ))
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(qweight_iterator)
# Some models may leave the weight-loading tracker unimplemented.
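
Net effect of this hunk: `load_weights` shrinks to a template over the new helpers, and `_get_quantized_weights_iterator` now reads `pre_quant`/`load_8bit` from `self` rather than taking them as parameters. A stubbed skeleton of the resulting flow (stand-in class, not vLLM's):

    class LoaderSkeleton:
        def _verify_model_compatibility(self, model, cfg) -> None: ...
        def _initialize_loader_state(self, model, cfg) -> None: ...
        def _get_quantized_weights_iterator(self, path, revision):
            return iter(()), {}

        def load_weights(self, model, cfg):
            self._verify_model_compatibility(model, cfg)   # 1. fail fast
            self._initialize_loader_state(model, cfg)      # 2. derive per-model state
            it, quant_states = self._get_quantized_weights_iterator(
                cfg.model, cfg.revision)                   # 3. stream weights
            return it, quant_states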
@@ -562,10 +582,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
offsets = torch.tensor(offsets).cpu()
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
- if load_8bit:
+ if self.load_8bit:
set_weight_attrs(
param, {"matmul_state": [None] * len(quant_states)})
torch.cuda.empty_cache()
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
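
A closing note on `bnb_shard_offsets` from the final hunk: the packed 4-bit sub-weights of a fused module are concatenated flat, and the stored offsets let downstream code slice each sub-projection back out. A sketch with illustrative sizes:

    import torch

    shard_sizes = [1024, 256, 256]     # packed elements for q, k, v (made up)
    offsets = [0, 1024, 1280, 1536]    # cumulative; stored via set_weight_attrs

    packed = torch.arange(offsets[-1])             # stand-in packed buffer
    q = packed[offsets[0]:offsets[1]]
    k = packed[offsets[1]:offsets[2]]
    v = packed[offsets[2]:offsets[3]]
    assert [q.numel(), k.numel(), v.numel()] == shard_sizes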