diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2a8ff746f1125..0102ca4739ad5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2663,6 +2663,18 @@ class GPUModelRunner( return make_empty_encoder_model_runner_output(scheduler_output) if not num_scheduled_tokens: + if ( + self.parallel_config.distributed_executor_backend + == "external_launcher" + and self.parallel_config.data_parallel_size > 1 + ): + # this is a corner case when both external launcher + # and DP are enabled, num_scheduled_tokens could be + # 0, and has_unfinished_requests in the outer loop + # returns True. before returning early here we call + # dummy run to ensure coordinate_batch_across_dp + # is called into to avoid out of sync issues. + self._dummy_run(1) if not has_kv_transfer_group(): # Return empty ModelRunnerOutput if no work to do. return EMPTY_MODEL_RUNNER_OUTPUT