From 631be12edb2d4a61cc21ea99ecc0eadb3ff7d6e8 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 3 Jul 2025 13:16:34 +0000 Subject: [PATCH] refactoring pplx_prepare_finalize.py Signed-off-by: Sage Moore --- .../layers/fused_moe/pplx_prepare_finalize.py | 56 ++++++------------- 1 file changed, 17 insertions(+), 39 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 6e577cfd9e04f..415b021c9d751 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -58,8 +58,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): num_tokens = a1.size(0) # M hidden_dim = a1.size(-1) # K ubatch_ctx = get_current_ubatch_context() - ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 - a2a_idx = 0 if ubatch_id == -1 else ubatch_id + a2a_idx = ubatch_ctx.id if ubatch_ctx is not None else 0 assert rank_topk_ids.size(0) == num_tokens # assert expert_map is None, "NYI" @@ -121,28 +120,17 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # There's not much point setting this unless it is != indices.size(0) bound_m: Optional[torch.Tensor] = None - def dispatch(send: bool): - self.a2as[a2a_idx].dispatch( - out_expert_num_tokens=expert_num_tokens, - out_expert_x=expert_x, - out_expert_x_scale=expert_x_scale, - dp_x=a1q, - dp_x_scale=a1q_scale, - indices=rank_topk_ids, - bound_m=bound_m, - do_send=send, - do_recv=not send, - ) - yield_and_switch_from_compute_to_comm_impl(schedule="default") - dispatch(True) # Send - # torch.cuda.synchronize() - # print(f"{ubatch_id} AFTER SEND SYNC", flush=True) - dispatch(False) # Recv - # torch.cuda.synchronize() - # print(f"{ubatch_id} AFTER RECV SYNC", flush=True) + self.a2as[a2a_idx].dispatch( + out_expert_num_tokens=expert_num_tokens, + out_expert_x=expert_x, + out_expert_x_scale=expert_x_scale, + dp_x=a1q, + dp_x_scale=a1q_scale, + indices=rank_topk_ids, + bound_m=bound_m, + ) yield_and_switch_from_comm_to_compute_impl(schedule="default") - # torch.cuda.synchronize() if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, 0:1] @@ -174,22 +162,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) - def combine(send: bool): - self.a2as[a2a_idx].combine( - out_tokens=output, - indices=topk_ids, - weights=topk_weights, - expert_y=fused_expert_output, - bound_m=bound_m, - do_send=send, - do_recv=not send, - ) - yield_and_switch_from_compute_to_comm_impl(schedule="default") - combine(True) - # torch.cuda.synchronize() - # print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True) - combine(False) - # print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True) + self.a2as[a2a_idx].combine( + out_tokens=output, + indices=topk_ids, + weights=topk_weights, + expert_y=fused_expert_output, + bound_m=bound_m, + ) yield_and_switch_from_comm_to_compute_impl(schedule="default") - # torch.cuda.synchronize()