Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-05-23 03:32:25 +00:00
parent 18bf91e6a8
commit 2dc3b8b0a2
2 changed files with 31 additions and 6 deletions

View File

@ -1300,12 +1300,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(token_slice, use_dummy_input)
with context:
if isinstance(context, UBatchContext):
print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape}")
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
if isinstance(context, UBatchContext):
print(f"Ran ubatch {context.id}putput {model_output.shape}")
if isinstance(context, UBatchContext):
# Clone before we leave the ubatch context
model_output = model_output.clone()

View File

@ -22,6 +22,7 @@ class UBatchContext:
cpu_signal_event: threading.Event,
gpu_wait_event: torch.cuda.Event,
gpu_signal_event: torch.cuda.Event,
gpu_wait_on_launch: bool = False,
schedule="default"):
self.id = id
self.stream = stream
@ -33,18 +34,27 @@ class UBatchContext:
self.gpu_signal_event = gpu_signal_event
self.schedule = schedule
self.done_event = torch.cuda.Event()
self.gpu_wait_on_launch = gpu_wait_on_launch
def __enter__(self):
global _CURRENT_CONTEXT
_CURRENT_CONTEXT[threading.get_ident()] = self
self._cpu_wait()
self.gpu_stream_wait()
start_event = torch.cuda.Event()
self.original_stream.record_event(start_event)
self.stream.wait_event(start_event)
print("Starting ubatch %d" % self.id)
if self.gpu_wait_on_launch:
self.gpu_stream_wait()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
global _CURRENT_CONTEXT
_CURRENT_CONTEXT[threading.get_ident()] = None
torch.cuda.set_stream(self.original_stream)
print("Finishing ubatch %d" % self.id)
self._signal()
self._signal()
self._signal()
return False
@ -72,9 +82,11 @@ class UBatchContext:
def gpu_stream_wait(self):
self.stream.wait_event(self.gpu_wait_event)
def yield_(self, gpu_wait: bool = True):
def _yield(self, gpu_wait: bool = True):
print("Yielding ubatch %d" % self.id)
self._signal()
self._cpu_wait()
print("Resuming ubatch %d" % self.id)
if gpu_wait:
self.gpu_stream_wait()
@ -91,11 +103,19 @@ class UBatchContext:
_CURRENT_CONTEXT: dict = {}
def yield_impl(schedule="default"):
def get_current_ubatch_context() -> Optional[UBatchContext]:
global _CURRENT_CONTEXT
"""
Get the current UBatchContext for the current thread.
"""
return _CURRENT_CONTEXT.get(threading.get_ident(), None)
def yield_impl(schedule="default", gpu_wait: bool = True):
# Perform the barrier if a context exists for this thread
ctx = _CURRENT_CONTEXT.get(threading.get_ident(), None)
if ctx is not None and ctx.schedule == schedule:
ctx.yield_()
ctx = get_current_ubatch_context()
print("you are in yield_impl", ctx)
if ctx is not None:
ctx._yield(gpu_wait=gpu_wait)
# 2) Register kernel for CUDA, mark as mutating to prevent the compiler from
@ -137,6 +157,7 @@ def make_ubatch_context_chain(
cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
gpu_wait_event=gpu_events[i],
gpu_signal_event=gpu_events[(i + 1) % num_micro_batches],
gpu_wait_on_launch=(i > 0),
)
ctxs.append(ctx)