[MoE][Kernel][Perf] Improve Shared Expert Stream Overlap (#28406)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
Alexander Matveev 2025-11-12 18:37:24 -05:00 committed by GitHub
parent 4ca5cd5740
commit 69d0e90313
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 81 additions and 36 deletions

View File

@ -222,6 +222,7 @@ if TYPE_CHECKING:
VLLM_USE_FBGEMM: bool = False
VLLM_GC_DEBUG: str = ""
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
VLLM_FLAT_LOGPROBS: bool = False
@ -1476,6 +1477,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"))
),
# Limits when we run shared_experts in a separate stream.
# We found out that for large batch sizes, the separate stream
# execution is not beneficial (most likely because of the input clone)
# TODO(alexm-redhat): Tune to be more dynamic based on GPU type
"VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD": lambda: int(
int(os.getenv("VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD", 256))
),
# Format for saving torch.compile cache artifacts
# - "binary": saves as binary file
# Safe for multiple vllm serve processes accessing the same torch compile cache.

View File

@ -48,7 +48,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import current_stream, direct_register_custom_op
from vllm.utils.torch_utils import (
aux_stream,
current_stream,
direct_register_custom_op,
)
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
if current_platform.is_cuda_alike():
@ -331,7 +335,11 @@ class FusedMoE(CustomOp):
logger.info_once("Disabling MoE shared_experts cuda stream")
self.shared_experts_stream = None
else:
self.shared_experts_stream = torch.cuda.Stream()
# TODO(rob): enable shared expert overlap with non-cuda.
# aux_stream() returns None on non-cuda platforms.
self.shared_experts_stream = aux_stream()
if self.shared_experts_stream is not None:
logger.info_once("Enabled separate cuda stream for MoE shared_experts")
if params_dtype is None:
params_dtype = torch.get_default_dtype()
@ -1606,7 +1614,9 @@ class FusedMoE(CustomOp):
if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None
shared_output = self.shared_experts(staged_hidden_states)
final_hidden_states = (
shared_output,
final_hidden_states,
@ -1684,13 +1694,34 @@ class FusedMoE(CustomOp):
use_chunked_impl = self.use_dp_chunking
if (
use_shared_experts_stream = (
has_separate_shared_experts
and not use_chunked_impl
and self.shared_experts_stream is not None
):
# Start the separate shared experts stream here since we want
# to run in parallel with the router/gate (next op below)
and (
hidden_states.shape[0]
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)
if use_shared_experts_stream:
assert self.shared_experts_stream is not None
# Clone BEFORE switching streams to avoid race condition
# where routed_expert kernel may mutate hidden_states.
hidden_states_clone = hidden_states.clone()
# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We dont need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream)
# Mark sync start point for the separate shared experts
# stream here since we want to run in parallel with the
# router/gate (next op below)
assert self.shared_experts_stream is not None
self.shared_experts_stream.wait_stream(current_stream())
# If router/gate provided, then apply it here.
@ -1709,33 +1740,6 @@ class FusedMoE(CustomOp):
self.quant_method, FusedMoEModularMethod
)
# If there are shared experts but we are not using a modular kernel, the
# shared experts must be called here
if has_separate_shared_experts:
assert self.shared_experts is not None
if self.shared_experts_stream is not None:
# Clone BEFORE switching streams to avoid race condition
# where routed_expert kernel may mutate hidden_states.
hidden_states_clone = hidden_states.clone()
self.shared_experts_stream.wait_stream(current_stream())
# Run shared experts in parallel on a separate stream
with torch.cuda.stream(self.shared_experts_stream):
shared_output = self.shared_experts(hidden_states_clone)
# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: we dont need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream)
else:
shared_output = self.shared_experts(hidden_states)
else:
shared_output = None
ctx = get_forward_context()
sp_ctx = (
ctx.dp_metadata.sp_local_sizes(self.sp_size)
@ -1776,12 +1780,21 @@ class FusedMoE(CustomOp):
)
if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None
# Wait for the parallel shared experts stream to finish here
if self.shared_experts_stream is not None:
if use_shared_experts_stream:
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output = self.shared_experts(hidden_states_clone)
current_stream().wait_stream(self.shared_experts_stream)
else:
shared_output = self.shared_experts(hidden_states)
final_hidden_states = (
shared_output,

View File

@ -409,6 +409,30 @@ def current_stream() -> torch.cuda.Stream:
return _current_stream_tls.value
# Global auxilary stream for running operations in background streams.
# We have single global auxilary stream to avoid an explosion of streams
# for every layer (and make profiling look sane).
#
# aux_stream() is currently used for:
# - MoE shared_expert overlap with router
_aux_stream: torch.cuda.Stream | None = None
def aux_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
"""
global _aux_stream
from vllm.platforms import current_platform
# TODO: validate this works properly on ROCm platform.
if _aux_stream is None and current_platform.is_cuda():
_aux_stream = torch.cuda.Stream()
return _aux_stream
@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
# Note: cuda_visible_devices is not used, but we keep it as an argument for