[Misc] Further reduce BNB static variable (#10597)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2024-11-27 14:54:12 +08:00 committed by GitHub
parent e85250b1d1
commit 15cc2a9f1a
14 changed files with 131 additions and 219 deletions
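
The core of this change is in the model loader: the per-model static default_bitsandbytes_target_modules lists are removed, and BitsAndBytesModelLoader now derives the BNB target modules at load time from the model's LinearBase layers and its bitsandbytes_stacked_params_mapping (see _get_bnb_target_modules in the loader diff below). The following is a minimal, self-contained sketch of that derivation; TinyModel, derive_bnb_target_modules, and the use of plain nn.Linear in place of vLLM's LinearBase subclasses are illustrative assumptions, not code from this commit.

from typing import Dict, List

import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical toy model exposing the mapping vLLM models provide."""

    # shard_name -> (packed_weight_name, index), as in the real models below.
    bitsandbytes_stacked_params_mapping = {
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
    }

    def __init__(self):
        super().__init__()
        self.qkv_proj = nn.Linear(16, 48)
        self.o_proj = nn.Linear(16, 16)


def derive_bnb_target_modules(model: nn.Module) -> List[str]:
    # Invert shard_name -> (packed, idx) into packed -> [shard names],
    # mirroring the inverse_stacked_mapping built in _get_bnb_target_modules.
    inverse: Dict[str, List[str]] = {}
    for orig, (packed, idx) in model.bitsandbytes_stacked_params_mapping.items():
        inverse.setdefault(packed, []).insert(idx, orig)

    targets: List[str] = []
    for name, module in model.named_modules():
        # The real loader checks isinstance(module, LinearBase) instead.
        if isinstance(module, nn.Linear):
            last = name.split(".")[-1]
            if sub_modules := inverse.get(last, []):
                # Map vLLM's packed name back to the checkpoint's names.
                targets.extend(name.replace(last, sub) for sub in sub_modules)
            else:
                targets.append(name)
    return targets


print(derive_bnb_target_modules(TinyModel()))
# -> ['q_proj', 'k_proj', 'v_proj', 'o_proj']

With this derivation in place, a model only needs to declare bitsandbytes_stacked_params_mapping to support BNB loading, which is why the remaining files in this commit simply delete their static target-module lists.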

View File

@ -28,7 +28,8 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (LinearBase,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
@ -78,12 +79,14 @@ def device_loading_context(module: torch.nn.Module,
original_device: torch.device = original_device_states[name]
if original_device.type == "cpu":
# `torch.empty_like` does not support `pin_memory` argument
cpu_data = torch.empty_strided(size=p.data.size(),
stride=p.data.stride(),
dtype=p.data.dtype,
layout=p.data.layout,
device="cpu",
pin_memory=pin_memory)
cpu_data = torch.empty_strided(
size=p.data.size(),
stride=p.data.stride(),
dtype=p.data.dtype,
layout=p.data.layout,
device="cpu",
pin_memory=pin_memory,
)
cpu_data.copy_(p.data)
p.data = cpu_data
else:
@ -112,7 +115,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module:
logger.warning(msg)
logger.warning(
"Trying to guess the arguments for old-style model class %s",
model_class)
model_class,
)
# try to be compatible with old-style model class
kwargs = {}
if "prefix" in all_params:
@ -198,14 +202,17 @@ class DefaultModelLoader(BaseModelLoader):
return model_path
return None
def _prepare_weights(self, model_name_or_path: str,
revision: Optional[str],
fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
def _prepare_weights(
self,
model_name_or_path: str,
revision: Optional[str],
fall_back_to_pt: bool,
) -> Tuple[str, List[str], bool]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
model_name_or_path = self._maybe_download_from_modelscope(
model_name_or_path, revision) or model_name_or_path
model_name_or_path = (self._maybe_download_from_modelscope(
model_name_or_path, revision) or model_name_or_path)
is_local = os.path.isdir(model_name_or_path)
load_format = self.load_config.load_format
@ -258,8 +265,11 @@ class DefaultModelLoader(BaseModelLoader):
# any files not found in the index.
if not is_local:
download_safetensors_index_file_from_hf(
model_name_or_path, index_file,
self.load_config.download_dir, revision)
model_name_or_path,
index_file,
self.load_config.download_dir,
revision,
)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file)
else:
@ -282,8 +292,11 @@ class DefaultModelLoader(BaseModelLoader):
# Currently np_cache only support *.bin checkpoints
assert use_safetensors is False
weights_iterator = np_cache_weights_iterator(
source.model_or_path, self.load_config.download_dir, hf_folder,
hf_weights_files)
source.model_or_path,
self.load_config.download_dir,
hf_folder,
hf_weights_files,
)
elif use_safetensors:
weights_iterator = safetensors_weights_iterator(hf_weights_files)
else:
@ -310,17 +323,19 @@ class DefaultModelLoader(BaseModelLoader):
model_config: ModelConfig,
model: nn.Module,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
primary_weights = DefaultModelLoader.Source(
model_config.model,
model_config.revision,
prefix="",
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
True))
True),
)
yield from self._get_weights_iterator(primary_weights)
secondary_weights = cast(Iterable[DefaultModelLoader.Source],
getattr(model, "secondary_weights", ()))
secondary_weights = cast(
Iterable[DefaultModelLoader.Source],
getattr(model, "secondary_weights", ()),
)
for source in secondary_weights:
yield from self._get_weights_iterator(source)
@ -416,7 +431,7 @@ class TensorizerLoader(BaseModelLoader):
self.tensorizer_config.verify_with_parallel_config(parallel_config)
def _get_weights_iterator(
self) -> Generator[Tuple[str, torch.Tensor], None, None]:
self, ) -> Generator[Tuple[str, torch.Tensor], None, None]:
tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
return tensorizer_weights_iterator(tensorizer_args)
@ -479,9 +494,10 @@ class TensorizerLoader(BaseModelLoader):
if parallel_config.tensor_parallel_size > 1:
from vllm.distributed import get_tensor_model_parallel_rank
self.tensorizer_config.tensorizer_uri = \
self.tensorizer_config.tensorizer_uri \
% get_tensor_model_parallel_rank()
self.tensorizer_config.tensorizer_uri = (
self.tensorizer_config.tensorizer_uri %
get_tensor_model_parallel_rank())
if is_vllm_tensorized(self.tensorizer_config):
return self._load_model_serialized(vllm_config=vllm_config)
@ -520,13 +536,13 @@ class ShardedStateLoader(BaseModelLoader):
@staticmethod
def _filter_subtensors(
tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]:
"""
Filter out all tensors that share the same memory or a subset of the
memory of another tensor.
"""
same_storage_groups: Dict[Any, List[Tuple[
str, torch.Tensor]]] = collections.defaultdict(list)
same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = (
collections.defaultdict(list))
for key, tensor in tensors.items():
if tensor.numel():
ptr = tensor.untyped_storage().data_ptr()
@ -615,8 +631,11 @@ class ShardedStateLoader(BaseModelLoader):
if tensor.shape != param_shape:
logger.warning(
"loading tensor of shape %s into "
"parameter '%s' of shape %s", tensor.shape,
key, param_shape)
"parameter '%s' of shape %s",
tensor.shape,
key,
param_shape,
)
param_data.copy_(tensor)
state_dict.pop(key)
if state_dict:
@ -634,6 +653,7 @@ class ShardedStateLoader(BaseModelLoader):
from safetensors.torch import save_file
from vllm.distributed import get_tensor_model_parallel_rank
if pattern is None:
pattern = ShardedStateLoader.DEFAULT_PATTERN
rank = get_tensor_model_parallel_rank()
@ -667,24 +687,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
possible_config_file_names = ["adapter_config.json"]
default_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
'.fc1.',
'.fc2.',
'.dense.',
'.query_key_value.',
'.qkv_proj.',
'.dense_h_to_4h.',
'.dense_4h_to_h.',
'.out_proj.',
]
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
@ -709,6 +711,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
with open(config_file_path) as f:
config = json.load(f)
self.target_modules = config["target_modules"]
# TODO: target_modules could be either a list or a regex string.
# We need to handle both cases.
assert isinstance(self.target_modules,
list), "Unsupported target_modules: "
f"{self.target_modules}"
def _get_config_file(self, qlora_adapter: str) -> str:
is_local = os.path.isdir(qlora_adapter)
@ -734,12 +741,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return config_file_path
def _get_weight_files(
self,
model_name_or_path: str,
allowed_patterns: List[str],
revision: Optional[str] = None) -> Tuple[List[str], str]:
"""Retrieve weight files. Download the files if necessary.
self,
model_name_or_path: str,
allowed_patterns: List[str],
revision: Optional[str] = None,
) -> Tuple[List[str], str]:
"""Retrieve weight files. Download the files if necessary.
Return the weight files and the file pattern."""
is_local = os.path.isdir(model_name_or_path)
@ -806,6 +814,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# only load the bitsandbytes module when needed
try:
import bitsandbytes
if bitsandbytes.__version__ < "0.44.0":
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.44.0.")
@ -839,8 +848,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def _is_4bit_weight_name(self, weight_name: str):
quantized_suffix = {
"absmax", "quant_map", "nested_absmax", "nested_quant_map",
"bitsandbytes"
"absmax",
"quant_map",
"nested_absmax",
"nested_quant_map",
"bitsandbytes",
}
suffix = weight_name.split(".")[-1]
return any(q_suffix in suffix for q_suffix in quantized_suffix)
@ -857,7 +869,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if self._is_8bit_weight_name(weight_name):
continue
@ -899,14 +910,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# pre quantized weights would have a quant_state
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if self._is_4bit_weight_name(weight_name):
continue
if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
in temp_state_dict) or \
(f"{weight_name}.quant_state.bitsandbytes__fp4" \
in temp_state_dict):
if (f"{weight_name}.quant_state.bitsandbytes__nf4"
in temp_state_dict) or (
f"{weight_name}.quant_state.bitsandbytes__fp4"
in temp_state_dict):
quant_state = _parse_quant_state(weight_name, temp_state_dict)
quant_state_dict[weight_name] = quant_state
yield weight_name, weight_tensor
@ -916,12 +926,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def _unquantized_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
from bitsandbytes.functional import quantize_4bit
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if any(target_module in weight_name for target_module in
self.target_modules) and weight_name.endswith(".weight"):
# Without sharding
@ -954,12 +964,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# get the start/end index of each shard weight tensor
total_start_index = list(
itertools.accumulate([0] + total_shard_sizes))[:-1]
shard_weights_index = [
(idx + size // tp_size * tp_rank,
idx + size // tp_size * (tp_rank + 1))
for idx, size in zip(total_start_index,
total_shard_sizes)
]
shard_weights_index = [(
idx + size // tp_size * tp_rank,
idx + size // tp_size * (tp_rank + 1),
) for idx, size in zip(total_start_index,
total_shard_sizes)]
# slice and reorder the weight tensor
weight_tensor = [
weight_tensor[start_index:end_index, ...]
@ -989,7 +998,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
processed_weight, quant_state = quantize_4bit(
loaded_weight,
compress_statistics=True,
quant_type="nf4")
quant_type="nf4",
)
quant_state_dict[weight_name] = quant_state
else:
@ -997,28 +1007,58 @@ class BitsAndBytesModelLoader(BaseModelLoader):
yield weight_name, processed_weight
def _get_bnb_target_modules(self, model: nn.Module) -> None:
# TODO: Maybe we can replace bitsandbytes_stacked_params_mapping with
# packed_modules_mapping.
inverse_stacked_mapping: Dict[str, List[str]] = {}
for orig, (
packed,
idx,
) in model.bitsandbytes_stacked_params_mapping.items():
if packed not in inverse_stacked_mapping:
inverse_stacked_mapping[packed] = []
inverse_stacked_mapping[packed].insert(idx, orig)
linear_module_lst = []
for name, module in model.named_modules():
if isinstance(module, (LinearBase, )):
last_name = name.split(".")[-1]
if sub_modules := inverse_stacked_mapping.get(last_name, []):
# Map vllm's names to transformers' names.
for sub_name in sub_modules:
linear_module_lst.append(
name.replace(last_name, sub_name))
else:
linear_module_lst.append(name)
if self.target_modules:
# Update self.target_modules
self.target_modules = [
qual_name for qual_name in linear_module_lst
if any(t in qual_name for t in self.target_modules)
]
else:
self.target_modules = linear_module_lst
assert (self.target_modules
), "vllm currently does not support BNB quantization for"
f" {type(model).__name__}"
def _load_weights(self, model_config: ModelConfig,
model: nn.Module) -> None:
if not hasattr(model, 'load_weights'):
if not hasattr(model, "load_weights"):
raise AttributeError(
"The required method 'load_weights' is not defined in class"
f" {type(model).__name__}.")
if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
if not hasattr(model, "bitsandbytes_stacked_params_mapping"):
raise AttributeError(
f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet.")
if len(self.target_modules) == 0:
if hasattr(model, 'default_bitsandbytes_target_modules'):
self.target_modules = model.default_bitsandbytes_target_modules
else:
self.target_modules = self.default_target_modules
# Modules whose weights might have fused on disk
# we need their output_sizes to make shard in flight correctly with TP
self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
self._get_bnb_target_modules(model)
for name, module in model.named_modules():
# Some modules like `ReplicatedLinear` should not have their weights
# sharded. The reason for implementing it this way is to avoid new
@ -1046,7 +1086,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
pre_quant = False
if quant_config is not None:
quant_method = quant_config.get('quant_method')
quant_method = quant_config.get("quant_method")
if quant_method == "bitsandbytes":
pre_quant = True
else:
@ -1063,11 +1103,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
load_8bit = False
if pre_quant:
load_8bit = quant_config.get('load_in_8bit', False)
load_8bit = quant_config.get("load_in_8bit", False)
qweight_iterator, quant_state_dict = \
self._get_quantized_weights_iterator(
model_config.model, model_config.revision, pre_quant, load_8bit)
qweight_iterator, quant_state_dict = (
self._get_quantized_weights_iterator(model_config.model,
model_config.revision,
pre_quant, load_8bit))
model.load_weights(qweight_iterator)
@ -1078,6 +1119,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# TODO: Change this lazy import to normal import
# after the checks are updated to run on a new version
from vllm.model_executor.models.utils import is_pp_missing_parameter
for quant_param_name in quant_state_dict:
if is_pp_missing_parameter(quant_param_name, model):
continue
@ -1086,9 +1128,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
shard_index = 0
for shard_name, (
weight_name, index
weight_name,
index,
) in model.bitsandbytes_stacked_params_mapping.items():
shard_pos = quant_param_name.find(shard_name)
# Some models, such as MiniCPM V2.5/2.6, contain both
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
@ -1123,8 +1165,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
num_elements = [0] * len(quant_states)
for seq, quant_state in quant_states.items():
num_elements[seq] = math.prod(
quant_state.shape) // pack_ratio
num_elements[seq] = (math.prod(quant_state.shape) //
pack_ratio)
offsets = np.concatenate(([0], np.cumsum(num_elements)))
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
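
A side note on the hunks above: the tensor-parallel slicing in _unquantized_generator computes, for each fused sub-weight, the [start, end) rows that belong to the current rank. A tiny standalone sketch of that arithmetic follows; the total_shard_sizes, tp_size, and tp_rank values are made up for illustration, not taken from this commit.

import itertools

# Hypothetical fused-QKV shard sizes and a 2-way tensor-parallel setup.
total_shard_sizes = [4096, 1024, 1024]   # illustrative q/k/v output sizes
tp_size, tp_rank = 2, 1

# Start offset of each fused shard inside the flat, fused weight.
total_start_index = list(itertools.accumulate([0] + total_shard_sizes))[:-1]
# -> [0, 4096, 5120]

# Per-rank slice of every shard, matching shard_weights_index in the loader.
shard_weights_index = [(
    idx + size // tp_size * tp_rank,
    idx + size // tp_size * (tp_rank + 1),
) for idx, size in zip(total_start_index, total_shard_sizes)]
print(shard_weights_index)
# -> [(2048, 4096), (4608, 5120), (5632, 6144)]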

View File

@ -351,14 +351,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_padding_modules = []
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".W_pack.",
".o_proj.",
".down_proj.",
".up_proj.",
".gate_proj.",
".up_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"gate_proj": ("gate_up_proj", 0),

View File

@ -412,12 +412,6 @@ class FalconForCausalLM(nn.Module, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {}
default_bitsandbytes_target_modules = [
".query_key_value.",
".dense.",
".dense_h_to_4h.",
".dense_4h_to_h.",
]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -350,15 +350,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"down_proj",
]
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),

View File

@ -386,15 +386,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_padding_modules = []
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),

View File

@ -656,21 +656,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
]
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
# vision_model
".fc1.",
".fc2.",
".out_proj.",
# connector
".proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),

View File

@ -463,15 +463,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_padding_modules = ["lm_head"]
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),

View File

@ -822,25 +822,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
]
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
# vision encoder
".fc1.",
".fc2.",
# Currently, vllm does not support BNB quantization for the `out_proj`
# of the resampler, so it's necessary to distinguish between the
# vision encoder and the resampler's out_proj. The same applies to
# MiniCPMV2_6.
".self_attn.out_proj.", # vision encoder out_proj
# resampler
".kv_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
@ -964,21 +945,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
]
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
# vision encoder
".fc1.",
".fc2.",
".self_attn.out_proj.",
# resampler
".kv_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),

View File

@ -1104,20 +1104,6 @@ class MllamaForCausalLM(nn.Module):
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
".fc1.",
".fc2.",
# The `multi_modal_projector` is at the top level of the model,
# so we can't add a dot in front of it.
"multi_modal_projector."
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),

View File

@ -337,9 +337,6 @@ class OPTForCausalLM(nn.Module, SupportsPP):
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}
default_bitsandbytes_target_modules = [
".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

View File

@ -286,9 +286,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
}
default_bitsandbytes_target_modules = [
".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense."
]
embedding_modules = {}
embedding_padding_modules = []

View File

@ -16,11 +16,5 @@ class Phi3ForCausalLM(LlamaForCausalLM):
}
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_up_proj.",
".down_proj.",
".qkv_proj.",
".o_proj.",
]
# Initialize an empty dict when there is no stacked parameter mapping.
bitsandbytes_stacked_params_mapping = {}

View File

@ -1028,12 +1028,7 @@ class QWenLLM(QWenBaseModel):
embedding_modules = {}
embedding_padding_modules = []
default_bitsandbytes_target_modules = [
".c_attn.",
".c_proj.",
".w1.",
".w2.",
]
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"w2": ("gate_up_proj", 0),

View File

@ -419,15 +419,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_padding_modules = []
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),