From 6f15ac5de7303ba0e7ea161452f8cfd9a1445cee Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Tue, 16 Dec 2025 13:40:26 +0000
Subject: [PATCH] Don't assume `position_embedding_type` will be present for
 BERT and RoBERTa models (#30770)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 vllm/model_executor/models/bert.py    |  4 +++-
 vllm/model_executor/models/roberta.py | 16 +++++-----------
 2 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index e774cd647ea8c..ee429bf458843 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -55,7 +55,9 @@ class BertEmbedding(nn.Module):
             "position_ids",
             torch.arange(config.max_position_embeddings).unsqueeze(0),
         )
-        self.position_embedding_type = config.position_embedding_type
+        self.position_embedding_type = getattr(
+            config, "position_embedding_type", "absolute"
+        )
         if self.position_embedding_type != "absolute":
             raise ValueError(
                 "Only 'absolute' position_embedding_type" + " is supported"
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index 31cc645099141..45b6e93307ac3 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -57,12 +57,6 @@ class RobertaEmbedding(nn.Module):
             torch.arange(config.max_position_embeddings).unsqueeze(0),
         )
 
-        self.position_embedding_type = config.position_embedding_type
-        if self.position_embedding_type != "absolute":
-            raise ValueError(
-                "Only 'absolute' position_embedding_type" + " is supported"
-            )
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -135,12 +129,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
     def _build_model(
         self, vllm_config: VllmConfig, prefix: str = ""
     ) -> BertModel | BertWithRope:
-        if vllm_config.model_config.hf_config.position_embedding_type == "rotary":
-            return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix)
+        hf_config = vllm_config.model_config.hf_config
+        kwargs = dict(vllm_config=vllm_config, prefix=prefix)
+        if getattr(hf_config, "position_embedding_type", "absolute") == "absolute":
+            return BertModel(**kwargs, embedding_class=RobertaEmbedding)
         else:
-            return BertModel(
-                vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding
-            )
+            return JinaRobertaModel(**kwargs)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         weights_list = list(weights)
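
Not part of the patch itself: a minimal sketch of the fallback behaviour the diff relies on, using types.SimpleNamespace as a stand-in for the Hugging Face config object (assumption: the real config exposes its options as plain attributes, so attribute access behaves the same way).

    # Illustrative only: why getattr(config, "position_embedding_type", "absolute")
    # avoids the AttributeError that config.position_embedding_type raised when
    # the option was absent from the config.
    from types import SimpleNamespace

    # Config without the option: the old code crashed here; the new code
    # falls back to the default "absolute".
    bare_config = SimpleNamespace(max_position_embeddings=512)
    assert getattr(bare_config, "position_embedding_type", "absolute") == "absolute"

    # Config that does set the option keeps its value, so e.g. "rotary" still
    # routes RobertaEmbeddingModel._build_model to JinaRobertaModel.
    rotary_config = SimpleNamespace(position_embedding_type="rotary")
    assert getattr(rotary_config, "position_embedding_type", "absolute") == "rotary"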