mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Misc]Further reduce BNB static variable (#10597)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

Commit 15cc2a9f1a (parent e85250b1d1)
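In short, this commit removes the static BitsAndBytes target-module lists: the loader's class-level default_target_modules and every model's default_bitsandbytes_target_modules attribute are deleted, and BitsAndBytesModelLoader instead gains a _get_bnb_target_modules helper that derives the target modules at load time from the model's linear layers (LinearBase instances) and its bitsandbytes_stacked_params_mapping. The remaining hunks in the model loader are formatting cleanups (argument wrapping, trailing commas, quote style).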
@@ -28,7 +28,8 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (LinearBase,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -78,12 +79,14 @@ def device_loading_context(module: torch.nn.Module,
                 original_device: torch.device = original_device_states[name]
                 if original_device.type == "cpu":
                     # `torch.empty_like` does not support `pin_memory` argument
-                    cpu_data = torch.empty_strided(size=p.data.size(),
+                    cpu_data = torch.empty_strided(
+                        size=p.data.size(),
                         stride=p.data.stride(),
                         dtype=p.data.dtype,
                         layout=p.data.layout,
                         device="cpu",
-                        pin_memory=pin_memory)
+                        pin_memory=pin_memory,
+                    )
                     cpu_data.copy_(p.data)
                     p.data = cpu_data
                 else:
@@ -112,7 +115,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
         logger.warning(msg)
         logger.warning(
             "Trying to guess the arguments for old-style model class %s",
-            model_class)
+            model_class,
+        )
         # try to be compatible with old-style model class
         kwargs = {}
         if "prefix" in all_params:
@@ -198,14 +202,17 @@ class DefaultModelLoader(BaseModelLoader):
             return model_path
         return None
 
-    def _prepare_weights(self, model_name_or_path: str,
+    def _prepare_weights(
+        self,
+        model_name_or_path: str,
         revision: Optional[str],
-        fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
+        fall_back_to_pt: bool,
+    ) -> Tuple[str, List[str], bool]:
         """Prepare weights for the model.
 
         If the model is not local, it will be downloaded."""
-        model_name_or_path = self._maybe_download_from_modelscope(
-            model_name_or_path, revision) or model_name_or_path
+        model_name_or_path = (self._maybe_download_from_modelscope(
+            model_name_or_path, revision) or model_name_or_path)
 
         is_local = os.path.isdir(model_name_or_path)
         load_format = self.load_config.load_format
@@ -258,8 +265,11 @@ class DefaultModelLoader(BaseModelLoader):
             # any files not found in the index.
             if not is_local:
                 download_safetensors_index_file_from_hf(
-                    model_name_or_path, index_file,
-                    self.load_config.download_dir, revision)
+                    model_name_or_path,
+                    index_file,
+                    self.load_config.download_dir,
+                    revision,
+                )
             hf_weights_files = filter_duplicate_safetensors_files(
                 hf_weights_files, hf_folder, index_file)
         else:
@@ -282,8 +292,11 @@ class DefaultModelLoader(BaseModelLoader):
             # Currently np_cache only support *.bin checkpoints
             assert use_safetensors is False
             weights_iterator = np_cache_weights_iterator(
-                source.model_or_path, self.load_config.download_dir, hf_folder,
-                hf_weights_files)
+                source.model_or_path,
+                self.load_config.download_dir,
+                hf_folder,
+                hf_weights_files,
+            )
         elif use_safetensors:
             weights_iterator = safetensors_weights_iterator(hf_weights_files)
         else:
@@ -310,17 +323,19 @@ class DefaultModelLoader(BaseModelLoader):
         model_config: ModelConfig,
         model: nn.Module,
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
 
         primary_weights = DefaultModelLoader.Source(
             model_config.model,
             model_config.revision,
             prefix="",
             fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
-                                    True))
+                                    True),
+        )
         yield from self._get_weights_iterator(primary_weights)
 
-        secondary_weights = cast(Iterable[DefaultModelLoader.Source],
-                                 getattr(model, "secondary_weights", ()))
+        secondary_weights = cast(
+            Iterable[DefaultModelLoader.Source],
+            getattr(model, "secondary_weights", ()),
+        )
         for source in secondary_weights:
             yield from self._get_weights_iterator(source)
 
@@ -416,7 +431,7 @@ class TensorizerLoader(BaseModelLoader):
         self.tensorizer_config.verify_with_parallel_config(parallel_config)
 
     def _get_weights_iterator(
-            self) -> Generator[Tuple[str, torch.Tensor], None, None]:
+            self, ) -> Generator[Tuple[str, torch.Tensor], None, None]:
         tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
         return tensorizer_weights_iterator(tensorizer_args)
 
@@ -479,9 +494,10 @@ class TensorizerLoader(BaseModelLoader):
 
         if parallel_config.tensor_parallel_size > 1:
             from vllm.distributed import get_tensor_model_parallel_rank
-            self.tensorizer_config.tensorizer_uri = \
-                self.tensorizer_config.tensorizer_uri \
-                % get_tensor_model_parallel_rank()
+            self.tensorizer_config.tensorizer_uri = (
+                self.tensorizer_config.tensorizer_uri %
+                get_tensor_model_parallel_rank())
 
         if is_vllm_tensorized(self.tensorizer_config):
             return self._load_model_serialized(vllm_config=vllm_config)
@@ -520,13 +536,13 @@ class ShardedStateLoader(BaseModelLoader):
 
     @staticmethod
     def _filter_subtensors(
-            tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+            tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]:
         """
         Filter out all tensors that share the same memory or a subset of the
         memory of another tensor.
         """
-        same_storage_groups: Dict[Any, List[Tuple[
-            str, torch.Tensor]]] = collections.defaultdict(list)
+        same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = (
+            collections.defaultdict(list))
         for key, tensor in tensors.items():
             if tensor.numel():
                 ptr = tensor.untyped_storage().data_ptr()
@@ -615,8 +631,11 @@ class ShardedStateLoader(BaseModelLoader):
                     if tensor.shape != param_shape:
                         logger.warning(
                             "loading tensor of shape %s into "
-                            "parameter '%s' of shape %s", tensor.shape,
-                            key, param_shape)
+                            "parameter '%s' of shape %s",
+                            tensor.shape,
+                            key,
+                            param_shape,
+                        )
                     param_data.copy_(tensor)
                     state_dict.pop(key)
             if state_dict:
@@ -634,6 +653,7 @@ class ShardedStateLoader(BaseModelLoader):
         from safetensors.torch import save_file
+
         from vllm.distributed import get_tensor_model_parallel_rank
 
         if pattern is None:
             pattern = ShardedStateLoader.DEFAULT_PATTERN
         rank = get_tensor_model_parallel_rank()
@@ -667,24 +687,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
     possible_config_file_names = ["adapter_config.json"]
 
-    default_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-        '.fc1.',
-        '.fc2.',
-        '.dense.',
-        '.query_key_value.',
-        '.qkv_proj.',
-        '.dense_h_to_4h.',
-        '.dense_4h_to_h.',
-        '.out_proj.',
-    ]
-
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
 
@@ -709,6 +711,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             with open(config_file_path) as f:
                 config = json.load(f)
                 self.target_modules = config["target_modules"]
+                # TODO: target_modules could be either a list or a regex string.
+                # We need to handle both cases.
+                assert isinstance(self.target_modules,
+                                  list), "Unsupported target_modules: "
+                f"{self.target_modules}"
 
     def _get_config_file(self, qlora_adapter: str) -> str:
         is_local = os.path.isdir(qlora_adapter)
@@ -737,7 +744,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             self,
             model_name_or_path: str,
             allowed_patterns: List[str],
-            revision: Optional[str] = None) -> Tuple[List[str], str]:
+            revision: Optional[str] = None,
+    ) -> Tuple[List[str], str]:
         """Retrieve weight files. Download the files if necessary.
 
         Return the weight files and the file pattern."""
@@ -806,6 +814,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         # only load the bitsandbytes module when needed
         try:
             import bitsandbytes
+
             if bitsandbytes.__version__ < "0.44.0":
                 raise ImportError("bitsandbytes version is wrong. Please "
                                   "install bitsandbytes>=0.44.0.")
@@ -839,8 +848,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
     def _is_4bit_weight_name(self, weight_name: str):
         quantized_suffix = {
-            "absmax", "quant_map", "nested_absmax", "nested_quant_map",
-            "bitsandbytes"
+            "absmax",
+            "quant_map",
+            "nested_absmax",
+            "nested_quant_map",
+            "bitsandbytes",
         }
         suffix = weight_name.split(".")[-1]
         return any(q_suffix in suffix for q_suffix in quantized_suffix)
@@ -857,7 +869,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
-
             if self._is_8bit_weight_name(weight_name):
                 continue
 
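_is_4bit_weight_name keys off the last dotted component of a checkpoint entry, so auxiliary bitsandbytes tensors are skipped when scanning for real weights. A small standalone illustration, using made-up key names rather than a real checkpoint:

quantized_suffix = {
    "absmax",
    "quant_map",
    "nested_absmax",
    "nested_quant_map",
    "bitsandbytes",
}

def is_4bit_weight_name(weight_name: str) -> bool:
    # Only the suffix after the last dot matters.
    suffix = weight_name.split(".")[-1]
    return any(q_suffix in suffix for q_suffix in quantized_suffix)

# Hypothetical checkpoint keys:
print(is_4bit_weight_name("model.layers.0.mlp.down_proj.weight.absmax"))  # True
print(is_4bit_weight_name("model.layers.0.mlp.down_proj.weight"))         # False
print(is_4bit_weight_name(
    "model.layers.0.qkv_proj.weight.quant_state.bitsandbytes__nf4"))      # True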
@@ -899,13 +910,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         # pre quantized weights would have a quant_state
         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
-
             if self._is_4bit_weight_name(weight_name):
                 continue
 
-            if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
-                in temp_state_dict) or \
-                (f"{weight_name}.quant_state.bitsandbytes__fp4" \
-                in temp_state_dict):
+            if (f"{weight_name}.quant_state.bitsandbytes__nf4"
+                    in temp_state_dict) or (
+                        f"{weight_name}.quant_state.bitsandbytes__fp4"
+                        in temp_state_dict):
                 quant_state = _parse_quant_state(weight_name, temp_state_dict)
                 quant_state_dict[weight_name] = quant_state
@@ -916,12 +926,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
     def _unquantized_generator(self, hf_weights_files, use_safetensors,
                                quant_state_dict) -> Generator:
         from bitsandbytes.functional import quantize_4bit
 
         tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
 
         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
 
             if any(target_module in weight_name for target_module in
                    self.target_modules) and weight_name.endswith(".weight"):
                 # Without sharding
@@ -954,12 +964,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 # get the start/end index of each shard weight tensor
                 total_start_index = list(
                     itertools.accumulate([0] + total_shard_sizes))[:-1]
-                shard_weights_index = [
-                    (idx + size // tp_size * tp_rank,
-                     idx + size // tp_size * (tp_rank + 1))
-                    for idx, size in zip(total_start_index,
-                                         total_shard_sizes)
-                ]
+                shard_weights_index = [(
+                    idx + size // tp_size * tp_rank,
+                    idx + size // tp_size * (tp_rank + 1),
+                ) for idx, size in zip(total_start_index,
+                                       total_shard_sizes)]
                 # slice and reorder the weight tensor
                 weight_tensor = [
                     weight_tensor[start_index:end_index, ...]
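The list comprehension above computes, for each fused shard, the row range owned by the current tensor-parallel rank. A minimal standalone sketch with toy sizes (the values are invented, not taken from the commit):

import itertools

# Toy example: a QKV-fused weight with per-projection row counts, split over 2 TP ranks.
total_shard_sizes = [8, 4, 4]   # hypothetical q, k, v row counts
tp_size, tp_rank = 2, 1

total_start_index = list(itertools.accumulate([0] + total_shard_sizes))[:-1]
shard_weights_index = [(
    idx + size // tp_size * tp_rank,
    idx + size // tp_size * (tp_rank + 1),
) for idx, size in zip(total_start_index, total_shard_sizes)]

print(shard_weights_index)  # [(4, 8), (10, 12), (14, 16)] -> rank-1 rows of q, k, v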
@@ -989,7 +998,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 processed_weight, quant_state = quantize_4bit(
                     loaded_weight,
                     compress_statistics=True,
-                    quant_type="nf4")
+                    quant_type="nf4",
+                )
 
                 quant_state_dict[weight_name] = quant_state
             else:
@@ -997,28 +1007,58 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
             yield weight_name, processed_weight
 
+    def _get_bnb_target_modules(self, model: nn.Module) -> None:
+
+        # TODO: Maybe we can replace bitsandbytes_stacked_params_mapping with
+        # packed_modules_mapping.
+        inverse_stacked_mapping: Dict[str, List[str]] = {}
+        for orig, (
+                packed,
+                idx,
+        ) in model.bitsandbytes_stacked_params_mapping.items():
+            if packed not in inverse_stacked_mapping:
+                inverse_stacked_mapping[packed] = []
+            inverse_stacked_mapping[packed].insert(idx, orig)
+
+        linear_module_lst = []
+        for name, module in model.named_modules():
+            if isinstance(module, (LinearBase, )):
+                last_name = name.split(".")[-1]
+                if sub_modules := inverse_stacked_mapping.get(last_name, []):
+                    # Map vllm's names to transformers' names.
+                    for sub_name in sub_modules:
+                        linear_module_lst.append(
+                            name.replace(last_name, sub_name))
+                else:
+                    linear_module_lst.append(name)
+        if self.target_modules:
+            # Update self.target_modules
+            self.target_modules = [
+                qual_name for qual_name in linear_module_lst
+                if any(t in qual_name for t in self.target_modules)
+            ]
+        else:
+            self.target_modules = linear_module_lst
+        assert (self.target_modules
+                ), "vllm currently does not support BNB quantization for"
+        f" {type(model).__name__}"
+
     def _load_weights(self, model_config: ModelConfig,
                       model: nn.Module) -> None:
-        if not hasattr(model, 'load_weights'):
+        if not hasattr(model, "load_weights"):
             raise AttributeError(
                 "The required method 'load_weights' is not defined in class"
                 f" {type(model).__name__}.")
 
-        if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
+        if not hasattr(model, "bitsandbytes_stacked_params_mapping"):
             raise AttributeError(
                 f"Model {type(model).__name__} does not support BitsAndBytes "
                 "quantization yet.")
 
-        if len(self.target_modules) == 0:
-            if hasattr(model, 'default_bitsandbytes_target_modules'):
-                self.target_modules = model.default_bitsandbytes_target_modules
-            else:
-                self.target_modules = self.default_target_modules
-
         # Modules whose weights might have fused on disk
         # we need their output_sizes to make shard in flight correctly with TP
         self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
+        self._get_bnb_target_modules(model)
         for name, module in model.named_modules():
             # Some modules like `ReplicatedLinear` should not have their weights
             # sharded. The reason for implementing it this way is to avoid new
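The new _get_bnb_target_modules inverts bitsandbytes_stacked_params_mapping so that a fused vLLM layer name such as qkv_proj expands back into the per-projection names seen in the original checkpoint. A self-contained sketch of that step, using a hypothetical mapping and module names rather than a real model:

from typing import Dict, List, Tuple

# Hypothetical mapping in the style of bitsandbytes_stacked_params_mapping:
# shard_name -> (packed_name, index)
stacked_mapping: Dict[str, Tuple[str, int]] = {
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

# Invert it: packed_name -> [shard names in shard-index order]
inverse: Dict[str, List[str]] = {}
for orig, (packed, idx) in stacked_mapping.items():
    inverse.setdefault(packed, []).insert(idx, orig)

# Expand fused linear-module names into per-projection target names.
linear_modules = ["model.layers.0.self_attn.qkv_proj", "model.layers.0.mlp.down_proj"]
targets = []
for name in linear_modules:
    last = name.split(".")[-1]
    for sub in inverse.get(last, []) or [last]:
        targets.append(name.replace(last, sub))

print(targets)
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj',
#  'model.layers.0.self_attn.v_proj', 'model.layers.0.mlp.down_proj']

When no user-supplied target_modules exist, the expanded list itself becomes the target set; otherwise it is filtered against the user's patterns, as in the method above.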
@@ -1046,7 +1086,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         pre_quant = False
         if quant_config is not None:
-            quant_method = quant_config.get('quant_method')
+            quant_method = quant_config.get("quant_method")
             if quant_method == "bitsandbytes":
                 pre_quant = True
             else:
@@ -1063,11 +1103,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
         load_8bit = False
         if pre_quant:
-            load_8bit = quant_config.get('load_in_8bit', False)
+            load_8bit = quant_config.get("load_in_8bit", False)
 
-        qweight_iterator, quant_state_dict = \
-            self._get_quantized_weights_iterator(
-            model_config.model, model_config.revision, pre_quant, load_8bit)
+        qweight_iterator, quant_state_dict = (
+            self._get_quantized_weights_iterator(model_config.model,
+                                                 model_config.revision,
+                                                 pre_quant, load_8bit))
 
         model.load_weights(qweight_iterator)
 
@@ -1078,6 +1119,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         # TODO: Change this lazy import to normal import
         # after the checks are updated to run on a new version
         from vllm.model_executor.models.utils import is_pp_missing_parameter
+
         for quant_param_name in quant_state_dict:
             if is_pp_missing_parameter(quant_param_name, model):
                 continue
@@ -1086,9 +1128,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
             shard_index = 0
             for shard_name, (
-                    weight_name, index
+                    weight_name,
+                    index,
             ) in model.bitsandbytes_stacked_params_mapping.items():
-
                 shard_pos = quant_param_name.find(shard_name)
                 # Some models, such as MiniCPM V2.5/2.6, contain both
                 # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
@@ -1123,8 +1165,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
 
             num_elements = [0] * len(quant_states)
             for seq, quant_state in quant_states.items():
-                num_elements[seq] = math.prod(
-                    quant_state.shape) // pack_ratio
+                num_elements[seq] = (math.prod(quant_state.shape) //
+                                     pack_ratio)
 
             offsets = np.concatenate(([0], np.cumsum(num_elements)))
             set_weight_attrs(param, {"bnb_shard_offsets": offsets})
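Each quant state records the logical shape of its shard; dividing the element count by the pack ratio gives the packed length, and the running sum becomes the bnb_shard_offsets attached to the parameter. A tiny worked example with invented shapes and an assumed pack ratio of 2 (the real value comes from the quantization config):

import math
import numpy as np

# Hypothetical shard shapes of a fused weight (e.g. q/k/v) and a pack ratio of 2.
shard_shapes = {0: (8, 16), 1: (4, 16), 2: (4, 16)}
pack_ratio = 2

num_elements = [0] * len(shard_shapes)
for seq, shape in shard_shapes.items():
    num_elements[seq] = math.prod(shape) // pack_ratio

offsets = np.concatenate(([0], np.cumsum(num_elements)))
print(offsets)  # [  0  64  96 128] -> packed start offset of each shard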
@@ -351,14 +351,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_padding_modules = []
 
     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".W_pack.",
-        ".o_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".gate_proj.",
-        ".up_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "gate_proj": ("gate_up_proj", 0),
@@ -412,12 +412,6 @@ class FalconForCausalLM(nn.Module, SupportsPP):
 
     # BitandBytes specific attributes
     bitsandbytes_stacked_params_mapping = {}
-    default_bitsandbytes_target_modules = [
-        ".query_key_value.",
-        ".dense.",
-        ".dense_h_to_4h.",
-        ".dense_4h_to_h.",
-    ]
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -350,15 +350,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "down_proj",
     ]
     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -386,15 +386,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_padding_modules = []
 
     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -656,21 +656,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
     ]
 
     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-        # vision_model
-        ".fc1.",
-        ".fc2.",
-        ".out_proj.",
-        # connector
-        ".proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -463,15 +463,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_padding_modules = ["lm_head"]
 
     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -822,25 +822,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
     ]
 
     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-        # vision encoder
-        ".fc1.",
-        ".fc2.",
-        # Currently, vllm does not support BNB quantization for the `out_proj`
-        # of the resampler, so it's necessary to distinguish between the
-        # vision encoder and the resampler's out_proj. The same applies to
-        # MiniCPMV2_6.
-        ".self_attn.out_proj.",  # vision encoder out_proj
-        # resampler
-        ".kv_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -964,21 +945,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
     ]
 
     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-        # vision encoder
-        ".fc1.",
-        ".fc2.",
-        ".self_attn.out_proj.",
-        # resampler
-        ".kv_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -1104,20 +1104,6 @@ class MllamaForCausalLM(nn.Module):
 @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
 class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-        ".fc1.",
-        ".fc2.",
-        # The `multi_modal_projector` is at the top level of the model,
-        # so we can't add a dot in front of it.
-        "multi_modal_projector."
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -337,9 +337,6 @@ class OPTForCausalLM(nn.Module, SupportsPP):
         "k_proj": ("qkv_proj", 1),
         "v_proj": ("qkv_proj", 2),
     }
-    default_bitsandbytes_target_modules = [
-        ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
-    ]
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -286,9 +286,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "k_proj": ("qkv_proj", 1),
         "v_proj": ("qkv_proj", 2),
     }
-    default_bitsandbytes_target_modules = [
-        ".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense."
-    ]
 
     embedding_modules = {}
     embedding_padding_modules = []
@@ -16,11 +16,5 @@ class Phi3ForCausalLM(LlamaForCausalLM):
     }
 
     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_up_proj.",
-        ".down_proj.",
-        ".qkv_proj.",
-        ".o_proj.",
-    ]
     # Initialize an empty dict when there is no stacked parameter mapping.
     bitsandbytes_stacked_params_mapping = {}
@@ -1028,12 +1028,7 @@ class QWenLLM(QWenBaseModel):
     embedding_modules = {}
     embedding_padding_modules = []
 
-    default_bitsandbytes_target_modules = [
-        ".c_attn.",
-        ".c_proj.",
-        ".w1.",
-        ".w2.",
-    ]
+    # BitandBytes specific attributes
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "w2": ("gate_up_proj", 0),
@@ -419,15 +419,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_padding_modules = []
 
     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),