mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 00:35:01 +08:00
[Misc] Add BNB support to GLM4-V model (#12184)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
936db119ed
commit
edaae198e7
@ -1105,15 +1105,22 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
|||||||
weight_name,
|
weight_name,
|
||||||
index,
|
index,
|
||||||
) in self.modules_mapping.inverse_packed_mapping.items():
|
) in self.modules_mapping.inverse_packed_mapping.items():
|
||||||
shard_pos = quant_param_name.find(shard_name)
|
|
||||||
# Some models, such as MiniCPM V2.5/2.6, contain both
|
# Some models, such as MiniCPM V2.5/2.6, contain both
|
||||||
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
|
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
|
||||||
# from being incorrectly identified as being present in
|
# from being incorrectly identified as being present in
|
||||||
# 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
|
# 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
|
||||||
if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
|
shard_pos = quant_param_name.find(shard_name)
|
||||||
shard_index = index
|
can_correct_rename = (shard_pos > 0) and (
|
||||||
quant_param_name = quant_param_name.replace(
|
quant_param_name[shard_pos - 1] == ".")
|
||||||
|
# If the quant_param_name is packed, it won't occur in the
|
||||||
|
# param_dict before renaming.
|
||||||
|
new_quant_param_name = quant_param_name.replace(
|
||||||
shard_name, weight_name)
|
shard_name, weight_name)
|
||||||
|
need_rename = (quant_param_name not in param_dict) \
|
||||||
|
and (new_quant_param_name in param_dict)
|
||||||
|
if can_correct_rename and need_rename:
|
||||||
|
shard_index = index
|
||||||
|
quant_param_name = new_quant_param_name
|
||||||
break
|
break
|
||||||
|
|
||||||
# Models like Clip/Siglip may skip some layers in initialization,
|
# Models like Clip/Siglip may skip some layers in initialization,
|
||||||
|
|||||||
@ -41,7 +41,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
|||||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||||
from .utils import (is_pp_missing_parameter,
|
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers,
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
maybe_prefix)
|
maybe_prefix)
|
||||||
|
|
||||||
@ -605,9 +605,50 @@ class ChatGLMModel(nn.Module):
|
|||||||
return IntermediateTensors({"hidden_states": hidden_states})
|
return IntermediateTensors({"hidden_states": hidden_states})
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
|
torch.Tensor]]) -> Set[str]:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
|
||||||
|
("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
loaded_params: Set[str] = set()
|
||||||
|
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if "rotary_pos_emb.inv_freq" in name:
|
||||||
|
continue
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||||
|
|
||||||
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
|
orig_to_new_substr={".word_embeddings": ""}, )
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
@ -660,52 +701,9 @@ class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
next_tokens = self.sampler(logits, sampling_metadata)
|
next_tokens = self.sampler(logits, sampling_metadata)
|
||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str,
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
torch.Tensor]]) -> Set[str]:
|
loader = AutoWeightsLoader(self)
|
||||||
# Merge two ColumnParallelLinear into one MergedColumnParallelLinear
|
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||||
merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
|
|
||||||
"transformer.vision.linear_proj.merged_proj.weight": {
|
|
||||||
"transformer.vision.linear_proj.gate_proj.weight": None,
|
|
||||||
"transformer.vision.linear_proj.dense_h_to_4h.weight": None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
|
||||||
loaded_params: Set[str] = set()
|
|
||||||
for name, loaded_weight in weights:
|
|
||||||
is_weight_to_be_merge = False
|
|
||||||
for _, merged_weight_dict in merged_weights_dict.items():
|
|
||||||
if name in merged_weight_dict:
|
|
||||||
assert merged_weight_dict[name] is None
|
|
||||||
merged_weight_dict[name] = loaded_weight
|
|
||||||
is_weight_to_be_merge = True
|
|
||||||
if is_weight_to_be_merge:
|
|
||||||
continue
|
|
||||||
if "rotary_pos_emb.inv_freq" in name:
|
|
||||||
continue
|
|
||||||
if "word_embeddings" in name:
|
|
||||||
name = name.replace(".word_embeddings", "")
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
if is_pp_missing_parameter(name, self):
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = getattr(param, "weight_loader",
|
|
||||||
default_weight_loader)
|
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
loaded_params.add(name)
|
|
||||||
|
|
||||||
for combined_name, merged_weight_dict in merged_weights_dict.items():
|
|
||||||
if combined_name in params_dict:
|
|
||||||
param = params_dict[combined_name]
|
|
||||||
combined_weight = torch.cat(list(merged_weight_dict.values()),
|
|
||||||
dim=0)
|
|
||||||
weight_loader = getattr(param, "weight_loader",
|
|
||||||
default_weight_loader)
|
|
||||||
weight_loader(param, combined_weight)
|
|
||||||
loaded_params.add(combined_name)
|
|
||||||
return loaded_params
|
|
||||||
|
|
||||||
|
|
||||||
class ChatGLM(ChatGLMBaseModel):
|
class ChatGLM(ChatGLMBaseModel):
|
||||||
@ -726,6 +724,7 @@ class ChatGLM(ChatGLMBaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
|
class ChatGLMV(ChatGLMBaseModel, SupportsMultiModal):
|
||||||
|
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"query_key_value": ["query_key_value"],
|
"query_key_value": ["query_key_value"],
|
||||||
"dense_h_to_4h": ["dense_h_to_4h"],
|
"dense_h_to_4h": ["dense_h_to_4h"],
|
||||||
@ -777,7 +776,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
|
|||||||
) -> None:
|
) -> None:
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
# Initialize VL
|
# Initialize VL
|
||||||
if hasattr(config, "visual"):
|
if hasattr(config, "vision_config"):
|
||||||
return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
|
return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
|
||||||
# Initialize LLM
|
# Initialize LLM
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -42,7 +42,8 @@ class PatchEmbedding(nn.Module):
|
|||||||
torch.Tensor
|
torch.Tensor
|
||||||
Transformed tensor with shape (B, L, D)
|
Transformed tensor with shape (B, L, D)
|
||||||
"""
|
"""
|
||||||
images = images.to(self.proj.weight.device)
|
images = images.to(device=self.proj.weight.device,
|
||||||
|
dtype=self.proj.weight.dtype)
|
||||||
x = self.proj(images)
|
x = self.proj(images)
|
||||||
x = x.flatten(2).transpose(1, 2)
|
x = x.flatten(2).transpose(1, 2)
|
||||||
cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
|
cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user