From 583507d13075783a12ccbd774575974d10ca4959 Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Thu, 22 May 2025 23:17:39 -0400 Subject: [PATCH] [Spec Decode] Make EAGLE3 draft token ID mapping optional (#18488) Signed-off-by: Benjamin Chislett Co-authored-by: Woosuk Kwon --- vllm/model_executor/models/llama_eagle3.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 96e666a3543dc..f211bfe54a7d7 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -214,6 +214,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) + if self.draft_id_to_target_id is None: + return logits + base = torch.arange(self.config.draft_vocab_size, device=logits.device) targets = base + self.draft_id_to_target_id logits_new = logits.new_full(( @@ -246,4 +249,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM): name = "model." + name model_weights[name] = loaded_weight - return loader.load_weights(model_weights.items()) + loaded_weights = loader.load_weights(model_weights.items()) + + if 'd2t' not in loaded_weights: + self.draft_id_to_target_id = None + + return loaded_weights