[Model] Refactor Molmo weights loading to use AutoWeightsLoader (#10771)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2024-11-30 12:19:14 +08:00 committed by GitHub
parent 40bc242579
commit 16ee07f22a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,7 @@ import re
from array import array from array import array
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict from typing import Iterable, List, Mapping, Optional, Set, Tuple, TypedDict
import torch import torch
from einops import rearrange from einops import rearrange
@ -44,7 +44,8 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
from vllm.transformers_utils.processor import get_processor from vllm.transformers_utils.processor import get_processor
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (get_vit_attn_backend, from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
@ -720,6 +721,42 @@ class MolmoVisionBackbone(nn.Module):
# image_features: (batch_size, num_image, num_patch, d_model) # image_features: (batch_size, num_image, num_patch, d_model)
return image_features return image_features
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Load checkpoint tensors into this module's parameters.

    Checkpoint names containing ``gate_proj`` or ``up_proj`` are rewritten
    to the fused ``gate_up_proj`` parameter and handed to that parameter's
    ``weight_loader`` together with a shard id (0 for gate, 1 for up).
    Every other tensor is loaded through the parameter's ``weight_loader``
    when it has one, otherwise through ``default_weight_loader``.

    Args:
        weights: ``(name, tensor)`` pairs, named relative to this module.

    Returns:
        The (possibly rewritten) names of the parameters that were loaded.

    Raises:
        KeyError: if a non-skipped name has no matching parameter.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: Set[str] = set()

    for name, loaded_weight in weights:
        # Resolve the stacked mapping exactly once. Stopping at the first
        # match matters: "up_proj" is a substring of the already-rewritten
        # "gate_up_proj", so trying further entries would mangle the name.
        shard_id = None
        for param_name, weight_name, candidate_shard in stacked_params_mapping:
            if weight_name in name:
                name = name.replace(weight_name, param_name)
                shard_id = candidate_shard
                break

        # Skip loading extra bias for GPTQ models.
        if name.endswith(".bias") and name not in params_dict:
            continue
        if is_pp_missing_parameter(name, self):
            continue

        param = params_dict[name]
        if shard_id is not None:
            # Fused parameters must expose their own shard-aware loader.
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
        else:
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params
@support_torch_compile @support_torch_compile
class MolmoModel(nn.Module): class MolmoModel(nn.Module):
@ -804,6 +841,28 @@ class MolmoModel(nn.Module):
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Load checkpoint tensors into this module's parameters.

    Tensors whose name contains ``gate_up_proj`` are split into two halves
    along dim 0 and re-concatenated with the halves swapped before loading.
    Bias entries with no matching parameter, and names for which
    ``is_pp_missing_parameter`` returns true, are skipped.

    Args:
        weights: ``(name, tensor)`` pairs, named relative to this module.

    Returns:
        The names of the parameters that were loaded.
    """
    own_params = dict(self.named_parameters())
    loaded: Set[str] = set()

    for name, tensor in weights:
        if "gate_up_proj" in name:
            # First half is treated as "up", second as "gate"; the fused
            # parameter is loaded in (gate, up) order.
            up_half, gate_half = tensor.chunk(2, dim=0)
            tensor = torch.cat([gate_half, up_half], dim=0)

        unknown_bias = name.endswith(".bias") and name not in own_params
        if unknown_bias or is_pp_missing_parameter(name, self):
            continue

        target = own_params[name]
        load_fn = getattr(target, "weight_loader", default_weight_loader)
        load_fn(target, tensor)
        loaded.add(name)

    return loaded
cached_get_processor = lru_cache(get_processor) cached_get_processor = lru_cache(get_processor)
@ -1200,103 +1259,53 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load a checkpoint by renaming its weights and delegating the load.

    The checkpoint names are translated through a ``WeightsMapper``
    (substring rules, then prefix rules as configured below), the two
    ``wte`` embedding tensors are merged by
    ``_get_weights_with_merged_embedding``, and the resulting stream is
    handed to ``AutoWeightsLoader``.

    Args:
        weights: ``(name, tensor)`` pairs as stored in the checkpoint.

    Returns:
        Whatever ``AutoWeightsLoader.load_weights`` returns.
    """
    substr_rules = {
        # vision backbone mapping
        "image_projector.w1.": "image_projector.gate_proj.",
        "image_projector.w3.": "image_projector.up_proj.",
        "image_projector.w2.": "image_projector.down_proj.",
        # language backbone mapping
        "att_proj": "self_attn.qkv_proj",
        "attn_out": "self_attn.o_proj",
        "q_norm": "self_attn.q_norm",
        "k_norm": "self_attn.k_norm",
        "ff_proj": "mlp.gate_up_proj",
        "ff_out": "mlp.down_proj",
        "attn_norm": "input_layernorm",
        "ff_norm": "post_attention_layernorm",
    }
    prefix_rules = {
        # vision backbone mapping
        "model.vision_backbone.": "vision_backbone.",
        # language backbone mapping
        "model.transformer.blocks.": "model.layers.",
        "model.transformer.ln_f.": "model.norm.",
        # The "ff_out" substring rule first turns the lm_head weight into
        # model.transformer.mlp.down_proj, so a second rename is needed
        # here to land it on lm_head.
        "model.transformer.mlp.down_proj.": "lm_head.",
    }
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr=substr_rules,
        orig_to_new_prefix=prefix_rules,
    )

    loader = AutoWeightsLoader(self)
    merged_weights = _get_weights_with_merged_embedding(weights)
    return loader.load_weights(merged_weights, mapper=hf_to_vllm_mapper)
params_mapping = [
("model.transformer.ln_f.weight", "model.norm.weight"),
("attn_out", "self_attn.o_proj"),
("att_proj", "self_attn.qkv_proj"),
("q_norm", "self_attn.q_norm"),
("k_norm", "self_attn.k_norm"),
("attn_norm", "input_layernorm"),
("ff_norm", "post_attention_layernorm"),
]
def _get_weights_with_merged_embedding(
    weights: Iterable[Tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
    """Yield ``weights`` with the two ``wte`` embedding tensors merged.

    Entries whose name contains ``wte.embedding`` or ``wte.new_embedding``
    are held back; every other entry is yielded unchanged in input order.
    After the input is exhausted, the two held-back tensors are
    concatenated along dim 0 (``embedding`` first) and yielded last under
    the name ``model.embed_tokens.weight``.

    Args:
        weights: ``(name, tensor)`` pairs from the checkpoint.

    Yields:
        ``(name, tensor)`` pairs as described above.

    Raises:
        KeyError: once the input is exhausted, if either embedding tensor
            was never seen. The message names the missing entries.
    """
    embedding_parts = {}
    for name, weight in weights:
        if "wte.embedding" in name:
            embedding_parts["embedding"] = weight
        elif "wte.new_embedding" in name:
            embedding_parts["new_embedding"] = weight
        else:
            yield (name, weight)

    missing = [
        key for key in ("embedding", "new_embedding")
        if key not in embedding_parts
    ]
    if missing:
        # Same exception type as the plain dict lookup would raise, but
        # with enough context to diagnose an incomplete checkpoint.
        raise KeyError(
            "Cannot build model.embed_tokens.weight; checkpoint is missing "
            "wte entries: " + ", ".join(missing))

    # this is compatible with most of quantization,
    # because they won't quantize embed_tokens
    merged_embedding = torch.cat(
        [embedding_parts["embedding"], embedding_parts["new_embedding"]],
        dim=0,
    )
    yield ("model.embed_tokens.weight", merged_embedding)
if "vision_backbone" in name:
if name.startswith("model"):
name = name[len("model."):]
if 'image_projector' in name:
if 'w1' in name:
projector_weight['gate_proj'] = loaded_weight
elif 'w3' in name:
projector_weight['up_proj'] = loaded_weight
elif 'w2' in name:
projector_weight['down_proj'] = loaded_weight
else:
raise ValueError(
f"Unexpected projector weight: {name}")
continue
else:
if "transformer.blocks" in name:
name = name.replace("transformer.blocks", "layers")
if "ff_proj" in name:
name = name.replace("ff_proj", "mlp.gate_up_proj")
assert 'weight' in name
up_weight, gate_weight = loaded_weight.chunk(2, dim=0)
loaded_weight = torch.cat([gate_weight, up_weight], dim=0)
elif "ff_out" in name:
if "layers" in name:
name = name.replace("ff_out", "mlp.down_proj")
else:
# lm head
name = name.replace("model.transformer.ff_out",
"lm_head")
else:
for (param_name, weight_name) in params_mapping:
if param_name in name:
name = name.replace(param_name, weight_name)
break
try:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
except KeyError:
raise ValueError(f"Unexpected weight: {name}") from None
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
gate_up_proj_weight = torch.cat(
[projector_weight["gate_proj"], projector_weight["up_proj"]],
dim=0)
name = "vision_backbone.image_projector.gate_up_proj.weight"
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, gate_up_proj_weight)
down_proj_weight = projector_weight["down_proj"]
name = "vision_backbone.image_projector.down_proj.weight"
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, down_proj_weight)
embedding_weight = torch.cat(
[embedding_weight["embedding"], embedding_weight["new_embedding"]],
dim=0)
name = "model.embed_tokens.weight"
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, embedding_weight)