Don't assume position_embedding_type will be present for BERT and RoBERTa models (#30770)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-12-16 13:40:26 +00:00 committed by GitHub
parent 676db55eec
commit 6f15ac5de7
2 changed files with 8 additions and 12 deletions
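
What the patch does, in short: both embedding classes and `RobertaEmbeddingModel._build_model` now read `position_embedding_type` with `getattr` and fall back to `"absolute"`, so checkpoint configs that omit the field no longer fail with an `AttributeError`. A minimal sketch of the pattern, using throwaway `SimpleNamespace` objects as stand-ins for Hugging Face configs (not the real config classes):

    from types import SimpleNamespace

    # Stand-in configs, for illustration only: one sets the field, one omits it.
    cfg_explicit = SimpleNamespace(position_embedding_type="rotary")
    cfg_missing = SimpleNamespace()  # no position_embedding_type attribute

    # The pattern adopted by this commit: default to "absolute" when the field is absent.
    assert getattr(cfg_explicit, "position_embedding_type", "absolute") == "rotary"
    assert getattr(cfg_missing, "position_embedding_type", "absolute") == "absolute"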


@@ -55,7 +55,9 @@ class BertEmbedding(nn.Module):
             "position_ids",
             torch.arange(config.max_position_embeddings).unsqueeze(0),
         )
-        self.position_embedding_type = config.position_embedding_type
+        self.position_embedding_type = getattr(
+            config, "position_embedding_type", "absolute"
+        )
         if self.position_embedding_type != "absolute":
             raise ValueError(
                 "Only 'absolute' position_embedding_type" + " is supported"


@@ -57,12 +57,6 @@ class RobertaEmbedding(nn.Module):
             torch.arange(config.max_position_embeddings).unsqueeze(0),
         )
-
-        self.position_embedding_type = config.position_embedding_type
-        if self.position_embedding_type != "absolute":
-            raise ValueError(
-                "Only 'absolute' position_embedding_type" + " is supported"
-            )
 
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -135,12 +129,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
     def _build_model(
         self, vllm_config: VllmConfig, prefix: str = ""
     ) -> BertModel | BertWithRope:
-        if vllm_config.model_config.hf_config.position_embedding_type == "rotary":
-            return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix)
+        hf_config = vllm_config.model_config.hf_config
+        kwargs = dict(vllm_config=vllm_config, prefix=prefix)
+        if getattr(hf_config, "position_embedding_type", "absolute") == "absolute":
+            return BertModel(**kwargs, embedding_class=RobertaEmbedding)
         else:
-            return BertModel(
-                vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding
-            )
+            return JinaRobertaModel(**kwargs)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         weights_list = list(weights)
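
Taken together, `_build_model` now routes on the same defaulted value: a config with no `position_embedding_type` (or an explicit "absolute") gets the plain `BertModel` with `RobertaEmbedding`, while any other value goes to `JinaRobertaModel` (the old code checked for "rotary" specifically). A self-contained sketch of that branch using stub classes, not the real vLLM models:

    # Stubs standing in for the vLLM model classes, for illustration only.
    class StubBertModel:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class StubJinaRobertaModel:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    def build_model(hf_config, **kwargs):
        # Mirrors the patched branch: a missing attribute defaults to "absolute",
        # which selects the plain BERT path; anything else selects the rotary path.
        if getattr(hf_config, "position_embedding_type", "absolute") == "absolute":
            return StubBertModel(**kwargs)
        return StubJinaRobertaModel(**kwargs)

    class BareConfig:  # config without position_embedding_type
        pass

    class RotaryConfig:
        position_embedding_type = "rotary"

    assert isinstance(build_model(BareConfig()), StubBertModel)
    assert isinstance(build_model(RotaryConfig()), StubJinaRobertaModel)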