From d71af5f5020e0fee5375b3cf7898852abbae22f2 Mon Sep 17 00:00:00 2001
From: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date: Wed, 5 Nov 2025 20:21:08 -0500
Subject: [PATCH] [Feature] Enable TP + EP `shared_experts` overlap with
 router, 3.7% E2E performance improvement (#28164)

Signed-off-by: yewentao256
---
 vllm/model_executor/layers/fused_moe/layer.py |  2 +-
 .../layers/fused_moe/shared_fused_moe.py      | 22 +++++++++++++------
 2 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 0a8c2f311f5c6..1236116386c97 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1178,7 +1178,7 @@ class FusedMoE(CustomOp):
         hidden_size: Input hidden state size of the transformer
         intermediate_size: Intermediate size of the experts
         params_dtype: Data type for the parameters.
-        reduce_results: Whether to all all_reduce on the output of the layer
+        reduce_results: Whether to all_reduce on the output of the layer
         renormalize: Whether to renormalize the logits in the fused_moe kernel
         quant_config: Quantization configure.
         enable_eplb: Whether to enable expert parallelism load balancer.
diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
index 2db733b765cea..6b4a0b8cf0730 100644
--- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
@@ -3,7 +3,10 @@
 
 import torch
 
-from vllm.distributed import tensor_model_parallel_all_reduce
+from vllm.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 
 
@@ -25,16 +28,13 @@ class SharedFusedMoE(FusedMoE):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
 
-        # Disable shared expert overlap if EP is disabled or we are not using
+        # Disable shared expert overlap if we are not using
         # flashinfer + DP since there is nothing to be gained in this case.
         # Disabling the overlap optimization also prevents the shared experts
         # from being hidden from torch.compile.
         self.use_overlapped = (
             use_overlapped
-            and not (
-                self.use_ep
-                or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
-            )
+            and not (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
             and self._shared_experts is not None
         )
 
@@ -65,7 +65,7 @@ class SharedFusedMoE(FusedMoE):
         # should have been created with reduce_results=False.
         if (
             self.reduce_results
-            and self.tp_size > 1
+            and get_tensor_model_parallel_world_size() > 1
             and self.must_reduce_shared_expert_outputs()
         ):
             shared_out = tensor_model_parallel_all_reduce(shared_out)
@@ -81,4 +81,12 @@
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
+        # ensure early TP reduction of shared expert outputs when required
+        if (
+            shared_out is not None
+            and self.reduce_results
+            and get_tensor_model_parallel_world_size() > 1
+            and self.must_reduce_shared_expert_outputs()
+        ):
+            shared_out = tensor_model_parallel_all_reduce(shared_out)
         return shared_out, fused_out
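
Reviewer note (illustration only, not part of the patch): the sketch below isolates the guard that now runs on both the overlapped and non-overlapped paths, i.e. the shared-expert output is summed across the TP group only when reduce_results is set, the TP world size is greater than 1, and must_reduce_shared_expert_outputs() requires it. The function name maybe_reduce_shared_out and the plain torch.distributed harness are assumptions made to keep the example self-contained; the patch itself uses vLLM's get_tensor_model_parallel_world_size and tensor_model_parallel_all_reduce helpers.

    import torch
    import torch.distributed as dist


    def maybe_reduce_shared_out(
        shared_out: torch.Tensor | None,
        reduce_results: bool,
        must_reduce_shared_expert_outputs: bool,
    ) -> torch.Tensor | None:
        # Hypothetical standalone version of the early-reduction guard added
        # by the patch; in vLLM the TP size and all-reduce come from
        # vllm.distributed rather than raw torch.distributed.
        tp_world_size = dist.get_world_size() if dist.is_initialized() else 1
        if (
            shared_out is not None
            and reduce_results
            and tp_world_size > 1
            and must_reduce_shared_expert_outputs
        ):
            # Sum the partial shared-expert outputs held by each TP rank.
            dist.all_reduce(shared_out, op=dist.ReduceOp.SUM)
        return shared_out

With this guard also applied on the overlapped forward path, the shared experts can run concurrently with the router under TP + EP without skipping the TP reduction that the non-overlapped path already performed.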