Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 04:54:56 +08:00)
[MoE][Kernel][Perf] Improve Shared Expert Stream Overlap (#28406)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
parent 4ca5cd5740
commit 69d0e90313
@@ -222,6 +222,7 @@ if TYPE_CHECKING:
     VLLM_USE_FBGEMM: bool = False
     VLLM_GC_DEBUG: str = ""
     VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
+    VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
     VLLM_FLAT_LOGPROBS: bool = False

@@ -1476,6 +1477,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
         int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"))
     ),
+    # Limits when we run shared_experts in a separate stream.
+    # We found out that for large batch sizes, the separate stream
+    # execution is not beneficial (most likely because of the input clone)
+    # TODO(alexm-redhat): Tune to be more dynamic based on GPU type
+    "VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD": lambda: int(
+        int(os.getenv("VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD", 256))
+    ),
     # Format for saving torch.compile cache artifacts
     # - "binary": saves as binary file
     # Safe for multiple vllm serve processes accessing the same torch compile cache.
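
For orientation, the new knob plugs into vLLM's lazy env-var machinery: each entry in environment_variables is a zero-argument lambda that is evaluated when the attribute is read from vllm.envs. A minimal usage sketch (not part of the commit), assuming the variables are overridden before first import; the values are only examples:

# Usage sketch: override the new threshold before vLLM reads it.
# The variable names come from this diff; the values are examples.
import os

os.environ["VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD"] = "128"
os.environ["VLLM_DISABLE_SHARED_EXPERTS_STREAM"] = "0"

import vllm.envs as envs

# Attribute access on vllm.envs evaluates the corresponding lambda above.
print(envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD)  # 128
print(envs.VLLM_DISABLE_SHARED_EXPERTS_STREAM)          # False
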
@@ -48,7 +48,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
 )
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv, round_up
-from vllm.utils.torch_utils import current_stream, direct_register_custom_op
+from vllm.utils.torch_utils import (
+    aux_stream,
+    current_stream,
+    direct_register_custom_op,
+)
 from vllm.v1.worker.ubatching import dbo_current_ubatch_id

 if current_platform.is_cuda_alike():

@@ -331,7 +335,11 @@ class FusedMoE(CustomOp):
             logger.info_once("Disabling MoE shared_experts cuda stream")
             self.shared_experts_stream = None
         else:
-            self.shared_experts_stream = torch.cuda.Stream()
+            # TODO(rob): enable shared expert overlap with non-cuda.
+            # aux_stream() returns None on non-cuda platforms.
+            self.shared_experts_stream = aux_stream()
+            if self.shared_experts_stream is not None:
+                logger.info_once("Enabled separate cuda stream for MoE shared_experts")

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
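
Taken together with VLLM_DISABLE_SHARED_EXPERTS_STREAM, the constructor now ends up with a side stream only when the user has not opted out and the platform can actually supply one. A standalone sketch of that decision, using torch.cuda.is_available() as a stand-in for the platform check and a fresh stream in place of the shared aux_stream() singleton:

# Sketch of the enable/disable decision (standalone, not the FusedMoE code).
import torch


def pick_shared_experts_stream(disable_flag: bool) -> "torch.cuda.Stream | None":
    if disable_flag:
        # Mirrors VLLM_DISABLE_SHARED_EXPERTS_STREAM=1: no overlap stream.
        return None
    if not torch.cuda.is_available():
        # Mirrors aux_stream() returning None on non-cuda platforms.
        return None
    # The real code reuses one global aux_stream() instead of a new stream.
    return torch.cuda.Stream()


stream = pick_shared_experts_stream(disable_flag=False)
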
@@ -1606,7 +1614,9 @@ class FusedMoE(CustomOp):
            if has_separate_shared_experts:
                assert not isinstance(final_hidden_states, tuple)
                assert self.shared_experts is not None
+
                shared_output = self.shared_experts(staged_hidden_states)
+
                final_hidden_states = (
                    shared_output,
                    final_hidden_states,

@@ -1684,13 +1694,34 @@ class FusedMoE(CustomOp):

         use_chunked_impl = self.use_dp_chunking

-        if (
+        use_shared_experts_stream = (
             has_separate_shared_experts
             and not use_chunked_impl
             and self.shared_experts_stream is not None
-        ):
-            # Start the separate shared experts stream here since we want
-            # to run in parallel with the router/gate (next op below)
+            and (
+                hidden_states.shape[0]
+                <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
+            )
+        )
+
+        if use_shared_experts_stream:
+            assert self.shared_experts_stream is not None
+
+            # Clone BEFORE switching streams to avoid race condition
+            # where routed_expert kernel may mutate hidden_states.
+            hidden_states_clone = hidden_states.clone()
+
+            # Record that the clone will be used by shared_experts_stream
+            # to avoid gc issue from deallocation of hidden_states_clone
+            # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
+            # NOTE: We dont need shared_output.record_stream(current_stream())
+            # because we synch the streams before using shared_output.
+            hidden_states_clone.record_stream(self.shared_experts_stream)
+
+            # Mark sync start point for the separate shared experts
+            # stream here since we want to run in parallel with the
+            # router/gate (next op below)
+            assert self.shared_experts_stream is not None
             self.shared_experts_stream.wait_stream(current_stream())

         # If router/gate provided, then apply it here.
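
In short, the pre-dispatch half added here does four things in order: gate the overlap on token count, clone hidden_states so the routed experts cannot race on it, record_stream() the clone so the caching allocator keeps it alive for the side stream, and make the side stream wait on the main stream. A self-contained PyTorch sketch of that fork, with an arbitrary threshold and shapes, guarded so it only runs when CUDA is available:

# Fork-side sketch of the pattern in this hunk (toy code, not vLLM internals).
import torch

TOKEN_THRESHOLD = 256  # stands in for VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    hidden_states = torch.randn(64, 1024, device="cuda")

    use_side_stream = hidden_states.shape[0] <= TOKEN_THRESHOLD
    if use_side_stream:
        # Clone BEFORE switching streams so routed-expert kernels that may
        # mutate hidden_states on the main stream cannot race with this copy.
        hidden_states_clone = hidden_states.clone()
        # The clone is allocated on the main stream but consumed on side_stream;
        # record_stream() stops the caching allocator from reusing it too early.
        hidden_states_clone.record_stream(side_stream)
        # Fork: side_stream must not run ahead of work already queued here.
        side_stream.wait_stream(torch.cuda.current_stream())
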
@@ -1709,33 +1740,6 @@ class FusedMoE(CustomOp):
             self.quant_method, FusedMoEModularMethod
         )

-        # If there are shared experts but we are not using a modular kernel, the
-        # shared experts must be called here
-        if has_separate_shared_experts:
-            assert self.shared_experts is not None
-
-            if self.shared_experts_stream is not None:
-                # Clone BEFORE switching streams to avoid race condition
-                # where routed_expert kernel may mutate hidden_states.
-                hidden_states_clone = hidden_states.clone()
-                self.shared_experts_stream.wait_stream(current_stream())
-
-                # Run shared experts in parallel on a separate stream
-                with torch.cuda.stream(self.shared_experts_stream):
-                    shared_output = self.shared_experts(hidden_states_clone)
-
-                # Record that the clone will be used by shared_experts_stream
-                # to avoid gc issue from deallocation of hidden_states_clone
-                # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
-                # NOTE: we dont need shared_output.record_stream(current_stream())
-                # because we synch the streams before using shared_output.
-                hidden_states_clone.record_stream(self.shared_experts_stream)
-
-            else:
-                shared_output = self.shared_experts(hidden_states)
-        else:
-            shared_output = None
-
         ctx = get_forward_context()
         sp_ctx = (
             ctx.dp_metadata.sp_local_sizes(self.sp_size)

@@ -1776,12 +1780,21 @@ class FusedMoE(CustomOp):
             )

             if has_separate_shared_experts:
-                assert not isinstance(final_hidden_states, tuple)
                 assert self.shared_experts is not None

-                # Wait for the parallel shared experts stream to finish here
-                if self.shared_experts_stream is not None:
+                if use_shared_experts_stream:
+                    # Run shared experts in parallel on a separate stream
+                    # NOTE: We start the separate stream here and mark the
+                    # sync end point immediately after it is done. This is
+                    # important to avoid excessive stream allocations by the cuda
+                    # graph replay later.
+                    with torch.cuda.stream(self.shared_experts_stream):
+                        # Note that hidden_states clone() is necessary here to avoid
+                        # conflict with the main stream
+                        shared_output = self.shared_experts(hidden_states_clone)
                     current_stream().wait_stream(self.shared_experts_stream)
+                else:
+                    shared_output = self.shared_experts(hidden_states)

                 final_hidden_states = (
                     shared_output,
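
The post-dispatch half is the matching join: the shared experts run under torch.cuda.stream(side stream) while the routed experts keep the main stream busy, and the main stream waits on the side stream before shared_output is consumed. A toy end-to-end sketch of fork plus join, with nn.Linear standing in for the shared experts and an element-wise op standing in for the routed experts:

# Join-side sketch (toy code): run a stand-in shared-experts module on a side
# stream, then synchronize before mixing its output with the routed result.
import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    shared_experts = torch.nn.Linear(1024, 1024).cuda()

    hidden_states = torch.randn(64, 1024, device="cuda")
    hidden_states_clone = hidden_states.clone()
    hidden_states_clone.record_stream(side_stream)
    side_stream.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(side_stream):
        # Overlaps with whatever the main stream launches next
        # (the routed experts in the real layer).
        shared_output = shared_experts(hidden_states_clone)

    routed_output = hidden_states * 2.0  # placeholder for the routed experts

    # Join: do not touch shared_output until side_stream has finished.
    torch.cuda.current_stream().wait_stream(side_stream)
    final_hidden_states = shared_output + routed_output
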
@@ -409,6 +409,30 @@ def current_stream() -> torch.cuda.Stream:
     return _current_stream_tls.value


+# Global auxilary stream for running operations in background streams.
+# We have single global auxilary stream to avoid an explosion of streams
+# for every layer (and make profiling look sane).
+#
+# aux_stream() is currently used for:
+# - MoE shared_expert overlap with router
+_aux_stream: torch.cuda.Stream | None = None
+
+
+def aux_stream() -> torch.cuda.Stream | None:
+    """
+    Ensures aux_stream is initialized only once
+    """
+    global _aux_stream
+
+    from vllm.platforms import current_platform
+
+    # TODO: validate this works properly on ROCm platform.
+    if _aux_stream is None and current_platform.is_cuda():
+        _aux_stream = torch.cuda.Stream()
+
+    return _aux_stream
+
+
 @lru_cache(maxsize=8)
 def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
     # Note: cuda_visible_devices is not used, but we keep it as an argument for
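
aux_stream() is a process-wide lazy singleton, so all FusedMoE layers share one auxiliary stream instead of allocating their own. A short usage sketch, assuming a CUDA build so the helper does not return None:

# Usage sketch for the new helper (assumes a CUDA build; aux_stream() and
# current_stream() are the functions from vllm.utils.torch_utils above).
from vllm.utils.torch_utils import aux_stream, current_stream

s1 = aux_stream()
s2 = aux_stream()

if s1 is not None:  # None on non-cuda platforms
    assert s1 is s2  # created once, then shared by every caller
    s1.wait_stream(current_stream())  # typical fork point before background work
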