[Model Runner V2] Minor code cleanup (#29570)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-11-26 19:47:17 -08:00 committed by GitHub
parent 9bb33c8919
commit 0aeb698b77
3 changed files with 18 additions and 18 deletions

vllm/v1/worker/gpu/cudagraph_utils.py

@@ -16,6 +16,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
 from vllm.v1.worker.gpu.block_table import BlockTables
+from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
 from vllm.v1.worker.gpu.input_batch import InputBuffers
@@ -127,15 +128,7 @@ class CudaGraphManager:
             slot_mappings=slot_mappings,
             kv_cache_config=kv_cache_config,
         )
-        if self.dp_size > 1:
-            num_tokens_across_dp = torch.full(
-                (self.dp_size,),
-                batch_size,
-                dtype=torch.int32,
-                device="cpu",
-            )
-        else:
-            num_tokens_across_dp = None
+        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, batch_size)
         # Warm up.
         with set_forward_context(

vllm/v1/worker/gpu/dp_utils.py

@@ -20,3 +20,12 @@ def get_batch_metadata_across_dp(
     tensor[1][dp_rank] = cudagraph_size
     dist.all_reduce(tensor, group=group)
     return tensor[0], tensor[1]
+
+
+def make_num_tokens_across_dp(
+    dp_size: int,
+    num_tokens: int,
+) -> torch.Tensor | None:
+    if dp_size == 1:
+        return None
+    return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
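
For reference, a small usage sketch of the new helper (the module path comes from the imports in the hunks above; the token counts here are arbitrary example values):

import torch

from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp

# dp_size == 1: no cross-DP metadata is needed, so the helper returns None.
assert make_num_tokens_across_dp(1, 256) is None

# dp_size > 1: the single token count is broadcast into a CPU int32 tensor of
# shape (dp_size,), matching the inline torch.full() it replaces at the two
# call sites.
num_tokens_across_dp = make_num_tokens_across_dp(4, 256)
assert num_tokens_across_dp.tolist() == [256, 256, 256, 256]
assert num_tokens_across_dp.dtype == torch.int32
assert num_tokens_across_dp.device.type == "cpu"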


@@ -35,7 +35,10 @@ from vllm.v1.worker.gpu.attn_utils import (
 )
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
-from vllm.v1.worker.gpu.dp_utils import get_batch_metadata_across_dp
+from vllm.v1.worker.gpu.dp_utils import (
+    get_batch_metadata_across_dp,
+    make_num_tokens_across_dp,
+)
 from vllm.v1.worker.gpu.input_batch import (
     InputBatch,
     InputBuffers,
@@ -255,12 +258,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if not skip_attn:
             self.prepare_dummy_attn_metadata(input_batch)
-        if self.dp_size == 1:
-            num_tokens_across_dp: torch.Tensor | None = None
-        else:
-            num_tokens_across_dp = torch.full(
-                (self.dp_size,), num_tokens, dtype=torch.int32, device="cpu"
-            )
+        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
         num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32)
         with (
             self.maybe_dummy_run_with_lora(
@@ -816,7 +814,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.req_states.last_sampled_tokens,
             next_prefill_tokens,
         )
-        self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
         return draft_tokens

     def get_cudagraph_and_dp_padding(
@@ -1006,7 +1003,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
         )
         if self.do_spec_decode:
-            _ = self.propose_draft(
+            draft_tokens = self.propose_draft(
                 input_batch,
                 sampling_metadata,
                 hidden_states,
@@ -1014,6 +1011,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 num_sampled,
                 num_rejected,
             )
+            self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
         if self.use_async_scheduling:
             return async_output
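
The final three hunks move the draft_tokens bookkeeping out of propose_draft() and into its caller. A minimal, self-contained sketch of that caller-side pattern (ReqStatesStub and propose_draft_stub are simplified stand-ins invented for illustration, not vLLM APIs; shapes and the vocab size are arbitrary):

import torch


class ReqStatesStub:
    # Simplified stand-in for the runner's persistent per-request state.
    def __init__(self, max_reqs: int, num_spec_tokens: int) -> None:
        self.draft_tokens = torch.zeros(max_reqs, num_spec_tokens, dtype=torch.int64)


def propose_draft_stub(num_reqs: int, num_spec_tokens: int) -> torch.Tensor:
    # Stand-in for propose_draft() after this commit: it only computes and
    # returns the draft tokens, without writing req_states itself.
    return torch.randint(0, 32_000, (num_reqs, num_spec_tokens))


req_states = ReqStatesStub(max_reqs=8, num_spec_tokens=4)
idx_mapping = torch.tensor([2, 5])  # persistent-state rows used by this batch

# Caller-side pattern from the final hunk: store the returned drafts into the
# persistent state right after the call.
draft_tokens = propose_draft_stub(num_reqs=2, num_spec_tokens=4)
req_states.draft_tokens[idx_mapping] = draft_tokens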