From 1bc86a3da1bd45e7d43347d6532a515950a438f0 Mon Sep 17 00:00:00 2001
From: Benjamin Chislett
Date: Sat, 31 May 2025 22:58:07 -0400
Subject: [PATCH] [Bugfix] Fix EAGLE3 broken logits (#18909)

Signed-off-by: Benjamin Chislett
---
 vllm/model_executor/models/llama_eagle3.py | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index f211bfe54a7d..1e40017fc792 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -215,6 +215,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
         if self.draft_id_to_target_id is None:
+            assert logits.shape[1] == self.config.vocab_size, \
+                "Expected logits to have shape " \
+                f"(*, {self.config.vocab_size}), but got {logits.shape}"
             return logits
 
         base = torch.arange(self.config.draft_vocab_size, device=logits.device)
@@ -234,24 +237,22 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         return self.model.fc(hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=None,
-        )
-
         model_weights = {}
+        includes_draft_id_mapping = False
         for name, loaded_weight in weights:
            if "t2d" in name:
                 continue
             if "d2t" in name:
                 name = name.replace("d2t", "draft_id_to_target_id")
+                includes_draft_id_mapping = True
             elif "lm_head" not in name:
                 name = "model." + name
             model_weights[name] = loaded_weight
 
-        loaded_weights = loader.load_weights(model_weights.items())
-
-        if 'd2t' not in loaded_weights:
-            self.draft_id_to_target_id = None
-
-        return loaded_weights
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=None,
+            skip_substrs=["draft_id_to_target_id"] \
+            if not includes_draft_id_mapping else None,
+        )
+        loader.load_weights(model_weights.items())
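
For context on the remapping that the patched compute_logits guards: when a draft_id_to_target_id tensor is present, logits over the smaller draft vocabulary are scattered into target-vocabulary positions, and the new assert documents that the pass-through case (no mapping) already has vocab_size columns. Below is a minimal, self-contained sketch of that idea, not code from llama_eagle3.py; the toy sizes and tensor values are assumptions chosen for illustration.

    import torch

    # Toy sizes, assumed for illustration only.
    draft_vocab_size = 4
    target_vocab_size = 10

    # d2t-style offsets: target_id = draft_id + offset[draft_id].
    # Draft ids 0..3 map to target ids 0, 3, 7, 9 in this example.
    draft_id_to_target_id = torch.tensor([0, 2, 5, 6])

    # Logits over the reduced draft vocabulary (batch of 2 tokens).
    draft_logits = torch.randn(2, draft_vocab_size)

    # Expand to the full target vocabulary: unmapped target ids get -inf so
    # they can never be sampled; mapped ids receive the draft logits.
    base = torch.arange(draft_vocab_size)
    targets = base + draft_id_to_target_id
    full_logits = draft_logits.new_full((2, target_vocab_size), float("-inf"))
    full_logits[:, targets] = draft_logits

    assert full_logits.shape[1] == target_vocab_size

Either way, the logits handed back by compute_logits are expected to span the full target vocab_size, which is the invariant the added assert checks on the early-return path.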