From af65838d1f9fd004d6585063a1b700d8e98f1888 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 18 Sep 2025 18:29:18 -0700
Subject: [PATCH] dummy run

Signed-off-by: Woosuk Kwon
---
 vllm/v1/worker/gpu/input_batch.py  | 51 ++++++++++++++++++++
 vllm/v1/worker/gpu/model_runner.py | 82 ++++++++++++++++++++++++++++-----
 vllm/v1/worker/gpu/states.py       | 33 +++++++++++++
 3 files changed, 157 insertions(+), 9 deletions(-)

diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py
index 76b9ef37b6fd1..e1000e52c5f69 100644
--- a/vllm/v1/worker/gpu/input_batch.py
+++ b/vllm/v1/worker/gpu/input_batch.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import Any
 
@@ -69,6 +70,56 @@ class InputBatch:
     # [num_reqs]
     logits_indices: torch.Tensor
 
+    @classmethod
+    def make_dummy(
+        cls,
+        num_reqs: int,
+        num_tokens: int,
+        device: torch.device,
+    ) -> "InputBatch":
+        """Build a placeholder batch for dummy and profiling runs.
+
+        The ``num_tokens`` tokens are split evenly across ``num_reqs``
+        requests, with the remainder assigned to the last request. Token ids
+        and positions are all zeros, and every ``attn_metadata`` lookup
+        returns ``None``.
+        """
+        assert 0 < num_reqs <= num_tokens
+        req_ids = [f"req_{i}" for i in range(num_reqs)]
+        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
+        idx_mapping = torch.tensor(idx_mapping_np, device=device)
+        num_scheduled_tokens = np.full(num_reqs,
+                                       num_tokens // num_reqs,
+                                       dtype=np.int32)
+        num_scheduled_tokens[-1] += num_tokens % num_reqs
+        is_chunked_prefilling = np.zeros(num_reqs, dtype=np.bool_)
+        input_ids = torch.zeros(num_tokens, dtype=torch.int32, device=device)
+        positions = torch.zeros(num_tokens, dtype=torch.int64, device=device)
+        # No real attention metadata exists for a dummy batch, so any key
+        # looked up in this mapping resolves to None.
+        attn_metadata = defaultdict(lambda: None)
+        # Index the last scheduled token of each request, not the first
+        # num_reqs tokens of the flattened batch.
+        logits_indices = torch.tensor(
+            np.cumsum(num_scheduled_tokens) - 1,
+            dtype=torch.int32,
+            device=device,
+        )
+        return cls(
+            req_ids=req_ids,
+            num_reqs=num_reqs,
+            idx_mapping=idx_mapping,
+            idx_mapping_np=idx_mapping_np,
+            num_scheduled_tokens=num_scheduled_tokens,
+            num_tokens=num_tokens,
+            num_tokens_after_padding=num_tokens,
+            is_chunked_prefilling=is_chunked_prefilling,
+            input_ids=input_ids,
+            positions=positions,
+            attn_metadata=attn_metadata,
+            logits_indices=logits_indices,
+        )
+
 
 # NOTE: With the type annotations, this function is pre-compiled
 # before the first call.
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index cf44f7d52b0b5..4600e6315a369 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import gc
 import time
 from copy import deepcopy
-from typing import Any
+from typing import Any, Optional
 
 import numpy as np
 import torch
@@ -24,7 +25,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
                                             prepare_inputs)
 from vllm.v1.worker.gpu.sampler import Sampler
-from vllm.v1.worker.gpu.states import RequestState
+from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
 
 logger = init_logger(__name__)
 
@@ -129,15 +130,78 @@ class GPUModelRunner:
             self.device,
         )
 
-    def _dummy_run(self, num_tokens: int, *args, **kwargs) -> None:
-        return None, None
+    def _dummy_run(
+        self,
+        num_tokens: int,
+        *args,
+        input_batch: Optional[InputBatch] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Run one model forward pass on placeholder inputs.
+
+        When ``input_batch`` is omitted, a dummy batch of ``num_tokens``
+        tokens spread over at most ``self.max_num_reqs`` requests is built.
+        Extra positional and keyword arguments are accepted but unused.
+
+        Returns the hidden states from the forward pass and the rows of them
+        selected by ``input_batch.logits_indices``.
+        """
+        if input_batch is None:
+            input_batch = InputBatch.make_dummy(
+                num_reqs=min(num_tokens, self.max_num_reqs),
+                num_tokens=num_tokens,
+                device=self.device,
+            )
 
-    def _dummy_sampler_run(self, hidden_states: torch.Tensor, *args,
-                           **kwargs) -> None:
-        return None
+        with set_forward_context(
+            input_batch.attn_metadata,
+            self.vllm_config,
+            num_tokens=num_tokens,
+        ):
+            hidden_states = self.model(
+                input_ids=input_batch.input_ids[:num_tokens],
+                positions=input_batch.positions[:num_tokens],
+            )
+        sample_hidden_states = hidden_states[input_batch.logits_indices]
+        return hidden_states, sample_hidden_states
 
-    def profile_run(self):
-        pass
+    def _dummy_sampler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> None:
+        """Compute logits for ``hidden_states`` and run the sampler on them.
+
+        One row of ``hidden_states`` is treated as one request; the sampler
+        is driven with dummy sampling metadata and its output is discarded.
+        """
+        num_reqs = hidden_states.shape[0]
+        sampling_metadata = SamplingMetadata.make_dummy(
+            num_reqs=num_reqs,
+            device=self.device,
+        )
+        logits = self.model.compute_logits(hidden_states, None)
+        self.sampler(logits, sampling_metadata)
+
+    def profile_run(self) -> None:
+        """Run a dummy forward pass and sampling step at maximum batch size.
+
+        Uses ``self.max_num_reqs`` requests and ``self.max_num_tokens``
+        tokens, then waits for the GPU and releases the intermediate tensors.
+        """
+        input_batch = InputBatch.make_dummy(
+            num_reqs=self.max_num_reqs,
+            num_tokens=self.max_num_tokens,
+            device=self.device,
+        )
+        hidden_states, sample_hidden_states = self._dummy_run(
+            self.max_num_tokens,
+            input_batch=input_batch,
+        )
+        self._dummy_sampler_run(sample_hidden_states)
+        # Wait for the queued GPU work to finish before dropping the tensors.
+        torch.cuda.synchronize()
+        del hidden_states, sample_hidden_states
+        gc.collect()
 
     def update_states(self, scheduler_output: SchedulerOutput) -> None:
         # for req_id in scheduler_output.preempted_req_ids:
diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py
index 1d315c9fee205..4deabd2439097 100644
--- a/vllm/v1/worker/gpu/states.py
+++ b/vllm/v1/worker/gpu/states.py
@@ -29,5 +29,38 @@ class SamplingMetadata:
     # None means no logprobs, 0 means sampled token logprobs only
     max_num_logprobs: Optional[int]
 
+    @classmethod
+    def make_dummy(
+        cls,
+        num_reqs: int,
+        device: torch.device,
+    ) -> "SamplingMetadata":
+        """Build placeholder sampling parameters for ``num_reqs`` requests.
+
+        Request 0 gets a temperature of 0.5 and a top-p of 0.99; every other
+        request gets a temperature of 0 and a top-p of 1. All requests use a
+        top-k of 1, a zero seed and a zero position, and ``max_num_logprobs``
+        is 20.
+        """
+        assert num_reqs > 0
+        temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
+        # NOTE(review): presumably request 0 is made non-greedy so the dummy
+        # run also exercises temperature/top-p sampling - confirm.
+        temperature[0] = 0.5
+        top_p = torch.ones(num_reqs, dtype=torch.float32, device=device)
+        top_p[0] = 0.99
+        top_k = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
+        pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
+        max_num_logprobs = 20
+        return cls(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            seeds=seeds,
+            pos=pos,
+            max_num_logprobs=max_num_logprobs,
+        )
+
 
 class RequestState: