Signed-off-by: Sage Moore <sage@neuralmagic.com>
Sage Moore 2025-07-03 17:23:30 +00:00
parent bb0645c644
commit 3a41a3dcff

@@ -1459,12 +1459,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _get_model_inputs(self, tokens_slice: slice,
                           scheduler_output: "SchedulerOutput"):
         num_tokens = tokens_slice.stop - tokens_slice.start
-        if num_tokens == 0:
-            # Dummy batch. (hopefully we are the last one so we can just
-            # update this to a one token batch and return)
-            tokens_slice = slice(tokens_slice.start, tokens_slice.start + 1)
-            num_tokens = 1
         assert tokens_slice.stop - tokens_slice.start > 0
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
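
For context, the deleted branch padded a zero-token ubatch slice up to one token so the forward pass never saw an empty batch. A minimal standalone sketch of that padding rule (the `pad_empty_slice` helper is hypothetical; the real code mutated locals inside `_get_model_inputs`):

```python
def pad_empty_slice(tokens_slice: slice) -> slice:
    """Hypothetical helper mirroring the removed branch: widen an
    empty "dummy" token slice to a single token so downstream code
    never operates on a zero-token batch."""
    if tokens_slice.stop - tokens_slice.start == 0:
        return slice(tokens_slice.start, tokens_slice.start + 1)
    return tokens_slice


# Usage: an empty slice becomes a one-token slice; non-empty slices
# pass through unchanged.
assert pad_empty_slice(slice(8, 8)) == slice(8, 9)
assert pad_empty_slice(slice(0, 4)) == slice(0, 4)
```
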
@@ -1604,8 +1599,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _run_ubatches(ubatch_metadata, model) -> torch.Tensor:
         results: list[tuple[int, torch.Tensor]] = []
-        # Ubatches will manually manage the forward context, so we override
-        # it to None here so we can have it restored correctly later
+        # Ubatch threads will manually manage the forward context, so we
+        # override it to None here so we can have it restored correctly
+        # after both threads have finished
         with override_forward_context(None):
             ubatch_threads = []
             for metadata in ubatch_metadata:
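
`override_forward_context` is a context manager that swaps in a new thread-global forward context and restores the previous one on exit; passing `None` clears it so each ubatch thread can install its own. A minimal sketch of that save/override/restore pattern, assuming a simplified global (the real vLLM context also carries attention metadata):

```python
import contextlib
from typing import Any, Optional

_forward_context: Optional[Any] = None  # simplified stand-in for the global


@contextlib.contextmanager
def override_forward_context(ctx: Optional[Any]):
    """Temporarily replace the global forward context, restoring the
    previous value when the block exits (even on exception)."""
    global _forward_context
    prev = _forward_context
    _forward_context = ctx
    try:
        yield
    finally:
        _forward_context = prev
```
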
@@ -1618,7 +1614,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 ubatch_threads.append(thread)
                 thread.start()
-            # logger.info("FINISHED WAKEUP LOOP")
             ubatch_metadata[0].context.cpu_wait_event.set()
             for thread in ubatch_threads:
                 thread.join()
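
The `ubatch_metadata[0].context.cpu_wait_event.set()` call unblocks the first ubatch thread; the threads then alternate by signaling each other's wait events until both finish and are joined. A minimal sketch of that event ping-pong under assumed names (`UbatchContext`, `cpu_signal_event`, and `run_ubatch` are hypothetical; the real contexts also coordinate GPU-side events):

```python
import threading
from dataclasses import dataclass


@dataclass
class UbatchContext:
    """Hypothetical stand-in for the per-ubatch context in the diff."""
    cpu_wait_event: threading.Event
    cpu_signal_event: threading.Event  # the *other* ubatch's wait event


def run_ubatch(ubatch_id: int, ctx: UbatchContext, steps: int) -> None:
    for step in range(steps):
        ctx.cpu_wait_event.wait()   # block until it's our turn
        ctx.cpu_wait_event.clear()  # re-arm for our next turn
        print(f"ubatch {ubatch_id} step {step}")
        ctx.cpu_signal_event.set()  # hand control to the other ubatch


ev0, ev1 = threading.Event(), threading.Event()
ctxs = [UbatchContext(ev0, ev1), UbatchContext(ev1, ev0)]
threads = [
    threading.Thread(target=run_ubatch, args=(i, ctx, 3))
    for i, ctx in enumerate(ctxs)
]
for t in threads:
    t.start()
ctxs[0].cpu_wait_event.set()  # as above: kick off the first ubatch
for t in threads:
    t.join()
```
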