From 51ba839555a5d122eadd91e9c16463ac288f5fa1 Mon Sep 17 00:00:00 2001
From: Calvin Chen
Date: Sun, 20 Jul 2025 16:15:50 +0800
Subject: [PATCH] [Model] use AutoWeightsLoader for bart (#18299)

Signed-off-by: calvin chen <120380290@qq.com>
---
 vllm/model_executor/models/bart.py | 166 ++++++++++++-----------------
 1 file changed, 68 insertions(+), 98 deletions(-)

diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py
index a0ec12674f19b..3d328c88ff6e0 100644
--- a/vllm/model_executor/models/bart.py
+++ b/vllm/model_executor/models/bart.py
@@ -46,7 +46,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from .interfaces import SupportsQuant, SupportsV0Only
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 logger = logging.get_logger(__name__)
 
@@ -700,7 +700,8 @@ class BartDecoder(nn.Module):
 
 class BartModel(nn.Module, SupportsQuant):
     _tied_weights_keys = [
-        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
+        "encoder.embed_tokens.weight",
+        "decoder.embed_tokens.weight",
     ]
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -763,10 +764,54 @@ class BartModel(nn.Module, SupportsQuant):
 
         return decoder_outputs
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        other_weights = []
+        loaded_stacked_params = []
+        model_params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                if name not in model_params_dict:
+                    continue
+                param = model_params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                loaded_stacked_params.append(name)
+                break
+            else:
+                if name in model_params_dict:
+                    other_weights.append((name, loaded_weight))
+
+        loader = AutoWeightsLoader(self)
+        loaded_params = loader.load_weights(other_weights)
+        loaded_params.update(loaded_stacked_params)
+        return loaded_params
+
 
 class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
-    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
 
-    base_model_prefix = "model"
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "decoder.": "model.decoder.",
+            "encoder.": "model.encoder.",
+            "shared.": "model.shared."
+        },
+        orig_to_new_substr={
+            "beta": "bias",
+            "gamma": "weight",
+            "LayerNorm": "layernorm",
+        },
+    )
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -789,7 +834,6 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
 
         self.lm_head = BartParallelLMHead(config.vocab_size,
                                           config.d_model,
                                           embed_scale=embed_scale)
-
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
@@ -828,61 +872,12 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
                                        sampling_metadata)
         return logits
 
-    stacked_params_mapping = {
-        "q_proj": {
-            "param_name": "qkv_proj",
-            "shard_id": "q",
-        },
-        "k_proj": {
-            "param_name": "qkv_proj",
-            "shard_id": "k",
-        },
-        "v_proj": {
-            "param_name": "qkv_proj",
-            "shard_id": "v",
-        },
-    }
-
-    params_mapping = {
-        "beta": "bias",
-        "gamma": "weight",
-        "LayerNorm": "layernorm",
-    }
-
-    def _rename_key(self, key: str):
-        prefix = f"{self.base_model_prefix}."
-        key = key[len(prefix):] if key.startswith(prefix) else key
-
-        for src, dst in self.params_mapping.items():
-            key = key.replace(src, dst)
-
-        return key
-
-    def _rename_stacked_param(
-        self,
-        name: str,
-    ) -> tuple[str, Optional[str]]:
-        for key, mapping in self.stacked_params_mapping.items():
-            if key in name:
-                name = name.replace(key, mapping["param_name"])
-                return name, mapping["shard_id"]
-        return name, None
-
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-
-        model_params_dict = dict(self.model.named_parameters())
-        top_params_dict = dict(self.named_parameters())
-
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
         weights_tuple_list = list(weights)
 
         shared_embedding_weight = None
-        shared_embedding_shard_id = None
-
         for name, loaded_weight in weights_tuple_list:
-
-            name = self._rename_key(name)
-            name, shard_id = self._rename_stacked_param(name)
-
             if ('shared.weight' in name
                     or 'encoder.embed_tokens.weight' in name
                     or 'decoder.embed_tokens.weight' in name
@@ -890,49 +885,24 @@ class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
                 assert shared_embedding_weight is None, (
                     "Conflicting embedding weights.")
                 shared_embedding_weight = loaded_weight
-                shared_embedding_shard_id = shard_id
-            else:
-                # Skip the specific downstream task weight.
-                if name.startswith('cls.'):
-                    continue
-                # use Pooler instead.
-                if name.startswith('pooler.'):
-                    continue
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in model_params_dict:
-                    continue
 
-                param = model_params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                if shard_id:
-                    weight_loader(param, loaded_weight, shard_id)
-                else:
-                    weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["cls.", "pooler."]),
+        )
+        loaded_params = loader.load_weights(weights_tuple_list,
+                                            mapper=self.hf_to_vllm_mapper)
 
-        # Assign shared weight values
-        encoder_in_param = model_params_dict['encoder.embed_tokens.weight']
-        encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader",
-                                           default_weight_loader)
+        if shared_embedding_weight is not None:
+            weight_loader = getattr(self.lm_head.weight, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(self.lm_head.weight, shared_embedding_weight)
 
-        decoder_in_param = model_params_dict['decoder.embed_tokens.weight']
-        decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader",
-                                           default_weight_loader)
+            self.model.encoder.embed_tokens.weight = self.lm_head.weight
+            self.model.decoder.embed_tokens.weight = self.lm_head.weight
+            loaded_params.update({
+                'model.encoder.embed_tokens.weight', 'lm_head.weight',
+                'model.decoder.embed_tokens.weight'
+            })
 
-        lm_head_in_param = top_params_dict['lm_head.weight']
-        lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader",
-                                           default_weight_loader)
-
-        assert shared_embedding_weight is not None
-
-        if shared_embedding_shard_id:
-            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight,
-                                     shared_embedding_shard_id)
-            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight,
-                                     shared_embedding_shard_id)
-            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight,
-                                     shared_embedding_shard_id)
-        else:
-            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight)
-            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight)
-            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight)
+        return loaded_params
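For context on the change above: the new BartForConditionalGeneration.load_weights records the shared embedding weight so it can tie lm_head and the encoder/decoder embed_tokens, and delegates the actual loading to AutoWeightsLoader, which applies hf_to_vllm_mapper to rename checkpoint keys before matching them against the model's parameters. Below is a minimal standalone sketch of that renaming step in plain Python; it is an approximation of the mapping encoded by hf_to_vllm_mapper, not vLLM's WeightsMapper implementation, and the checkpoint key names used are hypothetical examples.

    # Illustrative approximation of the prefix/substring renaming encoded by
    # hf_to_vllm_mapper above; not the actual vLLM WeightsMapper implementation.
    def rename_hf_key(name: str) -> str:
        prefix_rules = {
            "decoder.": "model.decoder.",
            "encoder.": "model.encoder.",
            "shared.": "model.shared.",
        }
        substr_rules = {"beta": "bias", "gamma": "weight", "LayerNorm": "layernorm"}
        # Rewrite a matching prefix first, then apply substring replacements.
        for old, new in prefix_rules.items():
            if name.startswith(old):
                name = new + name[len(old):]
                break
        for old, new in substr_rules.items():
            name = name.replace(old, new)
        return name

    # Hypothetical checkpoint keys, shown only to illustrate the mapping.
    assert rename_hf_key("shared.weight") == "model.shared.weight"
    assert rename_hf_key("encoder.layernorm_embedding.gamma") == \
        "model.encoder.layernorm_embedding.weight"

With the keys renamed this way, AutoWeightsLoader can match them directly against the module tree under self.model, which is why the hand-written _rename_key and _rename_stacked_param helpers could be deleted.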