Improve weight loading for encoder models in Transformers backend (#25289)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent 535d80056b
commit c308501cb6
@@ -702,21 +702,45 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
class TransformersModel(TransformersBase):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Handle BERT-like models
            "bert": "model",
            # Add `model.` prefix for base model checkpoints
            "": "model.",
            # Remove `model.` from places it should not be
            # Remove `model.` prefix if it was already there
            "model.model.": "model.",
            # Pooling adapters will be adjacent to `model`
            "model.pooler": "pooler",
            "model.score": "score",
            # Classifier adapter's classifier layer is renamed to score
            "model.classifier": "score",
        },
        orig_to_new_suffix={
            # Replace legacy suffixes used for norms
            ".gamma": ".weight",
            ".beta": ".bias",
        })
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        # After creating a pooling model, `pooler` will be duplicated.
        # The one inside `model` comes from the Transformers modelling code.
        # The one after `model` is an adapter from vLLM.
        # We want to use the adapter, so we nullify the original pooler.
        if getattr(self.model, "pooler", None) is not None:
            self.skip_prefixes.append("pooler.")
            self.model.pooler = torch.nn.Identity()

        # Some encoder models have the position_ids buffer in the checkpoint.
        # vLLM will always pass position_ids as an argument, so we skip
        # loading the buffer if it exists.
        self.skip_substrs.append("position_ids")

        # Some encoder models have the bias of the final classifier layer
        # in the checkpoint. vLLM does not use this bias, so we skip loading
        # it if it exists.
        self.skip_substrs.append("score.bias")
    def create_attention_instances(
            self, attn_type: AttentionType = AttentionType.DECODER):
        # TODO(hmellor): Better way to detect encoder models
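
To make the remapping concrete, here is a minimal, self-contained sketch of how the prefix and suffix rules in `hf_to_vllm_mapper` compose when applied in order. It is an illustrative re-implementation, not vLLM's actual `WeightsMapper`; the `remap` helper and the example checkpoint keys are hypothetical.

# Illustrative re-implementation of the prefix/suffix remapping above.
# `remap` is a hypothetical helper, not vLLM's WeightsMapper.
def remap(name: str) -> str:
    prefix_rules = {
        "bert": "model",
        "": "model.",
        "model.model.": "model.",
        "model.pooler": "pooler",
        "model.score": "score",
        "model.classifier": "score",
    }
    suffix_rules = {".gamma": ".weight", ".beta": ".bias"}
    # Apply every prefix rule in insertion order; the ordering is what
    # guarantees the `model.` prefix ends up present exactly once.
    for old, new in prefix_rules.items():
        if name.startswith(old):
            name = new + name[len(old):]
    for old, new in suffix_rules.items():
        if name.endswith(old):
            name = name[:-len(old)] + new
    return name

# A BERT checkpoint key: `bert` -> `model`, then `` -> `model.` makes it
# `model.model....`, which `model.model.` -> `model.` collapses again;
# the legacy `.gamma` norm suffix becomes `.weight`.
assert remap("bert.encoder.layer.0.output.LayerNorm.gamma") == \
    "model.encoder.layer.0.output.LayerNorm.weight"
# A classifier head from a base checkpoint is hoisted next to `model`
# and renamed to `score`.
assert remap("classifier.weight") == "score.weight"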
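
The pooler nullification in `__init__` relies on `torch.nn.Identity` keeping the attribute path valid while turning the call into a no-op. A toy sketch, where `ToyEncoder` is a hypothetical stand-in for a Transformers encoder:

import torch

class ToyEncoder(torch.nn.Module):
    # Hypothetical stand-in for a Transformers encoder with a pooler.
    def __init__(self):
        super().__init__()
        self.pooler = torch.nn.Linear(8, 8)

model = ToyEncoder()
# Nullify the original pooler, as done in __init__ above.
model.pooler = torch.nn.Identity()
x = torch.randn(2, 8)
assert torch.equal(model.pooler(x), x)  # Identity passes input through

Because the replaced pooler then owns no parameters, the matching checkpoint weights are skipped at load time via `skip_prefixes.append("pooler.")`.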
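
Finally, a rough sketch of the substring-based skipping used for `position_ids` and `score.bias`. The `should_load` helper is hypothetical and only illustrates the filtering behaviour implied by `skip_substrs`:

# Hypothetical filter illustrating the skip_substrs semantics above.
skip_substrs = ["position_ids", "score.bias"]

def should_load(name: str) -> bool:
    return not any(s in name for s in skip_substrs)

assert not should_load("embeddings.position_ids")  # buffer; vLLM passes its own
assert not should_load("score.bias")               # classifier bias vLLM ignores
assert should_load("score.weight")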