diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index 2dc84dc579092..3688ae8129db6 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -37,7 +37,8 @@ class UBatchContext: def __enter__(self): global _CURRENT_CONTEXT _CURRENT_CONTEXT[threading.get_ident()] = self - self._wait() + self._cpu_wait() + self.gpu_stream_wait() return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -53,9 +54,29 @@ class UBatchContext: torch.cuda.set_stream(self.stream) forward_context._forward_context = self.forward_context - def yield_(self): + # Seperate GPU wait so we can do + # ubatch0 + # 1) work + # 2) dispatch + # 3) yield + # ubatch1 + # 1) work + # 2) gpu wait + # 3) dispatch + # 4) yield + # + # This way we can have the CPU schedule ubatch1-dispatch while ubatch0 + # before yielding back to ubatch1 but ensure we wont start the dispatch + # until ubatch0-dispatch is done avoiding overlapping dispatches that + # might share underlying buffers + def gpu_stream_wait(self): + self.stream.wait_event(self.gpu_wait_event) + + def yield_(self, gpu_wait: bool = True): self._signal() - self._wait() + self._cpu_wait() + if gpu_wait: + self.gpu_stream_wait() def _signal(self): # Wait for the next batch to signal back @@ -63,10 +84,9 @@ class UBatchContext: # Signal that this batch reached the barrier self.cpu_signal_event.set() - def _wait(self): + def _cpu_wait(self): self.cpu_wait_event.wait() self.cpu_wait_event.clear() - self.stream.wait_event(self.gpu_wait_event) self._restore_context() _CURRENT_CONTEXT: dict = {} @@ -121,7 +141,7 @@ def make_ubatch_context_chain( ctxs.append(ctx) def start_hook(from_stream: torch.cuda.Stream): - ctxs[0].cpu_wait_event.set() - ctxs[0].gpu_wait_event.record(from_stream) + ctxs[0].gpu_wait_event.record(from_stream) + ctxs[0].cpu_wait_event.set() return ctxs, start_hook \ No newline at end of file