diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c93c2a097d4bc..6c52daaf05c84 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1638,7 +1638,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
     @torch.inference_mode()
     def _ubatch_thread(results, ubatch_ctx, input_ids, positions, inputs_embeds,
-                       intermediate_tensors, save_results):
+                       intermediate_tensors):
         # print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
         model_output = _run(context=ubatch_ctx,
                             input_ids=input_ids,
@@ -1646,10 +1646,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                             inputs_embeds=inputs_embeds,
                             intermediate_tensors=intermediate_tensors)
 
-        # TODO (Sage) I think we can just delete this check now that we
-        # enforce that all microbatches are valid
-        if save_results:
-            results.append((ubatch_ctx.id, model_output))
+        results.append((ubatch_ctx.id, model_output))
         # print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
 
     def _run_ubatches(ubatch_slices, ubatch_ctxs,
@@ -1663,19 +1660,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         with override_forward_context(None):
             ubatch_threads = []
             for i, (_, tokens_slice) in enumerate(ubatch_slices):
-                # TODO (Sage) Consolidate all of this is_dummy_run
-                # is_dummy_ubatch, is attn_metadata==None, num_tokens==0
-                # nonsense into some unified structure. It's way to hard
-                # to keep track of and keep consistent right now.
-                is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
-                assert not is_dummy_ubatch or i == len(
-                    ubatch_slices) - 1 or is_dummy_run
+                assert tokens_slice.stop > tokens_slice.start
 
-                use_dummy_input = is_dummy_run or is_dummy_ubatch
+                use_dummy_input = is_dummy_run
                 # The only time we don't save results is when one of our ubatches
                 # is a dummy batch
-                save_results = not is_dummy_ubatch or is_dummy_run
                 input_ids, positions, inputs_embeds, intermediate_tensors = \
                     model_inputs(tokens_slice, use_dummy_input)
                 thread = threading.Thread(target=_ubatch_thread,
@@ -1686,7 +1676,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                           positions,
                                           inputs_embeds,
                                           intermediate_tensors,
-                                          save_results
                                       ))
                 ubatch_threads.append(thread)
                 thread.start()
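
For orientation, the control flow this diff leaves behind can be sketched outside of vllm. The following is a minimal stand-alone illustration, not the real GPUModelRunner code: run_model, run_ubatches, and ubatch_thread are hypothetical stand-ins for _run, _run_ubatches, and _ubatch_thread, and the token list replaces the real model inputs. It mirrors the two key points of the change: every microbatch slice must now be non-empty, and each thread unconditionally appends its output, so the save_results flag is gone.

import threading


def run_model(tokens):
    # Hypothetical stand-in for the real forward pass (_run in the diff):
    # just produce a recognizable per-token output.
    return [t * 2 for t in tokens]


def run_ubatches(tokens, ubatch_slices):
    # Mirrors _run_ubatches after this change: every slice must be
    # non-empty, so every thread unconditionally records its output.
    results = []

    def ubatch_thread(ubatch_id, tokens_slice):
        # Mirrors _ubatch_thread: run the microbatch, then always save.
        model_output = run_model(tokens[tokens_slice])
        results.append((ubatch_id, model_output))

    threads = []
    for i, tokens_slice in enumerate(ubatch_slices):
        # Dummy (empty) microbatches are no longer allowed.
        assert tokens_slice.stop > tokens_slice.start
        thread = threading.Thread(target=ubatch_thread,
                                  args=(i, tokens_slice))
        threads.append(thread)
        thread.start()
    for thread in threads:
        thread.join()

    # Threads may finish out of order; reassemble by microbatch id.
    results.sort(key=lambda pair: pair[0])
    return [output for _, output in results]


if __name__ == "__main__":
    tokens = list(range(8))
    # Two microbatches of four tokens each.
    print(run_ubatches(tokens, [slice(0, 4), slice(4, 8)]))

Appending to the shared results list from multiple threads is safe here because list.append is atomic under CPython's GIL; sorting by microbatch id afterwards restores a deterministic ordering regardless of thread completion order.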