[Quantization] Improve BitsAndBytesModelLoader (#20242)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
commit 8fe7fc8634
parent e936e401de
@@ -20,8 +20,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 # yapf: enable
 from vllm.logger import init_logger
-# yapf conflicts with isort for this block
-# yapf: disable
 from vllm.model_executor.layers.linear import (LinearBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
@@ -39,6 +37,8 @@ from vllm.model_executor.utils import (get_packed_modules_mapping,
                                        set_weight_attrs)
 from vllm.platforms import current_platform
 
+# yapf conflicts with isort for this block
+
 logger = init_logger(__name__)
 
 
@@ -54,11 +54,17 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         self.unsharded_weights_modules: list[str] = []
         # Save the module names that are sharded by column.
         self.column_sharded_weights_modules: list[str] = []
+        # Modules whose weights might have fused on disk
+        # we need their output_sizes to make shard in flight correctly with TP
+        self.maybe_fused_weights_modules: dict[str, list[int]] = {}
         # Store all module names (from transformers) that support
         # BNB quantization.
         self.target_modules: list[str] = []
         # mapping weight names from transformers to vllm.
         self.weight_mapper: Callable = lambda name: name
+        self.pre_quant: bool = False
+        self.load_8bit: bool = False
+        self.is_pool_model: bool = False
 
     def _get_weight_files(
         self,
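The `maybe_fused_weights_modules` dict added above records, for each module whose on-disk weight may arrive fused, the `output_sizes` needed to cut the fused tensor apart and shard it in flight under tensor parallelism. A minimal standalone sketch of that splitting idea (the helper name `split_fused_weight` and the toy shapes are assumptions for illustration, not vLLM's real code):

import torch


def split_fused_weight(fused: torch.Tensor, output_sizes: list[int],
                       tp_rank: int, tp_size: int) -> list[torch.Tensor]:
    """Hypothetical helper: split a fused weight, then take one TP shard."""
    # Recover the individual sub-weights stacked along dim 0.
    sub_weights = torch.split(fused, output_sizes, dim=0)
    shards = []
    for w in sub_weights:
        # Each sub-weight is itself row-sharded across TP ranks.
        shard_size = w.shape[0] // tp_size
        shards.append(w[tp_rank * shard_size:(tp_rank + 1) * shard_size])
    return shards


# Toy example: a fused tensor holding two 8-row projections, TP size 2.
fused = torch.arange(16 * 4, dtype=torch.float32).reshape(16, 4)
for rank in range(2):
    gate, up = split_fused_weight(fused, [8, 8], rank, 2)
    print(rank, gate.shape, up.shape)  # each rank sees (4, 4) per projection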
@@ -134,13 +140,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         return hf_weights_files, use_safetensors
 
     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
-        def _maybe_pool_model(module_name:str):
+
+        def _maybe_pool_model(module_name: str):
             # For pool model, we need to add the prefix `model.`
             # for the weight name if possible.
             if self.is_pool_model and self.target_modules[0]. \
                 startswith("model.") and not module_name.startswith(
                     "model."):
-                return "model."+module_name
+                return "model." + module_name
 
             return module_name
 
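The reformatted `_maybe_pool_model` closure prefixes weight names with `model.` only when the loader is handling a pooling model whose target modules already carry that prefix. The same logic as a self-contained function, with the instance state passed in explicitly (a sketch, not the actual closure):

def maybe_pool_model(module_name: str, is_pool_model: bool,
                     target_modules: list[str]) -> str:
    # Standalone rendition of the diff's _maybe_pool_model logic.
    if (is_pool_model and target_modules[0].startswith("model.")
            and not module_name.startswith("model.")):
        return "model." + module_name
    return module_name


assert maybe_pool_model("layers.0.q_proj", True,
                        ["model.layers.0.q_proj"]) == "model.layers.0.q_proj"
assert maybe_pool_model("model.layers.0.q_proj", True,
                        ["model.layers.0.q_proj"]) == "model.layers.0.q_proj"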
@@ -159,8 +166,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             # mapping weight names from transformers to vllm while preserving
             # original names.
             mapped_name = self.weight_mapper(org_name)
-            mapped_name=_maybe_pool_model(mapped_name)
-
+            mapped_name = _maybe_pool_model(mapped_name)
 
             yield org_name, mapped_name, param
 
@@ -168,8 +174,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         self,
         model_name_or_path: str,
         revision: Optional[str],
-        pre_quant: bool,
-        load_8bit: bool,
     ) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str,
                                                                      Any]]:
         """Get an iterator to the model weights with bitsandbytes quantization,
@@ -192,8 +196,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         quant_state_dict: dict[str, Any] = {}
 
-        if pre_quant:
-            if load_8bit:
+        if self.pre_quant:
+            if self.load_8bit:
                 return self._quantized_8bit_generator(
                     hf_weights_files, use_safetensors,
                     quant_state_dict), quant_state_dict
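Because `pre_quant` and `load_8bit` now live on the loader instance, the weights-iterator factory reduces to a branch on instance flags. A toy dispatcher with stub generators showing that branch structure; the 4-bit and unquantized counterparts are assumed from context and are stubs here, not vLLM's generators:

from collections.abc import Generator
from typing import Any


class IteratorDispatch:
    """Toy stand-in showing how the instance flags pick a weight generator."""

    def __init__(self, pre_quant: bool, load_8bit: bool) -> None:
        self.pre_quant = pre_quant
        self.load_8bit = load_8bit

    def _quantized_8bit_generator(self) -> Generator[str, None, None]:
        yield "pre-quantized 8-bit weight"  # stub

    def _quantized_4bit_generator(self) -> Generator[str, None, None]:
        yield "pre-quantized 4-bit weight"  # stub

    def _unquantized_generator(self) -> Generator[str, None, None]:
        yield "fp16 weight, quantized in flight"  # stub

    def get_weights_iterator(
            self) -> tuple[Generator[str, None, None], dict[str, Any]]:
        quant_state_dict: dict[str, Any] = {}
        # Same branch structure as the diff: flags come from self now.
        if self.pre_quant:
            if self.load_8bit:
                return self._quantized_8bit_generator(), quant_state_dict
            return self._quantized_4bit_generator(), quant_state_dict
        return self._unquantized_generator(), quant_state_dict


gen, _ = IteratorDispatch(pre_quant=True,
                          load_8bit=False).get_weights_iterator()
print(next(gen))  # -> pre-quantized 4-bit weight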
@@ -390,10 +394,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             yield org_weight_name, processed_weight
 
     def _get_bnb_target_modules(self, model: nn.Module) -> None:
-
+        """
+        Identify and collect all modules that support BitsAndBytes
+        quantization.
+        """
         for name, module in model.named_modules():
-            if (isinstance(module, LinearBase) and
-                    hasattr(module.quant_method, "quant_config")):
+            if (isinstance(module, LinearBase)
+                    and hasattr(module.quant_method, "quant_config")):
                 if modules_info := self.modules_mapping.get_sub_modules(name):
                     # Map vllm's names to transformers's names.
                     rep_name, sub_modules = modules_info
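`get_sub_modules` maps a packed vLLM module name back to the per-tensor names used in a transformers checkpoint. A dict-based sketch of that lookup, using the common Llama-style packed mapping as example data (the real `ParamMapping` lives in vLLM's model loader utilities; this loosely mirrors its behavior):

# One fused vLLM module <- several checkpoint tensors (Llama-style example).
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}


def get_sub_modules(module_name: str):
    # Return (packed_name, sub_module_names) when the last path component
    # is a packed module.
    last = module_name.split(".")[-1]
    if last in packed_modules_mapping:
        return last, packed_modules_mapping[last]
    return None


print(get_sub_modules("model.layers.0.self_attn.qkv_proj"))
# -> ('qkv_proj', ['q_proj', 'k_proj', 'v_proj'])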
@@ -409,29 +416,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             ), "vllm currently does not support BNB quantization for"
             f" {type(model).__name__}"
 
-    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
-        if not hasattr(model, "load_weights"):
-            raise AttributeError(
-                "The required method 'load_weights' is not defined in class"
-                f" {type(model).__name__}.")
-
-        if not hasattr(model, "packed_modules_mapping"):
-            raise AttributeError(
-                f"Model {type(model).__name__} does not support BitsAndBytes "
-                "quantization yet. No 'packed_modules_mapping' found.")
-        self.is_pool_model=is_pooling_model(model)
-
-        self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
-
-        # For some models like Molmo, we need to use hf_to_vllm_mapper
-        # to ensure correct loading of weights.
-        if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
-            self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
-
-        # Modules whose weights might have fused on disk
-        # we need their output_sizes to make shard in flight correctly with TP
-        self.maybe_fused_weights_modules: dict[str, list[int]] = {}
-        self._get_bnb_target_modules(model)
+    def _classify_module_sharding(self, model: nn.Module):
+        """
+        Categorize modules based on their weight sharding requirements
+        for tensor parallelism.
+        """
         for name, module in model.named_modules():
             # Some modules like `ReplicatedLinear` should not have their weights
             # sharded. The reason for implementing it this way is to avoid new
@@ -449,19 +438,27 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             elif isinstance(module, (RowParallelLinear, )):
                 self.column_sharded_weights_modules.append(name)
 
         self.model_type = type(model).__name__
-
-        logger.info("Loading weights with BitsAndBytes quantization. "
-                    "May take a while ...")
+    def _verify_model_compatibility(self, model: nn.Module,
+                                    model_config: ModelConfig) -> None:
+        """
+        Verify that the model is compatible with BitsAndBytes quantization.
+        """
+        if not hasattr(model, "load_weights"):
+            raise AttributeError(
+                "The required method 'load_weights' is not defined in class"
+                f" {type(model).__name__}.")
+        if not hasattr(model, "packed_modules_mapping"):
+            raise AttributeError(
+                f"Model {type(model).__name__} does not support BitsAndBytes "
+                "quantization yet. No 'packed_modules_mapping' found.")
 
         quant_config = getattr(model_config.hf_config, "quantization_config",
                                None)
-
-        pre_quant = False
         if quant_config is not None:
             quant_method = quant_config.get("quant_method")
             if quant_method == "bitsandbytes":
-                pre_quant = True
+                self.pre_quant = True
             else:
                 raise ValueError(
                     f"BitsAndBytes loader does not support {quant_method} "
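`_classify_module_sharding` buckets linear modules by weight layout: `ReplicatedLinear` weights must stay whole on every rank, while `RowParallelLinear` weights are sharded along columns. A condensed sketch with plain torch modules standing in for vLLM's layer classes:

import torch.nn as nn


class ReplicatedLinear(nn.Linear):  # stand-in for vLLM's ReplicatedLinear
    pass


class RowParallelLinear(nn.Linear):  # stand-in for vLLM's RowParallelLinear
    pass


def classify_module_sharding(model: nn.Module):
    unsharded, column_sharded = [], []
    for name, module in model.named_modules():
        # Replicated weights must never be cut apart across TP ranks.
        if isinstance(module, ReplicatedLinear):
            unsharded.append(name)
        # Row-parallel layers shard their weight along the column dim.
        elif isinstance(module, RowParallelLinear):
            column_sharded.append(name)
    return unsharded, column_sharded


model = nn.Sequential(ReplicatedLinear(4, 4), RowParallelLinear(4, 4))
print(classify_module_sharding(model))  # -> (['0'], ['1'])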
@@ -469,20 +466,43 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         # The quant_states in pre_quantized models cannot work with a split
         # weight tensor. So TP does not work with pre_quantized bnb models.
-        if pre_quant and get_tensor_model_parallel_world_size() > 1:
+        if self.pre_quant and get_tensor_model_parallel_world_size() > 1:
             raise ValueError(
                 "Prequant BitsAndBytes models with tensor parallelism is not "
                 "supported. Please try with pipeline parallelism.")
+        if self.pre_quant:
+            self.load_8bit = quant_config.get("load_in_8bit", False)
 
-        load_8bit = False
-        if pre_quant:
-            load_8bit = quant_config.get("load_in_8bit", False)
-
+    def _initialize_loader_state(self, model: nn.Module,
+                                 model_config: ModelConfig) -> None:
+        """
+        Initialize the loader's internal state based on the model and
+        configuration.
+        """
+        self.is_pool_model = is_pooling_model(model)
+        self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
+
+        # For some models like Molmo, we need to use hf_to_vllm_mapper
+        # to ensure correct loading of weights.
+        if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
+            self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
+
+        self._get_bnb_target_modules(model)
+        self._classify_module_sharding(model)
+
+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
+        self._verify_model_compatibility(model, model_config)
+        self._initialize_loader_state(model, model_config)
+
+        logger.info("Loading weights with BitsAndBytes quantization. "
+                    "May take a while ...")
         qweight_iterator, quant_state_dict = (
-            self._get_quantized_weights_iterator(model_config.model,
-                                                 model_config.revision,
-                                                 pre_quant, load_8bit))
+            self._get_quantized_weights_iterator(
+                model_config.model,
+                model_config.revision,
+            ))
 
         weights_to_load = {name for name, _ in model.named_parameters()}
         loaded_weights = model.load_weights(qweight_iterator)
         # Some models may have weights loading tracker unimplemented.
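The refactored `load_weights` now reads as three phases: verify compatibility, initialize loader state, then stream quantized weights into the model. A toy driver showing only the phase ordering (all three methods are stubs; `_stream_weights` is an invented name summarizing the iterator-plus-`model.load_weights` tail of the real method):

class TinyLoader:
    """Toy driver showing the phase ordering of the refactored load_weights."""

    def _verify_model_compatibility(self, model, model_config):
        print("1. verify: load_weights / packed_modules_mapping / quant config")

    def _initialize_loader_state(self, model, model_config):
        print("2. init state: mappings, target modules, sharding buckets")

    def _stream_weights(self, model, model_config):
        print("3. stream: quantized weight iterator -> model.load_weights")

    def load_weights(self, model, model_config) -> None:
        # Mirrors the new structure: validate, then build state, then load.
        self._verify_model_compatibility(model, model_config)
        self._initialize_loader_state(model, model_config)
        self._stream_weights(model, model_config)


TinyLoader().load_weights(model=None, model_config=None)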
@@ -562,10 +582,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             offsets = torch.tensor(offsets).cpu()
             set_weight_attrs(param, {"bnb_shard_offsets": offsets})
 
-            if load_8bit:
+            if self.load_8bit:
                 set_weight_attrs(
                     param, {"matmul_state": [None] * len(quant_states)})
+
         torch.cuda.empty_cache()
 
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
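For context on the final hunk: `set_weight_attrs` attaches per-parameter metadata such as `bnb_shard_offsets` so later code can slice one flat quantized parameter back into its logical shards. A small sketch of consuming such offsets (toy tensor and boundaries; this is not vLLM's runtime path):

import torch

# A flat packed parameter holding three logical shards back to back.
param = torch.arange(12)
bnb_shard_offsets = torch.tensor([0, 4, 10, 12])  # shard boundaries

shards = [
    param[bnb_shard_offsets[i]:bnb_shard_offsets[i + 1]]
    for i in range(len(bnb_shard_offsets) - 1)
]
print([s.tolist() for s in shards])
# -> [[0, 1, 2, 3], [4, 5, 6, 7, 8, 9], [10, 11]]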