From 930efd02abeacd1d415ac838389106e80d7efc97 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Tue, 24 Jun 2025 21:53:54 +0000 Subject: [PATCH] yields now work with deepep_ll Signed-off-by: Sage Moore --- .../fused_moe/deepep_ll_prepare_finalize.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index c2aa5a831b0f5..f9c50ed40de69 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -60,7 +60,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # The dispatch function returns a handle that the combine function # requires. We store the handle here so it is available to the # combine function. - self.handle = None + self.handles: list[Optional[tuple]] = [None, None] def max_num_tokens_per_rank(self) -> Optional[int]: return self.max_tokens_per_rank @@ -155,7 +155,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): a1 = a1 * rank_topk_weights.to(a1.dtype) # Dispatch - expert_x, expert_num_tokens, self.handle, event, hook = \ + yield_and_switch_from_compute_to_comm_impl(schedule="default") + expert_x, expert_num_tokens, handle, event, hook = \ self.buffers[a2a_idx].low_latency_dispatch(a1, rank_topk_ids, self.max_tokens_per_rank, @@ -163,6 +164,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): use_fp8=self.use_fp8_dispatch, async_finish=False, return_recv_hook=False) + self.handles[a2a_idx] = handle + yield_and_switch_from_comm_to_compute_impl(schedule="default") expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale, a1.dtype) @@ -173,10 +176,11 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): topk_weights: torch.Tensor, topk_ids: torch.Tensor, apply_router_weight_on_input: bool) -> None: - assert self.handle is not None ubatch_ctx = get_current_ubatch_context() ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 a2a_idx = 0 if ubatch_id == -1 else ubatch_id + handle = self.handles[a2a_idx] + assert handle is not None combine_topk_weights = topk_weights if apply_router_weight_on_input: @@ -184,12 +188,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): combine_topk_weights = torch.ones_like(topk_weights) # TODO (varun) : Enable zero copy mode + yield_and_switch_from_compute_to_comm_impl(schedule="default") _, event, hook = self.buffers[a2a_idx].low_latency_combine( fused_expert_output, topk_ids, combine_topk_weights, - self.handle, + handle, async_finish=False, zero_copy=False, return_recv_hook=False, out=output) + # event.current_stream_wait() + yield_and_switch_from_comm_to_compute_impl(schedule="default") +