diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index c984845204c4f..2ec3edc5a0a7a 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -61,8 +61,11 @@ class BertEmbedding(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         token_type_ids = _decode_token_type_ids(input_ids)
-        inputs_embeds = self.word_embeddings(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
@@ -358,9 +360,10 @@ class BertModel(nn.Module, SupportsQuant):
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
-        else:
-            hidden_states = self.embeddings(input_ids=input_ids,
-                                            position_ids=positions)
+        hidden_states = self.embeddings(
+            input_ids=input_ids,
+            position_ids=positions,
+            inputs_embeds=inputs_embeds,
+        )
+
         return self.encoder(hidden_states)

     def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py
index 53e698c4fa806..a13042a6367cc 100644
--- a/vllm/model_executor/models/roberta.py
+++ b/vllm/model_executor/models/roberta.py
@@ -56,8 +56,11 @@ class RobertaEmbedding(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         token_type_ids = _decode_token_type_ids(input_ids)
-        inputs_embeds = self.word_embeddings(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
         position_embeddings = self.position_embeddings(position_ids)
         token_type_embeddings = self.token_type_embeddings(token_type_ids)