mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-08 22:08:03 +08:00
refactoring pplx_prepare_finalize.py
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
a9d47e8652
commit
631be12edb
@ -58,8 +58,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
num_tokens = a1.size(0) # M
|
num_tokens = a1.size(0) # M
|
||||||
hidden_dim = a1.size(-1) # K
|
hidden_dim = a1.size(-1) # K
|
||||||
ubatch_ctx = get_current_ubatch_context()
|
ubatch_ctx = get_current_ubatch_context()
|
||||||
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
a2a_idx = ubatch_ctx.id if ubatch_ctx is not None else 0
|
||||||
a2a_idx = 0 if ubatch_id == -1 else ubatch_id
|
|
||||||
|
|
||||||
assert rank_topk_ids.size(0) == num_tokens
|
assert rank_topk_ids.size(0) == num_tokens
|
||||||
# assert expert_map is None, "NYI"
|
# assert expert_map is None, "NYI"
|
||||||
@ -121,28 +120,17 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
# There's not much point setting this unless it is != indices.size(0)
|
# There's not much point setting this unless it is != indices.size(0)
|
||||||
bound_m: Optional[torch.Tensor] = None
|
bound_m: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
def dispatch(send: bool):
|
|
||||||
self.a2as[a2a_idx].dispatch(
|
|
||||||
out_expert_num_tokens=expert_num_tokens,
|
|
||||||
out_expert_x=expert_x,
|
|
||||||
out_expert_x_scale=expert_x_scale,
|
|
||||||
dp_x=a1q,
|
|
||||||
dp_x_scale=a1q_scale,
|
|
||||||
indices=rank_topk_ids,
|
|
||||||
bound_m=bound_m,
|
|
||||||
do_send=send,
|
|
||||||
do_recv=not send,
|
|
||||||
)
|
|
||||||
|
|
||||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||||
dispatch(True) # Send
|
self.a2as[a2a_idx].dispatch(
|
||||||
# torch.cuda.synchronize()
|
out_expert_num_tokens=expert_num_tokens,
|
||||||
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
|
out_expert_x=expert_x,
|
||||||
dispatch(False) # Recv
|
out_expert_x_scale=expert_x_scale,
|
||||||
# torch.cuda.synchronize()
|
dp_x=a1q,
|
||||||
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
|
dp_x_scale=a1q_scale,
|
||||||
|
indices=rank_topk_ids,
|
||||||
|
bound_m=bound_m,
|
||||||
|
)
|
||||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||||
# torch.cuda.synchronize()
|
|
||||||
if expert_x_scale is not None:
|
if expert_x_scale is not None:
|
||||||
expert_x_scale = expert_x_scale[:, :, 0:1]
|
expert_x_scale = expert_x_scale[:, :, 0:1]
|
||||||
|
|
||||||
@ -174,22 +162,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
if apply_router_weight_on_input:
|
if apply_router_weight_on_input:
|
||||||
topk_weights = torch.ones_like(topk_weights)
|
topk_weights = torch.ones_like(topk_weights)
|
||||||
|
|
||||||
def combine(send: bool):
|
|
||||||
self.a2as[a2a_idx].combine(
|
|
||||||
out_tokens=output,
|
|
||||||
indices=topk_ids,
|
|
||||||
weights=topk_weights,
|
|
||||||
expert_y=fused_expert_output,
|
|
||||||
bound_m=bound_m,
|
|
||||||
do_send=send,
|
|
||||||
do_recv=not send,
|
|
||||||
)
|
|
||||||
|
|
||||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||||
combine(True)
|
self.a2as[a2a_idx].combine(
|
||||||
# torch.cuda.synchronize()
|
out_tokens=output,
|
||||||
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
|
indices=topk_ids,
|
||||||
combine(False)
|
weights=topk_weights,
|
||||||
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
|
expert_y=fused_expert_output,
|
||||||
|
bound_m=bound_m,
|
||||||
|
)
|
||||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||||
# torch.cuda.synchronize()
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user