From 5340b0e2214ec71117e7f0a953cf3033e3194d2a Mon Sep 17 00:00:00 2001
From: Roger Wang <136131678+ywang96@users.noreply.github.com>
Date: Mon, 17 Mar 2025 11:26:38 -0700
Subject: [PATCH] [Bugfix] Fix interface for Olmo2 on V1 (#14976)

Signed-off-by: Roger Wang
---
 vllm/model_executor/models/olmo2.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py
index 54cc851de9347..f9427cdadf7a2 100644
--- a/vllm/model_executor/models/olmo2.py
+++ b/vllm/model_executor/models/olmo2.py
@@ -42,7 +42,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -283,17 +283,19 @@ class Olmo2Model(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         """
         :param input_ids: A tensor of shape `(batch_size, seq_len)`.
         """
         if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
             # Get embeddings of input.
             # shape: (batch_size, seq_len, d_model)
-            inputs_embeds = self.embed_tokens(input_ids)
+            else:
+                hidden_states = self.embed_tokens(input_ids)
 
-            # embed positions
-            hidden_states = inputs_embeds
         else:
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
@@ -337,7 +339,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
             prefix=maybe_prefix(prefix, "lm_head"),
         )
         self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = Sampler()
+        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -346,11 +348,13 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
             intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds,
         )
         return hidden_states
 
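A note on the interface this patch satisfies: on the V1 engine the model runner may hand the model precomputed prompt embeddings (for example, with multimodal features already merged in) instead of raw token ids, so `forward` must accept an optional `inputs_embeds` and fall back to the token-embedding lookup only when it is absent; switching to `get_sampler()` likewise lets vLLM pick the sampler implementation for the active engine rather than hard-wiring the V0 `Sampler`. Below is a minimal, self-contained sketch of that branch pattern, not vLLM code: `ToyDecoder`, its sizes, and the tensors are all invented for illustration.

# Illustrative only: a toy module mirroring the inputs_embeds-optional
# interface the patch gives Olmo2Model.forward. All names here are
# hypothetical; none of this is vLLM code.
from typing import Optional

import torch
import torch.nn as nn


class ToyDecoder(nn.Module):

    def __init__(self, vocab_size: int = 128, d_model: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, d_model)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,  # accepted for interface parity; unused here
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Same branch structure as the patched Olmo2Model.forward: prefer
        # caller-supplied embeddings, else embed the token ids ourselves.
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embed_tokens(input_ids)
        return hidden_states


toy = ToyDecoder()
ids = torch.randint(0, 128, (8,))
pos = torch.arange(8)
# Both call styles must produce the same result, which is exactly the
# contract the V1 runner relies on when it precomputes embeddings.
assert torch.equal(toy(ids, pos),
                   toy(ids, pos, inputs_embeds=toy.embed_tokens(ids)))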