From 7b31e8a8fff8a502d9480e805153168480725d44 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 27 May 2025 16:51:27 +0000 Subject: [PATCH] wip separate comm and compute threads Signed-off-by: Lucas Wilkinson --- .../layers/fused_moe/pplx_prepare_finalize.py | 21 +-- vllm/v1/worker/gpu_model_runner.py | 15 +- vllm/v1/worker/ubatching.py | 175 ++++++++---------- 3 files changed, 85 insertions(+), 126 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index f5276637326b0..07f5b2bfba998 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -7,7 +7,10 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) -from vllm.v1.worker.ubatching import get_current_ubatch_context, yield_impl +from vllm.v1.worker.ubatching import ( + get_current_ubatch_context, yield_and_switch_from_compute_to_comm_impl, + yield_and_switch_from_comm_to_compute_impl +) # Note use: layer.get_all_to_all() to get an AllToAll instance @@ -119,14 +122,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): do_recv=not send, ) - #print("Dispatch pre-wait") - if (ubatch_ctx := get_current_ubatch_context()) is not None: - ubatch_ctx.gpu_stream_wait() - #print("Dispatch launched") + yield_and_switch_from_compute_to_comm_impl(schedule="default") dispatch(True) # Send - yield_impl(gpu_wait=False) dispatch(False) # Recv - #print("Finished dispatch") + yield_and_switch_from_comm_to_compute_impl(schedule="default") return expert_x, expert_x_scale, expert_num_tokens @@ -164,11 +163,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): do_recv=not send, ) - #print("Combine pre-wait") - if (ubatch_ctx := get_current_ubatch_context()) is not None: - ubatch_ctx.gpu_stream_wait() + 
yield_and_switch_from_compute_to_comm_impl(schedule="default") combine(True) - #print("Combine launched") - yield_impl(gpu_wait=False) combine(False) - #print("Finished combine") + yield_and_switch_from_comm_to_compute_impl(schedule="default") \ No newline at end of file diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index be756929959ff..5c456eb409e6c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -59,7 +59,7 @@ from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.v1.worker.ubatching import make_ubatch_context_chain, UBatchContext +from vllm.v1.worker.ubatching import make_ubatch_contexts, UBatchContext from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) @@ -1342,19 +1342,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): # attn_metadata[i] if attn_metadata is not None else None, # self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start) # ) for i, (_, tokens_slice) in enumerate(ubatch_slices)] - ubatch_ctxs, start_hook = make_ubatch_context_chain( + ubatch_ctxs, start_hook = make_ubatch_contexts( len(ubatch_slices), - #fwd_ctxs=ubatch_fwd_ctxs, - streams=self.ubatch_streams, #stream=root_stream, # Only works currently if everything is run on the same stream + compute_stream=root_stream, device=self.device) setup_done = threading.Event() ubatch_threads = [] - # Initialize Events? 
not sure if this helps - for ubatch_ctx in ubatch_ctxs: - ubatch_ctx.gpu_wait_event.record(ubatch_ctx.stream) - ubatch_ctx.stream.wait_event(ubatch_ctx.gpu_wait_event) - # Ubatches will manually manage the forward context, so we override # it to None here so we can have it restored correctly later with override_forward_context(None): @@ -1388,9 +1382,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): for thread in ubatch_threads: thread.join() - - for ubatch_ctx in ubatch_ctxs: - root_stream.wait_stream(ubatch_ctx.stream) torch.cuda.set_stream(root_stream) return torch.cat(results, dim=0) diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 182a20698fd1f..1907e0509a91f 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -17,36 +17,31 @@ class UBatchContext: """ def __init__(self, id: int, - stream: torch.cuda.Stream, + comm_stream: torch.cuda.Stream, + compute_stream: torch.cuda.Stream, #fwd_ctx: forward_context.ForwardContext, cpu_wait_event: threading.Event, cpu_signal_event: threading.Event, - gpu_wait_event: torch.cuda.Event, - gpu_signal_event: torch.cuda.Event, - gpu_wait_on_launch: bool = False, - schedule="default"): + gpu_comm_done_event: torch.cuda.Event, + gpu_compute_done_event: torch.cuda.Event, + schedule: str = "default"): self.id = id - self.stream = stream + self.comm_stream = comm_stream + self.compute_stream = compute_stream self.original_stream = current_stream() self.forward_context = None #fwd_ctx self.cpu_wait_event = cpu_wait_event self.cpu_signal_event = cpu_signal_event - self.gpu_wait_event = gpu_wait_event - self.gpu_signal_event = gpu_signal_event + self.gpu_comm_done_event = gpu_comm_done_event + self.gpu_compute_done_event = gpu_compute_done_event self.schedule = schedule - self.done_event = torch.cuda.Event() - self.gpu_wait_on_launch = gpu_wait_on_launch def __enter__(self): global _CURRENT_CONTEXT _CURRENT_CONTEXT[threading.get_ident()] = self - self._cpu_wait() - # start_event = 
torch.cuda.Event() - # self.original_stream.record_event(start_event) - # self.stream.wait_event(start_event) - print("Starting ubatch %d" % self.id) - # if self.gpu_wait_on_launch: - self.gpu_stream_wait() + # Assume we start on the compute stream + assert current_stream() == self.compute_stream, \ + "Expected to start on the compute stream, but found %s" % current_stream() return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -54,9 +49,6 @@ class UBatchContext: _CURRENT_CONTEXT[threading.get_ident()] = None torch.cuda.set_stream(self.original_stream) print("Finishing ubatch %d" % self.id) - self._signal() - self._signal() - self._signal() return False def _restore_context(self): @@ -65,66 +57,37 @@ class UBatchContext: torch.cuda.set_stream(self.stream) forward_context._forward_context = self.forward_context - # Seperate GPU wait so we can do - # ubatch0 - # 1) work - # 2) dispatch - # 3) yield - # ubatch1 - # 1) work - # 2) gpu wait - # 3) dispatch - # 4) yield - # - # This way we can have the CPU schedule ubatch1-dispatch while ubatch0 - # before yielding back to ubatch1 but ensure we wont start the dispatch - # until ubatch0-dispatch is done avoiding overlapping dispatches that - # might share underlying buffers - # - # NOTE(lucas): I think we need to do: - # ubatch0 - # - work - # - dispatch send - # - yield - # ubatch1 - # - work - # - yield - # ubatch0 - # - dispatch recv - # - gpu record, event0 - # - yield - # ubatch1 - # - gpu wait, event0 - # - dispatch send - # - yield - # ubatch0 - # - work - # ..... 
- # To ensure we record the cuda event before waiting - def gpu_stream_wait(self): - print("Waiting ubatch %d on %s in stream %s" % (self.id, self.gpu_wait_event, self.stream)) - self.stream.wait_event(self.gpu_wait_event) + def _signal_comm_done(self): + self.gpu_comm_done_event.record(self.comm_stream) + + def _signal_compute_done(self): + self.gpu_compute_done_event.record(self.compute_stream) - def _yield(self, gpu_wait: bool = True): - #print("Yielding ubatch %d" % self.id) - self._signal() - self._cpu_wait() - #print("Resuming ubatch %d" % self.id) - if gpu_wait: - self.gpu_stream_wait() + def _wait_compute_done(self): + self.comm_stream.wait_event(self.gpu_compute_done_event) - def _signal(self): - # Wait for the next batch to signal back - print(f"signaling ubatch {self.id} to {self.gpu_signal_event} on {self.stream}") - self.gpu_signal_event.record(self.stream) - # Signal that this batch reached the barrier + def _wait_comm_done(self): + self.compute_stream.wait_event(self.gpu_comm_done_event) + + def _cpu_yield(self, gpu_wait: bool = True): self.cpu_signal_event.set() - - def _cpu_wait(self): self.cpu_wait_event.wait() self.cpu_wait_event.clear() self._restore_context() + def yield_and_switch_from_compute_to_comm(self): + self._signal_compute_done() + self._cpu_yield() + torch.cuda.set_stream(self.comm_stream) + self._wait_compute_done() + + def yield_and_switch_from_comm_to_compute(self): + self._signal_comm_done() + self._cpu_yield() + torch.cuda.set_stream(self.compute_stream) + self._wait_comm_done() + + _CURRENT_CONTEXT: dict = {} def get_current_ubatch_context() -> Optional[UBatchContext]: @@ -134,23 +97,36 @@ def get_current_ubatch_context() -> Optional[UBatchContext]: """ return _CURRENT_CONTEXT.get(threading.get_ident(), None) -def yield_impl(schedule="default", gpu_wait: bool = True): +def yield_and_switch_from_compute_to_comm_impl(schedule="default"): # Perform the barrier if a context exists for this thread ctx = get_current_ubatch_context() 
#print("you are in yield_impl", ctx) if ctx is not None: - ctx._yield(gpu_wait=gpu_wait) + ctx.yield_and_switch_from_compute_to_comm() +def yield_and_switch_from_comm_to_compute_impl(schedule="default"): + # Perform the barrier if a context exists for this thread + ctx = get_current_ubatch_context() + if ctx is not None: + ctx.yield_and_switch_from_comm_to_compute() # 2) Register kernel for CUDA, mark as mutating to prevent the compiler from # optimizing it away (TODO: see if this is actually needed) -@custom_op("vllm::yield_", mutates_args=("x",)) -def yield_(x: torch.Tensor, schedule: str="default") -> None: - yield_impl(schedule) +@custom_op("vllm::yield_and_switch_from_compute_to_comm", mutates_args=("x",)) +def yield_and_switch_from_compute_to_comm(x: torch.Tensor, schedule: str="default") -> None: + yield_and_switch_from_compute_to_comm_impl(schedule) # 3) Fake implementation for shape prop and FX tracing -@yield_.register_fake -def yield_(x: torch.Tensor, schedule: str="default") -> None: +@yield_and_switch_from_compute_to_comm.register_fake +def yield_and_switch_from_compute_to_comm(x: torch.Tensor, schedule: str="default") -> None: + pass + +@custom_op("vllm::yield_and_switch_from_comm_to_compute", mutates_args=("x",)) +def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="default") -> None: + yield_and_switch_from_comm_to_compute_impl(schedule) + +@yield_and_switch_from_comm_to_compute.register_fake +def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="default") -> None: pass def dump_ubatching_state(): @@ -169,16 +145,13 @@ def dump_ubatching_state(): f" CPU Signal Event: {ctx.cpu_signal_event}\n" f" GPU Signal Event: {ctx.gpu_signal_event} ({ctx.gpu_signal_event.query()})\n") - - """ - """ -def make_ubatch_context_chain( +def make_ubatch_contexts( num_micro_batches: int, - #fwd_ctxs: forward_context.ForwardContext, - streams: Optional[list[torch.Stream]] = None, - device: Optional[torch.device] = None + 
compute_stream: torch.cuda.Stream, + device: Optional[torch.device] = None, + schedule: str = "default", ) -> list[UBatchContext]: assert num_micro_batches == 2, "only been tested with 2 micro-batches" @@ -186,26 +159,26 @@ def make_ubatch_context_chain( Create a context manager for micro-batching synchronization. """ cpu_events = [threading.Event() for _ in range(num_micro_batches)] - gpu_events = [torch.cuda.Event(blocking=True) for _ in range(num_micro_batches)] + gpu_comm_done_events = [ + torch.cuda.Event() for _ in range(num_micro_batches) + ] + gpu_compute_done_events = [ + torch.cuda.Event() for _ in range(num_micro_batches) + ] device = device or torch.cuda.current_device() - + comm_stream = torch.cuda.Stream(device) + ctxs = [] for i in range(num_micro_batches): - stream = (streams[i] if streams else None) or torch.cuda.Stream(device) ctx = UBatchContext(id=i, - stream=stream, - #fwd_ctx=fwd_ctxs[i], + compute_stream=compute_stream, + comm_stream=comm_stream, cpu_wait_event=cpu_events[i], cpu_signal_event=cpu_events[(i + 1) % num_micro_batches], - gpu_wait_event=gpu_events[i], - gpu_signal_event=gpu_events[(i + 1) % num_micro_batches], - gpu_wait_on_launch=(i > 0), + gpu_comm_done_event=gpu_comm_done_events[i], + gpu_compute_done_event=gpu_compute_done_events[i], + schedule=schedule ) ctxs.append(ctx) - - def start_hook(from_stream: torch.cuda.Stream): - ctxs[0].gpu_wait_event.record(from_stream) - print('singal to ubatch %d event %s from stream %s' % (ctxs[0].id, ctxs[0].gpu_wait_event, from_stream)) - ctxs[0].cpu_wait_event.set() - return ctxs, start_hook \ No newline at end of file + return ctxs, \ No newline at end of file