mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 06:35:42 +08:00
[Model] Refactor and decouple weight loading logic for InternVL2 model (#7067)
This commit is contained in:
parent
a0d164567c
commit
0c25435daa
@ -4,7 +4,7 @@
|
|||||||
# Copyright (c) 2023 OpenGVLab
|
# Copyright (c) 2023 OpenGVLab
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
from typing import Optional
|
from typing import Iterable, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -16,6 +16,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
|||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
NORM2FN = {
|
NORM2FN = {
|
||||||
'rms_norm': RMSNorm,
|
'rms_norm': RMSNorm,
|
||||||
@ -268,3 +269,11 @@ class InternVisionModel(nn.Module):
|
|||||||
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
|
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
|
||||||
|
|
||||||
return encoder_outputs
|
return encoder_outputs
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
# Copyright (c) 2023 OpenGVLab
|
# Copyright (c) 2023 OpenGVLab
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
# --------------------------------------------------------
|
# --------------------------------------------------------
|
||||||
|
import itertools
|
||||||
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
|
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -414,58 +415,31 @@ class InternVLChatModel(nn.Module, SupportsVision):
|
|||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
return self.language_model.sample(logits, sampling_metadata)
|
return self.language_model.sample(logits, sampling_metadata)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def _filter_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
|
||||||
stacked_params_mapping = [
|
prefix: str):
|
||||||
# (param_name, shard_name, shard_id)
|
|
||||||
(".qkv_proj", ".q_proj", "q"),
|
|
||||||
(".qkv_proj", ".k_proj", "k"),
|
|
||||||
(".qkv_proj", ".v_proj", "v"),
|
|
||||||
(".gate_up_proj", ".gate_proj", 0),
|
|
||||||
(".gate_up_proj", ".up_proj", 1),
|
|
||||||
(".gate_up_proj", ".w1", 0),
|
|
||||||
(".gate_up_proj", ".w3", 1),
|
|
||||||
]
|
|
||||||
params_dict = dict(self.named_parameters())
|
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "rotary_emb.inv_freq" in name:
|
name = name.split(".")
|
||||||
continue
|
if prefix == name.pop(0):
|
||||||
if self.config.text_config.tie_word_embeddings \
|
name = ".".join(name)
|
||||||
and "lm_head.weight" in name:
|
yield name, loaded_weight
|
||||||
continue
|
|
||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
# We only do sharding for language model
|
# prepare weight iterators for components
|
||||||
# and not vision model for now.
|
vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
|
||||||
if "vision_embed_tokens" in name and self.vision_embed_tokens:
|
|
||||||
continue
|
# load vision encoder
|
||||||
if weight_name not in name:
|
vit_weights = self._filter_weights(vit_weights, "vision_model")
|
||||||
continue
|
self.vision_model.load_weights(vit_weights)
|
||||||
param = params_dict[name.replace(weight_name, param_name)]
|
|
||||||
weight_loader = param.weight_loader
|
# load mlp projector
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
mlp_weights = self._filter_weights(mlp_weights, "mlp1")
|
||||||
break
|
mlp_params_dict = dict(self.mlp1.named_parameters())
|
||||||
else:
|
for name, loaded_weight in mlp_weights:
|
||||||
# Skip loading extra bias for GPTQ models.
|
param = mlp_params_dict[name]
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
weight_loader = getattr(param, "weight_loader",
|
||||||
continue
|
default_weight_loader)
|
||||||
param = params_dict[name]
|
weight_loader(param, loaded_weight)
|
||||||
if "wqkv" in name:
|
|
||||||
config = self.config.text_config
|
# load llm backbone
|
||||||
kv_groups = (config.num_attention_heads //
|
llm_weights = self._filter_weights(llm_weights, "language_model")
|
||||||
config.num_key_value_heads)
|
self.language_model.load_weights(llm_weights)
|
||||||
head_dim = config.hidden_size // config.num_attention_heads
|
|
||||||
loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
|
|
||||||
head_dim,
|
|
||||||
loaded_weight.shape[-1])
|
|
||||||
wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
|
|
||||||
dim=1)
|
|
||||||
wq = wq.reshape(-1, wq.shape[-1])
|
|
||||||
wk = wk.reshape(-1, wk.shape[-1])
|
|
||||||
wv = wv.reshape(-1, wv.shape[-1])
|
|
||||||
weight_loader = param.weight_loader
|
|
||||||
weight_loader(param, wq, 'q')
|
|
||||||
weight_loader(param, wk, 'k')
|
|
||||||
weight_loader(param, wv, 'v')
|
|
||||||
continue
|
|
||||||
weight_loader = getattr(param, "weight_loader",
|
|
||||||
default_weight_loader)
|
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user