From 0aeb698b774e2d8593b14988e3af9ebbdd773730 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 26 Nov 2025 19:47:17 -0800
Subject: [PATCH] [Model Runner V2] Minor code cleanup (#29570)

Signed-off-by: Woosuk Kwon
---
 vllm/v1/worker/gpu/cudagraph_utils.py | 11 ++---------
 vllm/v1/worker/gpu/dp_utils.py        |  9 +++++++++
 vllm/v1/worker/gpu/model_runner.py    | 16 +++++++--------
 3 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py
index ba783e2d0c6f..6b056641c903 100644
--- a/vllm/v1/worker/gpu/cudagraph_utils.py
+++ b/vllm/v1/worker/gpu/cudagraph_utils.py
@@ -16,6 +16,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
 from vllm.v1.worker.gpu.block_table import BlockTables
+from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
 from vllm.v1.worker.gpu.input_batch import InputBuffers
 
 
@@ -127,15 +128,7 @@ class CudaGraphManager:
             slot_mappings=slot_mappings,
             kv_cache_config=kv_cache_config,
         )
-        if self.dp_size > 1:
-            num_tokens_across_dp = torch.full(
-                (self.dp_size,),
-                batch_size,
-                dtype=torch.int32,
-                device="cpu",
-            )
-        else:
-            num_tokens_across_dp = None
+        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, batch_size)
 
         # Warm up.
         with set_forward_context(
diff --git a/vllm/v1/worker/gpu/dp_utils.py b/vllm/v1/worker/gpu/dp_utils.py
index 9bfc7f25bef3..d71d91d1e5cb 100644
--- a/vllm/v1/worker/gpu/dp_utils.py
+++ b/vllm/v1/worker/gpu/dp_utils.py
@@ -20,3 +20,12 @@ def get_batch_metadata_across_dp(
     tensor[1][dp_rank] = cudagraph_size
     dist.all_reduce(tensor, group=group)
     return tensor[0], tensor[1]
+
+
+def make_num_tokens_across_dp(
+    dp_size: int,
+    num_tokens: int,
+) -> torch.Tensor | None:
+    if dp_size == 1:
+        return None
+    return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index e34a45f97980..6a78776b0a8a 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -35,7 +35,10 @@ from vllm.v1.worker.gpu.attn_utils import (
 )
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
-from vllm.v1.worker.gpu.dp_utils import get_batch_metadata_across_dp
+from vllm.v1.worker.gpu.dp_utils import (
+    get_batch_metadata_across_dp,
+    make_num_tokens_across_dp,
+)
 from vllm.v1.worker.gpu.input_batch import (
     InputBatch,
     InputBuffers,
@@ -255,12 +258,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         if not skip_attn:
             self.prepare_dummy_attn_metadata(input_batch)
-        if self.dp_size == 1:
-            num_tokens_across_dp: torch.Tensor | None = None
-        else:
-            num_tokens_across_dp = torch.full(
-                (self.dp_size,), num_tokens, dtype=torch.int32, device="cpu"
-            )
+        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
         num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32)
         with (
             self.maybe_dummy_run_with_lora(
@@ -816,7 +814,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.req_states.last_sampled_tokens,
             next_prefill_tokens,
         )
-        self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
         return draft_tokens
 
     def get_cudagraph_and_dp_padding(
@@ -1006,7 +1003,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
         )
         if self.do_spec_decode:
-            _ = self.propose_draft(
+            draft_tokens = self.propose_draft(
                 input_batch,
                 sampling_metadata,
                 hidden_states,
@@ -1014,6 +1011,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 num_sampled,
                 num_rejected,
             )
+            self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
 
         if self.use_async_scheduling:
             return async_output
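For reference, a minimal standalone sketch of the helper this patch factors out of the two call sites. The __main__ scaffolding and the literal sizes (4 ranks, 512 tokens) are illustrative assumptions, not part of the diff above.

# Sketch only: make_num_tokens_across_dp as introduced in dp_utils.py above,
# exercised the way the two call sites use it.
import torch


def make_num_tokens_across_dp(
    dp_size: int,
    num_tokens: int,
) -> torch.Tensor | None:
    # With a single data-parallel rank, no cross-rank tensor is needed.
    if dp_size == 1:
        return None
    # Otherwise, one int32 CPU entry per DP rank, all set to the same count.
    return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")


if __name__ == "__main__":
    # Single-rank case: both former call sites fell through to None.
    assert make_num_tokens_across_dp(1, 512) is None

    # Multi-rank case: batch_size in CudaGraphManager, num_tokens in the
    # GPUModelRunner dummy-run path.
    out = make_num_tokens_across_dp(4, 512)
    assert out is not None and out.tolist() == [512, 512, 512, 512]
    assert out.dtype == torch.int32 and out.device.type == "cpu"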