[MoE][Kernel][Perf] Improve Shared Expert Stream Overlap (#28406)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
parent 4ca5cd5740
commit 69d0e90313
@@ -222,6 +222,7 @@ if TYPE_CHECKING:
VLLM_USE_FBGEMM: bool = False
VLLM_GC_DEBUG: str = ""
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
VLLM_FLAT_LOGPROBS: bool = False
@@ -1476,6 +1477,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"))
),
# Limits when we run shared_experts in a separate stream.
# We found out that for large batch sizes, the separate stream
# execution is not beneficial (most likely because of the input clone)
# TODO(alexm-redhat): Tune to be more dynamic based on GPU type
"VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD": lambda: int(
int(os.getenv("VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD", 256))
),
# Format for saving torch.compile cache artifacts
# - "binary": saves as binary file
# Safe for multiple vllm serve processes accessing the same torch compile cache.
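The comments above give the rationale for the new environment variable registered in vllm's envs module: overlapping the shared experts on a side stream only pays off for small batches, because the extra clone() of the input dominates once the batch grows. A minimal sketch of how such a gate can be expressed, assuming the two environment variables behave as registered above (the helper name is illustrative, not vLLM API):

    import os
    import torch

    def should_use_shared_experts_stream(
        num_tokens: int, stream: torch.cuda.Stream | None
    ) -> bool:
        # Below the threshold, shared experts run on a side stream and overlap
        # with the routed experts; above it, the clone() cost wins and
        # everything stays on the main stream.
        threshold = int(os.getenv("VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD", "256"))
        disabled = bool(int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0")))
        return stream is not None and not disabled and num_tokens <= threshold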
@@ -48,7 +48,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import current_stream, direct_register_custom_op
from vllm.utils.torch_utils import (
aux_stream,
current_stream,
direct_register_custom_op,
)
from vllm.v1.worker.ubatching import dbo_current_ubatch_id

if current_platform.is_cuda_alike():
@@ -331,7 +335,11 @@ class FusedMoE(CustomOp):
logger.info_once("Disabling MoE shared_experts cuda stream")
self.shared_experts_stream = None
else:
self.shared_experts_stream = torch.cuda.Stream()
# TODO(rob): enable shared expert overlap with non-cuda.
# aux_stream() returns None on non-cuda platforms.
self.shared_experts_stream = aux_stream()
if self.shared_experts_stream is not None:
logger.info_once("Enabled separate cuda stream for MoE shared_experts")

if params_dtype is None:
params_dtype = torch.get_default_dtype()
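A hedged sketch of the constructor behavior above: aux_stream() may return None (non-CUDA platforms), and the user-facing disable flag forces None as well, so downstream code only needs a single None check to choose between overlapped and inline execution. The class and flag names below are illustrative stand-ins, not the FusedMoE wiring itself:

    import torch
    from vllm.utils.torch_utils import aux_stream

    class SharedExpertsHost:
        def __init__(self, disable_shared_experts_stream: bool = False) -> None:
            # None either because the user disabled the stream or because the
            # platform has no CUDA streams; both fall back to inline execution.
            self.shared_experts_stream: torch.cuda.Stream | None = (
                None if disable_shared_experts_stream else aux_stream()
            )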
@@ -1606,7 +1614,9 @@ class FusedMoE(CustomOp):
if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None

shared_output = self.shared_experts(staged_hidden_states)

final_hidden_states = (
shared_output,
final_hidden_states,
@@ -1684,13 +1694,34 @@ class FusedMoE(CustomOp):

use_chunked_impl = self.use_dp_chunking

if (
use_shared_experts_stream = (
has_separate_shared_experts
and not use_chunked_impl
and self.shared_experts_stream is not None
):
# Start the separate shared experts stream here since we want
# to run in parallel with the router/gate (next op below)
and (
hidden_states.shape[0]
<= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
)
)

if use_shared_experts_stream:
assert self.shared_experts_stream is not None

# Clone BEFORE switching streams to avoid race condition
# where routed_expert kernel may mutate hidden_states.
hidden_states_clone = hidden_states.clone()

# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: We dont need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream)

# Mark sync start point for the separate shared experts
# stream here since we want to run in parallel with the
# router/gate (next op below)
assert self.shared_experts_stream is not None
self.shared_experts_stream.wait_stream(current_stream())

# If router/gate provided, then apply it here.
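The comments in the hunk above describe a specific ordering: clone on the main stream, register the clone with the side stream, then make the side stream wait on the main stream so the shared experts can run concurrently with the router/gate that follows. A standalone sketch of that ordering, with generic names rather than the FusedMoE attributes:

    import torch

    def stage_input_for_side_stream(
        hidden_states: torch.Tensor, aux: torch.cuda.Stream
    ) -> torch.Tensor:
        # 1. Clone before any stream switch so a routed-expert kernel that
        #    mutates hidden_states in place cannot race the side stream.
        clone = hidden_states.clone()
        # 2. Tell the caching allocator the clone is consumed on `aux`, so its
        #    memory is not reused while the side stream may still read it.
        clone.record_stream(aux)
        # 3. Let `aux` catch up to the main stream here; work the main stream
        #    enqueues afterwards (the router/gate) then runs concurrently.
        aux.wait_stream(torch.cuda.current_stream())
        return clone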
@@ -1709,33 +1740,6 @@ class FusedMoE(CustomOp):
self.quant_method, FusedMoEModularMethod
)

# If there are shared experts but we are not using a modular kernel, the
# shared experts must be called here
if has_separate_shared_experts:
assert self.shared_experts is not None

if self.shared_experts_stream is not None:
# Clone BEFORE switching streams to avoid race condition
# where routed_expert kernel may mutate hidden_states.
hidden_states_clone = hidden_states.clone()
self.shared_experts_stream.wait_stream(current_stream())

# Run shared experts in parallel on a separate stream
with torch.cuda.stream(self.shared_experts_stream):
shared_output = self.shared_experts(hidden_states_clone)

# Record that the clone will be used by shared_experts_stream
# to avoid gc issue from deallocation of hidden_states_clone
# For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
# NOTE: we dont need shared_output.record_stream(current_stream())
# because we synch the streams before using shared_output.
hidden_states_clone.record_stream(self.shared_experts_stream)

else:
shared_output = self.shared_experts(hidden_states)
else:
shared_output = None

ctx = get_forward_context()
sp_ctx = (
ctx.dp_metadata.sp_local_sizes(self.sp_size)
@@ -1776,12 +1780,21 @@ class FusedMoE(CustomOp):
)

if has_separate_shared_experts:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None

# Wait for the parallel shared experts stream to finish here
if self.shared_experts_stream is not None:
if use_shared_experts_stream:
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
# important to avoid excessive stream allocations by the cuda
# graph replay later.
with torch.cuda.stream(self.shared_experts_stream):
# Note that hidden_states clone() is necessary here to avoid
# conflict with the main stream
shared_output = self.shared_experts(hidden_states_clone)
current_stream().wait_stream(self.shared_experts_stream)
else:
shared_output = self.shared_experts(hidden_states)

final_hidden_states = (
shared_output,
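The NOTE in the hunk above is about keeping the side-stream region short: enter the stream context only for the shared-experts call and synchronize back right away, which the comment says avoids excessive stream allocations during CUDA graph replay; the immediate sync is also why no record_stream() is needed on the output, since the main stream waits before reading it. A minimal sketch of that consumer side, continuing the illustrative names from the previous sketch:

    import torch

    def run_shared_experts_on_side_stream(
        shared_experts, clone: torch.Tensor, aux: torch.cuda.Stream
    ) -> torch.Tensor:
        with torch.cuda.stream(aux):
            shared_output = shared_experts(clone)  # overlaps the routed experts
        # The main stream must not touch shared_output until `aux` is done.
        torch.cuda.current_stream().wait_stream(aux)
        return shared_output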
@@ -409,6 +409,30 @@ def current_stream() -> torch.cuda.Stream:
return _current_stream_tls.value


# Global auxilary stream for running operations in background streams.
# We have single global auxilary stream to avoid an explosion of streams
# for every layer (and make profiling look sane).
#
# aux_stream() is currently used for:
# - MoE shared_expert overlap with router
_aux_stream: torch.cuda.Stream | None = None


def aux_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
"""
global _aux_stream

from vllm.platforms import current_platform

# TODO: validate this works properly on ROCm platform.
if _aux_stream is None and current_platform.is_cuda():
_aux_stream = torch.cuda.Stream()

return _aux_stream


@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
# Note: cuda_visible_devices is not used, but we keep it as an argument for
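A small usage sketch for the helper added above: because the stream is a lazily created module-level singleton, every layer that calls aux_stream() gets the same object (or None off CUDA), so a profile shows one auxiliary lane rather than one stream per layer:

    from vllm.utils.torch_utils import aux_stream

    # All callers share one stream object (or a single None off CUDA).
    streams = {aux_stream() for _ in range(64)}
    assert len(streams) == 1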