refactoring pplx_prepare_finalize.py

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-07-03 13:16:34 +00:00
parent a9d47e8652
commit 631be12edb

View File

@ -58,8 +58,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
a2a_idx = 0 if ubatch_id == -1 else ubatch_id
a2a_idx = ubatch_ctx.id if ubatch_ctx is not None else 0
assert rank_topk_ids.size(0) == num_tokens
# assert expert_map is None, "NYI"
@ -121,28 +120,17 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# There's not much point setting this unless it is != indices.size(0)
bound_m: Optional[torch.Tensor] = None
def dispatch(send: bool):
self.a2as[a2a_idx].dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=rank_topk_ids,
bound_m=bound_m,
do_send=send,
do_recv=not send,
)
yield_and_switch_from_compute_to_comm_impl(schedule="default")
dispatch(True) # Send
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
dispatch(False) # Recv
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
self.a2as[a2a_idx].dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=rank_topk_ids,
bound_m=bound_m,
)
yield_and_switch_from_comm_to_compute_impl(schedule="default")
# torch.cuda.synchronize()
if expert_x_scale is not None:
expert_x_scale = expert_x_scale[:, :, 0:1]
@ -174,22 +162,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
def combine(send: bool):
self.a2as[a2a_idx].combine(
out_tokens=output,
indices=topk_ids,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m,
do_send=send,
do_recv=not send,
)
yield_and_switch_from_compute_to_comm_impl(schedule="default")
combine(True)
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
combine(False)
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
self.a2as[a2a_idx].combine(
out_tokens=output,
indices=topk_ids,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m,
)
yield_and_switch_from_comm_to_compute_impl(schedule="default")
# torch.cuda.synchronize()