From 8df98c2161e28387391b667201f8458c2bdf29f4 Mon Sep 17 00:00:00 2001
From: Jiangyun Zhu
Date: Wed, 29 Oct 2025 16:12:54 +0800
Subject: [PATCH] [perf] Enable concurrent execution of "shared_experts" and "selected_experts" in qwen3-next (#27578)

Signed-off-by: zjy0516
---
 vllm/model_executor/models/qwen3_next.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index e81ad5f68d8f3..f452ba871582d 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -159,6 +159,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
 
         self.experts = SharedFusedMoE(
             shared_experts=self.shared_expert,
+            gate=self.gate,
             num_experts=self.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
@@ -181,11 +182,17 @@ class Qwen3NextSparseMoeBlock(nn.Module):
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)
 
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        if self.experts.is_internal_router:
+            # In this case, the gate/router runs inside the FusedMoE class
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=hidden_states
+            )
+        else:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
 
         if self.shared_expert is not None:
             final_hidden_states = final_hidden_states[0] + final_hidden_states[1]