mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 20:07:07 +08:00
refactoring pplx_prepare_finalize.py
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
a9d47e8652
commit
631be12edb
@ -58,8 +58,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
num_tokens = a1.size(0) # M
|
||||
hidden_dim = a1.size(-1) # K
|
||||
ubatch_ctx = get_current_ubatch_context()
|
||||
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
||||
a2a_idx = 0 if ubatch_id == -1 else ubatch_id
|
||||
a2a_idx = ubatch_ctx.id if ubatch_ctx is not None else 0
|
||||
|
||||
assert rank_topk_ids.size(0) == num_tokens
|
||||
# assert expert_map is None, "NYI"
|
||||
@ -121,28 +120,17 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# There's not much point setting this unless it is != indices.size(0)
|
||||
bound_m: Optional[torch.Tensor] = None
|
||||
|
||||
def dispatch(send: bool):
|
||||
self.a2as[a2a_idx].dispatch(
|
||||
out_expert_num_tokens=expert_num_tokens,
|
||||
out_expert_x=expert_x,
|
||||
out_expert_x_scale=expert_x_scale,
|
||||
dp_x=a1q,
|
||||
dp_x_scale=a1q_scale,
|
||||
indices=rank_topk_ids,
|
||||
bound_m=bound_m,
|
||||
do_send=send,
|
||||
do_recv=not send,
|
||||
)
|
||||
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
dispatch(True) # Send
|
||||
# torch.cuda.synchronize()
|
||||
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
|
||||
dispatch(False) # Recv
|
||||
# torch.cuda.synchronize()
|
||||
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
|
||||
self.a2as[a2a_idx].dispatch(
|
||||
out_expert_num_tokens=expert_num_tokens,
|
||||
out_expert_x=expert_x,
|
||||
out_expert_x_scale=expert_x_scale,
|
||||
dp_x=a1q,
|
||||
dp_x_scale=a1q_scale,
|
||||
indices=rank_topk_ids,
|
||||
bound_m=bound_m,
|
||||
)
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
# torch.cuda.synchronize()
|
||||
if expert_x_scale is not None:
|
||||
expert_x_scale = expert_x_scale[:, :, 0:1]
|
||||
|
||||
@ -174,22 +162,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
if apply_router_weight_on_input:
|
||||
topk_weights = torch.ones_like(topk_weights)
|
||||
|
||||
def combine(send: bool):
|
||||
self.a2as[a2a_idx].combine(
|
||||
out_tokens=output,
|
||||
indices=topk_ids,
|
||||
weights=topk_weights,
|
||||
expert_y=fused_expert_output,
|
||||
bound_m=bound_m,
|
||||
do_send=send,
|
||||
do_recv=not send,
|
||||
)
|
||||
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
combine(True)
|
||||
# torch.cuda.synchronize()
|
||||
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
|
||||
combine(False)
|
||||
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
|
||||
self.a2as[a2a_idx].combine(
|
||||
out_tokens=output,
|
||||
indices=topk_ids,
|
||||
weights=topk_weights,
|
||||
expert_y=fused_expert_output,
|
||||
bound_m=bound_m,
|
||||
)
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
# torch.cuda.synchronize()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user