diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b60d810ccdad3..ba82131bb9d2e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1459,12 +1459,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): def _get_model_inputs(self, tokens_slice: slice, scheduler_output: "SchedulerOutput"): - num_tokens = tokens_slice.stop - tokens_slice.start - if num_tokens == 0: - # Dummy batch. (hopefully we are the last one so we can just - # update this to a one token batch and return) - tokens_slice = slice(tokens_slice.start, tokens_slice.start + 1) - num_tokens = 1 + assert tokens_slice.stop - tokens_slice.start > 0 # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order @@ -1604,8 +1599,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): def _run_ubatches(ubatch_metadata, model) -> torch.Tensor: results: list[tuple[int, torch.Tensor]] = [] - # Ubatches will manually manage the forward context, so we override - # it to None here so we can have it restored correctly later + # Ubatch threads will manually manage the forward context, so we + # override it to None here so we can have it restored correctly + # after both threads have finished with override_forward_context(None): ubatch_threads = [] for metadata in ubatch_metadata: @@ -1618,7 +1614,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ubatch_threads.append(thread) thread.start() - # logger.info("FINISHED WAKEUP LOOP") ubatch_metadata[0].context.cpu_wait_event.set() for thread in ubatch_threads: thread.join()