Add EAGLE-3 Speculative Decoding Support for Qwen3 MoE (#26485)

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
This commit is contained in:
Rahul Tuli 2025-10-11 15:44:41 +05:30 committed by GitHub
parent 086609de64
commit d2a71530c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -64,7 +64,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
@ -422,6 +422,8 @@ class Qwen3MoeModel(nn.Module):
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
# Track layers for auxiliary hidden state outputs (EAGLE3)
self.aux_hidden_state_layers: tuple[int, ...] = ()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@ -432,7 +434,9 @@ class Qwen3MoeModel(nn.Module):
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
) -> Union[
torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]]
]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
@ -443,13 +447,29 @@ class Qwen3MoeModel(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in islice(self.layers, self.start_layer, self.end_layer):
aux_hidden_states = []
for layer_idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer),
start=self.start_layer,
):
# Collect auxiliary hidden states if specified
if layer_idx in self.aux_hidden_state_layers:
aux_hidden_state = (
hidden_states + residual if residual is not None else hidden_states
)
aux_hidden_states.append(aux_hidden_state)
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
# Return auxiliary hidden states if collected
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
@ -606,7 +626,9 @@ class Qwen3MoeModel(nn.Module):
return loaded_params
class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts):
class Qwen3MoeForCausalLM(
nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts
):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -702,6 +724,13 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts)
moe.n_redundant_experts = self.num_redundant_experts
moe.experts.update_expert_map()
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)