Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-05-12 18:09:09 +08:00)
[Quantization] Improve BitsAndBytesModelLoader (#20242)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
commit 8fe7fc8634
parent e936e401de
@@ -20,8 +20,6 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 # yapf: enable
 from vllm.logger import init_logger
-# yapf conflicts with isort for this block
-# yapf: disable
 from vllm.model_executor.layers.linear import (LinearBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
@@ -39,6 +37,8 @@ from vllm.model_executor.utils import (get_packed_modules_mapping,
                                        set_weight_attrs)
 from vllm.platforms import current_platform
 
+# yapf conflicts with isort for this block
+
 logger = init_logger(__name__)
 
 
@@ -54,11 +54,17 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         self.unsharded_weights_modules: list[str] = []
         # Save the module names that are sharded by column.
         self.column_sharded_weights_modules: list[str] = []
+        # Modules whose weights might have fused on disk
+        # we need their output_sizes to make shard in flight correctly with TP
+        self.maybe_fused_weights_modules: dict[str, list[int]] = {}
         # Store all module names (from transformers) that support
         # BNB quantization.
         self.target_modules: list[str] = []
         # mapping weight names from transformers to vllm.
         self.weight_mapper: Callable = lambda name: name
+        self.pre_quant: bool = False
+        self.load_8bit: bool = False
+        self.is_pool_model: bool = False
 
     def _get_weight_files(
         self,
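The three new attributes (pre_quant, load_8bit, is_pool_model) promote what were previously locals in load_weights to loader state, so the helper methods introduced later in this diff can share them. The maybe_fused_weights_modules map records per-module output_sizes because a checkpoint may store several logical weights fused into one tensor; to shard such a tensor for tensor parallelism, each constituent must be split on its own rows. A minimal sketch of that idea (hypothetical helper, not the vLLM implementation):

import torch

# A checkpoint may store gate_proj and up_proj fused into one tensor;
# each sub-weight has to be sharded on its own rows before the per-rank
# slices are concatenated back together, hence the recorded output_sizes.
def shard_fused_weight(fused: torch.Tensor, output_sizes: list[int],
                       tp_rank: int, tp_size: int) -> torch.Tensor:
    shards = []
    offset = 0
    for size in output_sizes:
        sub = fused[offset:offset + size]             # one logical sub-weight
        shard_size = size // tp_size
        start = tp_rank * shard_size
        shards.append(sub[start:start + shard_size])  # this rank's slice
        offset += size
    return torch.cat(shards, dim=0)

fused = torch.randn(64 + 64, 16)   # fused [gate_proj; up_proj] weight
local = shard_fused_weight(fused, [64, 64], tp_rank=0, tp_size=2)
assert local.shape == (64, 16)     # 32 rows taken from each half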
@@ -134,13 +140,14 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         return hf_weights_files, use_safetensors
 
     def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
-        def _maybe_pool_model(module_name:str):
+
+        def _maybe_pool_model(module_name: str):
             # For pool model, we need to add the prefix `model.`
             # for the weight name if possible.
             if self.is_pool_model and self.target_modules[0]. \
                 startswith("model.") and not module_name.startswith(
                     "model."):
-                return "model."+module_name
+                return "model." + module_name
 
             return module_name
 
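For reference, the helper's behavior in isolation: a standalone restatement of the rule above with hypothetical inputs. Pooling models expose their modules under a "model." prefix, but checkpoint weight names may lack it, so the prefix is added when missing.

def maybe_pool_model(module_name: str, is_pool_model: bool,
                     target_modules: list[str]) -> str:
    # Mirrors _maybe_pool_model in the hunk above, with the loader state
    # passed explicitly so the snippet is self-contained.
    if (is_pool_model and target_modules[0].startswith("model.")
            and not module_name.startswith("model.")):
        return "model." + module_name
    return module_name

targets = ["model.layers.0.self_attn.qkv_proj"]  # hypothetical target list
assert (maybe_pool_model("layers.0.mlp.gate_proj", True, targets)
        == "model.layers.0.mlp.gate_proj")
assert (maybe_pool_model("model.layers.0.mlp.gate_proj", True, targets)
        == "model.layers.0.mlp.gate_proj")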
@@ -159,8 +166,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             # mapping weight names from transformers to vllm while preserving
             # original names.
             mapped_name = self.weight_mapper(org_name)
-            mapped_name=_maybe_pool_model(mapped_name)
-
+            mapped_name = _maybe_pool_model(mapped_name)
 
             yield org_name, mapped_name, param
 
@@ -168,8 +174,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         self,
         model_name_or_path: str,
         revision: Optional[str],
-        pre_quant: bool,
-        load_8bit: bool,
     ) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str,
                                                                      Any]]:
         """Get an iterator to the model weights with bitsandbytes quantization,
@@ -192,8 +196,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         quant_state_dict: dict[str, Any] = {}
 
-        if pre_quant:
-            if load_8bit:
+        if self.pre_quant:
+            if self.load_8bit:
                 return self._quantized_8bit_generator(
                     hf_weights_files, use_safetensors,
                     quant_state_dict), quant_state_dict
@@ -390,10 +394,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 yield org_weight_name, processed_weight
 
     def _get_bnb_target_modules(self, model: nn.Module) -> None:
-
+        """
+        Identify and collect all modules that support BitsAndBytes
+        quantization.
+        """
         for name, module in model.named_modules():
-            if (isinstance(module, LinearBase) and
-                    hasattr(module.quant_method, "quant_config")):
+            if (isinstance(module, LinearBase)
+                    and hasattr(module.quant_method, "quant_config")):
                 if modules_info := self.modules_mapping.get_sub_modules(name):
                     # Map vllm's names to transformers's names.
                     rep_name, sub_modules = modules_info
@@ -409,29 +416,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             ), "vllm currently does not support BNB quantization for"
             f" {type(model).__name__}"
 
-    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
-        if not hasattr(model, "load_weights"):
-            raise AttributeError(
-                "The required method 'load_weights' is not defined in class"
-                f" {type(model).__name__}.")
-
-        if not hasattr(model, "packed_modules_mapping"):
-            raise AttributeError(
-                f"Model {type(model).__name__} does not support BitsAndBytes "
-                "quantization yet. No 'packed_modules_mapping' found.")
-        self.is_pool_model=is_pooling_model(model)
-
-        self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
-
-        # For some models like Molmo, we need to use hf_to_vllm_mapper
-        # to ensure correct loading of weights.
-        if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
-            self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
-
-        # Modules whose weights might have fused on disk
-        # we need their output_sizes to make shard in flight correctly with TP
-        self.maybe_fused_weights_modules: dict[str, list[int]] = {}
-        self._get_bnb_target_modules(model)
+    def _classify_module_sharding(self, model: nn.Module):
+        """
+        Categorize modules based on their weight sharding requirements
+        for tensor parallelism.
+        """
         for name, module in model.named_modules():
             # Some modules like `ReplicatedLinear` should not have their weights
             # sharded. The reason for implementing it this way is to avoid new
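The extracted _classify_module_sharding keeps the existing loop: modules such as ReplicatedLinear keep whole weights on every rank, while RowParallelLinear weights are sharded along columns. A toy sketch of the same classification pattern, using stand-in classes rather than vLLM's:

import torch.nn as nn

# Stand-in classes for illustration only; the real method walks vLLM's
# linear-layer hierarchy, not these.
class ReplicatedLinear(nn.Linear): ...      # weights kept whole on every rank
class RowParallelLinear(nn.Linear): ...     # weights sharded along columns

def classify(model: nn.Module) -> tuple[list[str], list[str]]:
    unsharded, column_sharded = [], []
    for name, module in model.named_modules():
        if isinstance(module, ReplicatedLinear):
            unsharded.append(name)
        elif isinstance(module, RowParallelLinear):
            column_sharded.append(name)
    return unsharded, column_sharded

model = nn.Sequential(ReplicatedLinear(8, 8), RowParallelLinear(8, 8))
assert classify(model) == (["0"], ["1"])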
@@ -449,19 +438,27 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             elif isinstance(module, (RowParallelLinear, )):
                 self.column_sharded_weights_modules.append(name)
 
-        self.model_type = type(model).__name__
+    def _verify_model_compatibility(self, model: nn.Module,
+                                    model_config: ModelConfig) -> None:
+        """
+        Verify that the model is compatible with BitsAndBytes quantization.
+        """
+        if not hasattr(model, "load_weights"):
+            raise AttributeError(
+                "The required method 'load_weights' is not defined in class"
+                f" {type(model).__name__}.")
 
-        logger.info("Loading weights with BitsAndBytes quantization. "
-                    "May take a while ...")
+        if not hasattr(model, "packed_modules_mapping"):
+            raise AttributeError(
+                f"Model {type(model).__name__} does not support BitsAndBytes "
+                "quantization yet. No 'packed_modules_mapping' found.")
 
         quant_config = getattr(model_config.hf_config, "quantization_config",
                                None)
-
-        pre_quant = False
         if quant_config is not None:
             quant_method = quant_config.get("quant_method")
             if quant_method == "bitsandbytes":
-                pre_quant = True
+                self.pre_quant = True
             else:
                 raise ValueError(
                     f"BitsAndBytes loader does not support {quant_method} "
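The compatibility check reads the checkpoint's quantization_config from hf_config; pre-quantized BNB checkpoints identify themselves via quant_method. A hypothetical example of the dict this inspects (keys follow Hugging Face's BitsAndBytesConfig) and the resulting flags:

# Hypothetical quantization_config as found in a pre-quantized checkpoint:
quantization_config = {
    "quant_method": "bitsandbytes",
    "load_in_8bit": False,
    "load_in_4bit": True,
}

# Mirrors the detection logic in the hunk above:
pre_quant = quantization_config.get("quant_method") == "bitsandbytes"
load_8bit = quantization_config.get("load_in_8bit", False)
assert pre_quant and not load_8bit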
@@ -469,20 +466,43 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         # The quant_states in pre_quantized models cannot work with a split
         # weight tensor. So TP does not work with pre_quantized bnb models.
-        if pre_quant and get_tensor_model_parallel_world_size() > 1:
+        if self.pre_quant and get_tensor_model_parallel_world_size() > 1:
             raise ValueError(
                 "Prequant BitsAndBytes models with tensor parallelism is not "
                 "supported. Please try with pipeline parallelism.")
+        if self.pre_quant:
+            self.load_8bit = quant_config.get("load_in_8bit", False)
 
-        load_8bit = False
-        if pre_quant:
-            load_8bit = quant_config.get("load_in_8bit", False)
-
+    def _initialize_loader_state(self, model: nn.Module,
+                                 model_config: ModelConfig) -> None:
+        """
+        Initialize the loader's internal state based on the model and
+        configuration.
+        """
+        self.is_pool_model = is_pooling_model(model)
+        self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
+
+        # For some models like Molmo, we need to use hf_to_vllm_mapper
+        # to ensure correct loading of weights.
+        if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
+            self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
+
+        self._get_bnb_target_modules(model)
+        self._classify_module_sharding(model)
+
+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
+
+        self._verify_model_compatibility(model, model_config)
+        self._initialize_loader_state(model, model_config)
+
+        logger.info("Loading weights with BitsAndBytes quantization. "
+                    "May take a while ...")
         qweight_iterator, quant_state_dict = (
-            self._get_quantized_weights_iterator(model_config.model,
-                                                 model_config.revision,
-                                                 pre_quant, load_8bit))
-
+            self._get_quantized_weights_iterator(
+                model_config.model,
+                model_config.revision,
+            ))
         weights_to_load = {name for name, _ in model.named_parameters()}
         loaded_weights = model.load_weights(qweight_iterator)
         # Some models may have weights loading tracker unimplemented.
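The comment retained above explains the tensor-parallelism restriction: bitsandbytes block-wise quant states describe the whole packed tensor, so slicing a pre-quantized weight leaves the metadata inconsistent. A rough, self-contained illustration of block-wise quantization (not bitsandbytes' actual kernels) showing why the packed values are useless without their per-block scales:

import torch

# Block-wise quantization keeps one absmax scale per block of the *whole*
# tensor. A naive slice of the packed int8 buffer is meaningless without
# the matching slice of this metadata, which is why pre-quantized weights
# cannot simply be split across TP ranks.
blocksize = 64
weight = torch.randn(256)
blocks = weight.view(-1, blocksize)
absmax = blocks.abs().amax(dim=1)                      # one scale per block
qweight = (blocks / absmax[:, None] * 127).round().to(torch.int8)

# Dequantization needs qweight *and* absmax together, in the same layout.
dequant = (qweight.float() * absmax[:, None] / 127).flatten()
assert torch.allclose(dequant, weight, atol=float(weight.abs().max()) / 127)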
@@ -562,10 +582,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             offsets = torch.tensor(offsets).cpu()
             set_weight_attrs(param, {"bnb_shard_offsets": offsets})
-            if load_8bit:
+
+            if self.load_8bit:
                 set_weight_attrs(
                     param, {"matmul_state": [None] * len(quant_states)})
         torch.cuda.empty_cache()
 
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)
 
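The "bnb_shard_offsets" attribute set here marks where each packed sub-weight begins inside a flattened quantized parameter, so downstream kernels can look up the quant state for each shard by index. A small illustration with hypothetical sizes:

import torch

# Two sub-weights packed into one flat quantized buffer (made-up lengths).
lengths = [4096, 4096]
offsets = [0]
for length in lengths:
    offsets.append(offsets[-1] + length)   # cumulative start positions
offsets = torch.tensor(offsets).cpu()      # same shape the loader attaches

for idx, length in enumerate(lengths):
    start, end = offsets[idx].item(), offsets[idx + 1].item()
    # elements [start:end) of the packed buffer belong to sub-weight idx,
    # and its quant state is found under the same index.
    assert end - start == length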