mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 11:01:19 +08:00
[Model] Refactor Molmo weights loading to use AutoWeightsLoader (#10771)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
40bc242579
commit
16ee07f22a
@ -3,7 +3,7 @@ import re
|
|||||||
from array import array
|
from array import array
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache, partial
|
from functools import lru_cache, partial
|
||||||
from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict
|
from typing import Iterable, List, Mapping, Optional, Set, Tuple, TypedDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
@ -44,7 +44,8 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
|||||||
from vllm.transformers_utils.processor import get_processor
|
from vllm.transformers_utils.processor import get_processor
|
||||||
|
|
||||||
from .interfaces import SupportsMultiModal, SupportsPP
|
from .interfaces import SupportsMultiModal, SupportsPP
|
||||||
from .utils import (get_vit_attn_backend,
|
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
|
||||||
|
is_pp_missing_parameter,
|
||||||
make_empty_intermediate_tensors_factory, make_layers,
|
make_empty_intermediate_tensors_factory, make_layers,
|
||||||
maybe_prefix)
|
maybe_prefix)
|
||||||
|
|
||||||
@ -720,6 +721,42 @@ class MolmoVisionBackbone(nn.Module):
|
|||||||
# image_features: (batch_size, num_image, num_patch, d_model)
|
# image_features: (batch_size, num_image, num_patch, d_model)
|
||||||
return image_features
|
return image_features
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
|
torch.Tensor]]) -> Set[str]:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
loaded_params: Set[str] = set()
|
||||||
|
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
# Skip loading extra bias for GPTQ models.
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile
|
@support_torch_compile
|
||||||
class MolmoModel(nn.Module):
|
class MolmoModel(nn.Module):
|
||||||
@ -804,6 +841,28 @@ class MolmoModel(nn.Module):
|
|||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
|
torch.Tensor]]) -> Set[str]:
|
||||||
|
params_dict = dict(self.named_parameters())
|
||||||
|
loaded_params: Set[str] = set()
|
||||||
|
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
if "gate_up_proj" in name:
|
||||||
|
up_proj, gate_proj = loaded_weight.chunk(2, dim=0)
|
||||||
|
loaded_weight = torch.cat([gate_proj, up_proj], dim=0)
|
||||||
|
|
||||||
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
|
continue
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
|
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
cached_get_processor = lru_cache(get_processor)
|
cached_get_processor = lru_cache(get_processor)
|
||||||
|
|
||||||
@ -1200,103 +1259,53 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
return next_tokens
|
return next_tokens
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
|
orig_to_new_substr={
|
||||||
|
# vision backbone mapping
|
||||||
|
"image_projector.w1.": "image_projector.gate_proj.",
|
||||||
|
"image_projector.w3.": "image_projector.up_proj.",
|
||||||
|
"image_projector.w2.": "image_projector.down_proj.",
|
||||||
|
# language backbone mapping
|
||||||
|
"att_proj": "self_attn.qkv_proj",
|
||||||
|
"attn_out": "self_attn.o_proj",
|
||||||
|
"q_norm": "self_attn.q_norm",
|
||||||
|
"k_norm": "self_attn.k_norm",
|
||||||
|
"ff_proj": "mlp.gate_up_proj",
|
||||||
|
"ff_out": "mlp.down_proj",
|
||||||
|
"attn_norm": "input_layernorm",
|
||||||
|
"ff_norm": "post_attention_layernorm",
|
||||||
|
},
|
||||||
|
orig_to_new_prefix={
|
||||||
|
# vision backbone mapping
|
||||||
|
"model.vision_backbone.": "vision_backbone.",
|
||||||
|
# language backbone mapping
|
||||||
|
"model.transformer.blocks.": "model.layers.",
|
||||||
|
"model.transformer.ln_f.": "model.norm.",
|
||||||
|
# lm_head is renamed to model.transformer.mlp.down_proj firstly,
|
||||||
|
# we need to run a second renaming for it
|
||||||
|
"model.transformer.mlp.down_proj.": "lm_head.",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
loader = AutoWeightsLoader(self)
|
||||||
|
weights = _get_weights_with_merged_embedding(weights)
|
||||||
|
return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
|
||||||
|
|
||||||
params_mapping = [
|
|
||||||
("model.transformer.ln_f.weight", "model.norm.weight"),
|
|
||||||
("attn_out", "self_attn.o_proj"),
|
|
||||||
("att_proj", "self_attn.qkv_proj"),
|
|
||||||
("q_norm", "self_attn.q_norm"),
|
|
||||||
("k_norm", "self_attn.k_norm"),
|
|
||||||
("attn_norm", "input_layernorm"),
|
|
||||||
("ff_norm", "post_attention_layernorm"),
|
|
||||||
]
|
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
def _get_weights_with_merged_embedding(
|
||||||
|
weights: Iterable[Tuple[str, torch.Tensor]]
|
||||||
embedding_weight = dict()
|
) -> Iterable[Tuple[str, torch.Tensor]]:
|
||||||
projector_weight = dict()
|
embedding_weights = {}
|
||||||
for name, loaded_weight in weights:
|
for name, weight in weights:
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "wte.embedding" in name:
|
||||||
continue
|
embedding_weights["embedding"] = weight
|
||||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
elif "wte.new_embedding" in name:
|
||||||
continue
|
embedding_weights["new_embedding"] = weight
|
||||||
|
else:
|
||||||
if "wte.embedding" in name:
|
yield (name, weight)
|
||||||
embedding_weight["embedding"] = loaded_weight
|
# this is compatible with most of quantization,
|
||||||
continue
|
# because they won't quantize embed_tokens
|
||||||
|
embedding_weights = torch.cat(
|
||||||
if "wte.new_embedding" in name:
|
[embedding_weights["embedding"], embedding_weights["new_embedding"]],
|
||||||
embedding_weight["new_embedding"] = loaded_weight
|
dim=0,
|
||||||
continue
|
)
|
||||||
|
yield ("model.embed_tokens.weight", embedding_weights)
|
||||||
if "vision_backbone" in name:
|
|
||||||
if name.startswith("model"):
|
|
||||||
name = name[len("model."):]
|
|
||||||
if 'image_projector' in name:
|
|
||||||
if 'w1' in name:
|
|
||||||
projector_weight['gate_proj'] = loaded_weight
|
|
||||||
elif 'w3' in name:
|
|
||||||
projector_weight['up_proj'] = loaded_weight
|
|
||||||
elif 'w2' in name:
|
|
||||||
projector_weight['down_proj'] = loaded_weight
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unexpected projector weight: {name}")
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
if "transformer.blocks" in name:
|
|
||||||
name = name.replace("transformer.blocks", "layers")
|
|
||||||
|
|
||||||
if "ff_proj" in name:
|
|
||||||
name = name.replace("ff_proj", "mlp.gate_up_proj")
|
|
||||||
assert 'weight' in name
|
|
||||||
up_weight, gate_weight = loaded_weight.chunk(2, dim=0)
|
|
||||||
loaded_weight = torch.cat([gate_weight, up_weight], dim=0)
|
|
||||||
|
|
||||||
elif "ff_out" in name:
|
|
||||||
if "layers" in name:
|
|
||||||
name = name.replace("ff_out", "mlp.down_proj")
|
|
||||||
else:
|
|
||||||
# lm head
|
|
||||||
name = name.replace("model.transformer.ff_out",
|
|
||||||
"lm_head")
|
|
||||||
|
|
||||||
else:
|
|
||||||
for (param_name, weight_name) in params_mapping:
|
|
||||||
if param_name in name:
|
|
||||||
name = name.replace(param_name, weight_name)
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError(f"Unexpected weight: {name}") from None
|
|
||||||
|
|
||||||
weight_loader = getattr(param, "weight_loader",
|
|
||||||
default_weight_loader)
|
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
|
|
||||||
gate_up_proj_weight = torch.cat(
|
|
||||||
[projector_weight["gate_proj"], projector_weight["up_proj"]],
|
|
||||||
dim=0)
|
|
||||||
name = "vision_backbone.image_projector.gate_up_proj.weight"
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
|
||||||
weight_loader(param, gate_up_proj_weight)
|
|
||||||
|
|
||||||
down_proj_weight = projector_weight["down_proj"]
|
|
||||||
name = "vision_backbone.image_projector.down_proj.weight"
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
|
||||||
weight_loader(param, down_proj_weight)
|
|
||||||
|
|
||||||
embedding_weight = torch.cat(
|
|
||||||
[embedding_weight["embedding"], embedding_weight["new_embedding"]],
|
|
||||||
dim=0)
|
|
||||||
name = "model.embed_tokens.weight"
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
|
||||||
weight_loader(param, embedding_weight)
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user