[Bugfix] Token type and position embeddings fail to be applied to inputs_embeds (#25922)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-10-01 00:23:12 +08:00 committed by GitHub
parent ef283548f7
commit 9f1c4ecaf2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 9 deletions

View File

@@ -61,11 +61,13 @@ class BertEmbedding(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         token_type_ids = _decode_token_type_ids(input_ids)
-        inputs_embeds = self.word_embeddings(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
@@ -358,11 +360,12 @@ class BertModel(nn.Module, SupportsQuant):
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
-        else:
-            hidden_states = self.embeddings(input_ids=input_ids,
-                                            position_ids=positions)
+        hidden_states = self.embeddings(
+            input_ids=input_ids,
+            position_ids=positions,
+            inputs_embeds=inputs_embeds,
+        )
         return self.encoder(hidden_states)

     def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

View File

@@ -56,11 +56,13 @@ class RobertaEmbedding(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         token_type_ids = _decode_token_type_ids(input_ids)
-        inputs_embeds = self.word_embeddings(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)