From 17a7ceef273b55f7dc5badfac58be5088e1e752d Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 3 Jul 2025 13:35:21 +0000 Subject: [PATCH] cleanup deepep ll Signed-off-by: Sage Moore --- .../fused_moe/deepep_ll_prepare_finalize.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index d98e607795792..244a63ddd82b0 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -131,8 +131,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): hidden_size = a1.size(1) ubatch_ctx = get_current_ubatch_context() - ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 - a2a_idx = 0 if ubatch_id == -1 else ubatch_id + a2a_idx = ubatch_ctx.id if ubatch_ctx is not None else 0 # assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ # (f"Hidden Size {hidden_size} not in supported list of hidden sizes" # f"{self.SUPPORTED_HIDDEN_SIZES}") @@ -155,8 +154,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): a1 = a1 * rank_topk_weights.to(a1.dtype) # Dispatch - # yield_and_switch_from_compute_to_comm_impl(schedule="default") - expert_x, expert_num_tokens, handle, event, hook = \ + yield_and_switch_from_compute_to_comm_impl(schedule="default") + expert_x, expert_num_tokens, handle, _, _= \ self.buffers[a2a_idx].low_latency_dispatch(a1, rank_topk_ids, self.max_tokens_per_rank, @@ -165,7 +164,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): async_finish=False, return_recv_hook=False) self.handles[a2a_idx] = handle - # yield_and_switch_from_comm_to_compute_impl(schedule="default") + yield_and_switch_from_comm_to_compute_impl(schedule="default") expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale, a1.dtype) @@ -177,8 +176,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): apply_router_weight_on_input: bool) -> None: ubatch_ctx = get_current_ubatch_context() - ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 - a2a_idx = 0 if ubatch_id == -1 else ubatch_id + a2a_idx = ubatch_ctx.id if ubatch_ctx is not None else 0 handle = self.handles[a2a_idx] assert handle is not None @@ -188,8 +186,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): combine_topk_weights = torch.ones_like(topk_weights) # TODO (varun) : Enable zero copy mode - # yield_and_switch_from_compute_to_comm_impl(schedule="default") - _, event, hook = self.buffers[a2a_idx].low_latency_combine( + yield_and_switch_from_compute_to_comm_impl(schedule="default") + _ = self.buffers[a2a_idx].low_latency_combine( fused_expert_output, topk_ids, combine_topk_weights, @@ -198,6 +196,5 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): zero_copy=False, return_recv_hook=False, out=output) - # event.current_stream_wait() - # yield_and_switch_from_comm_to_compute_impl(schedule="default") + yield_and_switch_from_comm_to_compute_impl(schedule="default")