From 462c6b0b504a6f751f4fda4527eee5150e10d4d8 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Tue, 8 Jul 2025 18:59:36 +0000
Subject: [PATCH] remove some dummy_run logic

Signed-off-by: Sage Moore
---
 vllm/v1/worker/gpu_model_runner.py | 112 ++++++++++++-----------------
 1 file changed, 47 insertions(+), 65 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 94a623dc878b8..f62cf76952a96 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1442,37 +1442,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         return DPMetadata.should_ubatch_across_dp(should_ubatch, dp_size,
                                                   dp_rank)
 
-    def _get_dummy_model_inputs(self, num_tokens: int) -> tuple:
-        # Dummy batch. (hopefully we are the last one so we can just
-        # update this to a one token batch and return)
-
-        if self.is_multimodal_model:
-            input_ids = None
-            inputs_embeds = self.inputs_embeds[:num_tokens]
-        else:
-            input_ids = self.input_ids[:num_tokens]
-            inputs_embeds = None
-
-        if self.uses_mrope:
-            positions = self.mrope_positions[:, :num_tokens]
-        else:
-            positions = self.positions[:num_tokens]
-
-        if get_pp_group().is_first_rank:
-            intermediate_tensors = None
-        else:
-            if self.intermediate_tensors is None:
-                self.intermediate_tensors = (
-                    self.model.make_empty_intermediate_tensors(
-                        batch_size=self.max_num_tokens,
-                        dtype=self.model_config.dtype,
-                        device=self.device))
-
-            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
-                slice(0, num_tokens), None, False)
-
-        return input_ids, positions, inputs_embeds, intermediate_tensors
-
     def _get_model_inputs(self, tokens_slice: slice,
                           scheduler_output: "SchedulerOutput"):
         assert tokens_slice.stop - tokens_slice.start > 0
@@ -1519,18 +1488,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             tokens_slice, intermediate_tensors, True)
         return input_ids, positions, inputs_embeds, intermediate_tensors
 
-    def model_inputs(self, tokens_slice: slice, use_dummy_input: bool,
-                     scheduler_output: Optional["SchedulerOutput"]) -> tuple:
-        if use_dummy_input:
-            return self._get_dummy_model_inputs(tokens_slice.stop -
-                                                tokens_slice.start)
-        else:
-            assert scheduler_output is not None
-            return self._get_model_inputs(tokens_slice, scheduler_output)
-
     def _make_ubatch_metadata(self, ubatch_slices, attn_metadata,
-                              compute_stream, is_dummy_run,
-                              num_tokens_across_dp, skip_cuda_graphs,
+                              compute_stream, num_tokens_across_dp,
+                              skip_cuda_graphs,
                               scheduler_output) -> list[UbatchMetadata]:
 
         # Create one forward context per ubatch
@@ -1554,7 +1514,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         ubatch_metadata: list[UbatchMetadata] = []
         for i, (_, tokens_slice) in enumerate(ubatch_slices):
             input_ids, positions, inputs_embeds, intermediate_tensors = \
-                self.model_inputs(tokens_slice, is_dummy_run, scheduler_output)
+                self._get_model_inputs(tokens_slice, scheduler_output)
             ubatch_metadata.append(
                 UbatchMetadata(context=ubatch_ctxs[i],
                                input_ids=input_ids,
@@ -1602,11 +1562,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         return result
 
     def _run_model(self,
-                   attn_metadata: Optional[PerLayerAttnMetadata],
-                   num_scheduled_tokens: Optional[int],
+                   attn_metadata: PerLayerAttnMetadata,
+                   num_scheduled_tokens: int,
+                   scheduler_output: "SchedulerOutput",
                    ubatch_slices: Optional[UBatchSlices] = None,
-                   scheduler_output: Optional["SchedulerOutput"] = None,
-                   is_dummy_run: bool = False,
                    num_tokens_across_dp: Optional[torch.Tensor] = None,
                    skip_cuda_graphs: bool = False):
 
@@ -1619,7 +1578,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 ubatch_slices=ubatch_slices,
                 attn_metadata=attn_metadata,
                 compute_stream=compute_stream,
-                is_dummy_run=is_dummy_run,
                 num_tokens_across_dp=num_tokens_across_dp,
                 skip_cuda_graphs=skip_cuda_graphs,
                 scheduler_output=scheduler_output)
@@ -1627,9 +1585,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # run normal batch
         else:
             input_ids, positions, inputs_embeds, intermediate_tensors = \
-                self.model_inputs(slice(0, num_scheduled_tokens),
-                                  is_dummy_run,
-                                  scheduler_output)
+                self._get_model_inputs(slice(0, num_scheduled_tokens),
+                                       scheduler_output)
             with set_forward_context(attn_metadata,
                                      vllm_config=self.vllm_config,
                                      num_tokens=num_scheduled_tokens or 1,
@@ -1723,14 +1680,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # compiled with full CUDA graphs, we have to skip them entirely.
         skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
 
-        # Run the decoder.
+        # Run the model.
         # Use persistent buffers for CUDA graphs.
         self.maybe_setup_kv_connector(scheduler_output)
         model_output = self._run_model(
             attn_metadata=attn_metadata,
             num_scheduled_tokens=num_input_tokens,
-            ubatch_slices=ubatch_slices,
             scheduler_output=scheduler_output,
+            ubatch_slices=ubatch_slices,
             num_tokens_across_dp=num_tokens_after_padding,
             skip_cuda_graphs=skip_cuda_graphs,
         )
@@ -2022,7 +1979,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
     def kv_connector_no_forward(
             self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
-        # KV send/recv even if no work to do.
         with set_forward_context(None, self.vllm_config):
             self.maybe_setup_kv_connector(scheduler_output)
 
@@ -2308,7 +2264,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _dummy_run(
         self,
         num_tokens: int,
-        # Maybe return a cudagraph here
        capture_attn_cudagraph: bool = False,
        skip_eplb: bool = False,
        is_profile: bool = False,
@@ -2336,8 +2291,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                         dtype=np.int32)
 
-        # We currently only microbatch if the number of tokens is
-        # over a certain threshold.
         attn_metadata: Optional[dict[str, Any]] = None
         if capture_attn_cudagraph:
             attn_metadata = {}
@@ -2350,13 +2303,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                            non_blocking=True)
             seq_lens = self.seq_lens[:num_reqs]
 
-            max_query_len = num_tokens
             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=query_start_loc,
                 seq_lens=seq_lens,
                 num_reqs=num_reqs,
                 num_actual_tokens=num_tokens,
-                max_query_len=max_query_len,
+                max_query_len=num_tokens,
             )
 
             for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -2369,12 +2321,42 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
-            outputs = self._run_model(
-                attn_metadata,
-                num_tokens,
-                is_dummy_run=True,
-                num_tokens_across_dp=num_tokens_across_dp,
-            )
+            model = self.model
+            if self.is_multimodal_model:
+                input_ids = None
+                inputs_embeds = self.inputs_embeds[:num_tokens]
+            else:
+                input_ids = self.input_ids[:num_tokens]
+                inputs_embeds = None
+            if self.uses_mrope:
+                positions = self.mrope_positions[:, :num_tokens]
+            else:
+                positions = self.positions[:num_tokens]
+
+            if get_pp_group().is_first_rank:
+                intermediate_tensors = None
+            else:
+                if self.intermediate_tensors is None:
+                    self.intermediate_tensors = (
+                        self.model.make_empty_intermediate_tensors(
+                            batch_size=self.max_num_tokens,
+                            dtype=self.model_config.dtype,
+                            device=self.device))
+
+                intermediate_tensors = self.sync_and_slice_intermediate_tensors(
+                    slice(0, num_tokens), None, False)
+
+            with self.maybe_randomize_inputs(input_ids), set_forward_context(
+                    attn_metadata,
+                    self.vllm_config,
+                    num_tokens=num_tokens,
+                    num_tokens_across_dp=num_tokens_across_dp):
+                outputs = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds,
+                )
         if self.use_aux_hidden_state_outputs:
             hidden_states, _ = outputs
         else:
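
The patch above is truncated mid-hunk. Its net effect is that _dummy_run now builds
its dummy inputs inline and calls the model directly, instead of routing through the
removed _get_dummy_model_inputs/model_inputs indirection, so _run_model and
_make_ubatch_metadata only ever see a real SchedulerOutput. The pattern the inlined
code relies on is slicing preallocated persistent buffers (self.input_ids,
self.positions, self.inputs_embeds) rather than allocating per call, which keeps
tensor addresses stable across CUDA graph capture and replay. Below is a minimal,
self-contained Python sketch of that pattern; TinyModel, MAX_NUM_TOKENS, and the
buffer names are illustrative stand-ins, not vLLM's actual classes or API.

    import torch
    import torch.nn as nn


    class TinyModel(nn.Module):
        """Illustrative stand-in; vLLM's real model also takes
        intermediate_tensors and inputs_embeds."""

        def __init__(self, vocab_size: int = 128, hidden: int = 16):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden)

        def forward(self, input_ids: torch.Tensor,
                    positions: torch.Tensor) -> torch.Tensor:
            # (num_tokens, hidden) + (num_tokens, 1) broadcasts cleanly.
            return self.embed(input_ids) + positions.unsqueeze(-1)


    MAX_NUM_TOKENS = 64  # capacity of the persistent buffers

    # Allocated once at startup, like self.input_ids / self.positions in the
    # runner. Dummy runs slice these rather than allocating fresh tensors, so
    # a captured CUDA graph keeps pointing at the same storage on every replay.
    input_ids_buf = torch.zeros(MAX_NUM_TOKENS, dtype=torch.long)
    positions_buf = torch.arange(MAX_NUM_TOKENS, dtype=torch.float32)


    def dummy_run(model: nn.Module, num_tokens: int) -> torch.Tensor:
        # A slice is a view, not a copy: no new allocation, stable base pointer.
        input_ids = input_ids_buf[:num_tokens]
        positions = positions_buf[:num_tokens]
        return model(input_ids=input_ids, positions=positions)


    model = TinyModel()
    print(dummy_run(model, num_tokens=8).shape)  # torch.Size([8, 16])

Inlining the dummy-input construction also lets _run_model drop its Optional
parameter types and the is_dummy_run flag, as the signature changes in the first
hunks show: a dummy run no longer threads a sentinel through the normal
execution path.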