seperate gpu wait

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-05-22 21:52:27 +00:00
parent a8439e2fd4
commit 00f526f55b

View File

@ -37,7 +37,8 @@ class UBatchContext:
def __enter__(self):
global _CURRENT_CONTEXT
_CURRENT_CONTEXT[threading.get_ident()] = self
self._wait()
self._cpu_wait()
self.gpu_stream_wait()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
@ -53,9 +54,29 @@ class UBatchContext:
torch.cuda.set_stream(self.stream)
forward_context._forward_context = self.forward_context
def yield_(self):
# Seperate GPU wait so we can do
# ubatch0
# 1) work
# 2) dispatch
# 3) yield
# ubatch1
# 1) work
# 2) gpu wait
# 3) dispatch
# 4) yield
#
# This way we can have the CPU schedule ubatch1-dispatch while ubatch0
# before yielding back to ubatch1 but ensure we wont start the dispatch
# until ubatch0-dispatch is done avoiding overlapping dispatches that
# might share underlying buffers
def gpu_stream_wait(self):
self.stream.wait_event(self.gpu_wait_event)
def yield_(self, gpu_wait: bool = True):
self._signal()
self._wait()
self._cpu_wait()
if gpu_wait:
self.gpu_stream_wait()
def _signal(self):
# Wait for the next batch to signal back
@ -63,10 +84,9 @@ class UBatchContext:
# Signal that this batch reached the barrier
self.cpu_signal_event.set()
def _wait(self):
def _cpu_wait(self):
self.cpu_wait_event.wait()
self.cpu_wait_event.clear()
self.stream.wait_event(self.gpu_wait_event)
self._restore_context()
_CURRENT_CONTEXT: dict = {}
@ -121,7 +141,7 @@ def make_ubatch_context_chain(
ctxs.append(ctx)
def start_hook(from_stream: torch.cuda.Stream):
ctxs[0].cpu_wait_event.set()
ctxs[0].gpu_wait_event.record(from_stream)
ctxs[0].gpu_wait_event.record(from_stream)
ctxs[0].cpu_wait_event.set()
return ctxs, start_hook