[Model] Add AutoWeightsLoader support for BERT, RoBERTa (#20534)

Signed-off-by: Jennifer He <islandhe@gmail.com>
Signed-off-by: Jen H <islandhe@gmail.com>
Jennifer He 2025-07-15 01:34:24 -04:00 committed by GitHub
parent 91b3d190ae
commit 85bd6599e4
2 changed files with 59 additions and 100 deletions
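Note: the change replaces the models' hand-written weight-loading loops with vLLM's AutoWeightsLoader helper, which matches checkpoint tensor names against the module's own parameter names. As rough orientation only, the sketch below is a simplified stand-in for that helper, not vLLM's actual implementation; the class name ToyWeightsLoader and its exact behavior are illustrative assumptions.

from collections.abc import Iterable
from typing import Optional

import torch
from torch import nn


class ToyWeightsLoader:
    """Simplified stand-in for AutoWeightsLoader: copy checkpoint tensors
    into same-named parameters, skip configured prefixes, and report
    which names were loaded."""

    def __init__(self,
                 module: nn.Module,
                 skip_prefixes: Optional[list[str]] = None):
        self.params = dict(module.named_parameters())
        self.skip_prefixes = skip_prefixes or []

    def load_weights(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loaded: set[str] = set()
        for name, tensor in weights:
            if any(name.startswith(p) for p in self.skip_prefixes):
                continue  # e.g. drop "pooler." weights when no pooler exists
            if name in self.params:
                self.params[name].data.copy_(tensor)
                loaded.add(name)
        return loaded


# Usage mirrors the new load_weights methods in the diff below:
#     loader = ToyWeightsLoader(self, skip_prefixes=["pooler."])
#     return loader.load_weights(other_weights)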

vllm/model_executor/models/bert.py

@@ -22,12 +22,11 @@ from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
class BertEmbedding(nn.Module):
@@ -44,9 +43,11 @@ class BertEmbedding(nn.Module):
config.type_vocab_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )
self.register_buffer(
"position_ids",
torch.arange(config.max_position_embeddings).unsqueeze(0),
)
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute":
raise ValueError("Only 'absolute' position_embedding_type" +
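Note: both files also switch position_ids from an uninitialized nn.Parameter (torch.empty) to a registered buffer holding real indices (torch.arange). The following self-contained snippet, which is not part of the diff, shows the practical difference: a buffer lands in state_dict and follows device/dtype moves, but it is not trainable and never appears in named_parameters().

import torch
from torch import nn


class PositionIds(nn.Module):

    def __init__(self, max_position_embeddings: int = 8):
        super().__init__()
        # A buffer is saved in state_dict and follows .to(device)/.half(),
        # but it is not a trainable parameter.
        self.register_buffer(
            "position_ids",
            torch.arange(max_position_embeddings).unsqueeze(0),
        )


m = PositionIds()
assert "position_ids" in m.state_dict()
assert "position_ids" not in dict(m.named_parameters())
print(m.position_ids)  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]])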
@@ -358,45 +359,45 @@ class BertModel(nn.Module, SupportsQuant):
("qkv_proj", "value", "v"),
]
loaded_stacked_params = []
other_weights = []
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if self.pooler is None and "pooler" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
loaded_stacked_params.append(name)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if name in params_dict:
other_weights.append((name, loaded_weight))
loader = AutoWeightsLoader(
self,
skip_prefixes=(["pooler."] if self.pooler is None else []),
)
loaded_params = loader.load_weights(other_weights)
loaded_params.update(loaded_stacked_params)
return loaded_params
class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
embedding operations and customized pooling functions.
This class encapsulates the BertModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of BertModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
Attributes:
model: An instance of BertModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
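Note: in BertModel.load_weights above, the manual loop now only handles the stacked QKV projection, where checkpoint query/key/value tensors must be routed into shards of the fused qkv_proj parameter; everything else is collected into other_weights and handed to AutoWeightsLoader. A minimal, self-contained sketch of the name-rewriting part of that loop (tensor values and layer names are made up for illustration):

import torch

# (fused_param_name, checkpoint_weight_name, shard_id), as in the diff above
stacked_params_mapping = [
    ("qkv_proj", "query", "q"),
    ("qkv_proj", "key", "k"),
    ("qkv_proj", "value", "v"),
]

checkpoint = [
    ("encoder.layer.0.attention.self.query.weight", torch.zeros(4, 4)),
    ("encoder.layer.0.attention.self.key.weight", torch.zeros(4, 4)),
    ("encoder.layer.0.attention.output.dense.weight", torch.zeros(4, 4)),
]

stacked, other_weights = [], []
for name, tensor in checkpoint:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            # Rewrite e.g. "...self.query.weight" -> "...self.qkv_proj.weight"
            # and remember which shard ("q"/"k"/"v") this tensor fills.
            stacked.append((name.replace(weight_name, param_name), shard_id))
            break
    else:
        # Everything else can be loaded purely by name matching.
        other_weights.append((name, tensor))

print(stacked)
# [('encoder.layer.0.attention.self.qkv_proj.weight', 'q'),
#  ('encoder.layer.0.attention.self.qkv_proj.weight', 'k')]
print([n for n, _ in other_weights])
# ['encoder.layer.0.attention.output.dense.weight']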
@@ -425,10 +426,15 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights)
weights = ((name, data) for name, data in weights
if not name.startswith("lm_head."))
self.model.load_weights(weights)
weights_list = list(weights)
has_model_prefix = any(
name.startswith("model.") for name, _ in weights_list)
if not has_model_prefix:
mapper = WeightsMapper(orig_to_new_prefix={"": "model."})
loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
return loader.load_weights(weights_list, mapper=mapper)
def _build_model(self,
vllm_config: VllmConfig,
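Note: BertEmbeddingModel keeps its encoder under self.model, so checkpoint names need a model. prefix before name matching can succeed; the new load_weights adds that prefix only when it is missing. A standalone sketch of the same normalization, where ensure_model_prefix is a hypothetical helper standing in for WeightsMapper(orig_to_new_prefix={"": "model."}):

def ensure_model_prefix(names: list[str]) -> list[str]:
    """Prepend "model." only when no checkpoint name already carries it."""
    if any(n.startswith("model.") for n in names):
        return names
    return ["model." + n for n in names]


print(ensure_model_prefix(["embeddings.word_embeddings.weight"]))
# ['model.embeddings.word_embeddings.weight']
print(ensure_model_prefix(["model.embeddings.word_embeddings.weight"]))
# ['model.embeddings.word_embeddings.weight'] (already prefixed, unchanged)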
@@ -470,26 +476,9 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only,
self.classifier, self.bert.pooler)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
self_weights = []
def weight_filter():
for name, weight in weights:
if name.startswith("bert."):
yield (name[len("bert."):], weight)
else:
self_weights.append((name, weight))
self.bert.load_weights(weight_filter())
params_dict = dict(self.named_parameters())
for name, loaded_weight in self_weights:
if name.startswith("classifier"):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loader = AutoWeightsLoader(self)
loaded_params = loader.load_weights(weights)
return loaded_params
def pooler(
self,

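Note: for BertForSequenceClassification, checkpoint names such as bert.embeddings.word_embeddings.weight and classifier.weight already spell out the attribute path from the top-level module to the right parameter, which is why the old weight_filter generator and the manual classifier loop could be dropped in favor of AutoWeightsLoader(self). A small demonstration of that attribute-path dispatch, using toy modules rather than vLLM classes:

import torch
from torch import nn


class ToyClassifier(nn.Module):

    def __init__(self):
        super().__init__()
        self.bert = nn.Linear(4, 4, bias=False)        # stand-in for BertModel
        self.classifier = nn.Linear(4, 2, bias=False)  # classification head


def resolve(module: nn.Module, dotted_name: str) -> torch.Tensor:
    """Walk attributes along a dotted checkpoint name to the parameter."""
    obj = module
    for part in dotted_name.split("."):
        obj = getattr(obj, part)
    return obj


model = ToyClassifier()
for name in ("bert.weight", "classifier.weight"):
    param = resolve(model, name)
    print(name, tuple(param.shape))
# bert.weight (4, 4)
# classifier.weight (2, 4)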
vllm/model_executor/models/roberta.py

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from collections.abc import Iterable
from typing import Optional, Union
@@ -13,9 +12,9 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import ClassifierPooler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
maybe_prefix)
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
@@ -39,8 +38,10 @@ class RobertaEmbedding(nn.Module):
config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.position_ids = nn.Parameter(
torch.empty((1, config.max_position_embeddings)), )
self.register_buffer(
"position_ids",
torch.arange(config.max_position_embeddings).unsqueeze(0),
)
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute":
@@ -136,16 +137,20 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
embedding_class=RobertaEmbedding)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights = self.hf_to_vllm_mapper.apply(weights)
# Separate weights in "roberta"-prefixed and all else (not in memory).
# For use with models like FacebookAI/roberta-base.
bert_weights, task_weights = roberta_task_weights_filter(weights)
loaded = self.model.load_weights(bert_weights)
if not len(loaded):
# Fix for models like `sentence-transformers/stsb-roberta-base-v2`
# which use the same architecture, but have no "roberta" prefix.
loaded = self.model.load_weights(task_weights)
assert len(loaded), "Unable to load RobertaEmbeddingModel"
weights_list = list(weights)
has_roberta_prefix = any(
name.startswith("roberta.") for name, _ in weights_list)
if has_roberta_prefix:
# For models with the `roberta.` prefix e.g.
# `FacebookAI/roberta-base`
mapper = WeightsMapper(orig_to_new_prefix={"roberta.": "model."})
else:
# For models without the `roberta.` prefix e.g.
# `sentence-transformers/stsb-roberta-base-v2`
mapper = WeightsMapper(orig_to_new_prefix={"": "model."})
loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
return loader.load_weights(weights_list, mapper=mapper)
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
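Note: the RoBERTa variant of the same idea: roberta.-prefixed checkpoints (e.g. FacebookAI/roberta-base) have the prefix swapped for model., while unprefixed ones (e.g. sentence-transformers/stsb-roberta-base-v2) have model. prepended, so both layouts converge on identical names. A quick self-contained check of that convergence; normalize is an illustrative helper, not vLLM's WeightsMapper:

def normalize(names: list[str]) -> list[str]:
    """Collapse both checkpoint layouts onto "model."-prefixed names."""
    if any(n.startswith("roberta.") for n in names):
        # FacebookAI/roberta-base layout: swap "roberta." for "model.".
        return ["model." + n[len("roberta."):] if n.startswith("roberta.") else n
                for n in names]
    # sentence-transformers layout: prepend "model.".
    return ["model." + n for n in names]


prefixed = ["roberta.embeddings.word_embeddings.weight"]
unprefixed = ["embeddings.word_embeddings.weight"]
assert normalize(prefixed) == normalize(unprefixed) == [
    "model.embeddings.word_embeddings.weight"
]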
@@ -187,19 +192,8 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding,
self.classifier)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
bert_weights, task_weights = roberta_task_weights_filter(weights)
bert_weights = self.jina_to_vllm_mapper.apply(bert_weights)
self.roberta.load_weights(bert_weights)
params_dict = dict(self.named_parameters())
for name, loaded_weight in task_weights:
if name.startswith("classifier"):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper)
def pooler(
self,
@@ -245,27 +239,3 @@ def create_position_ids_from_input_ids(input_ids,
past_key_values_length) * mask
return incremental_indices.long() + padding_idx
def roberta_task_weights_filter(
all_weights: Iterable[tuple[str, torch.Tensor]]
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str,
torch.Tensor]]]:
"""
Separate task-specific weights that are applied on top
of the encoder-decoder bert base.
To do so, return two generators over the original iterator.
Also, remove the "roberta." prefix to make it loadable
from vanilla BertModel.
"""
# Copy of a lazy iterator without in-memory overhead so both
# iterators can be iterated upon independently.
all_weights1, all_weights2 = itertools.tee(all_weights)
def encoder_decoder_weights():
for name, weight in all_weights1:
if name.startswith("roberta."):
yield (name[len("roberta."):], weight)
return encoder_decoder_weights(), ((n, w) for n, w in all_weights2
if not n.startswith("roberta."))
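Note: the deleted roberta_task_weights_filter relied on itertools.tee to split one lazy weights iterator into two independently consumable streams without materializing it up front; the AutoWeightsLoader path makes that unnecessary. For readers unfamiliar with the pattern, a tiny standalone example unrelated to vLLM:

import itertools

weights = ((f"roberta.layer.{i}.weight" if i % 2 == 0 else f"classifier.{i}.weight", i)
           for i in range(4))

# tee() yields two iterators over the same underlying lazy sequence.
a, b = itertools.tee(weights)

encoder = [(n[len("roberta."):], w) for n, w in a if n.startswith("roberta.")]
task = [(n, w) for n, w in b if not n.startswith("roberta.")]

print(encoder)  # [('layer.0.weight', 0), ('layer.2.weight', 2)]
print(task)     # [('classifier.1.weight', 1), ('classifier.3.weight', 3)]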