Improve weight loading for encoder models in Transformers backend (#25289)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor, committed by GitHub on 2025-09-20 04:11:03 +01:00
Commit: c308501cb6 (parent 535d80056b)

@@ -702,21 +702,45 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
class TransformersModel(TransformersBase):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Handle BERT-like models
            "bert": "model",
            # Add `model.` prefix for base model checkpoints
            "": "model.",
            # Remove `model.` prefix if it was already there
            "model.model.": "model.",
            # Pooling adapters will be adjacent to `model`
            "model.pooler": "pooler",
            "model.score": "score",
            # Classifier adapter's classifier layer is renamed to score
            "model.classifier": "score",
        },
        orig_to_new_suffix={
            # Replace legacy suffixes used for norms
            ".gamma": ".weight",
            ".beta": ".bias",
        })
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # After creating a pooling model, `pooler` will be duplicated.
        # The one inside `model` comes from the Transformers modelling code.
        # The one after `model` is an adapter from vLLM.
        # We want to use the adapter, so we nullify the original pooler.
        if getattr(self.model, "pooler", None) is not None:
            self.skip_prefixes.append("pooler.")
            self.model.pooler = torch.nn.Identity()

        # Some encoder models have the position_ids buffer in the checkpoint.
        # vLLM will always pass position_ids as an argument, so we skip
        # loading the buffer if it exists.
        self.skip_substrs.append("position_ids")

        # Some encoder models have the bias of the final classifier layer
        # in the checkpoint. vLLM does not use this bias, so we skip loading
        # it if it exists.
        self.skip_substrs.append("score.bias")

    def create_attention_instances(
            self, attn_type: AttentionType = AttentionType.DECODER):
        # TODO(hmellor): Better way to detect encoder models
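
For context, the mapper and the skip rules together decide which checkpoint tensors get loaded and under what names. Below is a minimal, standalone sketch of that behaviour. It is not part of this commit and not vLLM's actual `WeightsMapper` implementation; it assumes the prefix rules are applied sequentially in declaration order (which is why the `model.model.` -> `model.` rule exists) and folds the skip check into the same helper purely for illustration.

    # Illustrative sketch only -- names and structure here are assumptions
    # for demonstration, not vLLM internals.
    PREFIX_MAP = {
        "bert": "model",               # handle BERT-like checkpoints
        "": "model.",                  # add `model.` to base-model keys
        "model.model.": "model.",      # collapse a doubled `model.` prefix
        "model.pooler": "pooler",      # pooling adapter sits next to `model`
        "model.score": "score",
        "model.classifier": "score",   # classifier layer is renamed to score
    }
    SUFFIX_MAP = {
        ".gamma": ".weight",  # legacy LayerNorm weight name
        ".beta": ".bias",     # legacy LayerNorm bias name
    }
    SKIP_SUBSTRS = ("position_ids", "score.bias")


    def rename_key(key: str) -> str | None:
        """Map a HF checkpoint key to its vLLM-side name, or None to skip it."""
        # Apply every matching prefix rule in declaration order.
        for old, new in PREFIX_MAP.items():
            if key.startswith(old):
                key = new + key[len(old):]
        # Then rewrite legacy norm suffixes.
        for old, new in SUFFIX_MAP.items():
            if key.endswith(old):
                key = key[:len(key) - len(old)] + new
        # Finally, drop tensors the model never loads.
        if any(s in key for s in SKIP_SUBSTRS):
            return None
        return key


    for key in (
        "bert.encoder.layer.0.attention.output.LayerNorm.gamma",
        "classifier.weight",            # sequence-classification head
        "classifier.bias",              # bias vLLM does not use
        "bert.embeddings.position_ids", # buffer vLLM never loads
    ):
        print(f"{key!r} -> {rename_key(key)!r}")

Under these rules, `bert.encoder.layer.0.attention.output.LayerNorm.gamma` becomes `model.encoder.layer.0.attention.output.LayerNorm.weight`, `classifier.weight` becomes `score.weight`, and both `classifier.bias` and `bert.embeddings.position_ids` are skipped. In the real backend the skips are enforced by the weight loader via the `skip_prefixes`/`skip_substrs` lists set in `__init__` above, not inside the mapper itself.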