[Model] Refactor Molmo weights loading to use AutoWeightsLoader (#10771)

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py authored on 2024-11-30 12:19:14 +08:00; committed by GitHub
parent 40bc242579
commit 16ee07f22a

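For orientation: the diff below drops Molmo's hand-written weight-name translation loop in favor of a declarative WeightsMapper plus AutoWeightsLoader, with MolmoVisionBackbone and MolmoModel now exposing their own load_weights. The following is a minimal sketch of the dispatch idea, assuming only what is visible in this diff (submodule load_weights returning the set of loaded names); toy_auto_load is a hypothetical toy, not the real AutoWeightsLoader in vllm/model_executor/models/utils.py.

from typing import Iterable, Set, Tuple

import torch
import torch.nn as nn


def toy_auto_load(module: nn.Module,
                  weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
    # Toy dispatch: hand each weight to the child module named by its first
    # path component, preferring the child's own load_weights() when present
    # (as MolmoVisionBackbone and MolmoModel define below), otherwise copying
    # the tensor into the matching parameter directly.
    children = dict(module.named_children())
    params = dict(module.named_parameters())
    loaded: Set[str] = set()
    for name, tensor in weights:
        prefix, _, rest = name.partition(".")
        child = children.get(prefix)
        if child is not None and hasattr(child, "load_weights"):
            loaded |= {f"{prefix}.{n}"
                       for n in child.load_weights([(rest, tensor)])}
        else:
            params[name].data.copy_(tensor)
            loaded.add(name)
    return loaded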

@@ -3,7 +3,7 @@ import re
from array import array
from dataclasses import dataclass
from functools import lru_cache, partial
from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict
from typing import Iterable, List, Mapping, Optional, Set, Tuple, TypedDict
import torch
from einops import rearrange
@@ -44,7 +44,8 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
from vllm.transformers_utils.processor import get_processor
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (get_vit_attn_backend,
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
                    is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
@@ -720,6 +721,42 @@ class MolmoVisionBackbone(nn.Module):
        # image_features: (batch_size, num_image, num_patch, d_model)
        return image_features

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
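The for/else above is vLLM's usual stacked-parameter pattern: checkpoint shards such as gate_proj and up_proj are written into disjoint slices of the fused gate_up_proj via the parameter's weight_loader and shard_id, and everything else falls through to the default loader. Below is a self-contained toy version of that shard-merging idea; toy_shard_loader is a made-up stand-in for the real weight_loader.

import torch

# Toy fused parameter: rows [0, 4) hold gate_proj, rows [4, 8) hold up_proj.
fused = torch.zeros(8, 16)


def toy_shard_loader(param: torch.Tensor, shard: torch.Tensor,
                     shard_id: int) -> None:
    # Hypothetical stand-in for param.weight_loader: copy a shard into the
    # slice of the fused tensor selected by shard_id (0 = gate, 1 = up).
    rows = shard.shape[0]
    param[shard_id * rows:(shard_id + 1) * rows].copy_(shard)


stacked_params_mapping = [
    # (fused_name, checkpoint_name, shard_id)
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

checkpoint = {
    "mlp.gate_proj.weight": torch.randn(4, 16),
    "mlp.up_proj.weight": torch.randn(4, 16),
}

for name, weight in checkpoint.items():
    for fused_name, ckpt_name, shard_id in stacked_params_mapping:
        if ckpt_name not in name:
            continue
        toy_shard_loader(fused, weight, shard_id)
        break
    else:
        pass  # non-stacked weights would go through the default loader here

assert torch.equal(fused[:4], checkpoint["mlp.gate_proj.weight"])
assert torch.equal(fused[4:], checkpoint["mlp.up_proj.weight"])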
@support_torch_compile
class MolmoModel(nn.Module):
@@ -804,6 +841,28 @@ class MolmoModel(nn.Module):
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "gate_up_proj" in name:
                up_proj, gate_proj = loaded_weight.chunk(2, dim=0)
                loaded_weight = torch.cat([gate_proj, up_proj], dim=0)
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
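One detail worth calling out in MolmoModel.load_weights: the checkpoint stores the fused MLP projection as [up; gate] along dim 0 (the removed code further down chunks it the same way), whereas vLLM's fused gate_up_proj expects [gate; up], hence the chunk-and-swap. A tiny numeric illustration of that reordering:

import torch

# Pretend checkpoint layout: first half is up_proj, second half is gate_proj.
up = torch.full((2, 3), 1.0)
gate = torch.full((2, 3), 2.0)
ckpt_ff_proj = torch.cat([up, gate], dim=0)  # [up; gate], shape (4, 3)

# Reorder into vLLM's expected [gate; up] layout, as load_weights does above.
up_half, gate_half = ckpt_ff_proj.chunk(2, dim=0)
gate_up_proj = torch.cat([gate_half, up_half], dim=0)

assert torch.equal(gate_up_proj[:2], gate)
assert torch.equal(gate_up_proj[2:], up)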
cached_get_processor = lru_cache(get_processor)
@@ -1200,103 +1259,53 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        hf_to_vllm_mapper = WeightsMapper(
            orig_to_new_substr={
                # vision backbone mapping
                "image_projector.w1.": "image_projector.gate_proj.",
                "image_projector.w3.": "image_projector.up_proj.",
                "image_projector.w2.": "image_projector.down_proj.",
                # language backbone mapping
                "att_proj": "self_attn.qkv_proj",
                "attn_out": "self_attn.o_proj",
                "q_norm": "self_attn.q_norm",
                "k_norm": "self_attn.k_norm",
                "ff_proj": "mlp.gate_up_proj",
                "ff_out": "mlp.down_proj",
                "attn_norm": "input_layernorm",
                "ff_norm": "post_attention_layernorm",
            },
            orig_to_new_prefix={
                # vision backbone mapping
                "model.vision_backbone.": "vision_backbone.",
                # language backbone mapping
                "model.transformer.blocks.": "model.layers.",
                "model.transformer.ln_f.": "model.norm.",
                # lm_head is renamed to model.transformer.mlp.down_proj firstly,
                # we need to run a second renaming for it
                "model.transformer.mlp.down_proj.": "lm_head.",
            },
        )

        loader = AutoWeightsLoader(self)
        weights = _get_weights_with_merged_embedding(weights)
        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
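The WeightsMapper is pure name translation: HF checkpoint keys are rewritten into vLLM module paths before AutoWeightsLoader resolves them. Below is a toy re-implementation over an abridged subset of the tables above; the real class lives in vllm/model_executor/models/utils.py, and applying substring rules before prefix rules is an assumption inferred from the lm_head comment in the diff.

orig_to_new_substr = {
    "att_proj": "self_attn.qkv_proj",
    "attn_out": "self_attn.o_proj",
    "ff_proj": "mlp.gate_up_proj",
    "ff_out": "mlp.down_proj",
    "ff_norm": "post_attention_layernorm",
}
orig_to_new_prefix = {
    "model.transformer.blocks.": "model.layers.",
    "model.transformer.ln_f.": "model.norm.",
    "model.transformer.mlp.down_proj.": "lm_head.",
}


def toy_map(name: str) -> str:
    # Substring rules first, then a single prefix rule (assumed order).
    for old, new in orig_to_new_substr.items():
        if old in name:
            name = name.replace(old, new)
    for old, new in orig_to_new_prefix.items():
        if name.startswith(old):
            name = new + name[len(old):]
            break
    return name


print(toy_map("model.transformer.blocks.0.att_proj.weight"))
# model.layers.0.self_attn.qkv_proj.weight
print(toy_map("model.transformer.ff_out.weight"))
# lm_head.weight  (ff_out -> mlp.down_proj first, then the prefix rule)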
        params_mapping = [
            ("model.transformer.ln_f.weight", "model.norm.weight"),
            ("attn_out", "self_attn.o_proj"),
            ("att_proj", "self_attn.qkv_proj"),
            ("q_norm", "self_attn.q_norm"),
            ("k_norm", "self_attn.k_norm"),
            ("attn_norm", "input_layernorm"),
            ("ff_norm", "post_attention_layernorm"),
        ]

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        embedding_weight = dict()
        projector_weight = dict()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            if "wte.embedding" in name:
                embedding_weight["embedding"] = loaded_weight
                continue
            if "wte.new_embedding" in name:
                embedding_weight["new_embedding"] = loaded_weight
                continue

            if "vision_backbone" in name:
                if name.startswith("model"):
                    name = name[len("model."):]
                if 'image_projector' in name:
                    if 'w1' in name:
                        projector_weight['gate_proj'] = loaded_weight
                    elif 'w3' in name:
                        projector_weight['up_proj'] = loaded_weight
                    elif 'w2' in name:
                        projector_weight['down_proj'] = loaded_weight
                    else:
                        raise ValueError(
                            f"Unexpected projector weight: {name}")
                    continue
            else:
                if "transformer.blocks" in name:
                    name = name.replace("transformer.blocks", "layers")

                if "ff_proj" in name:
                    name = name.replace("ff_proj", "mlp.gate_up_proj")
                    assert 'weight' in name
                    up_weight, gate_weight = loaded_weight.chunk(2, dim=0)
                    loaded_weight = torch.cat([gate_weight, up_weight], dim=0)
                elif "ff_out" in name:
                    if "layers" in name:
                        name = name.replace("ff_out", "mlp.down_proj")
                    else:
                        # lm head
                        name = name.replace("model.transformer.ff_out",
                                            "lm_head")
                else:
                    for (param_name, weight_name) in params_mapping:
                        if param_name in name:
                            name = name.replace(param_name, weight_name)
                            break

            try:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
            except KeyError:
                raise ValueError(f"Unexpected weight: {name}") from None
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        gate_up_proj_weight = torch.cat(
            [projector_weight["gate_proj"], projector_weight["up_proj"]],
            dim=0)
        name = "vision_backbone.image_projector.gate_up_proj.weight"
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, gate_up_proj_weight)

        down_proj_weight = projector_weight["down_proj"]
        name = "vision_backbone.image_projector.down_proj.weight"
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, down_proj_weight)

        embedding_weight = torch.cat(
            [embedding_weight["embedding"], embedding_weight["new_embedding"]],
            dim=0)
        name = "model.embed_tokens.weight"
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, embedding_weight)
def _get_weights_with_merged_embedding(
    weights: Iterable[Tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
    embedding_weights = {}
    for name, weight in weights:
        if "wte.embedding" in name:
            embedding_weights["embedding"] = weight
        elif "wte.new_embedding" in name:
            embedding_weights["new_embedding"] = weight
        else:
            yield (name, weight)

    # this is compatible with most of quantization,
    # because they won't quantize embed_tokens
    embedding_weights = torch.cat(
        [embedding_weights["embedding"], embedding_weights["new_embedding"]],
        dim=0,
    )
    yield ("model.embed_tokens.weight", embedding_weights)