[Spec Decode] Add support for EAGLE3 heads that do not use_aux_hidden_states (#27688)

Signed-off-by: hjjq <hanjieq@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Co-authored-by: Benjamin Chislett <bchislett@nvidia.com>
Authored by Hanjie Qiu on 2025-11-24 20:13:12 -05:00; committed by GitHub
parent 699bca76c0
commit 5f9679a43b
3 changed files with 46 additions and 15 deletions
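For context, such heads opt in through their HF config. A minimal sketch of the knob this patch reads, with illustrative values (the key names follow the patch; the dict contents are made up, not copied from a real checkpoint):

# Hypothetical hf_config excerpt for an EAGLE3 head that skips aux hidden states.
eagle_config = {"use_aux_hidden_state": False}
# A missing key (or a missing eagle_config dict) falls back to the old default:
use_aux = eagle_config.get("use_aux_hidden_state", True)  # -> False here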

vllm/model_executor/models/llama_eagle3.py

@@ -142,6 +142,12 @@ class LlamaModel(nn.Module):
         # Get drafter's quantization config
         self.quant_config = get_draft_quant_config(vllm_config)
+        eagle_config = getattr(self.config, "eagle_config", None)
+        if eagle_config is not None and "use_aux_hidden_state" in eagle_config:
+            self.use_aux_hidden_state = eagle_config["use_aux_hidden_state"]
+        else:
+            self.use_aux_hidden_state = True
         current_vllm_config = get_current_vllm_config()
         self.embed_tokens = VocabParallelEmbedding(
@@ -161,20 +167,20 @@ class LlamaModel(nn.Module):
                 for layer_idx in range(self.config.num_hidden_layers)
             ]
         )
-        if hasattr(self.config, "target_hidden_size"):
-            fc_input_size = self.config.target_hidden_size * 3
-        else:
-            fc_input_size = self.config.hidden_size * 3
-        self.fc = ReplicatedLinear(
-            input_size=fc_input_size,
-            output_size=self.config.hidden_size,
-            bias=False,
-            params_dtype=vllm_config.model_config.dtype,
-            quant_config=self.quant_config,
-            prefix=maybe_prefix(prefix, "fc"),
-            return_bias=False,
-        )
+        if self.use_aux_hidden_state:
+            if hasattr(self.config, "target_hidden_size"):
+                fc_input_size = self.config.target_hidden_size * 3
+            else:
+                fc_input_size = self.config.hidden_size * 3
+            self.fc = ReplicatedLinear(
+                input_size=fc_input_size,
+                output_size=self.config.hidden_size,
+                bias=False,
+                params_dtype=vllm_config.model_config.dtype,
+                quant_config=self.quant_config,
+                prefix=maybe_prefix(prefix, "fc"),
+                return_bias=False,
+            )
         self.norm = RMSNorm(
             self.config.hidden_size,
             eps=self.config.rms_norm_eps,
@@ -332,6 +338,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         self,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
+        if not self.model.use_aux_hidden_state:
+            return hidden_states
         # combine multiple auxiliary hidden states returned by eagle3
         return self.model.fc(hidden_states)
@@ -357,6 +365,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
             skip_substrs.append("draft_id_to_target_id")
         if not includes_embed_tokens:
             skip_substrs.append("embed_tokens")
+        if not self.model.use_aux_hidden_state:
+            skip_substrs.append("fc.")
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=None,

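For reviewers, a minimal sketch of what combine_hidden_states now does in each mode (shapes are illustrative, and plain nn.Linear stands in for vLLM's ReplicatedLinear):

import torch
import torch.nn as nn

hidden_size, num_tokens = 4096, 8
fc = nn.Linear(3 * hidden_size, hidden_size, bias=False)  # stand-in for self.fc

# use_aux_hidden_state=True: three target-layer states arrive pre-concatenated
aux_states = torch.randn(num_tokens, 3 * hidden_size)
combined = fc(aux_states)                 # (num_tokens, hidden_size)

# use_aux_hidden_state=False: the last-layer output passes through untouched,
# so no fc weights exist (hence the "fc." entry in skip_substrs above)
last_hidden = torch.randn(num_tokens, hidden_size)
passthrough = last_hidden                 # returned as-is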
vllm/v1/spec_decode/eagle.py

@@ -83,6 +83,9 @@ class EagleProposer:
         self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
         self.attn_layer_names: list[str] = []
         self.indexer_layer_names: list[str] = []
+        self.eagle3_use_aux_hidden_state: bool = (
+            self._get_eagle3_use_aux_hidden_state_from_config()
+        )
         self.use_cuda_graph = False
@@ -1169,6 +1172,22 @@ class EagleProposer:
         )
         return builder

+    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
+        """
+        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
+        hidden states and directly use the last layer's output, just like eagle1.
+        They might indicate this by setting "use_aux_hidden_state" to False
+        inside the "eagle_config" dict of their hf_config.
+        """
+        if self.method != "eagle3":
+            return False
+        # Assume that eagle3 heads use aux hidden states by default
+        use_aux_hidden_state = True
+        eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
+        if eagle_config is not None:
+            use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True)
+        return use_aux_hidden_state
+
     def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Validate that all eagle layers belong to the same KVCacheGroup.

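A quick way to sanity-check the new lookup in isolation (SimpleNamespace stands in for a real hf_config; this mirrors the helper above for the eagle3 case only):

from types import SimpleNamespace

def lookup(hf_config) -> bool:
    # Mirrors _get_eagle3_use_aux_hidden_state_from_config for method == "eagle3"
    eagle_config = getattr(hf_config, "eagle_config", None)
    if eagle_config is not None:
        return eagle_config.get("use_aux_hidden_state", True)
    return True

assert lookup(SimpleNamespace()) is True                                         # no eagle_config
assert lookup(SimpleNamespace(eagle_config={})) is True                          # key absent
assert lookup(SimpleNamespace(eagle_config={"use_aux_hidden_state": False})) is False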
vllm/v1/worker/gpu_model_runner.py

@@ -375,7 +375,9 @@ class GPUModelRunner(
         elif self.speculative_config.use_eagle():
             self.drafter = EagleProposer(self.vllm_config, self.device, self)
             if self.speculative_config.method == "eagle3":
-                self.use_aux_hidden_state_outputs = True
+                self.use_aux_hidden_state_outputs = (
+                    self.drafter.eagle3_use_aux_hidden_state
+                )
         elif self.speculative_config.method == "medusa":
             self.drafter = MedusaProposer(
                 vllm_config=self.vllm_config, device=self.device
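Downstream, the flag decides what the runner hands the drafter. A rough sketch of the effect, under assumptions about the surrounding plumbing (every name except use_aux_hidden_state_outputs is illustrative, not vLLM's actual code):

import torch

use_aux_hidden_state_outputs = False      # what this patch can now set for eagle3
aux_hidden_states = [torch.randn(8, 4096) for _ in range(3)]
last_hidden_states = torch.randn(8, 4096)

if use_aux_hidden_state_outputs:
    # eagle3 with aux states: concatenate selected intermediate layers for the fc
    target_hidden_states = torch.cat(aux_hidden_states, dim=-1)  # (8, 3*4096)
else:
    # eagle1-style head: the final layer output only
    target_hidden_states = last_hidden_states                    # (8, 4096)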