mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 14:07:04 +08:00
164 lines
5.3 KiB
Python
164 lines
5.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
import numba
|
|
import numpy as np
|
|
import torch
|
|
from numba import types
|
|
|
|
from vllm.v1.utils import CpuGpuBuffer
|
|
|
|
|
|
class InputBuffers:
    """Preallocated CPU/GPU buffer pairs used to stage model inputs.

    Per-request buffers are sized from ``max_num_reqs`` and per-token
    buffers from ``max_num_tokens``. Every buffer is created through
    ``_make_buffer``, so all of them share this object's ``device`` and
    ``pin_memory`` settings.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_num_tokens: int,
        device: torch.device,
        pin_memory: bool,
    ):
        # Record the configuration first: _make_buffer reads self.device and
        # self.pin_memory when allocating.
        self.max_num_reqs = max_num_reqs
        self.max_num_tokens = max_num_tokens
        self.device = device
        self.pin_memory = pin_memory

        make = self._make_buffer
        int32, int64 = torch.int32, torch.int64

        # One slot per request.
        self.idx_mapping = make(max_num_reqs, dtype=int32)
        # One slot per token.
        self.input_ids = make(max_num_tokens, dtype=int32)
        self.positions = make(max_num_tokens, dtype=int64)
        # One more slot than the request capacity (leading 0 offset).
        self.query_start_loc = make(max_num_reqs + 1, dtype=int32)
        # One slot per request.
        self.seq_lens = make(max_num_reqs, dtype=int32)

    def _make_buffer(self, *size, dtype: torch.dtype) -> CpuGpuBuffer:
        """Allocate a CpuGpuBuffer of ``size`` and ``dtype`` with this object's settings."""
        return CpuGpuBuffer(
            *size,
            dtype=dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )
|
|
|
|
|
|
@dataclass
class InputBatch:
    """A batch of requests together with their flattened model inputs.

    Unless noted otherwise, per-request fields are indexed by the request's
    position in the batch (``batch_idx``).
    """

    # batch_idx -> req_id
    req_ids: list[str]
    num_reqs: int

    # batch_idx -> req_state_idx
    idx_mapping: torch.Tensor
    idx_mapping_np: np.ndarray

    # batch_idx -> num_scheduled_tokens
    num_scheduled_tokens: np.ndarray
    # sum(num_scheduled_tokens)
    num_tokens: int
    num_tokens_after_padding: int
    # [num_reqs]
    is_chunked_prefilling: np.ndarray

    # [max_num_batched_tokens]
    input_ids: torch.Tensor
    # [max_num_batched_tokens]
    positions: torch.Tensor

    # layer_name -> Metadata
    attn_metadata: dict[str, Any]

    # [num_reqs]
    logits_indices: torch.Tensor

    @classmethod
    def make_dummy(
        cls,
        num_reqs: int,
        num_tokens: int,
        device: torch.device,
    ) -> "InputBatch":
        """Build a synthetic batch of ``num_reqs`` requests and ``num_tokens`` tokens.

        Tokens are split as evenly as possible: every request gets
        ``num_tokens // num_reqs`` tokens and the last request also takes
        the remainder. Token ids and positions are all zeros, no request is
        marked as chunked-prefilling, and every attention-metadata lookup
        yields ``None``. ``logits_indices`` points at the last token of
        each request in the flattened token layout, consistent with
        ``num_scheduled_tokens``.

        Args:
            num_reqs: Number of requests; must be positive.
            num_tokens: Total number of tokens; must be >= ``num_reqs`` so
                that every request gets at least one token.
            device: Device on which the tensor fields are allocated.

        Raises:
            AssertionError: If ``not 0 < num_reqs <= num_tokens``.
        """
        assert 0 < num_reqs <= num_tokens
        req_ids = [f"req_{i}" for i in range(num_reqs)]
        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
        idx_mapping = torch.tensor(idx_mapping_np, device=device)
        num_scheduled_tokens = np.full(num_reqs,
                                       num_tokens // num_reqs,
                                       dtype=np.int32)
        # The last request absorbs the remainder so the counts sum to
        # exactly num_tokens.
        num_scheduled_tokens[-1] += num_tokens % num_reqs
        is_chunked_prefilling = np.zeros(num_reqs, dtype=np.bool_)
        input_ids = torch.zeros(num_tokens, dtype=torch.int32, device=device)
        positions = torch.zeros(num_tokens, dtype=torch.int64, device=device)
        # No real attention metadata is built: any layer name maps to None.
        attn_metadata = defaultdict(lambda: None)
        # Each request's last token sits at (inclusive prefix sum of token
        # counts) - 1 in the flattened layout; previously this was
        # arange(num_reqs), which did not match the per-request split.
        last_token_idx = np.cumsum(num_scheduled_tokens, dtype=np.int64) - 1
        logits_indices = torch.tensor(last_token_idx,
                                      dtype=torch.int32,
                                      device=device)
        return cls(
            req_ids=req_ids,
            num_reqs=num_reqs,
            idx_mapping=idx_mapping,
            idx_mapping_np=idx_mapping_np,
            num_scheduled_tokens=num_scheduled_tokens,
            num_tokens=num_tokens,
            num_tokens_after_padding=num_tokens,
            is_chunked_prefilling=is_chunked_prefilling,
            input_ids=input_ids,
            positions=positions,
            attn_metadata=attn_metadata,
            logits_indices=logits_indices,
        )
|
|
|
|
|
|
# NOTE: Because an explicit signature list is supplied, numba compiles this
# function when it is decorated instead of waiting for the first call.
@numba.jit(
    [
        types.none(
            types.int32[:],  # idx_mapping
            types.int32[:, :],  # token_ids
            types.int32[:],  # num_computed_tokens
            types.int32[:],  # num_scheduled_tokens
            types.int32[:],  # input_ids
            types.int64[:],  # positions
            types.int32[:],  # query_start_loc
            types.int32[:],  # seq_lens
        )
    ],
    nopython=True,
    cache=True,
)
def prepare_inputs(
    idx_mapping: np.ndarray,  # batch_idx -> req_idx
    token_ids: np.ndarray,  # [N, max_model_len]
    num_computed_tokens: np.ndarray,  # [N]
    num_scheduled_tokens: np.ndarray,  # [B]
    input_ids: np.ndarray,  # [num_input_tokens]
    positions: np.ndarray,  # [num_input_tokens]
    query_start_loc: np.ndarray,  # [B + 1]
    seq_lens: np.ndarray,  # [B]
) -> None:
    """Pack every request's scheduled tokens into flat, contiguous buffers.

    For batch slot ``b`` with request index ``r = idx_mapping[b]``, the
    ``num_scheduled_tokens[b]`` tokens of ``token_ids[r]`` that start at
    ``num_computed_tokens[r]`` are copied to ``input_ids`` right after the
    tokens of slot ``b - 1``. Their absolute positions are written to
    ``positions``, ``seq_lens[b]`` becomes the computed plus scheduled
    count, and ``query_start_loc[b + 1]`` becomes the running token total.

    Entries of ``query_start_loc`` past ``B + 1`` are filled with the final
    total, and entries of ``seq_lens`` past ``B`` are set to 0. Everything
    is written in place into the output arrays; nothing is returned.
    ``input_ids`` and ``positions`` beyond the packed tokens are left
    untouched.
    """
    batch_size = num_scheduled_tokens.shape[0]
    query_start_loc[0] = 0

    offset = 0  # number of tokens packed so far
    for b in range(batch_size):
        state_idx = idx_mapping[b]
        n_new = num_scheduled_tokens[b]
        first_pos = num_computed_tokens[state_idx]
        stop_pos = first_pos + n_new
        seq_lens[b] = stop_pos

        next_offset = offset + n_new
        # The new tokens continue directly after the already computed prefix.
        input_ids[offset:next_offset] = token_ids[state_idx, first_pos:stop_pos]
        positions[offset:next_offset] = np.arange(first_pos,
                                                  stop_pos,
                                                  dtype=np.int64)

        offset = next_offset
        query_start_loc[b + 1] = offset

    # CUDA-graph padding: repeat the final offset so query_start_loc stays
    # non-decreasing (kernels like FlashAttention require that), and zero
    # the unused seq_lens slots for full CUDA graph mode.
    query_start_loc[batch_size + 1:].fill(offset)
    seq_lens[batch_size:].fill(0)
|