mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 12:27:05 +08:00
dummy run
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
52ca2f517a
commit
af65838d1f
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
@ -69,6 +70,43 @@ class InputBatch:
|
||||
# [num_reqs]
|
||||
logits_indices: torch.Tensor
|
||||
|
||||
@classmethod
def make_dummy(
    cls,
    num_reqs: int,
    num_tokens: int,
    device: torch.device,
) -> "InputBatch":
    """Construct a synthetic batch for warmup/profiling runs.

    Spreads `num_tokens` tokens evenly over `num_reqs` requests (the
    last request absorbs the remainder) and fills all tensors with
    zero/placeholder values on `device`.
    """
    assert 0 < num_reqs <= num_tokens
    # Even split of tokens per request; remainder goes to the last one.
    per_req, leftover = divmod(num_tokens, num_reqs)
    scheduled = np.full(num_reqs, per_req, dtype=np.int32)
    scheduled[-1] += leftover

    mapping_np = np.arange(num_reqs, dtype=np.int32)
    return cls(
        req_ids=[f"req_{i}" for i in range(num_reqs)],
        num_reqs=num_reqs,
        idx_mapping=torch.tensor(mapping_np, device=device),
        idx_mapping_np=mapping_np,
        num_scheduled_tokens=scheduled,
        num_tokens=num_tokens,
        num_tokens_after_padding=num_tokens,
        is_chunked_prefilling=np.zeros(num_reqs, dtype=np.bool_),
        input_ids=torch.zeros(num_tokens, dtype=torch.int32, device=device),
        positions=torch.zeros(num_tokens, dtype=torch.int64, device=device),
        # Any attention-backend key maps to None (no real metadata needed).
        attn_metadata=defaultdict(lambda: None),
        logits_indices=torch.arange(num_reqs,
                                    dtype=torch.int32,
                                    device=device),
    )
|
||||
|
||||
|
||||
# NOTE: With the type annotations, this function is pre-compiled
|
||||
# before the first call.
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -24,7 +25,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers,
|
||||
prepare_inputs)
|
||||
from vllm.v1.worker.gpu.sampler import Sampler
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -129,15 +130,58 @@ class GPUModelRunner:
|
||||
self.device,
|
||||
)
|
||||
|
||||
def _dummy_run(self, num_tokens: int, *args, **kwargs) -> None:
|
||||
return None, None
|
||||
def _dummy_run(
    self,
    num_tokens: int,
    *args,
    input_batch: Optional[InputBatch] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Execute one forward pass over a synthetic batch.

    Returns `(hidden_states, sample_hidden_states)`, where the second
    tensor is `hidden_states` gathered at the batch's logits indices.
    Extra positional/keyword arguments are accepted for caller
    compatibility but ignored.
    """
    # Fabricate a batch when the caller did not provide one.
    if input_batch is None:
        dummy_reqs = min(num_tokens, self.max_num_reqs)
        input_batch = InputBatch.make_dummy(
            num_reqs=dummy_reqs,
            num_tokens=num_tokens,
            device=self.device,
        )

    ids = input_batch.input_ids[:num_tokens]
    pos = input_batch.positions[:num_tokens]
    with set_forward_context(
            input_batch.attn_metadata,
            self.vllm_config,
            num_tokens=num_tokens,
    ):
        hidden_states = self.model(input_ids=ids, positions=pos)
    return hidden_states, hidden_states[input_batch.logits_indices]
|
||||
|
||||
def profile_run(self):
|
||||
pass
|
||||
def _dummy_sampler_run(
    self,
    hidden_states: torch.Tensor,
) -> None:
    """Exercise logits computation and sampling on dummy hidden states.

    Used during warmup/profiling so that sampler memory is accounted
    for; results are discarded.
    """
    batch_size = hidden_states.shape[0]
    metadata = SamplingMetadata.make_dummy(
        num_reqs=batch_size,
        device=self.device,
    )
    # NOTE(review): second argument to compute_logits is passed as None
    # here, matching the dummy-run usage — confirm its meaning upstream.
    logits = self.model.compute_logits(hidden_states, None)
    self.sampler(logits, metadata)
|
||||
|
||||
def profile_run(self) -> None:
    """Run the largest configured dummy batch end-to-end.

    Drives both the model forward pass and the sampler at maximum batch
    size so peak memory can be measured, then releases the activations.
    """
    batch = InputBatch.make_dummy(
        num_reqs=self.max_num_reqs,
        num_tokens=self.max_num_tokens,
        device=self.device,
    )
    hidden, sampled = self._dummy_run(
        self.max_num_tokens,
        input_batch=batch,
    )
    self._dummy_sampler_run(sampled)
    torch.cuda.synchronize()
    # Drop the dummy activations before real workloads allocate.
    del hidden, sampled
    gc.collect()
|
||||
|
||||
def update_states(self, scheduler_output: SchedulerOutput) -> None:
|
||||
# for req_id in scheduler_output.preempted_req_ids:
|
||||
|
||||
@ -29,6 +29,30 @@ class SamplingMetadata:
|
||||
# None means no logprobs, 0 means sampled token logprobs only
|
||||
max_num_logprobs: Optional[int]
|
||||
|
||||
@classmethod
def make_dummy(
    cls,
    num_reqs: int,
    device: torch.device,
) -> "SamplingMetadata":
    """Build placeholder sampling metadata for `num_reqs` requests.

    The first request gets non-trivial temperature/top_p so that both
    the greedy and the random-sampling code paths are exercised during
    dummy runs.
    """
    assert num_reqs > 0
    temp = torch.zeros(num_reqs, dtype=torch.float32, device=device)
    temp[0] = 0.5  # one request samples with temperature
    nucleus = torch.ones(num_reqs, dtype=torch.float32, device=device)
    nucleus[0] = 0.99  # one request uses top-p filtering
    return cls(
        temperature=temp,
        top_p=nucleus,
        top_k=torch.ones(num_reqs, dtype=torch.int32, device=device),
        seeds=torch.zeros(num_reqs, dtype=torch.int64, device=device),
        pos=torch.zeros(num_reqs, dtype=torch.int64, device=device),
        max_num_logprobs=20,
    )
|
||||
|
||||
|
||||
class RequestState:
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user