diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 006c2b504541d..c2db793659312 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -661,8 +661,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): f"Hidden size mismatch {hidden_states.size(-1)} " f"!= {w1.size(2)}") - # print("in batched triton experts", hidden_states.shape, expert_num_tokens) - assert hidden_states.is_contiguous( ), "Hidden_states must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1" diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index c3bcbae46a9ae..bad45325117a7 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -8,9 +8,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) from vllm.v1.worker.ubatching import ( - get_current_ubatch_context, yield_and_switch_from_compute_to_comm_impl, - yield_and_switch_from_comm_to_compute_impl -) + get_current_ubatch_context, yield_and_switch_from_comm_to_compute_impl, + yield_and_switch_from_compute_to_comm_impl) # Note use: layer.get_all_to_all() to get an AllToAll instance @@ -124,14 +123,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): do_send=send, do_recv=not send, ) - + ubatch_ctx = get_current_ubatch_context() ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 yield_and_switch_from_compute_to_comm_impl(schedule="default") - dispatch(True) # Send + dispatch(True) # Send # torch.cuda.synchronize() # print(f"{ubatch_id} AFTER SEND SYNC", flush=True) - dispatch(False) # Recv + dispatch(False) # Recv # torch.cuda.synchronize() # print(f"{ubatch_id} AFTER RECV SYNC", flush=True) yield_and_switch_from_comm_to_compute_impl(schedule="default") @@ -174,6 +173,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): do_send=send, do_recv=not send, ) + yield_and_switch_from_compute_to_comm_impl(schedule="default") combine(True) # torch.cuda.synchronize() @@ -181,4 +181,4 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): combine(False) # print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True) yield_and_switch_from_comm_to_compute_impl(schedule="default") - # torch.cuda.synchronize() \ No newline at end of file + # torch.cuda.synchronize()