diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 34b5af846493..825272535a45 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -64,7 +64,7 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors -from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP +from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, PPMissingLayer, @@ -422,6 +422,8 @@ class Qwen3MoeModel(nn.Module): self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size ) + # Track layers for auxiliary hidden state outputs (EAGLE3) + self.aux_hidden_state_layers: tuple[int, ...] = () def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -432,7 +434,9 @@ class Qwen3MoeModel(nn.Module): positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + ) -> Union[ + torch.Tensor, IntermediateTensors, tuple[torch.Tensor, list[torch.Tensor]] + ]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -443,13 +447,29 @@ class Qwen3MoeModel(nn.Module): assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in islice(self.layers, self.start_layer, self.end_layer): + + aux_hidden_states = [] + for layer_idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer), + start=self.start_layer, + ): + # Collect auxiliary hidden states if specified + if layer_idx in self.aux_hidden_state_layers: + aux_hidden_state = ( + hidden_states + residual if residual is not None else hidden_states + ) + aux_hidden_states.append(aux_hidden_state) hidden_states, residual = layer(positions, hidden_states, residual) + if not get_pp_group().is_last_rank: return IntermediateTensors( {"hidden_states": hidden_states, "residual": residual} ) hidden_states, _ = self.norm(hidden_states, residual) + + # Return auxiliary hidden states if collected + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: @@ -606,7 +626,9 @@ class Qwen3MoeModel(nn.Module): return loaded_params -class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts): +class Qwen3MoeForCausalLM( + nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts +): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -702,6 +724,13 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExperts) moe.n_redundant_experts = self.num_redundant_experts moe.experts.update_expert_map() + def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None: + self.model.aux_hidden_state_layers = layers + + def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: + num_layers = len(self.model.layers) + return (2, num_layers // 2, num_layers - 3) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids)