[Performance] Dual stream execution of "shared_experts" and "selected_experts" inside FusedMoE (#26440)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
Alexander Matveev 2025-10-21 17:38:29 -04:00 committed by GitHub
parent becb7de40b
commit 344a0017c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 122 additions and 22 deletions

View File

@ -213,6 +213,7 @@ if TYPE_CHECKING:
VLLM_NCCL_INCLUDE_PATH: str | None = None
VLLM_USE_FBGEMM: bool = False
VLLM_GC_DEBUG: str = ""
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
def get_default_cache_root():
@ -1379,6 +1380,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with
# top 5 collected objects
"VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""),
# Disables parallel execution of shared_experts via separate cuda stream
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: os.getenv(
"VLLM_DISABLE_SHARED_EXPERTS_STREAM", False
),
}
# --8<-- [end:env-vars-definition]

View File

@ -57,7 +57,7 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import cdiv, has_deep_ep, has_pplx, round_up
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.utils.torch_utils import current_stream, direct_register_custom_op
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
if current_platform.is_cuda_alike():
@ -1082,6 +1082,17 @@ class FusedMoE(CustomOp):
n_shared_experts: int | None = None,
):
super().__init__()
# Allow disabling of the separate shared experts stream for
# debug purposes.
# TODO: Remove this after more extensive testings with TP/DP
# and other execution modes
if envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM:
logger.info_once("Disabling MoE shared_experts cuda stream")
self.shared_experts_stream = None
else:
self.shared_experts_stream = torch.cuda.Stream()
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
@ -1332,6 +1343,10 @@ class FusedMoE(CustomOp):
def shared_experts(self) -> torch.nn.Module | None:
return None
@property
def gate(self) -> torch.nn.Module | None:
return None
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@ -1390,6 +1405,11 @@ class FusedMoE(CustomOp):
or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
)
@property
def is_internal_router(self) -> bool:
# By default, router/gate is called before FusedMoE forward pass
return False
def update_expert_map(self):
# ep_size and ep_rank should already be updated
assert self.expert_map is not None
@ -2168,6 +2188,7 @@ class FusedMoE(CustomOp):
self,
full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor,
has_separate_shared_experts: bool,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
@ -2216,11 +2237,23 @@ class FusedMoE(CustomOp):
# If there are shared experts but we are not using a modular kernel,
# the shared experts must be called here
if (
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
and self.shared_experts is not None
):
shared_output = self.shared_experts(staged_hidden_states)
if has_separate_shared_experts:
assert self.shared_experts is not None
if self.shared_experts_stream is not None:
# For chunked, we start the shared experts stream here
# (Note that no concurrency with the router/gate)
self.shared_experts_stream.wait_stream(current_stream())
with torch.cuda.stream(self.shared_experts_stream):
# Note that staged_hidden_states clone() is necessary
# here to avoid conflict with the main stream
shared_output = self.shared_experts(
staged_hidden_states.clone()
)
else:
shared_output = self.shared_experts(staged_hidden_states)
else:
shared_output = None
@ -2249,9 +2282,14 @@ class FusedMoE(CustomOp):
logical_replica_count=self.logical_replica_count,
)
if shared_output is not None:
if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None
# Here we finish the shared experts stream
if self.shared_experts_stream is not None:
current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = (
shared_output,
final_hidden_states,
@ -2321,8 +2359,33 @@ class FusedMoE(CustomOp):
self.ensure_moe_quant_config()
if self.use_dp_chunking:
return self.forward_impl_chunked(hidden_states, router_logits)
has_separate_shared_experts = (
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
and self.shared_experts is not None
)
use_chunked_impl = self.use_dp_chunking
if (
has_separate_shared_experts
and not use_chunked_impl
and self.shared_experts_stream is not None
):
# Start the separate shared experts stream here since we want
# to run in parallel with the router/gate (next op below)
self.shared_experts_stream.wait_stream(current_stream())
# If router/gate provided, then apply it here.
# (Note: This code runs only when "overlapped mode" is on to allow
# parallel execution of shared experts with the FusedMoE via
# separate cuda stream)
if self.gate is not None:
router_logits, _ = self.gate(hidden_states)
if use_chunked_impl:
return self.forward_impl_chunked(
hidden_states, router_logits, has_separate_shared_experts
)
do_naive_dispatch_combine: bool = (
self.dp_size > 1 and not self.quant_method.using_modular_kernel
@ -2330,11 +2393,17 @@ class FusedMoE(CustomOp):
# If there are shared experts but we are not using a modular kernel, the
# shared experts must be called here
if (
not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
and self.shared_experts is not None
):
shared_output = self.shared_experts(hidden_states)
if has_separate_shared_experts:
assert self.shared_experts is not None
if self.shared_experts_stream is not None:
# Run shared experts in parallel on a separate stream
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output = self.shared_experts(hidden_states.clone())
else:
shared_output = self.shared_experts(hidden_states)
else:
shared_output = None
@ -2377,9 +2446,14 @@ class FusedMoE(CustomOp):
logical_replica_count=self.logical_replica_count,
)
if shared_output is not None:
if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None
# Wait for the parallel shared experts stream to finish here
if self.shared_experts_stream is not None:
current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = (
shared_output,
final_hidden_states,

View File

@ -18,25 +18,40 @@ class SharedFusedMoE(FusedMoE):
def __init__(
self,
shared_experts: torch.nn.Module | None,
gate: torch.nn.Module | None = None,
use_overlapped: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self._shared_experts = shared_experts
# Disable shared expert overlap if EP is disabled or we are not using
# flashinfer + DP since there is nothing to be gained in this case.
# Disabling the overlap optimization also prevents the shared experts
# from being hidden from torch.compile.
self.use_overlapped = (
use_overlapped
and not (self.use_ep or self.use_flashinfer_cutlass_kernels)
and not (
self.use_ep
or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
)
and self._shared_experts is not None
)
self._gate = gate
@property
def shared_experts(self) -> torch.nn.Module | None:
return self._shared_experts if self.use_overlapped else None
@property
def gate(self) -> torch.nn.Module | None:
return self._gate if self.use_overlapped else None
@property
def is_internal_router(self) -> bool:
return self.gate is not None
def forward(
self,
hidden_states: torch.Tensor,

View File

@ -227,6 +227,7 @@ class DeepseekV2MoE(nn.Module):
self.experts = SharedFusedMoE(
shared_experts=self.shared_experts,
gate=self.gate,
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
@ -264,12 +265,17 @@ class DeepseekV2MoE(nn.Module):
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=hidden_states
)
else:
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
shared_output, final_hidden_states = fused_moe_out
if self.shared_experts is None: