pplx format

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-02 19:17:15 +00:00
parent 243eac58a4
commit d46397661f
2 changed files with 7 additions and 9 deletions

View File

@ -661,8 +661,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
f"Hidden size mismatch {hidden_states.size(-1)} "
f"!= {w1.size(2)}")
# print("in batched triton experts", hidden_states.shape, expert_num_tokens)
assert hidden_states.is_contiguous(
), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"

View File

@ -8,9 +8,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.v1.worker.ubatching import (
get_current_ubatch_context, yield_and_switch_from_compute_to_comm_impl,
yield_and_switch_from_comm_to_compute_impl
)
get_current_ubatch_context, yield_and_switch_from_comm_to_compute_impl,
yield_and_switch_from_compute_to_comm_impl)
# Note use: layer.get_all_to_all() to get an AllToAll instance
@ -124,14 +123,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_send=send,
do_recv=not send,
)
ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
yield_and_switch_from_compute_to_comm_impl(schedule="default")
dispatch(True) # Send
dispatch(True) # Send
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
dispatch(False) # Recv
dispatch(False) # Recv
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
yield_and_switch_from_comm_to_compute_impl(schedule="default")
@ -174,6 +173,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_send=send,
do_recv=not send,
)
yield_and_switch_from_compute_to_comm_impl(schedule="default")
combine(True)
# torch.cuda.synchronize()
@ -181,4 +181,4 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
combine(False)
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
yield_and_switch_from_comm_to_compute_impl(schedule="default")
# torch.cuda.synchronize()
# torch.cuda.synchronize()