[Bugfix] Fix interface for Olmo2 on V1 (#14976)

Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Roger Wang 2025-03-17 11:26:38 -07:00 committed by GitHub
parent 37e3806132
commit 5340b0e221
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -42,7 +42,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -283,17 +283,19 @@ class Olmo2Model(nn.Module):
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
"""
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
inputs_embeds = self.embed_tokens(input_ids)
else:
hidden_states = self.embed_tokens(input_ids)
# embed positions
hidden_states = inputs_embeds
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
@ -337,7 +339,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
@ -346,11 +348,13 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP):
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states