diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
index 43974ba917e42..c6dc95acdb636 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
@@ -50,6 +50,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
                 prepare_finalize,
                 old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
                 shared_experts,
+                getattr(moe_layer, "shared_experts_stream", None),
             ),
         )

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index d9525a7439c3e..b2f554efd8a6f 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -850,6 +850,45 @@ class FusedMoE(CustomOp):
             dp_size=get_dp_group().world_size,
         )

+    def _maybe_setup_shared_experts_stream(
+        self,
+        hidden_states: torch.Tensor,
+        has_separate_shared_experts: bool,
+        use_chunked_impl: bool,
+    ) -> tuple[bool, torch.Tensor | None]:
+        use_shared_experts_stream = (
+            has_separate_shared_experts
+            and not use_chunked_impl
+            and self.shared_experts_stream is not None
+            and (
+                hidden_states.shape[0]
+                <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
+            )
+        )
+
+        hidden_states_clone: torch.Tensor | None = None
+        if use_shared_experts_stream:
+            assert self.shared_experts_stream is not None
+
+            # Clone BEFORE switching streams to avoid a race condition
+            # where the routed expert kernel may mutate hidden_states.
+            hidden_states_clone = hidden_states.clone()
+
+            # Record that the clone will be used by shared_experts_stream
+            # to avoid a GC issue from deallocation of hidden_states_clone.
+            # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
+            # NOTE: We don't need shared_output.record_stream(current_stream())
+            # because we sync the streams before using shared_output.
+            hidden_states_clone.record_stream(self.shared_experts_stream)
+
+            # Mark the sync start point for the separate shared experts
+            # stream here, since we want to run in parallel with the
+            # router/gate (the next op below).
+            assert self.shared_experts_stream is not None
+            self.shared_experts_stream.wait_stream(current_stream())
+
+        return use_shared_experts_stream, hidden_states_clone
+
     def _load_per_tensor_weight_scale(
         self,
         shard_id: str,
@@ -1819,36 +1858,12 @@ class FusedMoE(CustomOp):

         use_chunked_impl = self.use_dp_chunking

-        use_shared_experts_stream = (
-            has_separate_shared_experts
-            and not use_chunked_impl
-            and self.shared_experts_stream is not None
-            and (
-                hidden_states.shape[0]
-                <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
+        use_shared_experts_stream, hidden_states_clone = (
+            self._maybe_setup_shared_experts_stream(
+                hidden_states, has_separate_shared_experts, use_chunked_impl
             )
         )

-        if use_shared_experts_stream:
-            assert self.shared_experts_stream is not None
-
-            # Clone BEFORE switching streams to avoid race condition
-            # where routed_expert kernel may mutate hidden_states.
-            hidden_states_clone = hidden_states.clone()
-
-            # Record that the clone will be used by shared_experts_stream
-            # to avoid gc issue from deallocation of hidden_states_clone
-            # For more details: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html # noqa: E501
-            # NOTE: We dont need shared_output.record_stream(current_stream())
-            # because we synch the streams before using shared_output.
-            hidden_states_clone.record_stream(self.shared_experts_stream)
-
-            # Mark sync start point for the separate shared experts
-            # stream here since we want to run in parallel with the
-            # router/gate (next op below)
-            assert self.shared_experts_stream is not None
-            self.shared_experts_stream.wait_stream(current_stream())
-
         # If router/gate provided, then apply it here.
         # (Note: This code runs only when "overlapped mode" is on to allow
         #  parallel execution of shared experts with the FusedMoE via
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 093affe51f503..4af7af9257dfa 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.utils import (
     count_expert_num_tokens,
     disable_inplace,
 )
+from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv
 from vllm.v1.worker.ubatching import (
     dbo_current_ubatch_id,
@@ -709,11 +710,13 @@ class FusedMoEModularKernel(torch.nn.Module):
         prepare_finalize: FusedMoEPrepareAndFinalize,
         fused_experts: FusedMoEPermuteExpertsUnpermute,
         shared_experts: torch.nn.Module | None = None,
+        shared_experts_stream: torch.cuda.Stream | None = None,
     ):
         super().__init__()
         self.prepare_finalize = prepare_finalize
         self.fused_experts = fused_experts
         self.shared_experts = shared_experts
+        self.shared_experts_stream = shared_experts_stream
         self._post_init_setup()

         assert (
@@ -890,6 +893,34 @@ class FusedMoEModularKernel(torch.nn.Module):
             expert_num_tokens_cpu=c_expert_num_tokens_cpu,
         )

+    def _maybe_setup_shared_experts_stream(
+        self, hidden_states: torch.Tensor
+    ) -> tuple[bool, torch.Tensor | None]:
+        # Decide whether to run shared experts on a separate CUDA stream to
+        # overlap with the main fused MoE kernel.
+        use_shared_experts_stream = (
+            self.shared_experts is not None
+            and self.shared_experts_stream is not None
+            and hidden_states.is_cuda
+            and (
+                hidden_states.shape[0]
+                <= envs.VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD
+            )
+        )
+
+        hidden_states_clone: torch.Tensor | None = None
+        if use_shared_experts_stream and self.shared_experts_stream is not None:
+            # TODO: Optimize this (complicated)
+            # Note: this clone adds overhead but is required
+            # for correctness with multiple CUDA streams and CUDA graph capture.
+            hidden_states_clone = hidden_states.clone()
+            # Record that the clone will be used by the separate stream so its
+            # lifetime is correctly tracked.
+            hidden_states_clone.record_stream(self.shared_experts_stream)
+            self.shared_experts_stream.wait_stream(torch.cuda.current_stream())
+
+        return use_shared_experts_stream, hidden_states_clone
+
     def _prepare(
         self,
         hidden_states: torch.Tensor,
@@ -1077,12 +1108,30 @@ class FusedMoEModularKernel(torch.nn.Module):
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
+        hidden_states_clone: torch.Tensor | None = None,
+        use_shared_experts_stream: bool = False,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """
         The _finalize method is a wrapper around self.prepare_finalize.finalize
         that handles DBO, async and shared expert overlap.
         """
-        shared_output: torch.Tensor | None = None
+
+        def maybe_run_shared_experts() -> torch.Tensor | None:
+            if self.shared_experts is None:
+                return None
+
+            if (
+                not use_shared_experts_stream
+                or self.shared_experts_stream is not None
+                and (not hidden_states.is_cuda or not torch.cuda.is_available())
+            ):
+                # Fall back to running on the current stream.
+                return self.shared_experts(hidden_states)
+
+            assert hidden_states_clone is not None
+            # Launch shared experts on the dedicated stream.
+            with torch.cuda.stream(self.shared_experts_stream):
+                return self.shared_experts(hidden_states_clone)

         if not self.prepare_finalize.supports_async():
             assert not dbo_enabled()
@@ -1095,8 +1144,7 @@ class FusedMoEModularKernel(torch.nn.Module):
                 apply_router_weight_on_input,
                 self.fused_experts.finalize_weight_and_reduce_impl(),
             )
-            if self.shared_experts is not None:
-                shared_output = self.shared_experts(hidden_states)
+            shared_output = maybe_run_shared_experts()
         else:
             finalize_ret = self.prepare_finalize.finalize_async(
                 output,
@@ -1107,8 +1155,7 @@ class FusedMoEModularKernel(torch.nn.Module):
                 self.fused_experts.finalize_weight_and_reduce_impl(),
             )

-            if self.shared_experts is not None:
-                shared_output = self.shared_experts(hidden_states)
+            shared_output = maybe_run_shared_experts()

             # TODO(lucas): refactor this in the alternative schedules followup
             # currently unpack if we have hook + receiver pair or just
@@ -1131,12 +1178,28 @@ class FusedMoEModularKernel(torch.nn.Module):

             receiver()

+        self._wait_for_shared_experts_stream(hidden_states, use_shared_experts_stream)
+
         if self.shared_experts is None:
             return output
         else:
             assert shared_output is not None
             return shared_output, output

+    def _wait_for_shared_experts_stream(
+        self, hidden_states: torch.Tensor, use_shared_experts_stream: bool
+    ) -> None:
+        # Ensure that any work enqueued on the shared_experts_stream is
+        # completed before the shared_output tensor is consumed.
+        if (
+            self.shared_experts is not None
+            and use_shared_experts_stream
+            and self.shared_experts_stream is not None
+            and hidden_states.is_cuda
+            and current_platform.is_cuda()
+        ):
+            torch.cuda.current_stream().wait_stream(self.shared_experts_stream)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1183,6 +1246,10 @@ class FusedMoEModularKernel(torch.nn.Module):
         else:
             output = torch.zeros_like(hidden_states)

+        use_shared_experts_stream, hidden_states_clone = (
+            self._maybe_setup_shared_experts_stream(hidden_states)
+        )
+
         local_num_experts = w1.size(0)
         if global_num_experts == -1:
             global_num_experts = local_num_experts
@@ -1219,4 +1286,6 @@ class FusedMoEModularKernel(torch.nn.Module):
             topk_weights,
             topk_ids,
             apply_router_weight_on_input,
+            hidden_states_clone=hidden_states_clone,
+            use_shared_experts_stream=use_shared_experts_stream,
         )
diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py
index 9bb976fb9ec93..e27e2eb32da0f 100644
--- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py
@@ -45,7 +45,8 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
             assert topk == 1, (
                 "apply_router_weight_on_input is only implemented for topk=1"
            )
-            a1.mul_(topk_weights.to(a1.dtype))
+            # Note: do not modify a1 in place; needed for shared experts overlap
+            a1 = a1 * topk_weights.to(a1.dtype)

         a1q, a1q_scale = moe_kernel_quantize_input(
             a1,