From 83caef8bacf44d7202c5d5b61d23ce510c40c70b Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 3 Jul 2025 13:50:19 +0000 Subject: [PATCH] cleanups for ubatching.py Signed-off-by: Sage Moore --- .../fused_moe/deepep_ll_prepare_finalize.py | 12 ++--- .../layers/fused_moe/pplx_prepare_finalize.py | 12 ++--- vllm/v1/worker/ubatching.py | 52 +------------------ 3 files changed, 14 insertions(+), 62 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 244a63ddd82b0..360371bec5f17 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -8,8 +8,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) from vllm.v1.worker.ubatching import ( - get_current_ubatch_context, yield_and_switch_from_comm_to_compute_impl, - yield_and_switch_from_compute_to_comm_impl) + get_current_ubatch_context, yield_and_switch_from_comm_to_compute, + yield_and_switch_from_compute_to_comm) # DeepEP kernels quantize dispatch inputs in 128 element chunks. DEEPEP_QUANT_BLOCK_SIZE = 128 @@ -154,7 +154,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): a1 = a1 * rank_topk_weights.to(a1.dtype) # Dispatch - yield_and_switch_from_compute_to_comm_impl(schedule="default") + yield_and_switch_from_compute_to_comm(schedule="default") expert_x, expert_num_tokens, handle, _, _= \ self.buffers[a2a_idx].low_latency_dispatch(a1, rank_topk_ids, @@ -164,7 +164,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): async_finish=False, return_recv_hook=False) self.handles[a2a_idx] = handle - yield_and_switch_from_comm_to_compute_impl(schedule="default") + yield_and_switch_from_comm_to_compute(schedule="default") expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale, a1.dtype) @@ -186,7 +186,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): combine_topk_weights = torch.ones_like(topk_weights) # TODO (varun) : Enable zero copy mode - yield_and_switch_from_compute_to_comm_impl(schedule="default") + yield_and_switch_from_compute_to_comm(schedule="default") _ = self.buffers[a2a_idx].low_latency_combine( fused_expert_output, topk_ids, @@ -196,5 +196,5 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): zero_copy=False, return_recv_hook=False, out=output) - yield_and_switch_from_comm_to_compute_impl(schedule="default") + yield_and_switch_from_comm_to_compute(schedule="default") diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 415b021c9d751..e35ae3a4fc737 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -9,8 +9,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) from vllm.v1.worker.ubatching import ( - get_current_ubatch_context, yield_and_switch_from_comm_to_compute_impl, - yield_and_switch_from_compute_to_comm_impl) + get_current_ubatch_context, yield_and_switch_from_comm_to_compute, + yield_and_switch_from_compute_to_comm) # The max_num_tokens, world_size and dp_size must be the same @@ -120,7 +120,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # There's not much point setting this unless it is != indices.size(0) bound_m: Optional[torch.Tensor] = None - yield_and_switch_from_compute_to_comm_impl(schedule="default") + yield_and_switch_from_compute_to_comm(schedule="default") self.a2as[a2a_idx].dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, @@ -130,7 +130,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): indices=rank_topk_ids, bound_m=bound_m, ) - yield_and_switch_from_comm_to_compute_impl(schedule="default") + yield_and_switch_from_comm_to_compute(schedule="default") if expert_x_scale is not None: expert_x_scale = expert_x_scale[:, :, 0:1] @@ -162,7 +162,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): if apply_router_weight_on_input: topk_weights = torch.ones_like(topk_weights) - yield_and_switch_from_compute_to_comm_impl(schedule="default") + yield_and_switch_from_compute_to_comm(schedule="default") self.a2as[a2a_idx].combine( out_tokens=output, indices=topk_ids, @@ -170,4 +170,4 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): expert_y=fused_expert_output, bound_m=bound_m, ) - yield_and_switch_from_comm_to_compute_impl(schedule="default") + yield_and_switch_from_comm_to_compute(schedule="default") diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 3defe34e06bf5..f563f76be5d3c 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -3,12 +3,9 @@ import threading from typing import Optional import torch -import torch._dynamo -from torch.library import custom_op from vllm import forward_context from vllm.utils import current_stream -from vllm.distributed.parallel_state import get_dp_group class UBatchContext: @@ -75,31 +72,26 @@ class UBatchContext: pass def _signal_comm_done(self): - # assert False self.ctx_valid_state() self.gpu_comm_done_event.record(self.comm_stream) def _signal_compute_done(self): - # assert False self.ctx_valid_state() self.gpu_compute_done_event.record(self.compute_stream) def _wait_compute_done(self): - # assert False # print(f"{self.id} Waiting on COMPUTE stream", flush=True) self.ctx_valid_state() self.comm_stream.wait_event(self.gpu_compute_done_event) # print("Compute stream done", flush=True) def _wait_comm_done(self): - # assert False # print(f"{self.id} Waiting on COMM stream", flush=True) self.ctx_valid_state() self.compute_stream.wait_event(self.gpu_comm_done_event) # print("Comm stream done", flush=True) def stream_string(self): - # assert False if current_stream() == self.compute_stream: assert self.current_stream == self.compute_stream return "COMPUTE" @@ -118,7 +110,6 @@ class UBatchContext: # print(f"UBatchContext: {self.id} resuming CPU", flush=True) def yield_and_switch_from_compute_to_comm(self): - # assert False assert current_stream() == self.compute_stream # dp_rank = get_dp_group().rank_in_group # print(f"DP: {dp_rank} UB: {self.id} " @@ -134,7 +125,6 @@ class UBatchContext: self._wait_compute_done() def yield_and_switch_from_comm_to_compute(self): - # assert False assert current_stream() == self.comm_stream # dp_rank = get_dp_group().rank_in_group # print(f"DP: {dp_rank} UB: {self.id} " @@ -161,7 +151,7 @@ def get_current_ubatch_context() -> Optional[UBatchContext]: return _CURRENT_CONTEXT.get(threading.get_ident(), None) -def yield_and_switch_from_compute_to_comm_impl(schedule="default"): +def yield_and_switch_from_compute_to_comm(schedule="default"): # Perform the barrier if a context exists for this thread ctx = get_current_ubatch_context() #print("you are in yield_impl", ctx) @@ -169,50 +159,12 @@ def yield_and_switch_from_compute_to_comm_impl(schedule="default"): ctx.yield_and_switch_from_compute_to_comm() -def yield_and_switch_from_comm_to_compute_impl(schedule="default"): +def yield_and_switch_from_comm_to_compute(schedule="default"): # Perform the barrier if a context exists for this thread ctx = get_current_ubatch_context() if ctx is not None and ctx.schedule == schedule: ctx.yield_and_switch_from_comm_to_compute() - -# 2) Register kernel for CUDA, mark as mutating to prevent the compiler from -# optimizing it away (TODO: see if this is actually needed) -@custom_op("vllm::yield_and_switch_from_compute_to_comm", mutates_args=("x", )) -def yield_and_switch_from_compute_to_comm(x: torch.Tensor, - schedule: str = "default") -> None: - yield_and_switch_from_compute_to_comm_impl(schedule) - - -# 3) Fake implementation for shape prop and FX tracing -@yield_and_switch_from_compute_to_comm.register_fake -def yield_and_switch_from_compute_to_comm_fake(x: torch.Tensor, - schedule: str = "default" - ) -> None: - pass - - -@custom_op("vllm::yield_and_switch_from_comm_to_compute", mutates_args=("x", )) -def yield_and_switch_from_comm_to_compute(x: torch.Tensor, - schedule: str = "default") -> None: - yield_and_switch_from_comm_to_compute_impl(schedule) - - -@yield_and_switch_from_comm_to_compute.register_fake -def yield_and_switch_from_comm_to_compute_fake(x: torch.Tensor, - schedule: str = "default" - ) -> None: - pass - - -def dump_ubatching_state(): - pass - - -""" -""" - - def make_ubatch_contexts( num_micro_batches: int, compute_stream: torch.cuda.Stream,