[Misc]Further reduce BNB static variable (#10597)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2024-11-27 14:54:12 +08:00 committed by GitHub
parent e85250b1d1
commit 15cc2a9f1a
14 changed files with 131 additions and 219 deletions


@@ -28,7 +28,8 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (LinearBase,
+                                               MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -78,12 +79,14 @@ def device_loading_context(module: torch.nn.Module,
             original_device: torch.device = original_device_states[name]
             if original_device.type == "cpu":
                 # `torch.empty_like` does not support `pin_memory` argument
-                cpu_data = torch.empty_strided(size=p.data.size(),
-                                               stride=p.data.stride(),
-                                               dtype=p.data.dtype,
-                                               layout=p.data.layout,
-                                               device="cpu",
-                                               pin_memory=pin_memory)
+                cpu_data = torch.empty_strided(
+                    size=p.data.size(),
+                    stride=p.data.stride(),
+                    dtype=p.data.dtype,
+                    layout=p.data.layout,
+                    device="cpu",
+                    pin_memory=pin_memory,
+                )
                 cpu_data.copy_(p.data)
                 p.data = cpu_data
             else:
@@ -112,7 +115,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
         logger.warning(msg)
         logger.warning(
             "Trying to guess the arguments for old-style model class %s",
-            model_class)
+            model_class,
+        )
         # try to be compatible with old-style model class
         kwargs = {}
         if "prefix" in all_params:
@@ -198,14 +202,17 @@ class DefaultModelLoader(BaseModelLoader):
                 return model_path
         return None

-    def _prepare_weights(self, model_name_or_path: str,
-                         revision: Optional[str],
-                         fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
+    def _prepare_weights(
+        self,
+        model_name_or_path: str,
+        revision: Optional[str],
+        fall_back_to_pt: bool,
+    ) -> Tuple[str, List[str], bool]:
         """Prepare weights for the model.

         If the model is not local, it will be downloaded."""
-        model_name_or_path = self._maybe_download_from_modelscope(
-            model_name_or_path, revision) or model_name_or_path
+        model_name_or_path = (self._maybe_download_from_modelscope(
+            model_name_or_path, revision) or model_name_or_path)

         is_local = os.path.isdir(model_name_or_path)
         load_format = self.load_config.load_format
@@ -258,8 +265,11 @@ class DefaultModelLoader(BaseModelLoader):
             # any files not found in the index.
             if not is_local:
                 download_safetensors_index_file_from_hf(
-                    model_name_or_path, index_file,
-                    self.load_config.download_dir, revision)
+                    model_name_or_path,
+                    index_file,
+                    self.load_config.download_dir,
+                    revision,
+                )
             hf_weights_files = filter_duplicate_safetensors_files(
                 hf_weights_files, hf_folder, index_file)
         else:
@@ -282,8 +292,11 @@ class DefaultModelLoader(BaseModelLoader):
             # Currently np_cache only support *.bin checkpoints
             assert use_safetensors is False
             weights_iterator = np_cache_weights_iterator(
-                source.model_or_path, self.load_config.download_dir, hf_folder,
-                hf_weights_files)
+                source.model_or_path,
+                self.load_config.download_dir,
+                hf_folder,
+                hf_weights_files,
+            )
         elif use_safetensors:
             weights_iterator = safetensors_weights_iterator(hf_weights_files)
         else:
@@ -310,17 +323,19 @@ class DefaultModelLoader(BaseModelLoader):
         model_config: ModelConfig,
         model: nn.Module,
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
-
         primary_weights = DefaultModelLoader.Source(
             model_config.model,
             model_config.revision,
             prefix="",
             fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
-                                    True))
+                                    True),
+        )
         yield from self._get_weights_iterator(primary_weights)

-        secondary_weights = cast(Iterable[DefaultModelLoader.Source],
-                                 getattr(model, "secondary_weights", ()))
+        secondary_weights = cast(
+            Iterable[DefaultModelLoader.Source],
+            getattr(model, "secondary_weights", ()),
+        )
         for source in secondary_weights:
             yield from self._get_weights_iterator(source)
@@ -416,7 +431,7 @@ class TensorizerLoader(BaseModelLoader):
         self.tensorizer_config.verify_with_parallel_config(parallel_config)

     def _get_weights_iterator(
-            self) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        self, ) -> Generator[Tuple[str, torch.Tensor], None, None]:
         tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
         return tensorizer_weights_iterator(tensorizer_args)
@@ -479,9 +494,10 @@ class TensorizerLoader(BaseModelLoader):
         if parallel_config.tensor_parallel_size > 1:
             from vllm.distributed import get_tensor_model_parallel_rank
-            self.tensorizer_config.tensorizer_uri = \
-                self.tensorizer_config.tensorizer_uri \
-                % get_tensor_model_parallel_rank()
+
+            self.tensorizer_config.tensorizer_uri = (
+                self.tensorizer_config.tensorizer_uri %
+                get_tensor_model_parallel_rank())

         if is_vllm_tensorized(self.tensorizer_config):
             return self._load_model_serialized(vllm_config=vllm_config)
@@ -520,13 +536,13 @@ class ShardedStateLoader(BaseModelLoader):
     @staticmethod
     def _filter_subtensors(
-            tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]:
         """
         Filter out all tensors that share the same memory or a subset of the
         memory of another tensor.
         """
-        same_storage_groups: Dict[Any, List[Tuple[
-            str, torch.Tensor]]] = collections.defaultdict(list)
+        same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = (
+            collections.defaultdict(list))
         for key, tensor in tensors.items():
             if tensor.numel():
                 ptr = tensor.untyped_storage().data_ptr()
@@ -615,8 +631,11 @@ class ShardedStateLoader(BaseModelLoader):
                 if tensor.shape != param_shape:
                     logger.warning(
                         "loading tensor of shape %s into "
-                        "parameter '%s' of shape %s", tensor.shape,
-                        key, param_shape)
+                        "parameter '%s' of shape %s",
+                        tensor.shape,
+                        key,
+                        param_shape,
+                    )
                 param_data.copy_(tensor)
                 state_dict.pop(key)
         if state_dict:
@@ -634,6 +653,7 @@ class ShardedStateLoader(BaseModelLoader):
         from safetensors.torch import save_file

         from vllm.distributed import get_tensor_model_parallel_rank
+
         if pattern is None:
             pattern = ShardedStateLoader.DEFAULT_PATTERN
         rank = get_tensor_model_parallel_rank()
@@ -667,24 +687,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):

     possible_config_file_names = ["adapter_config.json"]

-    default_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-        '.fc1.',
-        '.fc2.',
-        '.dense.',
-        '.query_key_value.',
-        '.qkv_proj.',
-        '.dense_h_to_4h.',
-        '.dense_4h_to_h.',
-        '.out_proj.',
-    ]
-
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
@@ -709,6 +711,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             with open(config_file_path) as f:
                 config = json.load(f)
                 self.target_modules = config["target_modules"]
+                # TODO: target_modules could be either a list or a regex string.
+                # We need to handle both cases.
+                assert isinstance(self.target_modules,
+                                  list), "Unsupported target_modules: "
+                f"{self.target_modules}"

     def _get_config_file(self, qlora_adapter: str) -> str:
         is_local = os.path.isdir(qlora_adapter)
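The TODO above records that a PEFT adapter_config.json may give target_modules either as a list of module names or as a single regex string; only the list form passes the assert for now. A hedged sketch of how both forms could be normalized; the helper below is hypothetical and not part of this change:

    import re
    from typing import List, Union

    def normalize_target_modules(target_modules: Union[str, List[str]],
                                 module_names: List[str]) -> List[str]:
        """Hypothetical helper: expand a PEFT-style target_modules value.

        A list is returned unchanged; a regex string is matched against the
        model's module names.
        """
        if isinstance(target_modules, list):
            return target_modules
        pattern = re.compile(target_modules)
        return [name for name in module_names if pattern.search(name)]

    # Example: a regex selects the attention projections only.
    print(normalize_target_modules(
        r"\.(q|k|v)_proj$",
        ["model.layers.0.self_attn.q_proj", "model.layers.0.mlp.up_proj"]))
    # ['model.layers.0.self_attn.q_proj']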
@@ -737,7 +744,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         self,
         model_name_or_path: str,
         allowed_patterns: List[str],
-        revision: Optional[str] = None) -> Tuple[List[str], str]:
+        revision: Optional[str] = None,
+    ) -> Tuple[List[str], str]:
         """Retrieve weight files. Download the files if necessary.

         Return the weight files and the file pattern."""
@@ -806,6 +814,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         # only load the bitsandbytes module when needed
         try:
             import bitsandbytes
+
             if bitsandbytes.__version__ < "0.44.0":
                 raise ImportError("bitsandbytes version is wrong. Please "
                                   "install bitsandbytes>=0.44.0.")
@@ -839,8 +848,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):

     def _is_4bit_weight_name(self, weight_name: str):
         quantized_suffix = {
-            "absmax", "quant_map", "nested_absmax", "nested_quant_map",
-            "bitsandbytes"
+            "absmax",
+            "quant_map",
+            "nested_absmax",
+            "nested_quant_map",
+            "bitsandbytes",
         }
         suffix = weight_name.split(".")[-1]
         return any(q_suffix in suffix for q_suffix in quantized_suffix)
@@ -857,7 +869,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):

         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
-
             if self._is_8bit_weight_name(weight_name):
                 continue
@@ -899,13 +910,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         # pre quantized weights would have a quant_state
         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
-
             if self._is_4bit_weight_name(weight_name):
                 continue

-            if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
-                in temp_state_dict) or \
-                (f"{weight_name}.quant_state.bitsandbytes__fp4" \
-                in temp_state_dict):
+            if (f"{weight_name}.quant_state.bitsandbytes__nf4"
+                    in temp_state_dict) or (
+                        f"{weight_name}.quant_state.bitsandbytes__fp4"
+                        in temp_state_dict):
                 quant_state = _parse_quant_state(weight_name, temp_state_dict)
                 quant_state_dict[weight_name] = quant_state
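The membership test above leans on the naming convention bitsandbytes uses when serializing a pre-quantized checkpoint: each 4-bit weight is stored next to companion tensors such as <name>.quant_state.bitsandbytes__nf4 (or __fp4), plus absmax/quant_map entries that _is_4bit_weight_name filters out. A minimal sketch with a hypothetical state dict:

    # Hypothetical flat state dict from a pre-quantized 4-bit checkpoint.
    temp_state_dict = {
        "model.layers.0.self_attn.qkv_proj.weight": ...,  # packed 4-bit data
        "model.layers.0.self_attn.qkv_proj.weight.quant_state.bitsandbytes__nf4": ...,
        "model.layers.0.self_attn.qkv_proj.weight.absmax": ...,
    }

    weight_name = "model.layers.0.self_attn.qkv_proj.weight"
    is_prequantized_4bit = any(
        f"{weight_name}.quant_state.bitsandbytes__{qt}" in temp_state_dict
        for qt in ("nf4", "fp4"))
    print(is_prequantized_4bit)  # True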
@@ -916,12 +926,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
     def _unquantized_generator(self, hf_weights_files, use_safetensors,
                                quant_state_dict) -> Generator:
         from bitsandbytes.functional import quantize_4bit
+
         tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()

         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
-
             if any(target_module in weight_name for target_module in
                    self.target_modules) and weight_name.endswith(".weight"):
                 # Without sharding
@@ -954,12 +964,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                     # get the start/end index of each shard weight tensor
                     total_start_index = list(
                         itertools.accumulate([0] + total_shard_sizes))[:-1]
-                    shard_weights_index = [
-                        (idx + size // tp_size * tp_rank,
-                         idx + size // tp_size * (tp_rank + 1))
-                        for idx, size in zip(total_start_index,
-                                             total_shard_sizes)
-                    ]
+                    shard_weights_index = [(
+                        idx + size // tp_size * tp_rank,
+                        idx + size // tp_size * (tp_rank + 1),
+                    ) for idx, size in zip(total_start_index,
+                                           total_shard_sizes)]
                     # slice and reorder the weight tensor
                     weight_tensor = [
                         weight_tensor[start_index:end_index, ...]
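The comprehension above computes, for every sub-matrix of a fused weight, the row range owned by the current tensor-parallel rank. A numeric sketch of the same arithmetic with made-up sizes (tp_size=2, second rank):

    import itertools

    # Hypothetical fused QKV weight: q has 4096 rows, k and v have 1024 each.
    total_shard_sizes = [4096, 1024, 1024]
    tp_size, tp_rank = 2, 1

    total_start_index = list(itertools.accumulate([0] + total_shard_sizes))[:-1]
    shard_weights_index = [(
        idx + size // tp_size * tp_rank,
        idx + size // tp_size * (tp_rank + 1),
    ) for idx, size in zip(total_start_index, total_shard_sizes)]

    print(total_start_index)    # [0, 4096, 5120]
    print(shard_weights_index)  # [(2048, 4096), (4608, 5120), (5632, 6144)]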
@@ -989,7 +998,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 processed_weight, quant_state = quantize_4bit(
                     loaded_weight,
                     compress_statistics=True,
-                    quant_type="nf4")
+                    quant_type="nf4",
+                )

                 quant_state_dict[weight_name] = quant_state
             else:
@@ -997,28 +1007,58 @@ class BitsAndBytesModelLoader(BaseModelLoader):

             yield weight_name, processed_weight

+    def _get_bnb_target_modules(self, model: nn.Module) -> None:
+
+        # TODO: Maybe we can replace bitsandbytes_stacked_params_mapping with
+        # packed_modules_mapping.
+        inverse_stacked_mapping: Dict[str, List[str]] = {}
+        for orig, (
+                packed,
+                idx,
+        ) in model.bitsandbytes_stacked_params_mapping.items():
+            if packed not in inverse_stacked_mapping:
+                inverse_stacked_mapping[packed] = []
+            inverse_stacked_mapping[packed].insert(idx, orig)
+
+        linear_module_lst = []
+        for name, module in model.named_modules():
+            if isinstance(module, (LinearBase, )):
+                last_name = name.split(".")[-1]
+                if sub_modules := inverse_stacked_mapping.get(last_name, []):
+                    # Map vllm's names to transformers' names.
+                    for sub_name in sub_modules:
+                        linear_module_lst.append(
+                            name.replace(last_name, sub_name))
+                else:
+                    linear_module_lst.append(name)
+        if self.target_modules:
+            # Update self.target_modules
+            self.target_modules = [
+                qual_name for qual_name in linear_module_lst
+                if any(t in qual_name for t in self.target_modules)
+            ]
+        else:
+            self.target_modules = linear_module_lst
+        assert (self.target_modules
+                ), "vllm currently does not support BNB quantization for"
+        f" {type(model).__name__}"
+
     def _load_weights(self, model_config: ModelConfig,
                       model: nn.Module) -> None:
-        if not hasattr(model, 'load_weights'):
+        if not hasattr(model, "load_weights"):
             raise AttributeError(
                 "The required method 'load_weights' is not defined in class"
                 f" {type(model).__name__}.")

-        if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
+        if not hasattr(model, "bitsandbytes_stacked_params_mapping"):
             raise AttributeError(
                 f"Model {type(model).__name__} does not support BitsAndBytes "
                 "quantization yet.")

-        if len(self.target_modules) == 0:
-            if hasattr(model, 'default_bitsandbytes_target_modules'):
-                self.target_modules = model.default_bitsandbytes_target_modules
-            else:
-                self.target_modules = self.default_target_modules
-
         # Modules whose weights might have fused on disk
         # we need their output_sizes to make shard in flight correctly with TP
         self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
+        self._get_bnb_target_modules(model)
         for name, module in model.named_modules():
             # Some modules like `ReplicatedLinear` should not have their weights
             # sharded. The reason for implementing it this way is to avoid new
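The new _get_bnb_target_modules replaces the per-model default_bitsandbytes_target_modules lists: it inverts bitsandbytes_stacked_params_mapping so that each packed vLLM module (e.g. qkv_proj) expands back to the checkpoint's per-projection names, then collects the fully qualified names of every LinearBase submodule. A self-contained sketch of the name-expansion step, using a made-up mapping and module list (the dict.get default folds the else branch; this is an illustration, not vLLM code):

    from typing import Dict, List

    # Made-up mapping in the style of bitsandbytes_stacked_params_mapping:
    # shard_name -> (packed vLLM module name, shard index).
    stacked_mapping = {
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    # Invert it: packed module name -> [shard names in shard-index order].
    inverse_stacked_mapping: Dict[str, List[str]] = {}
    for orig, (packed, idx) in stacked_mapping.items():
        inverse_stacked_mapping.setdefault(packed, []).insert(idx, orig)

    # Made-up fully qualified names of the model's linear modules.
    linear_modules = [
        "model.layers.0.self_attn.qkv_proj",
        "model.layers.0.self_attn.o_proj",
        "model.layers.0.mlp.gate_up_proj",
        "model.layers.0.mlp.down_proj",
    ]

    target_modules = []
    for name in linear_modules:
        last_name = name.split(".")[-1]
        for sub_name in inverse_stacked_mapping.get(last_name, [last_name]):
            target_modules.append(name.replace(last_name, sub_name))

    print(target_modules)
    # ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj',
    #  'model.layers.0.self_attn.v_proj', 'model.layers.0.self_attn.o_proj',
    #  'model.layers.0.mlp.gate_proj', 'model.layers.0.mlp.up_proj',
    #  'model.layers.0.mlp.down_proj']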
@@ -1046,7 +1086,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):

         pre_quant = False
         if quant_config is not None:
-            quant_method = quant_config.get('quant_method')
+            quant_method = quant_config.get("quant_method")
             if quant_method == "bitsandbytes":
                 pre_quant = True
             else:
@@ -1063,11 +1103,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):

         load_8bit = False
         if pre_quant:
-            load_8bit = quant_config.get('load_in_8bit', False)
+            load_8bit = quant_config.get("load_in_8bit", False)

-        qweight_iterator, quant_state_dict = \
-            self._get_quantized_weights_iterator(
-            model_config.model, model_config.revision, pre_quant, load_8bit)
+        qweight_iterator, quant_state_dict = (
+            self._get_quantized_weights_iterator(model_config.model,
+                                                 model_config.revision,
+                                                 pre_quant, load_8bit))

         model.load_weights(qweight_iterator)
@@ -1078,6 +1119,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         # TODO: Change this lazy import to normal import
         # after the checks are updated to run on a new version
         from vllm.model_executor.models.utils import is_pp_missing_parameter
+
         for quant_param_name in quant_state_dict:
             if is_pp_missing_parameter(quant_param_name, model):
                 continue
@@ -1086,9 +1128,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):

             shard_index = 0
             for shard_name, (
-                    weight_name, index
+                    weight_name,
+                    index,
             ) in model.bitsandbytes_stacked_params_mapping.items():
                 shard_pos = quant_param_name.find(shard_name)
                 # Some models, such as MiniCPM V2.5/2.6, contain both
                 # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
@@ -1123,8 +1165,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):

             num_elements = [0] * len(quant_states)
             for seq, quant_state in quant_states.items():
-                num_elements[seq] = math.prod(
-                    quant_state.shape) // pack_ratio
+                num_elements[seq] = (math.prod(quant_state.shape) //
+                                     pack_ratio)

             offsets = np.concatenate(([0], np.cumsum(num_elements)))
             set_weight_attrs(param, {"bnb_shard_offsets": offsets})
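Each quantized shard contributes math.prod(shape) // pack_ratio packed elements, and the cumulative sum gives the bnb_shard_offsets used to locate every shard inside the flattened parameter. A numeric sketch; the shapes and the pack ratio of 2 (two 4-bit values per stored byte) are illustrative:

    import math

    import numpy as np

    pack_ratio = 2
    # Hypothetical quant_state shapes for the q/k/v shards of a fused weight.
    shard_shapes = {0: (4096, 4096), 1: (1024, 4096), 2: (1024, 4096)}

    num_elements = [0] * len(shard_shapes)
    for seq, shape in shard_shapes.items():
        num_elements[seq] = math.prod(shape) // pack_ratio

    offsets = np.concatenate(([0], np.cumsum(num_elements)))
    print(offsets.tolist())  # [0, 8388608, 10485760, 12582912]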


@@ -351,14 +351,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_padding_modules = []

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".W_pack.",
-        ".o_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".gate_proj.",
-        ".up_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "gate_proj": ("gate_up_proj", 0),


@@ -412,12 +412,6 @@ class FalconForCausalLM(nn.Module, SupportsPP):

     # BitandBytes specific attributes
     bitsandbytes_stacked_params_mapping = {}
-    default_bitsandbytes_target_modules = [
-        ".query_key_value.",
-        ".dense.",
-        ".dense_h_to_4h.",
-        ".dense_4h_to_h.",
-    ]

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()


@@ -350,15 +350,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "down_proj",
     ]

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),


@@ -386,15 +386,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_padding_modules = []

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),


@@ -656,21 +656,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
     ]

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-        # vision_model
-        ".fc1.",
-        ".fc2.",
-        ".out_proj.",
-        # connector
-        ".proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),


@@ -463,15 +463,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_padding_modules = ["lm_head"]

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),


@@ -822,25 +822,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
     ]

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-        # vision encoder
-        ".fc1.",
-        ".fc2.",
-        # Currently, vllm does not support BNB quantization for the `out_proj`
-        # of the resampler, so it's necessary to distinguish between the
-        # vision encoder and the resampler's out_proj. The same applies to
-        # MiniCPMV2_6.
-        ".self_attn.out_proj.",  # vision encoder out_proj
-        # resampler
-        ".kv_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -964,21 +945,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
     ]

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-        # vision encoder
-        ".fc1.",
-        ".fc2.",
-        ".self_attn.out_proj.",
-        # resampler
-        ".kv_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),


@@ -1104,20 +1104,6 @@ class MllamaForCausalLM(nn.Module):
 @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
 class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-        ".fc1.",
-        ".fc2.",
-        # The `multi_modal_projector` is at the top level of the model,
-        # so we can't add a dot in front of it.
-        "multi_modal_projector."
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),


@@ -337,9 +337,6 @@ class OPTForCausalLM(nn.Module, SupportsPP):
         "k_proj": ("qkv_proj", 1),
         "v_proj": ("qkv_proj", 2),
     }
-    default_bitsandbytes_target_modules = [
-        ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
-    ]

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()


@@ -286,9 +286,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "k_proj": ("qkv_proj", 1),
         "v_proj": ("qkv_proj", 2),
     }
-    default_bitsandbytes_target_modules = [
-        ".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense."
-    ]

     embedding_modules = {}
     embedding_padding_modules = []


@@ -16,11 +16,5 @@ class Phi3ForCausalLM(LlamaForCausalLM):
     }

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_up_proj.",
-        ".down_proj.",
-        ".qkv_proj.",
-        ".o_proj.",
-    ]
     # Initialize an empty dict when there is no stacked parameter mapping.
     bitsandbytes_stacked_params_mapping = {}


@@ -1028,12 +1028,7 @@ class QWenLLM(QWenBaseModel):
     embedding_modules = {}
     embedding_padding_modules = []

-    default_bitsandbytes_target_modules = [
-        ".c_attn.",
-        ".c_proj.",
-        ".w1.",
-        ".w2.",
-    ]
+    # BitandBytes specific attributes
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "w2": ("gate_up_proj", 0),


@@ -419,15 +419,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_padding_modules = []

     # BitandBytes specific attributes
-    default_bitsandbytes_target_modules = [
-        ".gate_proj.",
-        ".down_proj.",
-        ".up_proj.",
-        ".q_proj.",
-        ".k_proj.",
-        ".v_proj.",
-        ".o_proj.",
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),