[Model] use AutoWeightsLoader for bart (#18299)

Signed-off-by: calvin chen <120380290@qq.com>
Calvin Chen 2025-07-20 16:15:50 +08:00 committed by GitHub
parent d1fb65bde3
commit 51ba839555


@@ -46,7 +46,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsQuant, SupportsV0Only
from .utils import maybe_prefix
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
logger = logging.get_logger(__name__)
@@ -700,7 +700,8 @@ class BartDecoder(nn.Module):
class BartModel(nn.Module, SupportsQuant):
    _tied_weights_keys = [
        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
        "encoder.embed_tokens.weight",
        "decoder.embed_tokens.weight",
    ]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -763,10 +764,54 @@ class BartModel(nn.Module, SupportsQuant):
        return decoder_outputs

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        other_weights = []
        loaded_stacked_params = []
        model_params_dict = dict(self.named_parameters())
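        # Checkpoint q/k/v shards must be written into the fused qkv_proj
        # parameter with an explicit shard_id; every other weight is handed
        # off to AutoWeightsLoader below.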
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if name not in model_params_dict:
                    continue
                param = model_params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_stacked_params.append(name)
                break
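            # for-else: runs only when no stacked mapping matched this weight.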
            else:
                if name in model_params_dict:
                    other_weights.append((name, loaded_weight))
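        # AutoWeightsLoader matches the remaining checkpoint names against
        # this module's parameters, applies each parameter's weight_loader,
        # and returns the set of names it loaded.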
        loader = AutoWeightsLoader(self)
        loaded_params = loader.load_weights(other_weights)
        loaded_params.update(loaded_stacked_params)
        return loaded_params


class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
    base_model_prefix = "model"
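    # Rename HF checkpoint keys onto this module tree: top-level "encoder.",
    # "decoder." and "shared." prefixes move under "model.", and legacy
    # beta/gamma/LayerNorm names become bias/weight/layernorm.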
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "decoder.": "model.decoder.",
            "encoder.": "model.encoder.",
            "shared.": "model.shared."
        },
        orig_to_new_substr={
            "beta": "bias",
            "gamma": "weight",
            "LayerNorm": "layernorm",
        },
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -789,7 +834,6 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
        self.lm_head = BartParallelLMHead(config.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
@@ -828,61 +872,12 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
                                       sampling_metadata)
        return logits
    stacked_params_mapping = {
        "q_proj": {
            "param_name": "qkv_proj",
            "shard_id": "q",
        },
        "k_proj": {
            "param_name": "qkv_proj",
            "shard_id": "k",
        },
        "v_proj": {
            "param_name": "qkv_proj",
            "shard_id": "v",
        },
    }

    params_mapping = {
        "beta": "bias",
        "gamma": "weight",
        "LayerNorm": "layernorm",
    }
    def _rename_key(self, key: str):
        prefix = f"{self.base_model_prefix}."
        key = key[len(prefix):] if key.startswith(prefix) else key
        for src, dst in self.params_mapping.items():
            key = key.replace(src, dst)
        return key

    def _rename_stacked_param(
        self,
        name: str,
    ) -> tuple[str, Optional[str]]:
        for key, mapping in self.stacked_params_mapping.items():
            if key in name:
                name = name.replace(key, mapping["param_name"])
                return name, mapping["shard_id"]
        return name, None
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        model_params_dict = dict(self.model.named_parameters())
        top_params_dict = dict(self.named_parameters())

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        weights_tuple_list = list(weights)
        shared_embedding_weight = None
        shared_embedding_shard_id = None

        for name, loaded_weight in weights_tuple_list:
            name = self._rename_key(name)
            name, shard_id = self._rename_stacked_param(name)
            if ('shared.weight' in name
                    or 'encoder.embed_tokens.weight' in name
                    or 'decoder.embed_tokens.weight' in name
@@ -890,49 +885,24 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
                assert shared_embedding_weight is None, (
                    "Conflicting embedding weights.")
                shared_embedding_weight = loaded_weight
                shared_embedding_shard_id = shard_id
            else:
                # Skip the specific downstream task weight.
                if name.startswith('cls.'):
                    continue
                # use Pooler instead.
                if name.startswith('pooler.'):
                    continue
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in model_params_dict:
                    continue
                param = model_params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if shard_id:
                    weight_loader(param, loaded_weight, shard_id)
                else:
                    weight_loader(param, loaded_weight)
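        # skip_prefixes drops the cls./pooler. task-head weights that were
        # previously skipped by hand; the mapper applies the HF-to-vLLM
        # renaming before names are matched against parameters.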
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["cls.", "pooler."]),
        )
        loaded_params = loader.load_weights(weights_tuple_list,
                                            mapper=self.hf_to_vllm_mapper)
        # Assign shared weight values
        encoder_in_param = model_params_dict['encoder.embed_tokens.weight']
        encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader",
                                           default_weight_loader)
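        # The shared embedding is loaded once into lm_head; the encoder and
        # decoder embedding tables are then tied to it by reference.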
        if shared_embedding_weight is not None:
            weight_loader = getattr(self.lm_head.weight, "weight_loader",
                                    default_weight_loader)
            weight_loader(self.lm_head.weight, shared_embedding_weight)
        decoder_in_param = model_params_dict['decoder.embed_tokens.weight']
        decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader",
                                           default_weight_loader)
            self.model.encoder.embed_tokens.weight = self.lm_head.weight
            self.model.decoder.embed_tokens.weight = self.lm_head.weight
            loaded_params.update({
                'model.encoder.embed_tokens.weight', 'lm_head.weight',
                'model.decoder.embed_tokens.weight'
            })
        lm_head_in_param = top_params_dict['lm_head.weight']
        lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader",
                                           default_weight_loader)
        assert shared_embedding_weight is not None
        if shared_embedding_shard_id:
            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight,
                                     shared_embedding_shard_id)
            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight,
                                     shared_embedding_shard_id)
            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight,
                                     shared_embedding_shard_id)
        else:
            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight)
            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight)
            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight)
        return loaded_params