From 0e2b4bd5464a321100d68354851db8e93ca46391 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Wed, 25 Jun 2025 23:43:49 +0000
Subject: [PATCH] more refactoring

Signed-off-by: Sage Moore
---
 vllm/v1/worker/gpu_model_runner.py | 72 +++++++++++++++++++-----------
 1 file changed, 45 insertions(+), 27 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 2957cfb1c85b0..c93c2a097d4bc 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1618,11 +1618,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             assert scheduler_output is not None
             return self._get_model_inputs(tokens_slice, scheduler_output)
 
-        def _run(token_slice: slice,
-                 context,
-                 use_dummy_input: bool = False):
-            input_ids, positions, inputs_embeds, intermediate_tensors = \
-                model_inputs(token_slice, use_dummy_input)
+        def _run(context,
+                 input_ids,
+                 positions,
+                 inputs_embeds,
+                 intermediate_tensors):
             with context:
                 model_output = self.model(
                     input_ids=input_ids,
@@ -1637,26 +1637,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             return model_output
 
         @torch.inference_mode()
-        def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
-                           use_dummy_input):
+        def _ubatch_thread(results, ubatch_ctx, input_ids, positions, inputs_embeds,
+                           intermediate_tensors, save_results):
             # print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
-            model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
+            model_output = _run(context=ubatch_ctx,
+                                input_ids=input_ids,
+                                positions=positions,
+                                inputs_embeds=inputs_embeds,
+                                intermediate_tensors=intermediate_tensors)
+            # TODO (Sage) I think we can just delete this check now that we
+            # enforce that all microbatches are valid
             if save_results:
                 results.append((ubatch_ctx.id, model_output))
             # print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
 
-        def _run_ubatches(ubatch_slices, attn_metadata,
-                          is_dummy_run, num_tokens_across_dp) -> torch.Tensor:
+        def _run_ubatches(ubatch_slices, ubatch_ctxs,
+                          is_dummy_run) -> torch.Tensor:
             results: list[tuple[int, torch.Tensor]] = []
             assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
             root_stream = current_stream()
 
-            ubatch_ctxs = _make_ubatch_contexts(ubatch_slices=ubatch_slices,
-                                                attn_metadata=attn_metadata,
-                                                is_dummy_run=is_dummy_run,
-                                                num_tokens_across_dp=num_tokens_across_dp)
-
             # Ubatches will manually manage the forward context, so we override
             # it to None here so we can have it restored correctly later
             with override_forward_context(None):
@@ -1670,15 +1671,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     assert not is_dummy_ubatch or i == len(
                         ubatch_slices) - 1 or is_dummy_run
+                    use_dummy_input = is_dummy_run or is_dummy_ubatch
+
+                    # The only time we don't save results is when one of our ubatches
+                    # is a dummy batch
+                    save_results = not is_dummy_ubatch or is_dummy_run
+                    input_ids, positions, inputs_embeds, intermediate_tensors = \
+                        model_inputs(tokens_slice, use_dummy_input)
                     thread = threading.Thread(target=_ubatch_thread,
                                               args=(
-                                                  ubatch_ctxs[i],
-                                                  tokens_slice, results,
-                                                  not is_dummy_ubatch
-                                                  or is_dummy_run,
-                                                  is_dummy_ubatch
-                                                  or is_dummy_run
+                                                  ubatch_ctxs[i],
+                                                  input_ids,
+                                                  positions,
+                                                  inputs_embeds,
+                                                  intermediate_tensors,
+                                                  save_results
                                               ))
                     ubatch_threads.append(thread)
                     thread.start()
@@ -1775,21 +1783,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # num_tokens = ubatch_slices[1][1].stop
             # print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
             # assert not is_dummy_run
+            ubatch_ctxs = _make_ubatch_contexts(
+                ubatch_slices=ubatch_slices,
+                attn_metadata=attn_metadata,
+                is_dummy_run=is_dummy_run,
+                num_tokens_across_dp=num_tokens_across_dp
+            )
             model_output = _run_ubatches(ubatch_slices,
-                                         attn_metadata,
-                                         is_dummy_run,
-                                         num_tokens_across_dp=num_tokens_across_dp)
+                                         ubatch_ctxs,
+                                         is_dummy_run)
         # run single batch
         else:
             # print("RUN NORMAL")
+            input_ids, positions, inputs_embeds, intermediate_tensors = \
+                model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
             model_output = _run(
-                slice(0, num_scheduled_tokens),
-                set_forward_context(attn_metadata,
+                context = set_forward_context(attn_metadata,
                                     vllm_config=self.vllm_config,
                                     num_tokens=num_scheduled_tokens or 1,
                                     num_tokens_across_dp=num_tokens_across_dp,
                                     skip_cuda_graphs=skip_cuda_graphs),
+                input_ids=input_ids,
+                positions=positions,
+                inputs_embeds=inputs_embeds,
+                intermediate_tensors=intermediate_tensors
+            )
 
         return model_output
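
For review, the shape of the refactor: model_inputs(...) is now evaluated once on the
calling thread, and _run() / _ubatch_thread() receive the already-built inputs plus a
pre-constructed context instead of building them internally. Below is a minimal,
standalone sketch of that calling pattern; ubatch_context, prepare_inputs and run_model
are illustrative stand-ins, not vLLM APIs.

import threading
from contextlib import contextmanager


@contextmanager
def ubatch_context(ubatch_id):
    # Stand-in for the per-microbatch forward context that
    # _make_ubatch_contexts builds in the patch.
    yield ubatch_id


def prepare_inputs(token_slice):
    # Stand-in for model_inputs(tokens_slice, use_dummy_input): inputs are
    # materialized on the main thread, before the worker thread starts.
    return list(range(token_slice.start, token_slice.stop))


def run_model(inputs):
    # Stand-in for the model forward pass inside _run().
    return sum(inputs)


def ubatch_thread(results, ubatch_id, inputs, save_results):
    # Mirrors _ubatch_thread: it only runs the model under its own context
    # and records (id, output); it no longer prepares its own inputs.
    with ubatch_context(ubatch_id):
        output = run_model(inputs)
    if save_results:
        results.append((ubatch_id, output))


ubatch_slices = [slice(0, 4), slice(4, 8)]
results, threads = [], []
for i, token_slice in enumerate(ubatch_slices):
    inputs = prepare_inputs(token_slice)  # built before the thread is spawned
    t = threading.Thread(target=ubatch_thread,
                         args=(results, i, inputs, True))
    threads.append(t)
    t.start()
for t in threads:
    t.join()
print(sorted(results))  # [(0, 6), (1, 22)]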