From b2789112294cb1f54c55c2011d2a5737c2b22cae Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 25 Apr 2025 21:54:47 -0700 Subject: [PATCH] [Minor][Models] Fix Return Types of Llama & Eagle (#17220) Signed-off-by: Woosuk Kwon --- vllm/model_executor/models/llama.py | 3 ++- vllm/model_executor/models/llama_eagle.py | 4 ++-- vllm/model_executor/models/llama_eagle3.py | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 17d080fa5a28..04415dace4d9 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -345,7 +345,8 @@ class LlamaModel(nn.Module): positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, + list[torch.Tensor]]]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index 06f7cb08a7c8..56e53ac2b815 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -70,7 +70,7 @@ class LlamaModel(nn.Module): input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: input_embeds = self.embed_tokens(input_ids) hidden_states = self.fc( torch.cat((input_embeds, hidden_states), dim=-1)) @@ -133,7 +133,7 @@ class EagleLlamaForCausalLM(LlamaForCausalLM): input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: return self.model(input_ids, positions, hidden_states) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index ffbb9d75a06b..0b18e4a8fe2f 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -117,7 +117,7 @@ class LlamaModel(nn.Module): input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: input_embeds = self.embed_tokens(input_ids) if (hidden_states.shape[-1] != input_embeds.shape[-1]): hidden_states = self.fc(hidden_states) @@ -194,7 +194,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: return self.model(input_ids, positions, hidden_states) def compute_logits(