mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 12:04:28 +08:00
[Spec Decode] Add support for EAGLE3 heads that do not use_aux_hidden_states (#27688)
Signed-off-by: hjjq <hanjieq@nvidia.com> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
parent
699bca76c0
commit
5f9679a43b
@ -142,6 +142,12 @@ class LlamaModel(nn.Module):
|
||||
# Get drafter's quantization config
|
||||
self.quant_config = get_draft_quant_config(vllm_config)
|
||||
|
||||
eagle_config = getattr(self.config, "eagle_config", None)
|
||||
if eagle_config is not None and "use_aux_hidden_state" in eagle_config:
|
||||
self.use_aux_hidden_state = eagle_config["use_aux_hidden_state"]
|
||||
else:
|
||||
self.use_aux_hidden_state = True
|
||||
|
||||
current_vllm_config = get_current_vllm_config()
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
@ -161,20 +167,20 @@ class LlamaModel(nn.Module):
|
||||
for layer_idx in range(self.config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
if hasattr(self.config, "target_hidden_size"):
|
||||
fc_input_size = self.config.target_hidden_size * 3
|
||||
else:
|
||||
fc_input_size = self.config.hidden_size * 3
|
||||
self.fc = ReplicatedLinear(
|
||||
input_size=fc_input_size,
|
||||
output_size=self.config.hidden_size,
|
||||
bias=False,
|
||||
params_dtype=vllm_config.model_config.dtype,
|
||||
quant_config=self.quant_config,
|
||||
prefix=maybe_prefix(prefix, "fc"),
|
||||
return_bias=False,
|
||||
)
|
||||
|
||||
if self.use_aux_hidden_state:
|
||||
if hasattr(self.config, "target_hidden_size"):
|
||||
fc_input_size = self.config.target_hidden_size * 3
|
||||
else:
|
||||
fc_input_size = self.config.hidden_size * 3
|
||||
self.fc = ReplicatedLinear(
|
||||
input_size=fc_input_size,
|
||||
output_size=self.config.hidden_size,
|
||||
bias=False,
|
||||
params_dtype=vllm_config.model_config.dtype,
|
||||
quant_config=self.quant_config,
|
||||
prefix=maybe_prefix(prefix, "fc"),
|
||||
return_bias=False,
|
||||
)
|
||||
self.norm = RMSNorm(
|
||||
self.config.hidden_size,
|
||||
eps=self.config.rms_norm_eps,
|
||||
@ -332,6 +338,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if not self.model.use_aux_hidden_state:
|
||||
return hidden_states
|
||||
# combine multiple auxiliary hidden states returned by eagle3
|
||||
return self.model.fc(hidden_states)
|
||||
|
||||
@ -357,6 +365,8 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
||||
skip_substrs.append("draft_id_to_target_id")
|
||||
if not includes_embed_tokens:
|
||||
skip_substrs.append("embed_tokens")
|
||||
if not self.model.use_aux_hidden_state:
|
||||
skip_substrs.append("fc.")
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=None,
|
||||
|
||||
@ -83,6 +83,9 @@ class EagleProposer:
|
||||
self.draft_indexer_metadata_builder: AttentionMetadataBuilder | None = None
|
||||
self.attn_layer_names: list[str] = []
|
||||
self.indexer_layer_names: list[str] = []
|
||||
self.eagle3_use_aux_hidden_state: bool = (
|
||||
self._get_eagle3_use_aux_hidden_state_from_config()
|
||||
)
|
||||
|
||||
self.use_cuda_graph = False
|
||||
|
||||
@ -1169,6 +1172,22 @@ class EagleProposer:
|
||||
)
|
||||
return builder
|
||||
|
||||
def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
|
||||
"""
|
||||
Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
|
||||
hidden states and directly uses the last layer output just like eagle1.
|
||||
They might indicate this by setting "use_aux_hidden_state" to False
|
||||
inside the "eagle_config" dict of their hf_config.
|
||||
"""
|
||||
if self.method != "eagle3":
|
||||
return False
|
||||
# Assume that eagle3 heads use aux hidden states by default
|
||||
use_aux_hidden_state = True
|
||||
eagle_config = getattr(self.draft_model_config.hf_config, "eagle_config", None)
|
||||
if eagle_config is not None:
|
||||
use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", True)
|
||||
return use_aux_hidden_state
|
||||
|
||||
def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Validate that all eagle layers belong to the same KVCacheGroup.
|
||||
|
||||
@ -375,7 +375,9 @@ class GPUModelRunner(
|
||||
elif self.speculative_config.use_eagle():
|
||||
self.drafter = EagleProposer(self.vllm_config, self.device, self)
|
||||
if self.speculative_config.method == "eagle3":
|
||||
self.use_aux_hidden_state_outputs = True
|
||||
self.use_aux_hidden_state_outputs = (
|
||||
self.drafter.eagle3_use_aux_hidden_state
|
||||
)
|
||||
elif self.speculative_config.method == "medusa":
|
||||
self.drafter = MedusaProposer(
|
||||
vllm_config=self.vllm_config, device=self.device
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user