From 14e1e9b09acbb3192943fffeaac1b81cfb8a6b54 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Sat, 20 Sep 2025 04:11:03 +0100
Subject: [PATCH] Improve weight loading for encoder models in Transformers backend (#25289)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: yewentao256
---
 vllm/model_executor/models/transformers.py | 28 ++++++++++++++++++++++++++--
 1 file changed, 26 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index f40a20dee63d7..3bd4d10316ec6 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -702,21 +702,45 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
 class TransformersModel(TransformersBase):
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
+            # Handle BERT-like models
+            "bert": "model",
             # Add `model.` prefix for base model checkpoints
             "": "model.",
-            # Remove `model.` from places it should not be
+            # Remove `model.` prefix if it was already there
             "model.model.": "model.",
+            # Pooling adapters will be adjacent to `model`
+            "model.pooler": "pooler",
             "model.score": "score",
+            # Classifier adapter's classifier layer is renamed to score
+            "model.classifier": "score",
+        },
+        orig_to_new_suffix={
+            # Replace legacy suffixes used for norms
+            ".gamma": ".weight",
+            ".beta": ".bias",
         })
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-        # Some encoder models have the position_ids buffer in the checkpoint
+        # After creating a pooling model, `pooler` will be duplicated.
+        # The one inside `model` comes from the Transformers modelling code.
+        # The one after `model` is an adapter from vLLM.
+        # We want to use the adapter, so we nullify the original pooler.
+        if getattr(self.model, "pooler", None) is not None:
+            self.skip_prefixes.append("pooler.")
+            self.model.pooler = torch.nn.Identity()
+
+        # Some encoder models have the position_ids buffer in the checkpoint.
         # vLLM will always pass position_ids as an argument, so we skip loading
         # the buffer if it exists
         self.skip_substrs.append("position_ids")
 
+        # Some encoder models have the bias of the final classifier layer
+        # in the checkpoint. vLLM does not use this bias, so we skip loading
+        # it if it exists
+        self.skip_substrs.append("score.bias")
+
     def create_attention_instances(
             self, attn_type: AttentionType = AttentionType.DECODER):
         # TODO(hmellor): Better way to detect encoder models
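
As a reading aid for the mapper above: every matching rule fires, in insertion order, which is why the `"model.model.": "model."` entry has to come after `"": "model."` to undo the double prefix that the catch-all rule creates. The following is a minimal standalone sketch of that chaining behaviour in plain Python; it is not vLLM's actual `WeightsMapper` implementation, and the rule tables simply restate the ones in the diff.

    # Sketch only: chained prefix/suffix key remapping, applied in
    # insertion order, mirroring hf_to_vllm_mapper above.
    PREFIX_RULES = {
        "bert": "model",           # BERT-like checkpoints
        "": "model.",              # bare base-model checkpoints
        "model.model.": "model.",  # undo the double prefix from the rule above
        "model.pooler": "pooler",  # pooling adapter sits next to `model`
        "model.score": "score",
        "model.classifier": "score",
    }
    SUFFIX_RULES = {
        ".gamma": ".weight",       # legacy LayerNorm parameter names
        ".beta": ".bias",
    }

    def map_key(key: str) -> str:
        for prefix, new in PREFIX_RULES.items():
            if key.startswith(prefix):
                key = new + key[len(prefix):]
        for suffix, new in SUFFIX_RULES.items():
            if key.endswith(suffix):
                key = key[:-len(suffix)] + new
        return key

    # "bert" -> "model", then "" -> "model.", then "model.model." -> "model.",
    # then ".gamma" -> ".weight":
    assert (map_key("bert.encoder.layer.0.output.LayerNorm.gamma")
            == "model.encoder.layer.0.output.LayerNorm.weight")
    # "" -> "model.", then "model.classifier" -> "score":
    assert map_key("classifier.weight") == "score.weight"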
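
The pooler handling then builds on that mapping: a bare checkpoint's `pooler.*` keys are remapped through `model.pooler` back to `pooler.*`, the new `skip_prefixes` entry keeps those tensors from being loaded, and `torch.nn.Identity()` detaches the now-unused Transformers pooler from the module tree so only vLLM's adapter remains. Below is a toy illustration of the skip semantics and the `Identity` swap; the `keep_weight` helper is hypothetical, mirroring only what the two skip lists are configured to do, not vLLM's loader code.

    import torch.nn as nn

    # Hypothetical filter mirroring the skip lists configured in __init__;
    # vLLM's real loader lives elsewhere.
    def keep_weight(name: str, skip_prefixes: list[str],
                    skip_substrs: list[str]) -> bool:
        if any(name.startswith(p) for p in skip_prefixes):
            return False  # e.g. the checkpoint's pooler weights
        if any(s in name for s in skip_substrs):
            return False  # e.g. position_ids buffers, score.bias
        return True

    skip_prefixes = ["pooler."]                    # use vLLM's pooling adapter
    skip_substrs = ["position_ids", "score.bias"]  # names vLLM never uses

    assert not keep_weight("pooler.dense.weight", skip_prefixes, skip_substrs)
    assert not keep_weight("model.embeddings.position_ids",
                           skip_prefixes, skip_substrs)
    assert not keep_weight("score.bias", skip_prefixes, skip_substrs)
    assert keep_weight("model.encoder.layer.0.attention.self.query.weight",
                       skip_prefixes, skip_substrs)

    # The Identity swap: replacing a submodule drops its parameters from
    # the wrapper without otherwise restructuring the module tree.
    toy = nn.Sequential(nn.Linear(4, 4))
    toy[0] = nn.Identity()
    assert sum(p.numel() for p in toy.parameters()) == 0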