Add EAGLE-3 Speculative Decoding Support for Qwen3 MoE (#26485)
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
parent 086609de64
commit d2a71530c1
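For context, the feature this change enables can be exercised from vLLM's offline API roughly as follows. This is a minimal sketch, assuming an EAGLE-3 draft checkpoint trained against a Qwen3 MoE target; the draft-model path is a placeholder and the exact speculative_config keys may vary between vLLM releases.

from vllm import LLM, SamplingParams

# Sketch only: the target below is one Qwen3 MoE checkpoint; the draft path is
# a hypothetical placeholder for an EAGLE-3 head trained for that target.
llm = LLM(
    model="Qwen/Qwen3-30B-A3B",
    speculative_config={
        "method": "eagle3",
        "model": "<path-to-qwen3-moe-eagle3-draft>",
        "num_speculative_tokens": 3,
    },
)

outputs = llm.generate(
    ["Explain speculative decoding in one paragraph."],
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)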
@@ -64,7 +64,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -422,6 +422,8 @@ class Qwen3MoeModel(nn.Module):
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
         )
+        # Track layers for auxiliary hidden state outputs (EAGLE3)
+        self.aux_hidden_state_layers: tuple[int, ...] = ()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
@@ -432,7 +434,9 @@ class Qwen3MoeModel(nn.Module):
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
+    ) -> Union[
+        torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]]
+    ]:
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
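The widened return annotation above means callers now see a plain tensor, an IntermediateTensors bundle (on non-last pipeline ranks), or a (hidden_states, aux_hidden_states) tuple. Below is a small, self-contained sketch of how a caller might normalize the tensor cases; the IntermediateTensors branch is omitted and unpack_forward_output is a made-up helper name, not part of this commit.

import torch

def unpack_forward_output(output):
    # Normalize the widened return type into (hidden_states, aux_hidden_states).
    # The IntermediateTensors case (non-last pipeline ranks) is not handled here.
    if isinstance(output, tuple):
        hidden_states, aux_hidden_states = output
    else:
        hidden_states, aux_hidden_states = output, []
    return hidden_states, aux_hidden_states

h, aux = unpack_forward_output(torch.randn(4, 2048))
assert aux == []
h, aux = unpack_forward_output((torch.randn(4, 2048), [torch.randn(4, 2048)]))
assert len(aux) == 1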
@@ -443,13 +447,29 @@ class Qwen3MoeModel(nn.Module):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for layer in islice(self.layers, self.start_layer, self.end_layer):
+
+        aux_hidden_states = []
+        for layer_idx, layer in enumerate(
+            islice(self.layers, self.start_layer, self.end_layer),
+            start=self.start_layer,
+        ):
+            # Collect auxiliary hidden states if specified
+            if layer_idx in self.aux_hidden_state_layers:
+                aux_hidden_state = (
+                    hidden_states + residual if residual is not None else hidden_states
+                )
+                aux_hidden_states.append(aux_hidden_state)
             hidden_states, residual = layer(positions, hidden_states, residual)
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
                 {"hidden_states": hidden_states, "residual": residual}
             )
         hidden_states, _ = self.norm(hidden_states, residual)
+
+        # Return auxiliary hidden states if collected
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
         return hidden_states
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
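When auxiliary layers are configured, the last pipeline rank now returns (hidden_states, aux_hidden_states). Below is an illustrative, self-contained sketch of how those per-layer tensors could be fused for an EAGLE-3 style drafter that consumes them concatenated along the hidden dimension; this is not vLLM's actual drafter plumbing, and fuse_aux_hidden_states is a made-up helper name.

import torch

def fuse_aux_hidden_states(aux_hidden_states: list[torch.Tensor]) -> torch.Tensor:
    # Each entry is [num_tokens, hidden_size]; EAGLE-3 style drafters typically
    # take the selected early/middle/late layer activations concatenated along
    # the hidden dimension before projecting them back down.
    return torch.cat(aux_hidden_states, dim=-1)

# Illustrative shapes: three auxiliary layers, hidden_size 2048.
num_tokens, hidden_size = 8, 2048
aux = [torch.randn(num_tokens, hidden_size) for _ in range(3)]
assert fuse_aux_hidden_states(aux).shape == (num_tokens, 3 * hidden_size)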
@@ -606,7 +626,9 @@ class Qwen3MoeModel(nn.Module):
         return loaded_params
 
 
-class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts):
+class Qwen3MoeForCausalLM(
+    nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts
+):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
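Adding SupportsEagle3 to the base classes advertises the two methods introduced in the next hunk. As a hedged sketch of the contract that interface is expected to capture (the real definition lives in vllm/model_executor/models/interfaces.py and may carry additional machinery):

from typing import Protocol, runtime_checkable

@runtime_checkable
class SupportsEagle3Like(Protocol):
    # Hypothetical stand-in mirroring the methods this commit adds; not the
    # actual vLLM interface definition.
    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: ...

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: ...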
@@ -702,6 +724,13 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts)
             moe.n_redundant_experts = self.num_redundant_experts
             moe.experts.update_expert_map()
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
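The default above picks one early, one middle, and one late decoder layer as the EAGLE-3 auxiliary outputs. A small worked check of that heuristic, using a hypothetical 48-layer configuration rather than any specific Qwen3 MoE checkpoint:

def default_eagle3_aux_layers(num_layers: int) -> tuple[int, int, int]:
    # Mirrors the heuristic in get_eagle3_aux_hidden_state_layers above:
    # one early, one middle, and one late layer index.
    return (2, num_layers // 2, num_layers - 3)

# Hypothetical 48-layer model: layers 2, 24, and 45 are tapped.
assert default_eagle3_aux_layers(48) == (2, 24, 45)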