dummy run

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-09-18 18:29:18 -07:00
parent 52ca2f517a
commit af65838d1f
3 changed files with 115 additions and 9 deletions

vllm/v1/worker/gpu/input_batch.py

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import Any
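A note on the new `defaultdict` import: `InputBatch.make_dummy` below builds its `attn_metadata` as `defaultdict(lambda: None)`, so looking up any layer name yields `None`, which presumably signals to the attention backends that no real metadata exists during a dummy pass. A minimal sketch of that behavior (the layer name here is made up):

```python
from collections import defaultdict

# Any key, including layer names that were never inserted, resolves to None.
attn_metadata = defaultdict(lambda: None)
assert attn_metadata["model.layers.0.self_attn"] is None
```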
@@ -69,6 +70,43 @@ class InputBatch:
     # [num_reqs]
     logits_indices: torch.Tensor
+
+    @classmethod
+    def make_dummy(
+        cls,
+        num_reqs: int,
+        num_tokens: int,
+        device: torch.device,
+    ) -> "InputBatch":
+        assert 0 < num_reqs <= num_tokens
+        req_ids = [f"req_{i}" for i in range(num_reqs)]
+        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
+        idx_mapping = torch.tensor(idx_mapping_np, device=device)
+        num_scheduled_tokens = np.full(num_reqs,
+                                       num_tokens // num_reqs,
+                                       dtype=np.int32)
+        num_scheduled_tokens[-1] += num_tokens % num_reqs
+        is_chunked_prefilling = np.zeros(num_reqs, dtype=np.bool_)
+        input_ids = torch.zeros(num_tokens, dtype=torch.int32, device=device)
+        positions = torch.zeros(num_tokens, dtype=torch.int64, device=device)
+        attn_metadata = defaultdict(lambda: None)
+        logits_indices = torch.arange(num_reqs,
+                                      dtype=torch.int32,
+                                      device=device)
+        return cls(
+            req_ids=req_ids,
+            num_reqs=num_reqs,
+            idx_mapping=idx_mapping,
+            idx_mapping_np=idx_mapping_np,
+            num_scheduled_tokens=num_scheduled_tokens,
+            num_tokens=num_tokens,
+            num_tokens_after_padding=num_tokens,
+            is_chunked_prefilling=is_chunked_prefilling,
+            input_ids=input_ids,
+            positions=positions,
+            attn_metadata=attn_metadata,
+            logits_indices=logits_indices,
+        )

 # NOTE: With the type annotations, this function is pre-compiled
 # before the first call.
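The pre-compilation NOTE most likely describes Numba-style eager JIT: when the decorator is given an explicit type signature, Numba compiles the function at decoration (import) time rather than lazily on first call. A minimal illustrative sketch under that assumption, not the actual vLLM function:

```python
import numba
import numpy as np

# The explicit signature makes Numba compile eagerly at decoration time,
# so the first real call pays no JIT-compilation latency.
@numba.njit("void(int32[:], int32)", cache=True)
def add_inplace(counts, value):
    for i in range(counts.shape[0]):
        counts[i] += value

arr = np.zeros(4, dtype=np.int32)
add_inplace(arr, 3)  # already compiled; runs immediately
```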

vllm/v1/worker/gpu/model_runner.py

@@ -1,8 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import gc
 import time
 from copy import deepcopy
-from typing import Any
+from typing import Any, Optional
+
 import numpy as np
 import torch
@@ -24,7 +25,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
                                             prepare_inputs)
 from vllm.v1.worker.gpu.sampler import Sampler
-from vllm.v1.worker.gpu.states import RequestState
+from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata

 logger = init_logger(__name__)
@@ -129,15 +130,58 @@ class GPUModelRunner:
             self.device,
         )

-    def _dummy_run(self, num_tokens: int, *args, **kwargs) -> None:
-        return None, None
+    def _dummy_run(
+        self,
+        num_tokens: int,
+        *args,
+        input_batch: Optional[InputBatch] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if input_batch is None:
+            input_batch = InputBatch.make_dummy(
+                num_reqs=min(num_tokens, self.max_num_reqs),
+                num_tokens=num_tokens,
+                device=self.device,
+            )
+        with set_forward_context(
+            input_batch.attn_metadata,
+            self.vllm_config,
+            num_tokens=num_tokens,
+        ):
+            hidden_states = self.model(
+                input_ids=input_batch.input_ids[:num_tokens],
+                positions=input_batch.positions[:num_tokens],
+            )
+        sample_hidden_states = hidden_states[input_batch.logits_indices]
+        return hidden_states, sample_hidden_states

-    def _dummy_sampler_run(self, hidden_states: torch.Tensor, *args,
-                           **kwargs) -> None:
-        return None
+    def _dummy_sampler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> None:
+        num_reqs = hidden_states.shape[0]
+        sampling_metadata = SamplingMetadata.make_dummy(
+            num_reqs=num_reqs,
+            device=self.device,
+        )
+        logits = self.model.compute_logits(hidden_states, None)
+        self.sampler(logits, sampling_metadata)

-    def profile_run(self):
-        pass
+    def profile_run(self) -> None:
+        input_batch = InputBatch.make_dummy(
+            num_reqs=self.max_num_reqs,
+            num_tokens=self.max_num_tokens,
+            device=self.device,
+        )
+        hidden_states, sample_hidden_states = self._dummy_run(
+            self.max_num_tokens,
+            input_batch=input_batch,
+        )
+        self._dummy_sampler_run(sample_hidden_states)
+        torch.cuda.synchronize()
+        del hidden_states, sample_hidden_states
+        gc.collect()

     def update_states(self, scheduler_output: SchedulerOutput) -> None:
         # for req_id in scheduler_output.preempted_req_ids:
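To see how the new `profile_run` might be used for memory profiling, here is a hedged sketch, assuming an already-constructed `runner: GPUModelRunner` (not shown in this diff) and a CUDA device. Because the dummy batch uses the worst-case shape (`max_num_reqs`, `max_num_tokens`), the peak-memory counter after the run should reflect peak activation memory:

```python
import torch

# `runner` is assumed to be a fully initialized GPUModelRunner.
torch.cuda.reset_peak_memory_stats()
runner.profile_run()  # forward + sampler at max_num_tokens / max_num_reqs
peak_gib = torch.cuda.max_memory_allocated() / 2**30
print(f"peak allocated during profile run: {peak_gib:.2f} GiB")
```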

vllm/v1/worker/gpu/states.py

@@ -29,6 +29,30 @@ class SamplingMetadata:
     # None means no logprobs, 0 means sampled token logprobs only
     max_num_logprobs: Optional[int]
+
+    @classmethod
+    def make_dummy(
+        cls,
+        num_reqs: int,
+        device: torch.device,
+    ) -> "SamplingMetadata":
+        assert num_reqs > 0
+        temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
+        temperature[0] = 0.5
+        top_p = torch.ones(num_reqs, dtype=torch.float32, device=device)
+        top_p[0] = 0.99
+        top_k = torch.ones(num_reqs, dtype=torch.int32, device=device)
+        seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
+        pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
+        max_num_logprobs = 20
+        return cls(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            seeds=seeds,
+            pos=pos,
+            max_num_logprobs=max_num_logprobs,
+        )

 class RequestState:
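As a closing illustration, a hedged sketch of how this dummy metadata feeds `_dummy_sampler_run` above. The mixed values (request 0 with temperature 0.5 and top-p 0.99, the rest at temperature 0.0) presumably exercise both the random-sampling and greedy code paths in a single warmup call. Assumes a CUDA device and that `SamplingMetadata` is importable from `vllm.v1.worker.gpu.states`:

```python
import torch
from vllm.v1.worker.gpu.states import SamplingMetadata

meta = SamplingMetadata.make_dummy(num_reqs=4, device=torch.device("cuda"))
# Request 0 uses random sampling (temperature 0.5, top-p 0.99);
# requests 1..3 are greedy (temperature 0.0), covering both sampler branches.
assert meta.temperature[0].item() == 0.5
assert torch.all(meta.temperature[1:] == 0.0)
```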