diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 6c52daaf05c84..59c02fd5daa18 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1570,10 +1570,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         @dataclasses.dataclass
         class UbatchMetadata:
-            ubatch_id: int
             context: UBatchContext
-            ubatch_slice: UbatchSlice
-
             input_ids: torch.Tensor
             positions: torch.Tensor
             inputs_embeds: Optional[torch.Tensor]
@@ -1618,6 +1615,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             assert scheduler_output is not None
             return self._get_model_inputs(tokens_slice, scheduler_output)
 
+        def _make_ubatch_metadata(ubatch_slices,
+                                  attn_metadata,
+                                  is_dummy_run,
+                                  num_tokens_across_dp) -> list[UbatchMetadata]:
+            ubatch_ctxs = _make_ubatch_contexts(
+                ubatch_slices=ubatch_slices,
+                attn_metadata=attn_metadata,
+                is_dummy_run=is_dummy_run,
+                num_tokens_across_dp=num_tokens_across_dp
+            )
+            # Bundle each ubatch's model inputs together with its context
+            ubatch_metadata: list[UbatchMetadata] = []
+            for i, (_, tokens_slice) in enumerate(ubatch_slices):
+                input_ids, positions, inputs_embeds, intermediate_tensors = \
+                    model_inputs(tokens_slice, is_dummy_run)
+                ubatch_metadata.append(UbatchMetadata(
+                    context=ubatch_ctxs[i],
+                    input_ids=input_ids,
+                    positions=positions,
+                    inputs_embeds=inputs_embeds,
+                    intermediate_tensors=intermediate_tensors
+                ))
+
+            return ubatch_metadata
+
         def _run(context,
                  input_ids,
                  positions,
@@ -1637,49 +1659,34 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             return model_output
 
         @torch.inference_mode()
-        def _ubatch_thread(results, ubatch_ctx, input_ids, positions, inputs_embeds,
-                           intermediate_tensors):
+        def _ubatch_thread(results, ubatch_metadata):
             # print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
-            model_output = _run(context=ubatch_ctx,
-                                input_ids=input_ids,
-                                positions=positions,
-                                inputs_embeds=inputs_embeds,
-                                intermediate_tensors=intermediate_tensors)
+            model_output = _run(context=ubatch_metadata.context,
+                                input_ids=ubatch_metadata.input_ids,
+                                positions=ubatch_metadata.positions,
+                                inputs_embeds=ubatch_metadata.inputs_embeds,
+                                intermediate_tensors=ubatch_metadata.intermediate_tensors)
 
-            results.append((ubatch_ctx.id, model_output))
+            results.append((ubatch_metadata.context.id, model_output))
             # print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
 
-        def _run_ubatches(ubatch_slices, ubatch_ctxs,
-                          is_dummy_run) -> torch.Tensor:
+        def _run_ubatches(ubatch_metadata) -> torch.Tensor:
             results: list[tuple[int, torch.Tensor]] = []
-            assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
             root_stream = current_stream()
             # Ubatches will manually manage the forward context, so we override
             # it to None here so we can have it restored correctly later
             with override_forward_context(None):
                 ubatch_threads = []
-                for i, (_, tokens_slice) in enumerate(ubatch_slices):
-                    assert tokens_slice.stop > tokens_slice.start
-
-                    use_dummy_input = is_dummy_run
-
-                    # The only time we don't save results is when one of our ubatches
-                    # is a dummy batch
-                    input_ids, positions, inputs_embeds, intermediate_tensors = \
-                        model_inputs(tokens_slice, use_dummy_input)
+                for metadata in ubatch_metadata:
                     thread = threading.Thread(target=_ubatch_thread,
                                               args=(
                                                   results,
-                                                  ubatch_ctxs[i],
-                                                  input_ids,
-                                                  positions,
-                                                  inputs_embeds,
-                                                  intermediate_tensors,
+                                                  metadata,
                                               ))
                     ubatch_threads.append(thread)
                     thread.start()
 
-            ubatch_ctxs[0].cpu_wait_event.set()
+            ubatch_metadata[0].context.cpu_wait_event.set()
             for thread in ubatch_threads:
                 thread.join()
 
@@ -1769,18 +1776,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # )
 
         # run micro-batched
        if ubatch_slices is not None:
+            assert len(ubatch_slices) == 2, "Only two ubatches have been tested"
            # num_tokens = ubatch_slices[1][1].stop
            # print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
            # assert not is_dummy_run
-            ubatch_ctxs = _make_ubatch_contexts(
-                ubatch_slices=ubatch_slices,
-                attn_metadata=attn_metadata,
-                is_dummy_run=is_dummy_run,
-                num_tokens_across_dp=num_tokens_across_dp
-            )
-            model_output = _run_ubatches(ubatch_slices,
-                                         ubatch_ctxs,
-                                         is_dummy_run)
+            ubatch_metadata = _make_ubatch_metadata(
+                ubatch_slices=ubatch_slices,
+                attn_metadata=attn_metadata,
+                is_dummy_run=is_dummy_run,
+                num_tokens_across_dp=num_tokens_across_dp
+            )
+            model_output = _run_ubatches(ubatch_metadata)
        # run single batch
        else:
            # print("RUN NORMAL")
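
For readers following the refactor, here is a minimal, self-contained sketch of the coordination pattern the new `_make_ubatch_metadata` / `_run_ubatches` pair implements: per-ubatch inputs bundled with a context object, one worker thread per ubatch, and a CPU event that the parent thread sets to release the first ubatch. All names below (`MiniUbatchContext`, `make_metadata`, `run_ubatches`, the toy `model` callable) are hypothetical stand-ins rather than vLLM APIs, and the real code additionally coordinates CUDA streams and the forward context, which this sketch omits.

```python
# Hypothetical sketch of the thread-per-ubatch pattern; not vLLM code.
import dataclasses
import threading
from typing import Any, Callable


@dataclasses.dataclass
class MiniUbatchContext:
    id: int
    cpu_wait_event: threading.Event


@dataclasses.dataclass
class MiniUbatchMetadata:
    # Mirrors the diff's UbatchMetadata: one context plus that ubatch's inputs.
    context: MiniUbatchContext
    input_ids: list[int]


def make_metadata(token_slices: list[slice],
                  all_tokens: list[int]) -> list[MiniUbatchMetadata]:
    # Mirrors _make_ubatch_metadata: build one context per ubatch and bundle
    # it with the inputs sliced out for that ubatch.
    return [
        MiniUbatchMetadata(
            context=MiniUbatchContext(id=i, cpu_wait_event=threading.Event()),
            input_ids=all_tokens[s],
        )
        for i, s in enumerate(token_slices)
    ]


def run_ubatches(metadata: list[MiniUbatchMetadata],
                 model: Callable[[list[int]], Any]) -> list[Any]:
    results: list[tuple[int, Any]] = []

    def ubatch_thread(md: MiniUbatchMetadata) -> None:
        # Each thread blocks until released. In the real code the contexts
        # ping-pong this event between ubatches at sync points inside the
        # forward pass; here it is a single handshake per ubatch.
        md.context.cpu_wait_event.wait()
        results.append((md.context.id, model(md.input_ids)))
        # Release the next ubatch (wraps around; a no-op for an already-set
        # event, so the last ubatch re-setting the first one is harmless).
        metadata[(md.context.id + 1) % len(metadata)].context.cpu_wait_event.set()

    threads = [threading.Thread(target=ubatch_thread, args=(md,))
               for md in metadata]
    for t in threads:
        t.start()
    # As in the diff, the first ubatch is kicked off from the parent thread.
    metadata[0].context.cpu_wait_event.set()
    for t in threads:
        t.join()
    # Reassemble outputs in ubatch order; append order is nondeterministic.
    return [out for _, out in sorted(results)]


if __name__ == "__main__":
    tokens = list(range(8))
    md = make_metadata([slice(0, 4), slice(4, 8)], tokens)
    print(run_ubatches(md, model=lambda ids: sum(ids)))  # [6, 22]
```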