From 040242820069b46aac9c7287fac2497b2d3c1124 Mon Sep 17 00:00:00 2001 From: Lehua Ding Date: Sat, 25 Oct 2025 04:45:36 +0800 Subject: [PATCH] [Perf][Async Scheduling] Remove CPU->GPU sync in dummy_run (#27455) Signed-off-by: Lehua Ding --- vllm/v1/worker/gpu_model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a08d2262f0f30..efa88d5c68f6b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3492,7 +3492,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.eplb_step(is_dummy=True, is_profile=is_profile) logit_indices = np.cumsum(num_scheduled_tokens) - 1 - return hidden_states, hidden_states[logit_indices] + logit_indices_device = torch.from_numpy(logit_indices).to( + self.device, non_blocking=True + ) + return hidden_states, hidden_states[logit_indices_device] @torch.inference_mode() def _dummy_sampler_run(