diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9d0f26266f0c5..3539f75612050 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1903,7 +1903,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): **self._init_model_kwargs(num_scheduled_tokens), **self._extract_mm_kwargs(scheduler_output), } - elif (self.enable_prompt_embeds and get_pp_group().is_first_rank): + elif self.enable_prompt_embeds and get_pp_group().is_first_rank: # Get the input embeddings for the tokens that are not input embeds, # then put them into the appropriate positions. # TODO(qthequartermasterman): Since even when prompt embeds are @@ -2125,6 +2125,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): invalid_req_indices, ) + @contextmanager + def synchronize_input_prep(self): + if self.prepare_inputs_event is None: + yield + return + + # Ensure prior step has finished with reused CPU tensors. + # This is required in the async scheduling case because + # the CPU->GPU transfer happens async. + self.prepare_inputs_event.synchronize() + try: + yield + finally: + self.prepare_inputs_event.record() + @torch.inference_mode() def execute_model( self, @@ -2132,33 +2147,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: with record_function_or_nullcontext("Preprocess"): - self._update_states(scheduler_output) - if not scheduler_output.total_num_scheduled_tokens: - if not has_kv_transfer_group(): - # Return empty ModelRunnerOutput if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT - return self.kv_connector_no_forward(scheduler_output, - self.vllm_config) - if self.cache_config.kv_sharing_fast_prefill: - assert not self.input_batch.num_prompt_logprobs, ( - "--kv-sharing-fast-prefill produces incorrect logprobs for " - "prompt tokens, tokens, please disable it when the requests" - " need prompt logprobs") + with self.synchronize_input_prep(): + # Update persistent batch states. + self._update_states(scheduler_output) + + if not scheduler_output.total_num_scheduled_tokens: + if not has_kv_transfer_group(): + # Return empty ModelRunnerOutput if no work to do. + return EMPTY_MODEL_RUNNER_OUTPUT + return self.kv_connector_no_forward( + scheduler_output, self.vllm_config) + if self.cache_config.kv_sharing_fast_prefill: + assert not self.input_batch.num_prompt_logprobs, ( + "--kv-sharing-fast-prefill produces incorrect " + "logprobs for prompt tokens, tokens, please disable " + "it when the requests need prompt logprobs") - if self.prepare_inputs_event is not None: - # Ensure prior step has finished with reused CPU tensors. - self.prepare_inputs_event.synchronize() - try: # Prepare the decoder inputs. (attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, spec_decode_common_attn_metadata, max_query_len, ubatch_slices, num_tokens_after_padding ) = self._prepare_inputs(scheduler_output) - finally: - if self.prepare_inputs_event is not None: - self.prepare_inputs_event.record() - ( num_scheduled_tokens, num_input_tokens,