From e5c78956c0c576d8f7230c29550ff09ffff0c064 Mon Sep 17 00:00:00 2001
From: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com>
Date: Fri, 14 Nov 2025 17:13:46 -0500
Subject: [PATCH] [Bugfix] Fix incorrect use of hidden_states for
 shared_experts due to do_naive_dispatch_combine (#28740)

Signed-off-by: Alexander Matveev
---
 vllm/model_executor/layers/fused_moe/layer.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index aed8245cbd830..023132acfed3f 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1749,14 +1749,16 @@ class FusedMoE(CustomOp):
 
         with sp_ctx:
             if do_naive_dispatch_combine:
-                hidden_states, router_logits = get_ep_group().dispatch(
+                hidden_states_combined, router_logits = get_ep_group().dispatch(
                     hidden_states, router_logits, self.is_sequence_parallel
                 )
 
             # Matrix multiply.
             final_hidden_states = self.quant_method.apply(
                 layer=self,
-                x=hidden_states,
+                x=hidden_states_combined
+                if do_naive_dispatch_combine
+                else hidden_states,
                 router_logits=router_logits,
                 top_k=self.top_k,
                 renormalize=self.renormalize,
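
Note: the sketch below is a minimal illustration of the pattern this patch restores,
not the vLLM implementation. When naive dispatch/combine is enabled, the routed
experts should consume the dispatched tensor, while a shared-expert branch must keep
seeing the original per-rank hidden_states; overwriting the `hidden_states` name with
the dispatch output is what the bug fixed here did. `toy_dispatch`, `SharedExpert`,
and `moe_forward` are hypothetical stand-ins for get_ep_group().dispatch(), a model's
shared_experts module, and the FusedMoE forward path.

# Hedged sketch under the assumptions stated above; not vLLM code.
import torch
import torch.nn as nn


def toy_dispatch(hidden_states: torch.Tensor, world_size: int = 2) -> torch.Tensor:
    # Stand-in for an EP dispatch: the routed-expert path sees a tensor whose
    # token dimension differs from the local batch.
    return hidden_states.repeat(world_size, 1)


class SharedExpert(nn.Module):
    # Dense MLP applied to every *local* token, independent of routing.
    def __init__(self, hidden: int) -> None:
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def moe_forward(
    hidden_states: torch.Tensor,
    shared_experts: SharedExpert,
    do_naive_dispatch_combine: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    if do_naive_dispatch_combine:
        # Keep the dispatched tensor under a new name so the original
        # per-rank activations stay available for the shared experts.
        hidden_states_combined = toy_dispatch(hidden_states)
    # Routed experts operate on the dispatched tensor when it exists ...
    routed_input = (
        hidden_states_combined if do_naive_dispatch_combine else hidden_states
    )
    # ... while the shared experts always see the original local tokens.
    shared_out = shared_experts(hidden_states)
    return routed_input, shared_out


if __name__ == "__main__":
    x = torch.randn(4, 8)
    routed_input, shared_out = moe_forward(x, SharedExpert(8), True)
    assert routed_input.shape[0] == 8   # dispatched view for the routed experts
    assert shared_out.shape == x.shape  # shared experts kept the local batch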