vllm/vllm/v1/worker/gpu/input_batch.py
Woosuk Kwon af65838d1f dummy run
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-09-18 18:29:18 -07:00

164 lines
5.3 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass
from typing import Any
import numba
import numpy as np
import torch
from numba import types
from vllm.v1.utils import CpuGpuBuffer
class InputBuffers:
    """Fixed-size input buffers sized by the maximum batch dimensions.

    Every buffer is created through ``_make_buffer`` and therefore shares
    this object's ``device`` and ``pin_memory`` settings.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_num_tokens: int,
        device: torch.device,
        pin_memory: bool,
    ):
        """Record the sizing/device configuration and allocate all buffers.

        Args:
            max_num_reqs: Length of the per-request buffers.
            max_num_tokens: Length of the per-token buffers.
            device: Device forwarded to every ``CpuGpuBuffer``.
            pin_memory: Pinning flag forwarded to every ``CpuGpuBuffer``.
        """
        self.max_num_reqs = max_num_reqs
        self.max_num_tokens = max_num_tokens
        self.device = device
        self.pin_memory = pin_memory

        alloc = self._make_buffer
        self.idx_mapping = alloc(max_num_reqs, dtype=torch.int32)
        self.input_ids = alloc(max_num_tokens, dtype=torch.int32)
        self.positions = alloc(max_num_tokens, dtype=torch.int64)
        # One entry more than the number of requests.
        self.query_start_loc = alloc(max_num_reqs + 1, dtype=torch.int32)
        self.seq_lens = alloc(max_num_reqs, dtype=torch.int32)

    def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
        """Build a ``CpuGpuBuffer`` of shape ``args`` with the given dtype.

        The buffer inherits ``self.device`` and ``self.pin_memory``.
        """
        return CpuGpuBuffer(
            *args,
            dtype=dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )
@dataclass
class InputBatch:
    """Container for the tensors and metadata describing one input batch."""

    # batch_idx -> req_id
    req_ids: list[str]
    num_reqs: int

    # batch_idx -> req_state_idx
    idx_mapping: torch.Tensor
    idx_mapping_np: np.ndarray

    # batch_idx -> num_scheduled_tokens
    num_scheduled_tokens: np.ndarray
    # sum(num_scheduled_tokens)
    num_tokens: int
    num_tokens_after_padding: int
    # [num_reqs]
    is_chunked_prefilling: np.ndarray

    # [max_num_batched_tokens]
    input_ids: torch.Tensor
    # [max_num_batched_tokens]
    positions: torch.Tensor

    # layer_name -> Metadata
    attn_metadata: dict[str, Any]

    # [num_reqs]
    logits_indices: torch.Tensor

    @classmethod
    def make_dummy(
        cls,
        num_reqs: int,
        num_tokens: int,
        device: torch.device,
    ) -> "InputBatch":
        """Create a placeholder batch of ``num_reqs`` requests.

        ``num_tokens`` is divided evenly across the requests and the
        remainder is added to the last request. ``input_ids`` and
        ``positions`` are all zeros, no request is marked as chunked
        prefilling, and ``attn_metadata`` yields ``None`` for any layer name.

        Args:
            num_reqs: Number of requests; must satisfy
                ``0 < num_reqs <= num_tokens``.
            num_tokens: Total number of tokens in the batch.
            device: Device on which the torch tensors are created.

        Returns:
            The populated ``InputBatch``; ``num_tokens_after_padding`` equals
            ``num_tokens``.
        """
        assert 0 < num_reqs <= num_tokens

        # Identity mapping: batch slot i refers to request state i.
        identity_np = np.arange(num_reqs, dtype=np.int32)
        identity = torch.tensor(identity_np, device=device)

        # Even share per request; the last request absorbs the remainder.
        per_req, leftover = divmod(num_tokens, num_reqs)
        tokens_per_req = np.full(num_reqs, per_req, dtype=np.int32)
        tokens_per_req[-1] += leftover

        return cls(
            req_ids=[f"req_{i}" for i in range(num_reqs)],
            num_reqs=num_reqs,
            idx_mapping=identity,
            idx_mapping_np=identity_np,
            num_scheduled_tokens=tokens_per_req,
            num_tokens=num_tokens,
            num_tokens_after_padding=num_tokens,
            is_chunked_prefilling=np.zeros(num_reqs, dtype=np.bool_),
            input_ids=torch.zeros(
                num_tokens, dtype=torch.int32, device=device
            ),
            positions=torch.zeros(
                num_tokens, dtype=torch.int64, device=device
            ),
            # Any layer name maps to None.
            attn_metadata=defaultdict(lambda: None),
            logits_indices=torch.arange(
                num_reqs, dtype=torch.int32, device=device
            ),
        )
# NOTE: Because explicit type signatures are given, numba compiles this
# function ahead of its first call.
@numba.jit(
    [
        types.none(
            types.int32[:],  # idx_mapping
            types.int32[:, :],  # token_ids
            types.int32[:],  # num_computed_tokens
            types.int32[:],  # num_scheduled_tokens
            types.int32[:],  # input_ids
            types.int64[:],  # positions
            types.int32[:],  # query_start_loc
            types.int32[:],  # seq_lens
        )
    ],
    nopython=True,
    cache=True,
)
def prepare_inputs(
    idx_mapping: np.ndarray,  # batch_idx -> req_idx
    token_ids: np.ndarray,  # [N, max_model_len]
    num_computed_tokens: np.ndarray,  # [N]
    num_scheduled_tokens: np.ndarray,  # [B]
    input_ids: np.ndarray,  # [num_input_tokens]
    positions: np.ndarray,  # [num_input_tokens]
    query_start_loc: np.ndarray,  # [B + 1]
    seq_lens: np.ndarray,  # [B]
) -> None:
    """Fill the flattened per-token and per-request input arrays in place.

    For each batch slot ``i`` (with ``r = idx_mapping[i]``), the tokens
    ``token_ids[r, start:end]`` are copied into the next free span of
    ``input_ids``, where ``start = num_computed_tokens[r]`` and
    ``end = start + num_scheduled_tokens[i]``. The matching span of
    ``positions`` receives ``start .. end - 1``, ``seq_lens[i]`` is set to
    ``end``, and ``query_start_loc[i + 1]`` is set to the running token
    count. Entries past the batch are padded: ``query_start_loc`` with the
    total token count and ``seq_lens`` with zero.

    All outputs are written in place; nothing is returned.
    """
    batch_size = len(num_scheduled_tokens)
    query_start_loc[0] = 0

    # Running write offset into the flattened input_ids / positions arrays.
    offset = 0
    for slot in range(batch_size):
        state_idx = idx_mapping[slot]
        num_new = num_scheduled_tokens[slot]

        # Token range of this request within its row of token_ids.
        ctx_begin = num_computed_tokens[state_idx]
        ctx_end = ctx_begin + num_new
        seq_lens[slot] = ctx_end

        next_offset = offset + num_new
        input_ids[offset:next_offset] = token_ids[state_idx,
                                                  ctx_begin:ctx_end]
        positions[offset:next_offset] = np.arange(ctx_begin,
                                                  ctx_end,
                                                  dtype=np.int64)

        offset = next_offset
        query_start_loc[slot + 1] = offset

    # Pad the inputs for CUDA graphs.
    # Note: query_start_loc is padded so that it stays non-decreasing, as
    # kernels like FlashAttention require that.
    query_start_loc[batch_size + 1:].fill(offset)
    # Fill the unused entries with 0 for full CUDA graph mode.
    seq_lens[batch_size:].fill(0)