mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-01 08:48:42 +08:00
[Model] Refactor Molmo weights loading to use AutoWeightsLoader (#10771)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
40bc242579
commit
16ee07f22a
@ -3,7 +3,7 @@ import re
|
||||
from array import array
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache, partial
|
||||
from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict
|
||||
from typing import Iterable, List, Mapping, Optional, Set, Tuple, TypedDict
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
@ -44,7 +44,8 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
from vllm.transformers_utils.processor import get_processor
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (get_vit_attn_backend,
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
@ -720,6 +721,42 @@ class MolmoVisionBackbone(nn.Module):
|
||||
# image_features: (batch_size, num_image, num_patch, d_model)
|
||||
return image_features
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class MolmoModel(nn.Module):
|
||||
@ -804,6 +841,28 @@ class MolmoModel(nn.Module):
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if "gate_up_proj" in name:
|
||||
up_proj, gate_proj = loaded_weight.chunk(2, dim=0)
|
||||
loaded_weight = torch.cat([gate_proj, up_proj], dim=0)
|
||||
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
cached_get_processor = lru_cache(get_processor)
|
||||
|
||||
@ -1200,103 +1259,53 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_substr={
|
||||
# vision backbone mapping
|
||||
"image_projector.w1.": "image_projector.gate_proj.",
|
||||
"image_projector.w3.": "image_projector.up_proj.",
|
||||
"image_projector.w2.": "image_projector.down_proj.",
|
||||
# language backbone mapping
|
||||
"att_proj": "self_attn.qkv_proj",
|
||||
"attn_out": "self_attn.o_proj",
|
||||
"q_norm": "self_attn.q_norm",
|
||||
"k_norm": "self_attn.k_norm",
|
||||
"ff_proj": "mlp.gate_up_proj",
|
||||
"ff_out": "mlp.down_proj",
|
||||
"attn_norm": "input_layernorm",
|
||||
"ff_norm": "post_attention_layernorm",
|
||||
},
|
||||
orig_to_new_prefix={
|
||||
# vision backbone mapping
|
||||
"model.vision_backbone.": "vision_backbone.",
|
||||
# language backbone mapping
|
||||
"model.transformer.blocks.": "model.layers.",
|
||||
"model.transformer.ln_f.": "model.norm.",
|
||||
# lm_head is renamed to model.transformer.mlp.down_proj firstly,
|
||||
# we need to run a second renaming for it
|
||||
"model.transformer.mlp.down_proj.": "lm_head.",
|
||||
},
|
||||
)
|
||||
loader = AutoWeightsLoader(self)
|
||||
weights = _get_weights_with_merged_embedding(weights)
|
||||
return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
|
||||
|
||||
params_mapping = [
|
||||
("model.transformer.ln_f.weight", "model.norm.weight"),
|
||||
("attn_out", "self_attn.o_proj"),
|
||||
("att_proj", "self_attn.qkv_proj"),
|
||||
("q_norm", "self_attn.q_norm"),
|
||||
("k_norm", "self_attn.k_norm"),
|
||||
("attn_norm", "input_layernorm"),
|
||||
("ff_norm", "post_attention_layernorm"),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
|
||||
embedding_weight = dict()
|
||||
projector_weight = dict()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||
continue
|
||||
|
||||
if "wte.embedding" in name:
|
||||
embedding_weight["embedding"] = loaded_weight
|
||||
continue
|
||||
|
||||
if "wte.new_embedding" in name:
|
||||
embedding_weight["new_embedding"] = loaded_weight
|
||||
continue
|
||||
|
||||
if "vision_backbone" in name:
|
||||
if name.startswith("model"):
|
||||
name = name[len("model."):]
|
||||
if 'image_projector' in name:
|
||||
if 'w1' in name:
|
||||
projector_weight['gate_proj'] = loaded_weight
|
||||
elif 'w3' in name:
|
||||
projector_weight['up_proj'] = loaded_weight
|
||||
elif 'w2' in name:
|
||||
projector_weight['down_proj'] = loaded_weight
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected projector weight: {name}")
|
||||
continue
|
||||
else:
|
||||
if "transformer.blocks" in name:
|
||||
name = name.replace("transformer.blocks", "layers")
|
||||
|
||||
if "ff_proj" in name:
|
||||
name = name.replace("ff_proj", "mlp.gate_up_proj")
|
||||
assert 'weight' in name
|
||||
up_weight, gate_weight = loaded_weight.chunk(2, dim=0)
|
||||
loaded_weight = torch.cat([gate_weight, up_weight], dim=0)
|
||||
|
||||
elif "ff_out" in name:
|
||||
if "layers" in name:
|
||||
name = name.replace("ff_out", "mlp.down_proj")
|
||||
else:
|
||||
# lm head
|
||||
name = name.replace("model.transformer.ff_out",
|
||||
"lm_head")
|
||||
|
||||
else:
|
||||
for (param_name, weight_name) in params_mapping:
|
||||
if param_name in name:
|
||||
name = name.replace(param_name, weight_name)
|
||||
break
|
||||
|
||||
try:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
except KeyError:
|
||||
raise ValueError(f"Unexpected weight: {name}") from None
|
||||
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
gate_up_proj_weight = torch.cat(
|
||||
[projector_weight["gate_proj"], projector_weight["up_proj"]],
|
||||
dim=0)
|
||||
name = "vision_backbone.image_projector.gate_up_proj.weight"
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, gate_up_proj_weight)
|
||||
|
||||
down_proj_weight = projector_weight["down_proj"]
|
||||
name = "vision_backbone.image_projector.down_proj.weight"
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, down_proj_weight)
|
||||
|
||||
embedding_weight = torch.cat(
|
||||
[embedding_weight["embedding"], embedding_weight["new_embedding"]],
|
||||
dim=0)
|
||||
name = "model.embed_tokens.weight"
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, embedding_weight)
|
||||
def _get_weights_with_merged_embedding(
|
||||
weights: Iterable[Tuple[str, torch.Tensor]]
|
||||
) -> Iterable[Tuple[str, torch.Tensor]]:
|
||||
embedding_weights = {}
|
||||
for name, weight in weights:
|
||||
if "wte.embedding" in name:
|
||||
embedding_weights["embedding"] = weight
|
||||
elif "wte.new_embedding" in name:
|
||||
embedding_weights["new_embedding"] = weight
|
||||
else:
|
||||
yield (name, weight)
|
||||
# this is compatible with most of quantization,
|
||||
# because they won't quantize embed_tokens
|
||||
embedding_weights = torch.cat(
|
||||
[embedding_weights["embedding"], embedding_weights["new_embedding"]],
|
||||
dim=0,
|
||||
)
|
||||
yield ("model.embed_tokens.weight", embedding_weights)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user