[Misc] Further reduce BNB static variable (#10597)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

commit 15cc2a9f1a
parent e85250b1d1
@@ -28,7 +28,8 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (LinearBase,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -78,12 +79,14 @@ def device_loading_context(module: torch.nn.Module,
                 original_device: torch.device = original_device_states[name]
                 if original_device.type == "cpu":
                     # `torch.empty_like` does not support `pin_memory` argument
-                    cpu_data = torch.empty_strided(size=p.data.size(),
-                                                   stride=p.data.stride(),
-                                                   dtype=p.data.dtype,
-                                                   layout=p.data.layout,
-                                                   device="cpu",
-                                                   pin_memory=pin_memory)
+                    cpu_data = torch.empty_strided(
+                        size=p.data.size(),
+                        stride=p.data.stride(),
+                        dtype=p.data.dtype,
+                        layout=p.data.layout,
+                        device="cpu",
+                        pin_memory=pin_memory,
+                    )
                     cpu_data.copy_(p.data)
                     p.data = cpu_data
                 else:
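Editor's note (not part of the diff): the comment in the hunk above explains why the loader calls torch.empty_strided directly, since torch.empty_like has no pin_memory argument. A minimal standalone sketch of that pattern; the helper name and the pin_memory guard are illustrative assumptions, not code from this commit.

import torch

def pinned_cpu_copy(t: torch.Tensor) -> torch.Tensor:
    # Allocate a CPU buffer with the same size/stride/dtype/layout, pin it
    # when a GPU is present, then copy the source tensor into it.
    cpu_data = torch.empty_strided(
        size=t.size(),
        stride=t.stride(),
        dtype=t.dtype,
        layout=t.layout,
        device="cpu",
        pin_memory=torch.cuda.is_available(),
    )
    cpu_data.copy_(t)
    return cpu_data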
@@ -112,7 +115,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
     logger.warning(msg)
     logger.warning(
         "Trying to guess the arguments for old-style model class %s",
-        model_class)
+        model_class,
+    )
     # try to be compatible with old-style model class
     kwargs = {}
     if "prefix" in all_params:
@@ -198,14 +202,17 @@ class DefaultModelLoader(BaseModelLoader):
             return model_path
         return None

-    def _prepare_weights(self, model_name_or_path: str,
-                         revision: Optional[str],
-                         fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
+    def _prepare_weights(
+        self,
+        model_name_or_path: str,
+        revision: Optional[str],
+        fall_back_to_pt: bool,
+    ) -> Tuple[str, List[str], bool]:
         """Prepare weights for the model.

         If the model is not local, it will be downloaded."""
-        model_name_or_path = self._maybe_download_from_modelscope(
-            model_name_or_path, revision) or model_name_or_path
+        model_name_or_path = (self._maybe_download_from_modelscope(
+            model_name_or_path, revision) or model_name_or_path)

         is_local = os.path.isdir(model_name_or_path)
         load_format = self.load_config.load_format
@@ -258,8 +265,11 @@ class DefaultModelLoader(BaseModelLoader):
             # any files not found in the index.
             if not is_local:
                 download_safetensors_index_file_from_hf(
-                    model_name_or_path, index_file,
-                    self.load_config.download_dir, revision)
+                    model_name_or_path,
+                    index_file,
+                    self.load_config.download_dir,
+                    revision,
+                )
             hf_weights_files = filter_duplicate_safetensors_files(
                 hf_weights_files, hf_folder, index_file)
         else:
@@ -282,8 +292,11 @@ class DefaultModelLoader(BaseModelLoader):
             # Currently np_cache only support *.bin checkpoints
             assert use_safetensors is False
             weights_iterator = np_cache_weights_iterator(
-                source.model_or_path, self.load_config.download_dir, hf_folder,
-                hf_weights_files)
+                source.model_or_path,
+                self.load_config.download_dir,
+                hf_folder,
+                hf_weights_files,
+            )
         elif use_safetensors:
             weights_iterator = safetensors_weights_iterator(hf_weights_files)
         else:
@@ -310,17 +323,19 @@ class DefaultModelLoader(BaseModelLoader):
         model_config: ModelConfig,
         model: nn.Module,
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
-
         primary_weights = DefaultModelLoader.Source(
             model_config.model,
             model_config.revision,
             prefix="",
             fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
-                                    True))
+                                    True),
+        )
         yield from self._get_weights_iterator(primary_weights)

-        secondary_weights = cast(Iterable[DefaultModelLoader.Source],
-                                 getattr(model, "secondary_weights", ()))
+        secondary_weights = cast(
+            Iterable[DefaultModelLoader.Source],
+            getattr(model, "secondary_weights", ()),
+        )
         for source in secondary_weights:
             yield from self._get_weights_iterator(source)

@@ -416,7 +431,7 @@ class TensorizerLoader(BaseModelLoader):
         self.tensorizer_config.verify_with_parallel_config(parallel_config)

     def _get_weights_iterator(
-            self) -> Generator[Tuple[str, torch.Tensor], None, None]:
+            self, ) -> Generator[Tuple[str, torch.Tensor], None, None]:
         tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
         return tensorizer_weights_iterator(tensorizer_args)

@@ -479,9 +494,10 @@ class TensorizerLoader(BaseModelLoader):

         if parallel_config.tensor_parallel_size > 1:
             from vllm.distributed import get_tensor_model_parallel_rank
-            self.tensorizer_config.tensorizer_uri = \
-                self.tensorizer_config.tensorizer_uri \
-                % get_tensor_model_parallel_rank()
+
+            self.tensorizer_config.tensorizer_uri = (
+                self.tensorizer_config.tensorizer_uri %
+                get_tensor_model_parallel_rank())

         if is_vllm_tensorized(self.tensorizer_config):
             return self._load_model_serialized(vllm_config=vllm_config)
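Editor's note (not part of the diff): the tensorizer_uri in the hunk above is a printf-style template that each tensor-parallel rank fills with its own index. A small illustration; the bucket path and rank value are invented examples.

# Each TP rank loads its own serialized shard; the URI carries a %-style
# placeholder that is substituted with the rank.
tensorizer_uri = "s3://example-bucket/model-rank-%03d.tensors"  # hypothetical template
tp_rank = 1
print(tensorizer_uri % tp_rank)  # s3://example-bucket/model-rank-001.tensors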
@@ -520,13 +536,13 @@ class ShardedStateLoader(BaseModelLoader):

     @staticmethod
     def _filter_subtensors(
-            tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+            tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]:
         """
         Filter out all tensors that share the same memory or a subset of the
         memory of another tensor.
         """
-        same_storage_groups: Dict[Any, List[Tuple[
-            str, torch.Tensor]]] = collections.defaultdict(list)
+        same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = (
+            collections.defaultdict(list))
         for key, tensor in tensors.items():
             if tensor.numel():
                 ptr = tensor.untyped_storage().data_ptr()
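Editor's note (not part of the diff): _filter_subtensors groups tensors by the storage they live in so only one owner per memory region is kept when saving sharded state. A minimal, self-contained sketch of that grouping step; the example tensors are invented for illustration.

import collections
from typing import Any, Dict, List, Tuple

import torch

def group_by_storage(
        tensors: Dict[str, torch.Tensor]
) -> Dict[Any, List[Tuple[str, torch.Tensor]]]:
    # Tensors that view the same untyped storage share a data pointer,
    # so (device, data_ptr) identifies the underlying memory region.
    groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = (
        collections.defaultdict(list))
    for key, tensor in tensors.items():
        if tensor.numel():
            ptr = tensor.untyped_storage().data_ptr()
            groups[(tensor.device, ptr)].append((key, tensor))
    return groups

base = torch.zeros(8)
print(len(group_by_storage({"a": base, "b": base[:4]})))  # 1: both views share one storage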
@@ -615,8 +631,11 @@ class ShardedStateLoader(BaseModelLoader):
                     if tensor.shape != param_shape:
                         logger.warning(
                             "loading tensor of shape %s into "
-                            "parameter '%s' of shape %s", tensor.shape,
-                            key, param_shape)
+                            "parameter '%s' of shape %s",
+                            tensor.shape,
+                            key,
+                            param_shape,
+                        )
                     param_data.copy_(tensor)
                     state_dict.pop(key)
         if state_dict:
@@ -634,6 +653,7 @@ class ShardedStateLoader(BaseModelLoader):
         from safetensors.torch import save_file

         from vllm.distributed import get_tensor_model_parallel_rank
+
         if pattern is None:
             pattern = ShardedStateLoader.DEFAULT_PATTERN
         rank = get_tensor_model_parallel_rank()
@@ -667,24 +687,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):

     possible_config_file_names = ["adapter_config.json"]

-    default_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-        '.fc1.',
-        '.fc2.',
-        '.dense.',
-        '.query_key_value.',
-        '.qkv_proj.',
-        '.dense_h_to_4h.',
-        '.dense_4h_to_h.',
-        '.out_proj.',
-    ]
-
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)

@@ -709,6 +711,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         with open(config_file_path) as f:
             config = json.load(f)
             self.target_modules = config["target_modules"]
+            # TODO: target_modules could be either a list or a regex string.
+            # We need to handle both cases.
+            assert isinstance(self.target_modules,
+                              list), "Unsupported target_modules: "
+            f"{self.target_modules}"

     def _get_config_file(self, qlora_adapter: str) -> str:
         is_local = os.path.isdir(qlora_adapter)
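Editor's note (not part of the diff): the target modules come straight out of the adapter's adapter_config.json, and the assertion added above only accepts the list form. A hedged, standalone illustration; the file path and module names are hypothetical.

import json

# adapter_config.json is the file name the loader searches for
# (possible_config_file_names earlier in this class).
with open("adapter_config.json") as f:   # hypothetical local adapter directory
    config = json.load(f)

target_modules = config["target_modules"]
# Accepted: a list such as ["q_proj", "k_proj", "v_proj", "o_proj"].
# A regex string (also allowed by PEFT adapters) would fail this check.
assert isinstance(target_modules, list), (
    f"Unsupported target_modules: {target_modules}")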
@@ -734,12 +741,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         return config_file_path

     def _get_weight_files(
-            self,
-            model_name_or_path: str,
-            allowed_patterns: List[str],
-            revision: Optional[str] = None) -> Tuple[List[str], str]:
-        """Retrieve weight files. Download the files if necessary.
-
+        self,
+        model_name_or_path: str,
+        allowed_patterns: List[str],
+        revision: Optional[str] = None,
+    ) -> Tuple[List[str], str]:
+        """Retrieve weight files. Download the files if necessary.
+
         Return the weight files and the file pattern."""
         is_local = os.path.isdir(model_name_or_path)

@@ -806,6 +814,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         # only load the bitsandbytes module when needed
         try:
             import bitsandbytes
+
             if bitsandbytes.__version__ < "0.44.0":
                 raise ImportError("bitsandbytes version is wrong. Please "
                                   "install bitsandbytes>=0.44.0.")
@@ -839,8 +848,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):

     def _is_4bit_weight_name(self, weight_name: str):
         quantized_suffix = {
-            "absmax", "quant_map", "nested_absmax", "nested_quant_map",
-            "bitsandbytes"
+            "absmax",
+            "quant_map",
+            "nested_absmax",
+            "nested_quant_map",
+            "bitsandbytes",
         }
         suffix = weight_name.split(".")[-1]
         return any(q_suffix in suffix for q_suffix in quantized_suffix)
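Editor's note (not part of the diff): these suffixes mark the auxiliary tensors that bitsandbytes stores next to a 4-bit quantized weight. A quick illustration of how the check behaves; the weight names are invented examples.

quantized_suffix = {"absmax", "quant_map", "nested_absmax",
                    "nested_quant_map", "bitsandbytes"}

def is_4bit_aux_name(weight_name: str) -> bool:
    # Only the last dot-separated component is inspected.
    suffix = weight_name.split(".")[-1]
    return any(q_suffix in suffix for q_suffix in quantized_suffix)

print(is_4bit_aux_name("model.layers.0.mlp.gate_proj.weight.absmax"))  # True
print(is_4bit_aux_name("model.layers.0.mlp.gate_proj.weight.quant_state.bitsandbytes__nf4"))  # True
print(is_4bit_aux_name("model.layers.0.mlp.gate_proj.weight"))  # False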
@@ -857,7 +869,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):

         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
-
             if self._is_8bit_weight_name(weight_name):
                 continue

@@ -899,14 +910,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         # pre quantized weights would have a quant_state
         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
-
             if self._is_4bit_weight_name(weight_name):
                 continue

-            if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
-                in temp_state_dict) or \
-                (f"{weight_name}.quant_state.bitsandbytes__fp4" \
-                in temp_state_dict):
+            if (f"{weight_name}.quant_state.bitsandbytes__nf4"
+                    in temp_state_dict) or (
+                        f"{weight_name}.quant_state.bitsandbytes__fp4"
+                        in temp_state_dict):
                 quant_state = _parse_quant_state(weight_name, temp_state_dict)
                 quant_state_dict[weight_name] = quant_state
                 yield weight_name, weight_tensor
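Editor's note (not part of the diff): a pre-quantized checkpoint stores its quantization state under keys derived from the weight name, which is what the rewrapped condition above looks for. A minimal illustration; the weight name and dictionary contents are made up.

weight_name = "model.layers.0.mlp.gate_proj.weight"   # hypothetical
temp_state_dict = {
    f"{weight_name}.quant_state.bitsandbytes__nf4": object(),
    f"{weight_name}.absmax": object(),
}

is_prequantized_4bit = (
    f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict
    or f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict)
print(is_prequantized_4bit)  # True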
@@ -916,12 +926,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
     def _unquantized_generator(self, hf_weights_files, use_safetensors,
                                quant_state_dict) -> Generator:
         from bitsandbytes.functional import quantize_4bit
+
         tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()

         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
-
             if any(target_module in weight_name for target_module in
                    self.target_modules) and weight_name.endswith(".weight"):
                 # Without sharding
@@ -954,12 +964,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                     # get the start/end index of each shard weight tensor
                     total_start_index = list(
                         itertools.accumulate([0] + total_shard_sizes))[:-1]
-                    shard_weights_index = [
-                        (idx + size // tp_size * tp_rank,
-                         idx + size // tp_size * (tp_rank + 1))
-                        for idx, size in zip(total_start_index,
-                                             total_shard_sizes)
-                    ]
+                    shard_weights_index = [(
+                        idx + size // tp_size * tp_rank,
+                        idx + size // tp_size * (tp_rank + 1),
+                    ) for idx, size in zip(total_start_index,
+                                           total_shard_sizes)]
                     # slice and reorder the weight tensor
                     weight_tensor = [
                         weight_tensor[start_index:end_index, ...]
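Editor's note (not part of the diff): the comprehension above computes, for each fused sub-weight, the slice of rows owned by this tensor-parallel rank. A worked example with assumed sizes (two fused shards of 11008 rows, TP size 2, rank 1; none of the numbers come from the diff).

import itertools

total_shard_sizes = [11008, 11008]   # assumed sizes of the two fused shards
tp_size, tp_rank = 2, 1

total_start_index = list(itertools.accumulate([0] + total_shard_sizes))[:-1]
shard_weights_index = [(
    idx + size // tp_size * tp_rank,
    idx + size // tp_size * (tp_rank + 1),
) for idx, size in zip(total_start_index, total_shard_sizes)]

print(shard_weights_index)  # [(5504, 11008), (16512, 22016)]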
@@ -989,7 +998,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                     processed_weight, quant_state = quantize_4bit(
                         loaded_weight,
                         compress_statistics=True,
-                        quant_type="nf4")
+                        quant_type="nf4",
+                    )

                 quant_state_dict[weight_name] = quant_state
             else:
@@ -997,28 +1007,58 @@ class BitsAndBytesModelLoader(BaseModelLoader):

             yield weight_name, processed_weight

+    def _get_bnb_target_modules(self, model: nn.Module) -> None:
+
+        # TODO: Maybe we can replace bitsandbytes_stacked_params_mapping with
+        # packed_modules_mapping.
+        inverse_stacked_mapping: Dict[str, List[str]] = {}
+        for orig, (
+                packed,
+                idx,
+        ) in model.bitsandbytes_stacked_params_mapping.items():
+            if packed not in inverse_stacked_mapping:
+                inverse_stacked_mapping[packed] = []
+            inverse_stacked_mapping[packed].insert(idx, orig)
+
+        linear_module_lst = []
+        for name, module in model.named_modules():
+            if isinstance(module, (LinearBase, )):
+                last_name = name.split(".")[-1]
+                if sub_modules := inverse_stacked_mapping.get(last_name, []):
+                    # Map vllm's names to transformers' names.
+                    for sub_name in sub_modules:
+                        linear_module_lst.append(
+                            name.replace(last_name, sub_name))
+                else:
+                    linear_module_lst.append(name)
+        if self.target_modules:
+            # Update self.target_modules
+            self.target_modules = [
+                qual_name for qual_name in linear_module_lst
+                if any(t in qual_name for t in self.target_modules)
+            ]
+        else:
+            self.target_modules = linear_module_lst
+        assert (self.target_modules
+                ), "vllm currently does not support BNB quantization for"
+        f" {type(model).__name__}"
+
     def _load_weights(self, model_config: ModelConfig,
                       model: nn.Module) -> None:
-        if not hasattr(model, 'load_weights'):
+        if not hasattr(model, "load_weights"):
             raise AttributeError(
                 "The required method 'load_weights' is not defined in class"
                 f" {type(model).__name__}.")

-        if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
+        if not hasattr(model, "bitsandbytes_stacked_params_mapping"):
             raise AttributeError(
                 f"Model {type(model).__name__} does not support BitsAndBytes "
                 "quantization yet.")

-        if len(self.target_modules) == 0:
-            if hasattr(model, 'default_bitsandbytes_target_modules'):
-                self.target_modules = model.default_bitsandbytes_target_modules
-            else:
-                self.target_modules = self.default_target_modules
-
         # Modules whose weights might have fused on disk
         # we need their output_sizes to make shard in flight correctly with TP
         self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
-
+        self._get_bnb_target_modules(model)
         for name, module in model.named_modules():
             # Some modules like `ReplicatedLinear` should not have their weights
             # sharded. The reason for implementing it this way is to avoid new
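Editor's note (not part of the diff): the new _get_bnb_target_modules above derives the BNB target modules from the model's linear layers, replacing the per-model static lists that the later hunks delete. The key step is inverting bitsandbytes_stacked_params_mapping so a packed vLLM module name can be expanded back into checkpoint-style sub-module names. A hedged, standalone sketch; the mapping mirrors the Llama-style entries visible elsewhere in this diff, and the module name is a made-up example.

from typing import Dict, List

# shard_name -> (packed vLLM name, shard index), as declared on the models.
bitsandbytes_stacked_params_mapping = {
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

inverse_stacked_mapping: Dict[str, List[str]] = {}
for orig, (packed, idx) in bitsandbytes_stacked_params_mapping.items():
    inverse_stacked_mapping.setdefault(packed, []).insert(idx, orig)

print(inverse_stacked_mapping)
# {'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
#  'gate_up_proj': ['gate_proj', 'up_proj']}

# A packed linear module found on the model expands to checkpoint-style names,
# which can then be matched against self.target_modules.
name = "model.layers.0.self_attn.qkv_proj"   # hypothetical module name
last = name.split(".")[-1]
subs = inverse_stacked_mapping.get(last, [])
expanded = [name.replace(last, sub) for sub in subs] or [name]
print(expanded)
# ['model.layers.0.self_attn.q_proj',
#  'model.layers.0.self_attn.k_proj',
#  'model.layers.0.self_attn.v_proj']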
@@ -1046,7 +1086,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):

         pre_quant = False
         if quant_config is not None:
-            quant_method = quant_config.get('quant_method')
+            quant_method = quant_config.get("quant_method")
             if quant_method == "bitsandbytes":
                 pre_quant = True
             else:
@@ -1063,11 +1103,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):

         load_8bit = False
         if pre_quant:
-            load_8bit = quant_config.get('load_in_8bit', False)
+            load_8bit = quant_config.get("load_in_8bit", False)

-        qweight_iterator, quant_state_dict = \
-            self._get_quantized_weights_iterator(
-            model_config.model, model_config.revision, pre_quant, load_8bit)
+        qweight_iterator, quant_state_dict = (
+            self._get_quantized_weights_iterator(model_config.model,
+                                                 model_config.revision,
+                                                 pre_quant, load_8bit))

         model.load_weights(qweight_iterator)

@@ -1078,6 +1119,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         # TODO: Change this lazy import to normal import
         # after the checks are updated to run on a new version
         from vllm.model_executor.models.utils import is_pp_missing_parameter
+
         for quant_param_name in quant_state_dict:
             if is_pp_missing_parameter(quant_param_name, model):
                 continue
@@ -1086,9 +1128,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):

             shard_index = 0
             for shard_name, (
-                    weight_name, index
+                    weight_name,
+                    index,
             ) in model.bitsandbytes_stacked_params_mapping.items():
-
                 shard_pos = quant_param_name.find(shard_name)
                 # Some models, such as MiniCPM V2.5/2.6, contain both
                 # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
@@ -1123,8 +1165,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):

                 num_elements = [0] * len(quant_states)
                 for seq, quant_state in quant_states.items():
-                    num_elements[seq] = math.prod(
-                        quant_state.shape) // pack_ratio
+                    num_elements[seq] = (math.prod(quant_state.shape) //
+                                         pack_ratio)

                 offsets = np.concatenate(([0], np.cumsum(num_elements)))
                 set_weight_attrs(param, {"bnb_shard_offsets": offsets})
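Editor's note (not part of the diff): the offsets above tell the quantized linear layer where each packed shard starts inside the flat 4-bit buffer. A worked example with assumed shapes (a qkv_proj packing three shards, pack_ratio = 2, i.e. two 4-bit values per byte; none of the numbers are from the diff).

import math

import numpy as np

pack_ratio = 2                                            # assumed
quant_state_shapes = {0: (4096, 4096), 1: (1024, 4096), 2: (1024, 4096)}

num_elements = [0] * len(quant_state_shapes)
for seq, shape in quant_state_shapes.items():
    num_elements[seq] = math.prod(shape) // pack_ratio

offsets = np.concatenate(([0], np.cumsum(num_elements)))
print(offsets.tolist())  # [0, 8388608, 10485760, 12582912]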
@@ -351,14 +351,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_padding_modules = []

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".W_pack.",
-        ".o_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".gate_proj.",
-        ".up_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "gate_proj": ("gate_up_proj", 0),
@@ -412,12 +412,6 @@ class FalconForCausalLM(nn.Module, SupportsPP):

     # BitandBytes specific attributes
     bitsandbytes_stacked_params_mapping = {}
-    default_bitsandbytes_target_modules = [
-        ".query_key_value.",
-        ".dense.",
-        ".dense_h_to_4h.",
-        ".dense_4h_to_h.",
-    ]

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -350,15 +350,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "down_proj",
     ]
     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -386,15 +386,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_padding_modules = []

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -656,21 +656,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
     ]

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-        # vision_model
-        ".fc1.",
-        ".fc2.",
-        ".out_proj.",
-        # connector
-        ".proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -463,15 +463,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_padding_modules = ["lm_head"]

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -822,25 +822,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
     ]

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-        # vision encoder
-        ".fc1.",
-        ".fc2.",
-        # Currently, vllm does not support BNB quantization for the `out_proj`
-        # of the resampler, so it's necessary to distinguish between the
-        # vision encoder and the resampler's out_proj. The same applies to
-        # MiniCPMV2_6.
-        ".self_attn.out_proj.",  # vision encoder out_proj
-        # resampler
-        ".kv_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -964,21 +945,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
     ]

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-        # vision encoder
-        ".fc1.",
-        ".fc2.",
-        ".self_attn.out_proj.",
-        # resampler
-        ".kv_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -1104,20 +1104,6 @@ class MllamaForCausalLM(nn.Module):
 @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
 class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-        ".fc1.",
-        ".fc2.",
-        # The `multi_modal_projector` is at the top level of the model,
-        # so we can't add a dot in front of it.
-        "multi_modal_projector."
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -337,9 +337,6 @@ class OPTForCausalLM(nn.Module, SupportsPP):
         "k_proj": ("qkv_proj", 1),
         "v_proj": ("qkv_proj", 2),
     }
-    default_bitsandbytes_target_modules = [
-        ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
-    ]

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -286,9 +286,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "k_proj": ("qkv_proj", 1),
         "v_proj": ("qkv_proj", 2),
     }
-    default_bitsandbytes_target_modules = [
-        ".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense."
-    ]

     embedding_modules = {}
     embedding_padding_modules = []
@@ -16,11 +16,5 @@ class Phi3ForCausalLM(LlamaForCausalLM):
     }

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_up_proj.",
-        ".down_proj.",
-        ".qkv_proj.",
-        ".o_proj.",
-    ]
     # Initialize an empty dict when there is no stacked parameter mapping.
     bitsandbytes_stacked_params_mapping = {}
@@ -1028,12 +1028,7 @@ class QWenLLM(QWenBaseModel):
     embedding_modules = {}
     embedding_padding_modules = []

-    default_bitsandbytes_target_modules = [
-        ".c_attn.",
-        ".c_proj.",
-        ".w1.",
-        ".w2.",
-    ]
+    # BitandBytes specific attributes
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "w2": ("gate_up_proj", 0),
@@ -419,15 +419,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_padding_modules = []

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
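Editor's note (not part of the diff): after this change a model only needs to declare bitsandbytes_stacked_params_mapping (or an empty dict, as in Phi3 above); the per-model default_bitsandbytes_target_modules lists removed in these hunks are now derived by the loader from the model's linear modules. A hedged sketch of the minimal class-level declaration; the class name is hypothetical and the mapping is modelled on the Llama-style entries in this diff.

import torch.nn as nn

class MyForCausalLM(nn.Module):   # hypothetical model class
    # shard_name -> (packed weight name, shard index); this is all the
    # BitsAndBytes loader needs from the model after this commit.
    bitsandbytes_stacked_params_mapping = {
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }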