[Feature] Shared Experts Overlap with FI deepgemm swap kernel, 2.2% throughput improvement and 3.6% TTFT improvement (#28879)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Wentao Ye 2025-11-20 20:41:49 -05:00 committed by GitHub
parent 87cbbdff63
commit df44df0143
4 changed files with 119 additions and 33 deletions
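
The diffs below wire a dedicated CUDA stream through FusedMoE and FusedMoEModularKernel so the shared experts can run concurrently with the routed-experts kernel. As orientation, here is a minimal sketch of that fork/join pattern in plain PyTorch. It is not the vLLM code: overlapped_forward, routed_experts, and shared_experts are illustrative stand-ins, and a real MoE layer combines the two outputs in its finalize step rather than with a simple add.

import torch

def overlapped_forward(
    hidden_states: torch.Tensor,
    routed_experts: torch.nn.Module,
    shared_experts: torch.nn.Module,
    shared_stream: torch.cuda.Stream,
) -> torch.Tensor:
    # Clone BEFORE any stream switch so the routed-experts kernel cannot
    # mutate the tensor the side stream is still reading.
    hs_clone = hidden_states.clone()
    # Tie the clone's lifetime to the side stream so the caching allocator
    # does not recycle its memory while that stream still uses it.
    hs_clone.record_stream(shared_stream)
    # Fork: the side stream must observe the clone before it starts.
    shared_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(shared_stream):
        shared_out = shared_experts(hs_clone)     # overlapped work
    routed_out = routed_experts(hidden_states)    # main-stream work
    # Join: shared_out must be ready before the main stream consumes it.
    torch.cuda.current_stream().wait_stream(shared_stream)
    return routed_out + shared_out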


@@ -50,6 +50,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
prepare_finalize,
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts,
getattr(moe_layer, "shared_experts_stream", None),
),
)


@@ -850,6 +850,45 @@ class FusedMoE(CustomOp):
dp_size=get_dp_group().world_size,
)
def _maybe_setup_shared_experts_stream(
self,
hidden_states: torch.Tensor,
has_separate_shared_experts: bool,
use_chunked_impl: bool,
) -> tuple[bool, torch.Tensor | None]:
use_shared_experts_stream = (
has_separate_shared_experts
and not use_chunked_impl
and self.shared_experts_stream is not None
and (
hidden_states.shape[0]
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)
hidden_states_clone: torch.Tensor | None = None
if use_shared_experts_stream:
assert self.shared_experts_stream is not None
# Clone BEFORE switching streams to avoid a race condition
# where the routed experts kernel may mutate hidden_states.
hidden_states_clone = hidden_states.clone()
# Record that the clone will be used by shared_experts_stream
# to avoid a GC issue from premature deallocation of hidden_states_clone.
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We don't need shared_output.record_stream(current_stream())
# because we sync the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream)
# Mark sync start point for the separate shared experts
# stream here since we want to run in parallel with the
# router/gate (next op below)
assert self.shared_experts_stream is not None
self.shared_experts_stream.wait_stream(current_stream())
return use_shared_experts_stream, hidden_states_clone
def _load_per_tensor_weight_scale(
self,
shard_id: str,
@@ -1819,36 +1858,12 @@
use_chunked_impl = self.use_dp_chunking
use_shared_experts_stream = (
has_separate_shared_experts
and not use_chunked_impl
and self.shared_experts_stream is not None
and (
hidden_states.shape[0]
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
use_shared_experts_stream, hidden_states_clone = (
self._maybe_setup_shared_experts_stream(
hidden_states, has_separate_shared_experts, use_chunked_impl
)
)
if use_shared_experts_stream:
assert self.shared_experts_stream is not None
# Clone BEFORE switching streams to avoid a race condition
# where the routed experts kernel may mutate hidden_states.
hidden_states_clone = hidden_states.clone()
# Record that the clone will be used by shared_experts_stream
# to avoid a GC issue from premature deallocation of hidden_states_clone.
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We don't need shared_output.record_stream(current_stream())
# because we sync the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream)
# Mark sync start point for the separate shared experts
# stream here since we want to run in parallel with the
# router/gate (next op below)
assert self.shared_experts_stream is not None
self.shared_experts_stream.wait_stream(current_stream())
# If router/gate provided, then apply it here.
# (Note: This code runs only when "overlapped mode" is on to allow
# parallel execution of shared experts with the FusedMoE via

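Both copies of the setup helper gate the overlap on batch size: above VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD tokens the routed-experts kernel is assumed to keep the GPU busy enough that the extra clone and stream switch no longer pay off. A condensed sketch of that gate, assuming the threshold is read from the environment; the default value below is a placeholder, not the one vLLM ships in vllm.envs.

import os
import torch

# Placeholder default; the real default lives in vllm.envs.
_TOKEN_THRESHOLD = int(
    os.environ.get("VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD", "256")
)

def should_overlap_shared_experts(
    hidden_states: torch.Tensor,
    shared_experts_stream: torch.cuda.Stream | None,
    use_chunked_impl: bool,
) -> bool:
    # Mirrors the checks in the diff: a dedicated stream must exist, the
    # chunked DP path must be off, the input must live on the GPU, and the
    # batch must be small enough for the overlap to pay for the clone.
    return (
        shared_experts_stream is not None
        and not use_chunked_impl
        and hidden_states.is_cuda
        and hidden_states.shape[0] <= _TOKEN_THRESHOLD
    )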

@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.utils import (
count_expert_num_tokens,
disable_inplace,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id,
@@ -709,11 +710,13 @@ class FusedMoEModularKernel(torch.nn.Module):
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: torch.nn.Module | None = None,
shared_experts_stream: torch.cuda.Stream | None = None,
):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts
self.shared_experts_stream = shared_experts_stream
self._post_init_setup()
assert (
@@ -890,6 +893,34 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_num_tokens_cpu=c_expert_num_tokens_cpu,
)
def _maybe_setup_shared_experts_stream(
self, hidden_states: torch.Tensor
) -> tuple[bool, torch.Tensor | None]:
# decide whether to run shared experts on a separate CUDA stream to
# overlap with the main fused MoE kernel.
use_shared_experts_stream = (
self.shared_experts is not None
and self.shared_experts_stream is not None
and hidden_states.is_cuda
and (
hidden_states.shape[0]
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)
hidden_states_clone: torch.Tensor | None = None
if use_shared_experts_stream and self.shared_experts_stream is not None:
# TODO: Optimize this (complicated)
# Note: this clone adds overhead but is required
# for correctness with multiple CUDA streams and CUDA graph capture.
hidden_states_clone = hidden_states.clone()
# record that the clone will be used by the separate stream so its
# lifetime is correctly tracked.
hidden_states_clone.record_stream(self.shared_experts_stream)
self.shared_experts_stream.wait_stream(torch.cuda.current_stream())
return use_shared_experts_stream, hidden_states_clone
def _prepare(
self,
hidden_states: torch.Tensor,
@@ -1077,12 +1108,30 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
hidden_states_clone: torch.Tensor | None = None,
use_shared_experts_stream: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
The _finalize method is a wrapper around self.prepare_finalize.finalize
that handles DBO, async finalization, and shared-experts overlap.
"""
shared_output: torch.Tensor | None = None
def maybe_run_shared_experts() -> torch.Tensor | None:
if self.shared_experts is None:
return None
if not use_shared_experts_stream or (
self.shared_experts_stream is not None
and (not hidden_states.is_cuda or not torch.cuda.is_available())
):
# fall back to running on the current stream
return self.shared_experts(hidden_states)
assert hidden_states_clone is not None
# launch shared experts on the dedicated stream.
with torch.cuda.stream(self.shared_experts_stream):
return self.shared_experts(hidden_states_clone)
if not self.prepare_finalize.supports_async():
assert not dbo_enabled()
@@ -1095,8 +1144,7 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
shared_output = maybe_run_shared_experts()
else:
finalize_ret = self.prepare_finalize.finalize_async(
output,
@@ -1107,8 +1155,7 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
shared_output = maybe_run_shared_experts()
# TODO(lucas): refactor this in the alternative schedules followup
# currently unpack if we have hook + receiver pair or just
@@ -1131,12 +1178,28 @@ class FusedMoEModularKernel(torch.nn.Module):
receiver()
self._wait_for_shared_experts_stream(hidden_states, use_shared_experts_stream)
if self.shared_experts is None:
return output
else:
assert shared_output is not None
return shared_output, output
def _wait_for_shared_experts_stream(
self, hidden_states: torch.Tensor, use_shared_experts_stream: bool
) -> None:
# ensure that any work enqueued on the shared_experts_stream is
# completed before the shared_output tensor is consumed
if (
self.shared_experts is not None
and use_shared_experts_stream
and self.shared_experts_stream is not None
and hidden_states.is_cuda
and current_platform.is_cuda()
):
torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
def forward(
self,
hidden_states: torch.Tensor,
@@ -1183,6 +1246,10 @@ class FusedMoEModularKernel(torch.nn.Module):
else:
output = torch.zeros_like(hidden_states)
use_shared_experts_stream, hidden_states_clone = (
self._maybe_setup_shared_experts_stream(hidden_states)
)
local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = local_num_experts
@@ -1219,4 +1286,6 @@ class FusedMoEModularKernel(torch.nn.Module):
topk_weights,
topk_ids,
apply_router_weight_on_input,
hidden_states_clone=hidden_states_clone,
use_shared_experts_stream=use_shared_experts_stream,
)

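The commit title claims a 2.2% throughput and 3.6% TTFT improvement from this overlap. A generic way to sanity-check that two pieces of GPU work actually overlap is CUDA event timing; the matmuls below merely stand in for the routed and shared expert kernels, and this is not the benchmark behind the numbers above. At large sizes a single kernel may already saturate the GPU, which is the same reason the token threshold above disables the overlap for big batches.

import torch

def timed_ms(fn) -> float:
    # Wall time of fn() in milliseconds, measured with CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end)

if torch.cuda.is_available():
    a = torch.randn(4096, 4096, device="cuda")
    b = torch.randn(4096, 4096, device="cuda")
    side = torch.cuda.Stream()

    def sequential():
        _ = a @ a
        _ = b @ b

    def overlapped():
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            _ = b @ b                    # "shared experts" stand-in
        _ = a @ a                        # "routed experts" stand-in
        torch.cuda.current_stream().wait_stream(side)

    for name, fn in (("sequential", sequential), ("overlapped", overlapped)):
        fn()                             # warm-up
        print(f"{name}: {timed_ms(fn):.3f} ms")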

@@ -45,7 +45,8 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
a1.mul_(topk_weights.to(a1.dtype))
# Note: do not use inplace for shared experts overlap
a1 = a1 * topk_weights.to(a1.dtype)
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
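
The last hunk swaps an in-place mul_ for an out-of-place multiply; per the added note, in-place scaling is unsafe once the shared-experts overlap may still need the original activations. A tiny, self-contained illustration of the aliasing difference, independent of the kernel above:

import torch

w = torch.full((4,), 2.0)

# In-place: the caller's tensor is mutated, so any other consumer still
# holding a reference observes the scaled values.
a1 = torch.ones(4)
other_consumer_view = a1
a1.mul_(w)
print(other_consumer_view)      # tensor([2., 2., 2., 2.]) -- originals gone

# Out-of-place: only the local name is rebound to the scaled copy; the
# original tensor stays intact for everyone else.
a1 = torch.ones(4)
other_consumer_view = a1
a1 = a1 * w
print(other_consumer_view, a1)  # tensor([1., 1., 1., 1.]) tensor([2., 2., 2., 2.])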