From bb0645c6448dc135c67d5c0ba93f1f197e618614 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Thu, 3 Jul 2025 17:07:58 +0000
Subject: [PATCH] separate ubatch and normal runs

Signed-off-by: Sage Moore
---
 vllm/v1/worker/gpu_model_runner.py | 59 +++++++++++-------------------
 1 file changed, 21 insertions(+), 38 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 02b632c1377ac..b60d810ccdad3 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1588,38 +1588,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         return ubatch_metadata
 
-        def _run(context,
-                 input_ids,
-                 positions,
-                 inputs_embeds,
-                 intermediate_tensors):
-            with context:
-                model_output = self.model(
-                    input_ids=input_ids,
-                    positions=positions,
-                    intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=inputs_embeds,
-                )
-                if isinstance(context, UBatchContext):
-                    # Clone before we leave the ubatch context
-                    model_output = model_output.clone()
-
-            return model_output
-
         @torch.inference_mode()
-        def _ubatch_thread(results, ubatch_metadata):
+        def _ubatch_thread(results, model, ubatch_metadata):
             # print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
-            context = ubatch_metadata.context
-            model_output = _run(context=context,
-                                input_ids=ubatch_metadata.input_ids,
-                                positions=ubatch_metadata.positions,
-                                inputs_embeds=ubatch_metadata.inputs_embeds,
-                                intermediate_tensors=ubatch_metadata.intermediate_tensors)
-
+            with ubatch_metadata.context:
+                model_output = model(
+                    input_ids=ubatch_metadata.input_ids,
+                    positions=ubatch_metadata.positions,
+                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
+                    inputs_embeds=ubatch_metadata.inputs_embeds,
+                )
             results.append((ubatch_metadata.context.id, model_output))
             # print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
 
-        def _run_ubatches(ubatch_metadata) -> torch.Tensor:
+        def _run_ubatches(ubatch_metadata, model) -> torch.Tensor:
             results: list[tuple[int, torch.Tensor]] = []
 
             # Ubatches will manually manage the forward context, so we override
@@ -1630,6 +1612,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 thread = threading.Thread(target=_ubatch_thread,
                                           args=(
                                               results,
+                                              model,
                                               metadata,
                                           ))
                 ubatch_threads.append(thread)
@@ -1657,22 +1640,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 num_tokens_across_dp=num_tokens_across_dp,
                 skip_cuda_graphs=skip_cuda_graphs
             )
-            return _run_ubatches(ubatch_metadata)
+            return _run_ubatches(ubatch_metadata, self.model)
         # run normal batch
         else:
             input_ids, positions, inputs_embeds, intermediate_tensors = \
                 model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
             logger.info(f"NORMAL RUN {num_scheduled_tokens}")
-            return _run(
-                context = set_forward_context(attn_metadata,
-                              vllm_config=self.vllm_config,
-                              num_tokens=num_scheduled_tokens or 1,
-                              num_tokens_across_dp=num_tokens_across_dp,
-                              skip_cuda_graphs=skip_cuda_graphs),
-                input_ids=input_ids,
-                positions=positions,
-                inputs_embeds=inputs_embeds,
-                intermediate_tensors=intermediate_tensors
+            with set_forward_context(attn_metadata,
+                                     vllm_config=self.vllm_config,
+                                     num_tokens=num_scheduled_tokens or 1,
+                                     num_tokens_across_dp=num_tokens_across_dp,
+                                     skip_cuda_graphs=skip_cuda_graphs):
+                return self.model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds,
             )
 
     def _pool(
         self,
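
The core pattern in this patch, stripped of vLLM internals: fan the micro-batches out to one thread each, run every forward pass inside that ubatch's own context, tag each output with the context id, and stitch the results back together in order. Below is a minimal, runnable sketch of that fan-out/fan-in scheme. StubContext and stub_model are hypothetical stand-ins for UBatchContext and self.model, not the real vLLM API, and the real code also wraps the thread body in torch.inference_mode().

import threading


class StubContext:
    """Stand-in for UBatchContext: carries an .id and acts as a
    context manager, like ubatch_metadata.context in the patch."""

    def __init__(self, ubatch_id: int):
        self.id = ubatch_id

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False


def stub_model(tokens: list[int]) -> list[int]:
    # Stand-in for the real forward pass.
    return [t * 2 for t in tokens]


def run_ubatches(ubatches: list[list[int]]) -> list[list[int]]:
    results: list[tuple[int, list[int]]] = []

    def ubatch_thread(results, model, ctx, tokens):
        # Mirrors _ubatch_thread: run the model inside the ubatch
        # context, then record a (context.id, model_output) pair.
        with ctx:
            out = model(tokens)
        results.append((ctx.id, out))

    threads = []
    for i, tokens in enumerate(ubatches):
        t = threading.Thread(target=ubatch_thread,
                             args=(results, stub_model, StubContext(i),
                                   tokens))
        threads.append(t)
        t.start()
    for t in threads:
        t.join()

    # Reassemble outputs in ubatch-id order rather than thread
    # completion order.
    return [out for _, out in sorted(results)]


if __name__ == "__main__":
    print(run_ubatches([[1, 2], [3, 4, 5]]))  # -> [[2, 4], [6, 8, 10]]

The shared results list needs no lock because list.append is atomic under the GIL, and tagging each output with its context id is what lets the caller recover ubatch order regardless of which thread finishes first.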