From 5f9679a43bf92fc0fc8610f0ba5cc9c857148ccf Mon Sep 17 00:00:00 2001
From: Hanjie Qiu <50634613+hjjq@users.noreply.github.com>
Date: Mon, 24 Nov 2025 20:13:12 -0500
Subject: [PATCH] [Spec Decode] Add support for EAGLE3 heads that do not
 use_aux_hidden_states (#27688)

Signed-off-by: hjjq
Signed-off-by: Benjamin Chislett
Co-authored-by: Benjamin Chislett
---
 vllm/model_executor/models/llama_eagle3.py | 38 ++++++++++++++--------
 vllm/v1/spec_decode/eagle.py               | 19 +++++++++++
 vllm/v1/worker/gpu_model_runner.py         |  4 ++-
 3 files changed, 46 insertions(+), 15 deletions(-)

diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 3eaf2d80082f1..7a57644db1b13 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -142,6 +142,12 @@ class LlamaModel(nn.Module):
         # Get drafter's quantization config
         self.quant_config = get_draft_quant_config(vllm_config)
 
+        eagle_config = getattr(self.config, "eagle_config", None)
+        if eagle_config is not None and "use_aux_hidden_state" in eagle_config:
+            self.use_aux_hidden_state = eagle_config["use_aux_hidden_state"]
+        else:
+            self.use_aux_hidden_state = True
+
         current_vllm_config = get_current_vllm_config()
 
         self.embed_tokens = VocabParallelEmbedding(
@@ -161,20 +167,20 @@ class LlamaModel(nn.Module):
                 for layer_idx in range(self.config.num_hidden_layers)
             ]
         )
-        if hasattr(self.config, "target_hidden_size"):
-            fc_input_size = self.config.target_hidden_size * 3
-        else:
-            fc_input_size = self.config.hidden_size * 3
-        self.fc = ReplicatedLinear(
-            input_size=fc_input_size,
-            output_size=self.config.hidden_size,
-            bias=False,
-            params_dtype=vllm_config.model_config.dtype,
-            quant_config=self.quant_config,
-            prefix=maybe_prefix(prefix, "fc"),
-            return_bias=False,
-        )
-
+        if self.use_aux_hidden_state:
+            if hasattr(self.config, "target_hidden_size"):
+                fc_input_size = self.config.target_hidden_size * 3
+            else:
+                fc_input_size = self.config.hidden_size * 3
+            self.fc = ReplicatedLinear(
+                input_size=fc_input_size,
+                output_size=self.config.hidden_size,
+                bias=False,
+                params_dtype=vllm_config.model_config.dtype,
+                quant_config=self.quant_config,
+                prefix=maybe_prefix(prefix, "fc"),
+                return_bias=False,
+            )
         self.norm = RMSNorm(
             self.config.hidden_size,
             eps=self.config.rms_norm_eps,
@@ -332,6 +338,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         self,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
+        if not self.model.use_aux_hidden_state:
+            return hidden_states
         # combine multiple auxiliary hidden states returned by eagle3
         return self.model.fc(hidden_states)
 
@@ -357,6 +365,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
             skip_substrs.append("draft_id_to_target_id")
         if not includes_embed_tokens:
             skip_substrs.append("embed_tokens")
+        if not self.model.use_aux_hidden_state:
+            skip_substrs.append("fc.")
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=None,
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 3de418f1d13c8..afa16573eea10 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -83,6 +83,9 @@ class EagleProposer:
         self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
         self.attn_layer_names: list[str] = []
         self.indexer_layer_names: list[str] = []
+        self.eagle3_use_aux_hidden_state: bool = (
+            self._get_eagle3_use_aux_hidden_state_from_config()
+        )
 
         self.use_cuda_graph = False
 
@@ -1169,6 +1172,22 @@ class EagleProposer:
         )
         return builder
 
+    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
+        """
+        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
+        hidden states and directly use the last-layer output, just like eagle1.
+        They can indicate this by setting "use_aux_hidden_state" to False
+        inside the "eagle_config" dict of their hf_config.
+        """
+        if self.method != "eagle3":
+            return False
+        # Assume that eagle3 heads use aux hidden states by default
+        use_aux_hidden_state = True
+        eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
+        if eagle_config is not None:
+            use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True)
+        return use_aux_hidden_state
+
     def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Validate that all eagle layers belong to the same KVCacheGroup.
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index cbafc9c993cc2..6a83ac14e0b3f 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -375,7 +375,9 @@ class GPUModelRunner(
         elif self.speculative_config.use_eagle():
             self.drafter = EagleProposer(self.vllm_config, self.device, self)
             if self.speculative_config.method == "eagle3":
-                self.use_aux_hidden_state_outputs = True
+                self.use_aux_hidden_state_outputs = (
+                    self.drafter.eagle3_use_aux_hidden_state
+                )
         elif self.speculative_config.method == "medusa":
             self.drafter = MedusaProposer(
                 vllm_config=self.vllm_config, device=self.device
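
Reviewer note (not part of the patch): the opt-out is driven entirely by the draft
head's HF config. A minimal sketch of the config.json fragment such a head would
ship, showing only the keys this patch actually reads ("eagle_config" and
"use_aux_hidden_state"); the rest of the head's config is unchanged:

    {
      "eagle_config": {
        "use_aux_hidden_state": false
      }
    }

When the key (or the whole "eagle_config" dict) is absent, both LlamaModel and
EagleProposer fall back to True, so existing Eagle3 heads keep building the
3x-wide fc layer that combines the target model's auxiliary hidden states, and
only opted-out heads consume the last-layer hidden state directly, like eagle1.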