diff --git a/vllm/envs.py b/vllm/envs.py
index 1c3247a315c1..0530938c32f9 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -222,6 +222,7 @@ if TYPE_CHECKING:
     VLLM_USE_FBGEMM: bool = False
     VLLM_GC_DEBUG: str = ""
     VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
+    VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
     VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
     VLLM_FLAT_LOGPROBS: bool = False
 
@@ -1476,6 +1477,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
         int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"))
     ),
+    # Token-count threshold below which shared_experts run in a separate
+    # cuda stream. For large batch sizes, separate-stream execution is not
+    # beneficial (most likely because of the cost of the input clone).
+    # TODO(alexm-redhat): Tune this dynamically based on GPU type
+    "VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD": lambda: int(
+        os.getenv("VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD", "256")
+    ),
     # Format for saving torch.compile cache artifacts
     # - "binary": saves as binary file
     #   Safe for multiple vllm serve processes accessing the same torch compile cache.
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 3bd7c54c520c..aed8245cbd83 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -48,7 +48,11 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
 )
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv, round_up
-from vllm.utils.torch_utils import current_stream, direct_register_custom_op
+from vllm.utils.torch_utils import (
+    aux_stream,
+    current_stream,
+    direct_register_custom_op,
+)
 from vllm.v1.worker.ubatching import dbo_current_ubatch_id
 
 if current_platform.is_cuda_alike():
@@ -331,7 +335,11 @@ class FusedMoE(CustomOp):
             logger.info_once("Disabling MoE shared_experts cuda stream")
             self.shared_experts_stream = None
         else:
-            self.shared_experts_stream = torch.cuda.Stream()
+            # TODO(rob): enable shared-expert overlap on non-CUDA platforms.
+            # aux_stream() returns None on non-CUDA platforms.
+            self.shared_experts_stream = aux_stream()
+            if self.shared_experts_stream is not None:
+                logger.info_once("Enabled separate cuda stream for MoE shared_experts")
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -1606,7 +1614,9 @@ class FusedMoE(CustomOp):
         if has_separate_shared_experts:
             assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None
+
             shared_output = self.shared_experts(staged_hidden_states)
+
             final_hidden_states = (
                 shared_output,
                 final_hidden_states,
@@ -1684,13 +1694,34 @@ class FusedMoE(CustomOp):
 
         use_chunked_impl = self.use_dp_chunking
 
-        if (
+        use_shared_experts_stream = (
             has_separate_shared_experts
             and not use_chunked_impl
             and self.shared_experts_stream is not None
-        ):
-            # Start the separate shared experts stream here since we want
-            # to run in parallel with the router/gate (next op below)
+            and (
+                hidden_states.shape[0]
+                <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
+            )
+        )
+
+        if use_shared_experts_stream:
+            assert self.shared_experts_stream is not None
+
+            # Clone BEFORE switching streams to avoid a race condition
+            # where the routed_expert kernel may mutate hidden_states.
+            hidden_states_clone = hidden_states.clone()
+
+            # Record that the clone will be used by shared_experts_stream
+            # to avoid a GC issue from premature deallocation of the clone.
+            # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
+            # NOTE: We don't need shared_output.record_stream(current_stream())
+            # because we sync the streams before using shared_output.
+            hidden_states_clone.record_stream(self.shared_experts_stream)
+
+            # Mark the sync start point for the separate shared experts
+            # stream here, since we want it to run in parallel with the
+            # router/gate (the next op below).
+            assert self.shared_experts_stream is not None
             self.shared_experts_stream.wait_stream(current_stream())
 
         # If router/gate provided, then apply it here.
@@ -1709,33 +1740,6 @@ class FusedMoE(CustomOp):
             self.quant_method, FusedMoEModularMethod
         )
 
-        # If there are shared experts but we are not using a modular kernel, the
-        # shared experts must be called here
-        if has_separate_shared_experts:
-            assert self.shared_experts is not None
-
-            if self.shared_experts_stream is not None:
-                # Clone BEFORE switching streams to avoid race condition
-                # where routed_expert kernel may mutate hidden_states.
-                hidden_states_clone = hidden_states.clone()
-                self.shared_experts_stream.wait_stream(current_stream())
-
-                # Run shared experts in parallel on a separate stream
-                with torch.cuda.stream(self.shared_experts_stream):
-                    shared_output = self.shared_experts(hidden_states_clone)
-
-                # Record that the clone will be used by shared_experts_stream
-                # to avoid gc issue from deallocation of hidden_states_clone
-                # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
-                # NOTE: we dont need shared_output.record_stream(current_stream())
-                # because we synch the streams before using shared_output.
-                hidden_states_clone.record_stream(self.shared_experts_stream)
-
-            else:
-                shared_output = self.shared_experts(hidden_states)
-        else:
-            shared_output = None
-
         ctx = get_forward_context()
         sp_ctx = (
             ctx.dp_metadata.sp_local_sizes(self.sp_size)
@@ -1776,12 +1780,21 @@
         )
 
         if has_separate_shared_experts:
-            assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None
 
-            # Wait for the parallel shared experts stream to finish here
-            if self.shared_experts_stream is not None:
+            if use_shared_experts_stream:
+                # Run shared experts in parallel on a separate stream.
+                # NOTE: We launch the work on the separate stream here and
+                # mark the sync end point immediately after it is done. This
+                # is important to avoid excessive stream allocations during
+                # cuda graph replay later.
+                with torch.cuda.stream(self.shared_experts_stream):
+                    # Note that the hidden_states clone() is necessary here
+                    # to avoid a conflict with the main stream.
+                    shared_output = self.shared_experts(hidden_states_clone)
                 current_stream().wait_stream(self.shared_experts_stream)
+            else:
+                shared_output = self.shared_experts(hidden_states)
 
             final_hidden_states = (
                 shared_output,
diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py
index fd5c1b73f191..7c094e14cff7 100644
--- a/vllm/utils/torch_utils.py
+++ b/vllm/utils/torch_utils.py
@@ -409,6 +409,30 @@ def current_stream() -> torch.cuda.Stream:
     return _current_stream_tls.value
 
 
+# Global auxiliary stream for running operations in background streams.
+# We keep a single global auxiliary stream, rather than one per layer, to
+# avoid an explosion of streams (and to keep profiling readable).
+#
+# aux_stream() is currently used for:
+# - MoE shared_expert overlap with router
+_aux_stream: torch.cuda.Stream | None = None
+
+
+def aux_stream() -> torch.cuda.Stream | None:
+    """
+    Ensures aux_stream is initialized only once
+    """
+    global _aux_stream
+
+    from vllm.platforms import current_platform
+
+    # TODO: validate this works properly on ROCm platform.
+    if _aux_stream is None and current_platform.is_cuda():
+        _aux_stream = torch.cuda.Stream()
+
+    return _aux_stream
+
+
 @lru_cache(maxsize=8)
 def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
     # Note: cuda_visible_devices is not used, but we keep it as an argument for
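
Below is a minimal, standalone sketch (not part of the patch) of the stream choreography the new use_shared_experts_stream path follows: clone the input before touching the auxiliary stream, order the auxiliary stream after the main one with wait_stream(), protect the clone's memory with record_stream(), launch the shared experts under torch.cuda.stream(...), and have the main stream wait again before consuming the result. ToyMoE, the nn.Linear stand-ins, and the hard-coded 256-token threshold are illustrative assumptions only; they are not vLLM's FusedMoE kernels, and the final add is a simplification of how FusedMoE combines the shared and routed outputs.

# Standalone illustration of the overlap pattern (assumed names; not vLLM code).
import torch
from torch import nn

# Illustrative stand-in for VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD.
TOKEN_THRESHOLD = 256


class ToyMoE(nn.Module):
    """Toy module: routed experts on the main stream, shared experts on an
    auxiliary stream, mirroring the use_shared_experts_stream path above."""

    def __init__(self, hidden: int = 1024):
        super().__init__()
        self.routed_experts = nn.Linear(hidden, hidden)  # stand-in kernel
        self.shared_experts = nn.Linear(hidden, hidden)  # stand-in kernel
        # One auxiliary stream, created once (cf. aux_stream() in the diff).
        self.aux = torch.cuda.Stream() if torch.cuda.is_available() else None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        use_aux = (
            self.aux is not None
            and hidden_states.is_cuda
            and hidden_states.shape[0] <= TOKEN_THRESHOLD
        )

        if use_aux:
            main = torch.cuda.current_stream()
            # Clone BEFORE any stream switch so the routed-expert kernels
            # cannot mutate the tensor the shared experts will read.
            clone = hidden_states.clone()
            # Tell the caching allocator the clone is used on the aux stream,
            # so its memory is not recycled while that stream still needs it.
            clone.record_stream(self.aux)
            # Order the aux stream after all work issued so far on main.
            self.aux.wait_stream(main)

        # Routed experts (and router/gate) are enqueued on the main stream.
        routed_out = self.routed_experts(hidden_states)

        if use_aux:
            # Shared experts run concurrently on the auxiliary stream.
            with torch.cuda.stream(self.aux):
                shared_out = self.shared_experts(clone)
            # Main stream must not read shared_out before aux is done.
            main.wait_stream(self.aux)
        else:
            # Serial fallback (aux stream disabled or batch too large).
            shared_out = self.shared_experts(hidden_states)

        return routed_out + shared_out

As a quick smoke test, ToyMoE().cuda()(torch.randn(128, 1024, device="cuda")) takes the overlapped branch, while a 512-row input falls back to the serial branch, mirroring the token-threshold gating added in this diff.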