Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-05-27 18:37:43 +00:00
parent a743a35948
commit f0b66d6929

View File

@ -48,8 +48,9 @@ class UBatchContext:
def __exit__(self, exc_type, exc_val, exc_tb):
global _CURRENT_CONTEXT
_CURRENT_CONTEXT[threading.get_ident()] = None
torch.cuda.set_stream(self.original_stream)
print("Finishing ubatch %d" % self.id)
print("Finishing ubatch %d\n" % self.id)
self.cpu_signal_event.set()
torch.cuda.set_stream(self.compute_stream)
return False
def _restore_context(self):
@ -67,11 +68,13 @@ class UBatchContext:
def _wait_comm_done(self):
self.compute_stream.wait_event(self.gpu_comm_done_event)
def _cpu_yield(self, gpu_wait: bool = True):
def _cpu_yield(self):
print("UBatchContext: %d yielding CPU\n" % self.id)
self.cpu_signal_event.set()
self.cpu_wait_event.wait()
self.cpu_wait_event.clear()
self._restore_context()
print("UBatchContext: %d resuming CPU\n" % self.id)
def yield_and_switch_from_compute_to_comm(self):
self._signal_compute_done()
@ -99,13 +102,13 @@ def yield_and_switch_from_compute_to_comm_impl(schedule="default"):
# Perform the barrier if a context exists for this thread
ctx = get_current_ubatch_context()
#print("you are in yield_impl", ctx)
if ctx is not None:
if ctx is not None and ctx.schedule == schedule:
ctx.yield_and_switch_from_compute_to_comm()
def yield_and_switch_from_comm_to_compute_impl(schedule="default"):
# Perform the barrier if a context exists for this thread
ctx = get_current_ubatch_context()
if ctx is not None:
if ctx is not None and ctx.schedule == schedule:
ctx.yield_and_switch_from_comm_to_compute()
# 2) Register kernel for CUDA, mark as mutating to prevent the compiler from