wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 48bca9a109
commit a1e3745150
@@ -90,9 +90,9 @@ class Sampler(nn.Module):
        # Apply bad words exclusion.
        logits = self.apply_bad_words(logits, sampling_metadata)

        # Apply logits processors which can impact greedy sampling
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)
        # # Apply logits processors which can impact greedy sampling
        # for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
        #     logits = processor.apply(logits)

        # Apply penalties (e.g., min_tokens, freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata)
@@ -167,10 +167,10 @@ class Sampler(nn.Module):
        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature)

        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)
        # # Apply logits processors that only apply to random sampling
        # # (argmax invariant)
        # for processor in sampling_metadata.logitsprocs.argmax_invariant:
        #     logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        random_sampled, processed_logprobs = self.topk_topp_sampler(
vllm/v1/worker/gpu_block_table.py (new file, 312 lines)
@@ -0,0 +1,312 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import torch
import triton
import triton.language as tl

from vllm.utils import cdiv
from vllm.v1.worker.utils import CpuGpuBuffer

PAD_SLOT_ID = -1


class BlockTables:

    def __init__(
        self,
        block_sizes: list[int],
        max_num_reqs: int,
        max_num_cached_reqs: int,
        max_num_batched_tokens: int,
        max_model_len: int,
        device: torch.device,
        pin_memory: bool,
    ):
        self.block_sizes = block_sizes
        self.max_num_reqs = max_num_reqs
        self.max_num_cached_reqs = max_num_cached_reqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_model_len = max_model_len
        self.device = device
        self.pin_memory = pin_memory

        self.num_kv_cache_groups = len(self.block_sizes)
        # [num_kv_cache_groups, max_num_reqs, max_num_blocks]
        self.block_tables: list[torch.Tensor] = []
        # [num_kv_cache_groups, max_num_cached_reqs, max_num_blocks]
        self.block_table_buffers: list[torch.Tensor] = []
        # [num_kv_cache_groups, max_num_reqs]
        self.num_blocks: list[torch.Tensor] = []
        # [num_kv_cache_groups, max_num_tokens]
        self.slot_mappings: list[torch.Tensor] = []
        for i in range(self.num_kv_cache_groups):
            block_size = self.block_sizes[i]
            max_num_blocks = cdiv(self.max_model_len, block_size)

            block_table = torch.zeros(
                self.max_num_reqs,
                max_num_blocks,
                dtype=torch.int32,
                device=self.device,
            )
            self.block_tables.append(block_table)

            block_table_buffer = torch.zeros(
                self.max_num_cached_reqs,
                max_num_blocks,
                dtype=torch.int32,
                device=self.device,
            )
            self.block_table_buffers.append(block_table_buffer)

            num_blocks = torch.zeros(self.max_num_reqs,
                                     dtype=torch.int32,
                                     device=self.device)
            self.num_blocks.append(num_blocks)

            slot_mapping = torch.zeros(self.max_num_batched_tokens,
                                       dtype=torch.int64,
                                       device=self.device)
            self.slot_mappings.append(slot_mapping)

        self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
        self.buffer_ptrs = self._make_ptr_tensor(self.block_table_buffers)
        self.block_table_strides = torch.tensor(
            [b.stride(0) for b in self.block_tables],
            dtype=torch.int64,
            device=self.device)
        self.num_blocks_ptrs = self._make_ptr_tensor(self.num_blocks)
        self.block_sizes_tensor = torch.tensor(self.block_sizes,
                                               dtype=torch.int32,
                                               device=self.device)
        self.slot_mapping_ptrs = self._make_ptr_tensor(self.slot_mappings)

        # Misc buffers.
        self.req_indices = self._make_buffer(self.max_num_reqs, torch.int32)
        self.overwrite = self._make_buffer(self.max_num_reqs, torch.bool)
        self.cu_num_new_blocks: list[CpuGpuBuffer] = []
        self.new_block_ids: list[CpuGpuBuffer] = []
        for i in range(self.num_kv_cache_groups):
            self.cu_num_new_blocks.append(
                self._make_buffer(self.max_num_reqs + 1, torch.int32))
            # NOTE(woosuk): Here, we assume that total number of new blocks
            # is ALWAYS less than max_num_batched_tokens.
            # TODO(woosuk): Rigorously verify that this assumption is correct.
            self.new_block_ids.append(
                self._make_buffer(self.max_num_batched_tokens, torch.int32))

    def _make_buffer(self, n: int, dtype: torch.dtype) -> CpuGpuBuffer:
        return CpuGpuBuffer(n,
                            dtype=dtype,
                            pin_memory=self.pin_memory,
                            device=self.device)

    def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
        ptrs_tensor_cpu = torch.tensor([t.data_ptr() for t in x],
                                       dtype=torch.int64,
                                       device="cpu",
                                       pin_memory=self.pin_memory)
        return ptrs_tensor_cpu.to(self.device, non_blocking=True)

    def append_block_ids(
        self,
        # [num_reqs]
        req_indices: list[int],
        # [num_kv_cache_groups, num_reqs + 1]
        cu_num_new_blocks: list[list[int]],
        # [num_kv_cache_groups, num_new_blocks]
        new_block_ids: list[list[int]],
        # [num_reqs]
        overwrite: list[bool],
    ) -> None:
        # TODO(woosuk): Optimize & simplify this.
        num_reqs = len(req_indices)
        self.req_indices.np[:num_reqs] = req_indices
        self.overwrite.np[:num_reqs] = overwrite
        for i in range(self.num_kv_cache_groups):
            self.cu_num_new_blocks[i].np[:num_reqs + 1] = cu_num_new_blocks[i]
            n = len(new_block_ids[i])
            self.new_block_ids[i].np[:n] = new_block_ids[i]

        cu_num_new_blocks_ptrs = self._make_ptr_tensor(
            [x.copy_to_gpu(num_reqs + 1) for x in self.cu_num_new_blocks])
        new_block_ids_ptrs = self._make_ptr_tensor([
            x.copy_to_gpu(len(new_block_ids[i]))
            for i, x in enumerate(self.new_block_ids)
        ])
        _append_block_ids_kernel[(num_reqs, self.num_kv_cache_groups)](
            self.req_indices.copy_to_gpu(num_reqs),
            cu_num_new_blocks_ptrs,
            new_block_ids_ptrs,
            self.overwrite.copy_to_gpu(num_reqs),
            self.block_table_strides,
            self.buffer_ptrs,
            self.num_blocks_ptrs,
            BLOCK_SIZE=1024,
        )

    def compute_block_tables(
        self,
        idx_mapping: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        batch_size = idx_mapping.shape[0]
        _compute_block_tables_kernel[(batch_size, self.num_kv_cache_groups)](
            idx_mapping,
            self.buffer_ptrs,
            self.block_table_ptrs,
            self.block_table_strides,
            self.num_blocks_ptrs,
            BLOCK_SIZE=1024,
        )
        return tuple(b[:batch_size] for b in self.block_tables)

    def compute_slot_mappings(
        self,
        cu_num_tokens: torch.Tensor,
        pos: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        num_tokens = pos.shape[0]
        num_reqs = cu_num_tokens.shape[0] - 1
        num_groups = self.num_kv_cache_groups
        _compute_slot_mappings_kernel[(num_reqs + 1, num_groups)](
            num_tokens,
            self.max_num_batched_tokens,
            cu_num_tokens,
            pos,
            self.block_table_ptrs,
            self.block_table_strides,
            self.block_sizes_tensor,
            self.slot_mapping_ptrs,
            PAD_ID=PAD_SLOT_ID,
            BLOCK_SIZE=1024,
        )
        return tuple(x[:num_tokens] for x in self.slot_mappings)


@triton.jit
def _append_block_ids_kernel(
    # Inputs
    req_indices,  # [num_reqs]
    cu_num_new_block_ptrs,  # [num_kv_cache_groups, num_reqs + 1]
    new_block_id_ptrs,  # [num_kv_cache_groups, num_new_blocks]
    overwrite,  # [num_reqs]
    block_table_strides,  # [num_kv_cache_groups]
    # Outputs
    block_table_buffer_ptrs,  # [num_kv_cache_groups]
    num_block_ptrs,  # [num_kv_cache_groups]
    # Constants
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    group_id = tl.program_id(1)
    req_idx = tl.load(req_indices + batch_idx)
    do_overwrite = tl.load(overwrite + batch_idx)

    cu_num_new_blocks_ptr = _load_ptr(cu_num_new_block_ptrs + group_id,
                                      tl.int32)
    start_idx = tl.load(cu_num_new_blocks_ptr + batch_idx)
    end_idx = tl.load(cu_num_new_blocks_ptr + batch_idx + 1)
    num_new_blocks = end_idx - start_idx

    num_blocks_ptr = _load_ptr(num_block_ptrs + group_id, tl.int32)
    if do_overwrite:
        dst_start_idx = 0
    else:
        dst_start_idx = tl.load(num_blocks_ptr + req_idx)
    dst_end_idx = dst_start_idx + num_new_blocks
    tl.store(num_blocks_ptr + req_idx, dst_end_idx)

    # Destination
    block_table_buffer_ptr = _load_ptr(block_table_buffer_ptrs + group_id,
                                       tl.int32)
    block_table_stride = tl.load(block_table_strides + group_id)
    buffer_row_ptr = block_table_buffer_ptr + req_idx * block_table_stride

    new_block_ids_ptr = _load_ptr(new_block_id_ptrs + group_id, tl.int32)
    for i in tl.range(0, num_new_blocks, BLOCK_SIZE):
        offset = i + tl.arange(0, BLOCK_SIZE)
        block_ids = tl.load(new_block_ids_ptr + start_idx + offset,
                            mask=offset < num_new_blocks)
        tl.store(buffer_row_ptr + dst_start_idx + offset,
                 block_ids,
                 mask=offset < num_new_blocks)


@triton.jit
def _compute_block_tables_kernel(
    batch_idx_to_req_idx,  # [batch_size]
    src_block_table_ptrs,  # [num_kv_cache_groups]
    dst_block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    num_blocks_ptrs,  # [num_kv_cache_groups]
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    # kv cache group id
    group_id = tl.program_id(1)
    req_idx = tl.load(batch_idx_to_req_idx + batch_idx)

    num_blocks_ptr = _load_ptr(num_blocks_ptrs + group_id, tl.int32)
    num_blocks = tl.load(num_blocks_ptr + req_idx)

    stride = tl.load(block_table_strides + group_id)
    src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
    src_row_ptr = src_block_table_ptr + req_idx * stride
    dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
    dst_row_ptr = dst_block_table_ptr + batch_idx * stride

    for i in tl.range(0, num_blocks, BLOCK_SIZE):
        offset = i + tl.arange(0, BLOCK_SIZE)
        block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
        tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)


@triton.jit
def _compute_slot_mappings_kernel(
    num_tokens,
    max_num_tokens,
    cu_num_tokens,  # [num_reqs + 1]
    pos,  # [num_tokens]
    block_table_ptrs,  # [num_kv_cache_groups]
    block_table_strides,  # [num_kv_cache_groups]
    page_sizes,  # [num_kv_cache_groups]
    slot_mapping_ptrs,  # [num_kv_cache_groups]
    PAD_ID: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    # kv cache group id
    group_id = tl.program_id(1)
    slot_mapping_ptr = _load_ptr(slot_mapping_ptrs + group_id, tl.int64)

    if req_idx == tl.num_programs(0) - 1:
        # Pad remaining slots to -1. This is needed for CUDA graphs.
        for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE):
            offset = i + tl.arange(0, BLOCK_SIZE)
            tl.store(slot_mapping_ptr + offset,
                     PAD_ID,
                     mask=offset < max_num_tokens)
        return

    block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
    block_table_stride = tl.load(block_table_strides + group_id)
    page_size = tl.load(page_sizes + group_id)

    start_idx = tl.load(cu_num_tokens + req_idx)
    end_idx = tl.load(cu_num_tokens + req_idx + 1)
    for i in tl.range(start_idx, end_idx, BLOCK_SIZE):
        offset = i + tl.arange(0, BLOCK_SIZE)
        positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
        block_indices = positions // page_size
        block_numbers = tl.load(block_table_ptr +
                                req_idx * block_table_stride + block_indices)
        slot_ids = block_numbers * page_size + positions % page_size
        tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)


@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
    ptr = tl.load(ptr_to_ptr)
    return tl.cast(ptr, tl.pointer_type(elem_dtype))
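For reference, the per-token arithmetic performed by _compute_slot_mappings_kernel can be sketched in plain NumPy. This is an illustrative sketch, not part of the commit; the block size and block-table row below are made up.

import numpy as np

# Illustrative sketch of the slot computation in _compute_slot_mappings_kernel.
block_size = 4  # tokens per KV cache block (assumed for this example)
block_table_row = np.array([7, 2, 9], dtype=np.int32)  # logical -> physical block id
positions = np.array([0, 1, 2, 3, 4, 5], dtype=np.int64)  # token positions

block_indices = positions // block_size          # which logical block each token falls in
block_numbers = block_table_row[block_indices]   # physical block id from the block table
slot_ids = block_numbers * block_size + positions % block_size
print(slot_ids)  # [28 29 30 31  8  9]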
@@ -1,401 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch

from dataclasses import dataclass
from typing import Optional
from typing import Any, Optional

import numba
import numpy as np
import torch
import triton
import triton.language as tl
from typing_extensions import deprecated
from numba import types

from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (MultiModalKwargsItem,
                                    MultiModalKwargsItems, PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import cdiv, get_cuda_view_from_cpu_tensor, is_uva_available
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata

PAD_SLOT_ID = -1


@dataclass
class RequestData:

    mm_kwargs: list[MultiModalKwargsItem]
    mm_positions: list[PlaceholderRange]
    sampling_params: Optional[SamplingParams]
    pooling_params: Optional[PoolingParams]

    # M-RoPE (only for Qwen2/2.5-VL)
    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    lora_request: Optional[LoRARequest] = None

    # Temporary back-compatibility for plugins that define model runner
    @property
    @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
                "removed in v0.13. Please use `mm_kwargs` instead.")
    def mm_inputs(self) -> list[MultiModalKwargsItems]:
        return [
            MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
        ]


class PerRequestAttribute:

    def __init__(
        self,
        num_rows_cpu: int,
        num_cols: int,
        num_rows_gpu: int,
        dtype: torch.dtype,
        device: torch.device,
        is_scalar: bool = False,
    ):
        assert is_uva_available(), "UVA is not available."
        self.cpu = torch.zeros(num_rows_cpu,
                               num_cols,
                               dtype=dtype,
                               device="cpu",
                               pin_memory=True)
        self.np = self.cpu.numpy()
        self.uva = get_cuda_view_from_cpu_tensor(self.cpu)
        self.gpu = torch.zeros(num_rows_gpu,
                               num_cols,
                               dtype=dtype,
                               device=device)
        if is_scalar:
            assert num_cols == 1
            self.cpu.squeeze_(1)
            self.np = self.cpu.numpy()
            self.uva.squeeze_(1)
            self.gpu.squeeze_(1)


class RequestState:

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        max_num_cached_reqs: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        logitsprocs: Optional[LogitsProcessors] = None,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_cached_reqs = max_num_cached_reqs
        self.device = device
        self.pin_memory = pin_memory
        self.vocab_size = vocab_size
        self.is_spec_decode = is_spec_decode
        self.pooling_params = None
        self.block_sizes = block_sizes
        self.num_prompt_logprobs = {}

        self.req_id_to_index: dict[str, int] = {}
        self.index_to_req_id: dict[int, str] = {}
        self.free_indices = list(range(max_num_cached_reqs))
        # Used to construct the input batch.
        self._add_scalar_attr("idx_mapping", torch.int32)

        # Request states.
        self.req_data: dict[int, RequestData] = {}
        # TODO(woosuk): Because the token_ids tensor can be very big, we only
        # initialize it on CPU memory.
        self._add_vector_attr_cpu("token_ids", self.max_model_len, torch.int32)
        self._add_scalar_attr("num_prompt_tokens", torch.int32)
        self._add_scalar_attr("num_tokens", torch.int32)
        self._add_scalar_attr("num_computed_tokens", torch.int32)

        # Sampling-related.
        self._add_scalar_attr("temperature", torch.float32)
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()
        self._add_scalar_attr("top_p", torch.float32)
        self.top_p_reqs: set[str] = set()
        self._add_scalar_attr("top_k", torch.int32)
        self.top_k_reqs: set[str] = set()
        self._add_scalar_attr("frequency_penalties", torch.float32)
        self.frequency_penalties_reqs: set[str] = set()
        self._add_scalar_attr("presence_penalties", torch.float32)
        self.presence_penalties_reqs: set[str] = set()
        self._add_scalar_attr("repetition_penalties", torch.float32)
        self.repetition_penalties_reqs: set[str] = set()

        # req_idx -> generator
        self.generators: dict[int, torch.Generator] = {}

        # Block table(s).
        self._init_block_tables()

    def add_request(
        self,
        req_id: str,
        prompt_token_ids: list[int],
        num_computed_tokens: int,
        block_ids: tuple[list[int], ...],
        sampling_params: SamplingParams,
    ) -> None:
        req_idx = self.free_indices.pop()
        self.req_id_to_index[req_id] = req_idx
        self.index_to_req_id[req_idx] = req_id

        self.num_prompt_tokens.np[req_idx] = len(prompt_token_ids)
        self.num_computed_tokens.np[req_idx] = num_computed_tokens
        self.append_token_ids(req_idx, prompt_token_ids)
        self.append_block_ids(req_idx, block_ids, overwrite=True)

        self.temperature.np[req_idx] = sampling_params.temperature
        if sampling_params.sampling_type == SamplingType.GREEDY:
            # NOTE: Be careful about division by zero.
            self.greedy_reqs.add(req_id)
        elif sampling_params.sampling_type == SamplingType.RANDOM:
            self.random_reqs.add(req_id)

        self.top_p.np[req_idx] = sampling_params.top_p
        if sampling_params.top_p < 1.0:
            self.top_p_reqs.add(req_id)

        top_k = sampling_params.top_k
        if 0 < top_k < self.vocab_size:
            self.top_k_reqs.add(req_id)
        else:
            top_k = self.vocab_size
        self.top_k.np[req_idx] = top_k

        self.frequency_penalties.np[
            req_idx] = sampling_params.frequency_penalty
        if sampling_params.frequency_penalty != 0.0:
            self.frequency_penalties_reqs.add(req_id)
        self.presence_penalties.np[req_idx] = sampling_params.presence_penalty
        if sampling_params.presence_penalty != 0.0:
            self.presence_penalties_reqs.add(req_id)
        self.repetition_penalties.np[
            req_idx] = sampling_params.repetition_penalty
        if sampling_params.repetition_penalty != 1.0:
            self.repetition_penalties_reqs.add(req_id)

        if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
            generator = torch.Generator(device=self.device)
            generator.manual_seed(sampling_params.seed)
            self.generators[req_idx] = generator

    def append_token_ids(self, req_idx: int, token_ids: list[int]) -> None:
        start_idx = self.num_tokens.np[req_idx]
        end_idx = start_idx + len(token_ids)
        self.token_ids.np[req_idx, start_idx:end_idx] = token_ids
        self.num_tokens.np[req_idx] = end_idx

    # TODO(woosuk): Further vectorize this to minimize overheads.
    def append_block_ids(
        self,
        req_idx: int,
        new_block_ids: tuple[list[int], ...],
        overwrite: bool,
    ) -> None:
        for i in range(self.num_block_tables):
            block_table = self.block_tables[i]
            num_blocks = self.num_blocks[i]
            num_new_blocks = len(new_block_ids[i])
            if overwrite:
                # Replace the existing block IDs with the new ones.
                # This happens when the request is resumed from preemption.
                block_table.np[req_idx, :num_new_blocks] = new_block_ids[i]
                num_blocks.np[req_idx] = num_new_blocks
            else:
                # Append the new block IDs to the existing ones (common case).
                start_idx = num_blocks.np[req_idx]
                end_idx = start_idx + num_new_blocks
                block_table.np[req_idx, start_idx:end_idx] = new_block_ids[i]
                num_blocks.np[req_idx] = end_idx

    def remove_request(self, req_id: str) -> None:
        req_idx = self.req_id_to_index.pop(req_id, None)
        if req_idx is None:
            # Request not found.
            return
        self.index_to_req_id.pop(req_idx, None)
        self.free_indices.append(req_idx)

        self.greedy_reqs.discard(req_id)
        self.random_reqs.discard(req_id)
        self.top_p_reqs.discard(req_id)
        self.top_k_reqs.discard(req_id)
        self.frequency_penalties_reqs.discard(req_id)
        self.presence_penalties_reqs.discard(req_id)
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_idx, None)

    def get_index_mapping(self, idx_mapping: list[int]) -> torch.Tensor:
        num_reqs = len(idx_mapping)
        self.idx_mapping.np[:num_reqs] = idx_mapping
        return self.idx_mapping.gpu[:num_reqs].copy_(
            self.idx_mapping.uva[:num_reqs], non_blocking=True)

    def make_sampling_metadata(
        self,
        batch_idx_to_req_idx: torch.Tensor,
    ) -> SamplingMetadata:
        batch_size = batch_idx_to_req_idx.shape[0]
        _make_sampling_metadata_kernel[(batch_size, )](
            batch_idx_to_req_idx,
            self.temperature.uva,
            self.temperature.gpu,
            self.top_p.uva,
            self.top_p.gpu,
            self.top_k.uva,
            self.top_k.gpu,
            self.frequency_penalties.uva,
            self.frequency_penalties.gpu,
            self.presence_penalties.uva,
            self.presence_penalties.gpu,
            self.repetition_penalties.uva,
            self.repetition_penalties.gpu,
            num_warps=1,
            num_stages=1,
        )
        no_penalties = not (self.frequency_penalties_reqs
                            or self.presence_penalties_reqs
                            or self.repetition_penalties_reqs)
        return SamplingMetadata(
            temperature=self.temperature.gpu[:batch_size],
            all_greedy=not self.random_reqs,
            all_random=not self.greedy_reqs,
            top_p=self.top_p.gpu[:batch_size],
            top_k=self.top_k.gpu[:batch_size],
            frequency_penalties=self.frequency_penalties.gpu[:batch_size],
            presence_penalties=self.presence_penalties.gpu[:batch_size],
            repetition_penalties=self.repetition_penalties.gpu[:batch_size],
            no_penalties=no_penalties,
            # TODO
            generators={},
            token_ids=self.token_ids.gpu[:batch_size],
            max_num_logprobs=None,
            allowed_token_ids_mask=None,
            bad_words_token_ids={},
            logitsprocs=None,
        )

    @property
    def num_reqs(self) -> int:
        return len(self.req_id_to_index)

    def _add_scalar_attr(self, name: str, dtype: torch.dtype):
        attr = PerRequestAttribute(self.max_num_cached_reqs,
                                   1,
                                   self.max_num_reqs,
                                   dtype,
                                   self.device,
                                   is_scalar=True)
        setattr(self, name, attr)

    def _add_vector_attr(self, name: str, max_len: int, dtype: torch.dtype):
        attr = PerRequestAttribute(self.max_num_cached_reqs, max_len,
                                   self.max_num_reqs, dtype, self.device)
        setattr(self, name, attr)

    def _add_vector_attr_cpu(self, name: str, max_len: int,
                             dtype: torch.dtype):
        attr = PerRequestAttribute(self.max_num_cached_reqs, max_len, 0, dtype,
                                   self.device)
        setattr(self, name, attr)

    def _init_block_tables(self):
        self.num_block_tables = len(self.block_sizes)
        self.block_tables = []
        self.num_blocks = []
        self.slot_mappings: list[torch.Tensor] = []
        for i in range(self.num_block_tables):
            max_num_blocks = cdiv(self.max_model_len, self.block_sizes[i])
            block_table = PerRequestAttribute(self.max_num_cached_reqs,
                                              max_num_blocks,
                                              self.max_num_reqs, torch.int32,
                                              self.device)
            self.block_tables.append(block_table)
            num_blocks = PerRequestAttribute(self.max_num_cached_reqs,
                                             1,
                                             self.max_num_reqs,
                                             torch.int32,
                                             self.device,
                                             is_scalar=True)
            self.num_blocks.append(num_blocks)
            slot_mapping = torch.zeros(self.max_num_batched_tokens,
                                       dtype=torch.int64,
                                       device=self.device)
            self.slot_mappings.append(slot_mapping)

        def make_ptr_tensor(x: list[torch.Tensor]) -> torch.Tensor:
            return torch.tensor([t.data_ptr() for t in x],
                                dtype=torch.int64,
                                device=self.device)

        self.uva_block_table_ptrs = make_ptr_tensor(
            [b.uva for b in self.block_tables])
        self.gpu_block_table_ptrs = make_ptr_tensor(
            [b.gpu for b in self.block_tables])
        self.uva_num_blocks_ptrs = make_ptr_tensor(
            [n.uva for n in self.num_blocks])
        self.gpu_num_blocks_ptrs = make_ptr_tensor(
            [n.gpu for n in self.num_blocks])
        self.block_table_strides = torch.tensor(
            [b.gpu.shape[1] for b in self.block_tables],
            dtype=torch.int64,
            device=self.device)
        self.block_sizes_tensor = torch.tensor(self.block_sizes,
                                               dtype=torch.int32,
                                               device=self.device)
        self.slot_mapping_ptrs = make_ptr_tensor(self.slot_mappings)

    def make_block_tables(
        self,
        idx_mapping: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        batch_size = idx_mapping.shape[0]
        _make_block_tables_kernel[(batch_size, self.num_block_tables)](
            idx_mapping,
            self.uva_block_table_ptrs,
            self.gpu_block_table_ptrs,
            self.block_table_strides,
            self.uva_num_blocks_ptrs,
            self.gpu_num_blocks_ptrs,
            BLOCK_SIZE=1024,
        )
        return tuple(b.gpu[:batch_size] for b in self.block_tables)

    def make_slot_mappings(
        self,
        cu_num_tokens: torch.Tensor,
        pos: torch.Tensor,
    ) -> tuple[torch.Tensor, ...]:
        num_tokens = pos.shape[0]
        num_reqs = cu_num_tokens.shape[0] - 1
        _make_slot_mappings_kernel[(num_reqs + 1, self.num_block_tables)](
            num_tokens,
            self.max_num_batched_tokens,
            cu_num_tokens,
            pos,
            self.gpu_block_table_ptrs,
            self.block_table_strides,
            self.block_sizes_tensor,
            self.slot_mapping_ptrs,
            PAD_ID=PAD_SLOT_ID,
            BLOCK_SIZE=1024,
        )
        return tuple(x[:num_tokens] for x in self.slot_mappings)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata


@dataclass
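The removed PerRequestAttribute class keeps each per-request value in pinned CPU memory (with a zero-copy CUDA view via get_cuda_view_from_cpu_tensor) plus a compact GPU tensor for the current batch. The following is a rough sketch of that pattern using only plain pinned memory and an async copy; it is illustrative only and not part of the commit, and the sizes are made up.

import torch

# Illustrative sketch (not repo code): per-request scalars live in pinned CPU
# memory so the CPU can update them cheaply; only the rows needed for the
# current batch are copied to a small GPU tensor.
if torch.cuda.is_available():
    max_num_cached_reqs, max_num_reqs = 8, 4
    cpu = torch.zeros(max_num_cached_reqs, dtype=torch.float32,
                      pin_memory=True)
    gpu = torch.zeros(max_num_reqs, dtype=torch.float32, device="cuda")

    cpu[3] = 0.7  # e.g., temperature of the request cached at index 3
    batch_rows = torch.tensor([3])  # req indices forming the current batch
    gpu[:1].copy_(cpu[batch_rows], non_blocking=True)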
@@ -403,134 +16,78 @@ class InputBatch:

    # batch_idx -> req_id
    req_ids: list[str]

    # req_id -> batch_idx
    req_id_to_batch_idx: dict[str, int]

    # batch_idx -> req_state_idx
    idx_mapping: torch.Tensor
    idx_mapping_np: np.ndarray

    # [num_kv_cache_groups, num_reqs, max_num_blocks_per_req]
    block_tables: tuple[torch.Tensor, ...]
    # [num_kv_cache_groups, num_tokens]
    slot_mappings: tuple[torch.Tensor, ...]
    # batch_idx -> num_scheduled_tokens
    num_scheduled_tokens: np.ndarray
    total_num_tokens: int
    max_num_tokens: int
    num_reqs: int

    # [num_reqs] mostly
    sampling_metadata: SamplingMetadata
    attn_metadata: dict[str, Any]
    spec_decode_common_attn_metadata: Optional[Any]
    spec_decode_metadata: Optional[SpecDecodeMetadata]

    logits_indices: torch.Tensor


@triton.jit
def _make_sampling_metadata_kernel(
    batch_idx_to_req_idx,  # [batch_size]
    src_temperature,
    dst_temperature,
    src_top_p,
    dst_top_p,
    src_top_k,
    dst_top_k,
    src_frequency_penalties,
    dst_frequency_penalties,
    src_presence_penalties,
    dst_presence_penalties,
    src_repetition_penalties,
    dst_repetition_penalties,
):
    batch_idx = tl.program_id(0)
    req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
# NOTE: With the type annotations, this function is pre-compiled
# before the first call.
@numba.jit(
    [
        types.none(
            types.int32[:],  # idx_mapping
            types.int32[:, :],  # token_ids
            types.int32[:],  # num_computed_tokens
            types.int32[:],  # num_scheduled_tokens
            types.int32[:],  # input_ids
            types.int32[:],  # query_start_loc
            types.int32[:],  # seq_lens
            types.int64[:],  # positions
        )
    ],
    nopython=True,
    cache=True,
)
def prepare_inputs(
    # Inputs
    idx_mapping: np.ndarray,  # batch_idx -> req_idx
    token_ids: np.ndarray,  # [N, max_model_len]
    num_computed_tokens: np.ndarray,  # [N]
    num_scheduled_tokens: np.ndarray,  # [B]
    # Outputs
    input_ids: np.ndarray,  # [num_input_tokens]
    query_start_loc: np.ndarray,  # [B + 1]
    seq_lens: np.ndarray,  # [B]
    positions: np.ndarray,  # [num_input_tokens]
) -> None:
    num_reqs = num_scheduled_tokens.shape[0]
    query_start_loc[0] = 0

    temperature = tl.load(src_temperature + req_idx)
    tl.store(dst_temperature + batch_idx, temperature)
    cu_num_tokens = 0
    for i in range(num_reqs):
        req_idx = idx_mapping[i]
        start = num_computed_tokens[req_idx]
        end = start + num_scheduled_tokens[i]
        seq_lens[i] = end

    top_p = tl.load(src_top_p + req_idx)
    tl.store(dst_top_p + batch_idx, top_p)
        start_idx = cu_num_tokens
        end_idx = start_idx + num_scheduled_tokens[i]
        input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
        positions[start_idx:end_idx] = np.arange(start, end)

    top_k = tl.load(src_top_k + req_idx)
    tl.store(dst_top_k + batch_idx, top_k)
        cu_num_tokens = end_idx
        query_start_loc[i + 1] = cu_num_tokens

    frequency_penalties = tl.load(src_frequency_penalties + req_idx)
    tl.store(dst_frequency_penalties + batch_idx, frequency_penalties)

    presence_penalties = tl.load(src_presence_penalties + req_idx)
    tl.store(dst_presence_penalties + batch_idx, presence_penalties)

    repetition_penalties = tl.load(src_repetition_penalties + req_idx)
    tl.store(dst_repetition_penalties + batch_idx, repetition_penalties)


@triton.jit
def _make_block_tables_kernel(
    batch_idx_to_req_idx,  # [batch_size]
    src_block_table_ptrs,  # [num_block_tables]
    dst_block_table_ptrs,  # [num_block_tables]
    block_table_strides,  # [num_block_tables]
    src_num_blocks_ptrs,  # [num_block_tables]
    dst_num_blocks_ptrs,  # [num_block_tables]
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    # kv cache group id
    group_id = tl.program_id(1)
    req_idx = tl.load(batch_idx_to_req_idx + batch_idx)

    src_num_blocks_ptr = _load_ptr(src_num_blocks_ptrs, group_id, tl.int32)
    dst_num_blocks_ptr = _load_ptr(dst_num_blocks_ptrs, group_id, tl.int32)
    num_blocks = tl.load(src_num_blocks_ptr + req_idx)
    tl.store(dst_num_blocks_ptr + batch_idx, num_blocks)

    stride = tl.load(block_table_strides + group_id)
    src_block_table_ptr = _load_ptr(src_block_table_ptrs, group_id, tl.int32)
    src_row_ptr = src_block_table_ptr + req_idx * stride
    dst_block_table_ptr = _load_ptr(dst_block_table_ptrs, group_id, tl.int32)
    dst_row_ptr = dst_block_table_ptr + batch_idx * stride

    for i in tl.range(0, num_blocks, BLOCK_SIZE):
        offset = i + tl.arange(0, BLOCK_SIZE)
        block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
        tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)


@triton.jit
def _make_slot_mappings_kernel(
    num_tokens,
    max_num_tokens,
    cu_num_tokens,  # [num_reqs + 1]
    pos,  # [num_tokens]
    block_table_ptrs,  # [num_block_tables]
    block_table_strides,  # [num_block_tables]
    page_sizes,  # [num_block_tables]
    slot_mapping_ptrs,  # [num_block_tables]
    PAD_ID: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    num_reqs = tl.num_programs(0)
    # kv cache group id
    group_id = tl.program_id(1)
    slot_mapping_ptr = _load_ptr(slot_mapping_ptrs, group_id, tl.int64)

    if req_idx == num_reqs - 1:
        # Pad remaining slots to -1. This is needed for CUDA graphs.
        for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE):
            offset = num_tokens + i + tl.arange(0, BLOCK_SIZE)
            tl.store(slot_mapping_ptr + offset,
                     PAD_ID,
                     mask=offset < max_num_tokens)
        return

    block_table_ptr = _load_ptr(block_table_ptrs, group_id, tl.int32)
    block_table_stride = tl.load(block_table_strides + group_id)
    page_size = tl.load(page_sizes + group_id)

    start_idx = tl.load(cu_num_tokens + req_idx)
    end_idx = tl.load(cu_num_tokens + req_idx + 1)
    for i in tl.range(start_idx, end_idx, BLOCK_SIZE):
        offset = start_idx + i + tl.arange(0, BLOCK_SIZE)
        positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
        block_indices = positions // page_size
        block_numbers = tl.load(block_table_ptr +
                                req_idx * block_table_stride + block_indices)
        slot_ids = block_numbers * page_size + positions % page_size
        tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)


@triton.jit
def _load_ptr(base, offset, elem_dtype):
    ptr = tl.load(base + offset)
    return tl.cast(ptr, tl.pointer_type(elem_dtype))
    # Pad the inputs for CUDA graphs.
    # Note: pad query_start_loc to be non-decreasing, as kernels
    # like FlashAttention requires that
    query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
    # Fill unused with 0 for full cuda graph mode.
    seq_lens[num_reqs:].fill(0)
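The numba-jitted prepare_inputs added above flattens the scheduled tokens of all requests into contiguous input_ids/positions arrays and builds query_start_loc and seq_lens. The toy example below mirrors that logic in plain NumPy; it is an illustrative sketch with made-up values, not the repo code.

import numpy as np

# Toy illustration of what prepare_inputs computes for a two-request batch.
max_model_len = 8
token_ids = np.array([[10, 11, 12, 13, 0, 0, 0, 0],
                      [20, 21, 22, 0, 0, 0, 0, 0]], dtype=np.int32)
num_computed_tokens = np.array([2, 0], dtype=np.int32)   # per cached request
idx_mapping = np.array([0, 1], dtype=np.int32)           # batch_idx -> req_idx
num_scheduled_tokens = np.array([2, 3], dtype=np.int32)  # per batch entry

total = int(num_scheduled_tokens.sum())
input_ids = np.zeros(total, dtype=np.int32)
positions = np.zeros(total, dtype=np.int64)
query_start_loc = np.zeros(len(idx_mapping) + 1, dtype=np.int32)
seq_lens = np.zeros(len(idx_mapping), dtype=np.int32)

cu = 0
for i, req_idx in enumerate(idx_mapping):
    start = num_computed_tokens[req_idx]
    end = start + num_scheduled_tokens[i]
    seq_lens[i] = end
    input_ids[cu:cu + end - start] = token_ids[req_idx, start:end]
    positions[cu:cu + end - start] = np.arange(start, end)
    cu += end - start
    query_start_loc[i + 1] = cu

print(input_ids)        # [12 13 20 21 22]
print(positions)        # [2 3 0 1 2]
print(query_start_loc)  # [0 2 5]
print(seq_lens)         # [4 3]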
@@ -68,7 +68,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
                             LogprobsTensors, ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
@@ -76,7 +76,9 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_block_table import BlockTables
from vllm.v1.worker.gpu_input_batch import InputBatch, prepare_inputs
from vllm.v1.worker.gpu_worker_states import RequestState
from vllm.v1.worker.kv_connector_model_runner_mixin import (
    KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -200,18 +202,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self.rejection_sampler = RejectionSampler()

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}

        # Input Batch
        # NOTE(Chen): Ideally, we should initialize the input batch inside
        # `initialize_kv_cache` based on the kv cache config. However, as in
        # https://github.com/vllm-project/vllm/pull/18298, due to some unknown
        # reasons, we have to initialize the input batch before `load_model`,
        # quantization + weight offloading will fail otherwise. As a temporary
        # solution, we initialize the input batch here, and re-initialize it
        # in `initialize_kv_cache` if the block_sizes here is different from
        # the block_sizes in the kv cache config.
        self.input_batch = InputBatch(
        self.requests = RequestState(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
@@ -220,12 +211,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=[self.cache_config.block_size],
            is_spec_decode=bool(self.vllm_config.speculative_config),
            logitsprocs=build_logitsprocs(
                self.vllm_config, self.device, self.pin_memory,
                self.is_pooling_model,
                self.vllm_config.model_config.logits_processors),
            is_pooling_model=self.is_pooling_model,
        )

        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
@@ -253,9 +238,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self.seq_lens = torch.zeros(self.max_num_reqs,
                                    dtype=torch.int32,
                                    device=self.device)
        self.slot_mapping = torch.zeros(self.max_num_tokens,
                                        dtype=torch.int64,
                                        device=self.device)

        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: Optional[IntermediateTensors] = None
@@ -290,12 +272,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                           dtype=self.dtype,
                                           device=self.device)

        self.block_tables = BlockTables(
            block_sizes=[self.cache_config.block_size],
            max_num_reqs=self.max_num_reqs,
            max_num_cached_reqs=2 * self.max_num_reqs,
            max_num_batched_tokens=self.max_num_tokens,
            max_model_len=self.max_model_len,
            device=self.device,
            pin_memory=self.pin_memory,
        )

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        # Keep in int64 to avoid overflow with long context
        self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                       self.max_model_len,
                                       self.max_num_tokens),
                                   dtype=np.int64)

        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
@@ -303,6 +296,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.input_ids_np = self.input_ids_cpu.numpy()
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
@@ -319,6 +313,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        self.index_mapping_cpu = torch.zeros(self.max_num_reqs,
                                             dtype=torch.int32,
                                             device="cpu",
                                             pin_memory=self.pin_memory)
        self.index_mapping_np = self.index_mapping_cpu.numpy()
        self.index_mapping = self.index_mapping_cpu.to(self.device)

        # Layer pairings for cross-layer KV sharing.
        # If an Attention layer `layer_name` is in the keys of this dict, it
        # means this layer will perform attention using the keys and values
@@ -410,10 +411,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        The SamplingMetadata is updated and copied to the GPU if there is a
        new/resumed/paused/finished request in the batch.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)
        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
@@ -421,7 +418,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        for req_id in scheduler_output.finished_req_ids:
            self.input_batch.remove_request(req_id)
            self.requests.remove_request(req_id)
            self.encoder_cache.pop(req_id, None)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
@@ -431,120 +429,82 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            if not encoder_outputs:
                self.encoder_cache.pop(req_id, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            self.input_batch.remove_request(req_id)
        req_indices: list[int] = []
        cu_num_new_blocks: list[list[int]] = [
            [0] for _ in range(self.block_tables.num_kv_cache_groups)
        ]
        new_block_ids: list[list[int]] = [
            [] for _ in range(self.block_tables.num_kv_cache_groups)
        ]
        overwrite: list[bool] = []

        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            pooling_params = new_req_data.pooling_params

            if pooling_params:
                task = pooling_params.task
                assert task is not None, "You did not set `task` in the API"

                model = cast(VllmModelForPooling, self.get_model())
                to_update = model.pooler.get_pooling_updates(task)
                to_update.apply(pooling_params)

            req_state = CachedRequestState(
                req_id=req_id,
                mm_kwargs=new_req_data.mm_kwargs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                pooling_params=pooling_params,
                lora_request=new_req_data.lora_request,
            )
            self.requests[req_id] = req_state
            self.input_batch.add_request(
            self.requests.add_request(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                block_ids=new_req_data.block_ids,
                sampling_params=sampling_params,
                sampling_params=new_req_data.sampling_params,
            )

            req_index = self.requests.req_id_to_index[req_id]
            req_indices.append(req_index)
            for i, block_ids in enumerate(new_req_data.block_ids):
                x = cu_num_new_blocks[i][-1]
                cu_num_new_blocks[i].append(x + len(block_ids))
                new_block_ids[i].extend(block_ids)
            overwrite.append(True)

            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            if self.uses_mrope:
                self._init_mrope_positions(req_state)
                self._init_mrope_positions(req_id)

        # Update the states of the running/resumed requests.
        is_last_rank = get_pp_group().is_last_rank
        req_data = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(req_data.req_ids):
            req_index = self.input_batch.req_id_to_index[req_id]
        cached_reqs = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(cached_reqs.req_ids):
            req_index = self.requests.req_id_to_index[req_id]

            # Update input batch.
            if not is_last_rank:
                # When using PP, the scheduler sends the sampled tokens back,
                # because there's no direct communication between the first-
                # stage worker and the last-stage worker.
                new_token_ids = req_data.new_token_ids[i]
                self.input_batch.append_token_ids(req_index, new_token_ids)
                new_token_ids = cached_reqs.new_token_ids[i]
                self.requests.append_token_ids(req_index, new_token_ids)

            new_block_ids = req_data.new_block_ids[i]
            if new_block_ids is not None:
            if cached_reqs.new_block_ids[i] is not None:
                req_indices.append(req_index)
                for i, block_ids in enumerate(cached_reqs.new_block_ids[i]):
                    x = cu_num_new_blocks[i][-1]
                    cu_num_new_blocks[i].append(x + len(block_ids))
                    new_block_ids[i].extend(block_ids)
                # If the request is resumed from preemption, we need to
                # overwrite the existing block IDs.
                self.input_batch.append_block_ids(
                    req_index,
                    new_block_ids,
                    overwrite=req_data.resumed_from_preemption[i],
                )
                overwrite.append(cached_reqs.resumed_from_preemption[i])

            self.input_batch.num_computed_tokens.np[req_index] = (
                req_data.num_computed_tokens[i])
            self.requests.num_computed_tokens.np[req_index] = (
                cached_reqs.num_computed_tokens[i])

    def _init_mrope_states(self, req_state: CachedRequestState) -> None:
        image_grid_thw = []
        video_grid_thw = []
        second_per_grid_ts = []
        audio_feature_lengths = []
        use_audio_in_video = False
        for mm_item in req_state.mm_kwargs:
            mm_input = mm_item.get_data()
            if (t := mm_input.get("image_grid_thw")) is not None:
                image_grid_thw.append(t.tolist())
            if (t := mm_input.get("video_grid_thw")) is not None:
                video_grid_thw.append(t.tolist())
            if (t := mm_input.get("second_per_grid_ts")) is not None:
                second_per_grid_ts.append(t)
            if (t := mm_input.get("audio_feature_lengths")) is not None:
                audio_feature_lengths.append(t)
            if mm_input.get("use_audio_in_video") is True:
                use_audio_in_video = True

        req_state.mrope_positions, req_state.mrope_position_delta = \
            MRotaryEmbedding.get_input_positions_tensor(
                req_state.prompt_token_ids,
                hf_config=self.model_config.hf_config,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
                audio_feature_lengths=audio_feature_lengths,
                use_audio_in_video=use_audio_in_video,
        if req_indices:
            self.block_tables.append_block_ids(
                req_indices=req_indices,
                cu_num_new_blocks=cu_num_new_blocks,
                new_block_ids=new_block_ids,
                overwrite=overwrite,
            )

    def _init_mrope_positions(self, req_state: CachedRequestState):
    def _init_mrope_positions(self, req_id: str) -> None:
        req_idx = self.requests.req_id_to_index[req_id]
        req_data = self.requests.req_data[req_idx]

        image_grid_thw = []
        video_grid_thw = []
        second_per_grid_ts = []
        audio_feature_lengths = []
        use_audio_in_video = False
        for mm_item in req_state.mm_kwargs:
        for mm_item in req_data.mm_kwargs:
            mm_input = mm_item.get_data()
            if (t := mm_input.get("image_grid_thw")) is not None:
                image_grid_thw.append(t.tolist())
@@ -557,9 +517,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            if mm_input.get("use_audio_in_video") is True:
                use_audio_in_video = True

        req_state.mrope_positions, req_state.mrope_position_delta = \
        req_data.mrope_positions, req_data.mrope_position_delta = \
            MRotaryEmbedding.get_input_positions_tensor(
                req_state.prompt_token_ids,
                req_data.prompt_token_ids,
                hf_config=self.model_config.hf_config,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
@@ -622,91 +582,55 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata],
               np.ndarray, Optional[CommonAttentionMetadata], int]:
        """
        :return: tuple[
            attn_metadata: layer-to-attention_metadata mapping,
            logits_indices, spec_decode_metadata
        ]
        """
    ) -> InputBatch:
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0
        num_reqs = len(scheduler_output.num_scheduled_tokens)

        # FIXME
        # batch_idx -> req_id
        req_ids = list(scheduler_output.num_scheduled_tokens.keys())
        req_ids = sorted(scheduler_output.num_scheduled_tokens,
                         key=scheduler_output.num_scheduled_tokens.get)

        # req_id -> batch_idx
        req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)}

        # batch_idx -> req_idx
        idx_mapping = [
            self.input_batch.req_id_to_index[req_id] for req_id in req_ids
        idx_mapping_list = [
            self.requests.req_id_to_index[req_id] for req_id in req_ids
        ]
        # batch_idx -> req_idx
        idx_mapping_tensor = self.input_batch.get_index_mapping(idx_mapping)
        num_reqs = len(req_ids)
        self.index_mapping_np[:num_reqs] = idx_mapping_list
        index_mapping_np = self.index_mapping_np[:num_reqs]
        idx_mapping = self.index_mapping[:num_reqs].copy_(
            self.index_mapping_cpu[:num_reqs], non_blocking=True)

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        block_tables = self.input_batch.make_block_tables(idx_mapping_tensor)
        block_tables = self.block_tables.compute_block_tables(idx_mapping)

        # Get the number of scheduled tokens for each request.
        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
        max_num_scheduled_tokens = max(tokens)

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)

        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        cu_num_tokens, arange = self._get_cumsum_and_arange(
            num_scheduled_tokens)

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens.np[req_indices],
               arange,
               out=positions_np)

        prepare_inputs(
            idx_mapping=index_mapping_np,
            token_ids=self.requests.token_ids.np,
            num_computed_tokens=self.requests.num_computed_tokens.np,
            num_scheduled_tokens=num_scheduled_tokens,
            input_ids=self.input_ids_np,
            query_start_loc=self.query_start_loc_np,
            seq_lens=self.seq_lens_np,
            positions=self.positions_np,
        )
        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids.np.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids.cpu.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
        # Note: pad query_start_loc to be non-decreasing, as kernels
        # like FlashAttention requires that
        self.query_start_loc_np[num_reqs + 1:].fill(cu_num_tokens[-1])
        self.query_start_loc.copy_(self.query_start_loc_cpu, non_blocking=True)
        query_start_loc = self.query_start_loc[:num_reqs + 1]

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens.np[:num_reqs] +
            num_scheduled_tokens)
        # Fill unused with 0 for full cuda graph mode.
        self.seq_lens_np[num_reqs:].fill(0)
        self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True)
        seq_lens = self.seq_lens[:num_reqs]
        max_seq_len = self.seq_lens_np[:num_reqs].max().item()
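The comments in the removed code above already give the key example for the cumsum/arange helper. The following NumPy sketch reproduces those values; it is illustrative only and does not claim to be the implementation of _get_cumsum_and_arange.

import numpy as np

# Sketch of the cumsum/arange shapes referenced in the comments above.
num_scheduled_tokens = np.array([2, 5, 3], dtype=np.int32)
cu_num_tokens = np.cumsum(num_scheduled_tokens)  # [2, 7, 10]
# Per-token offsets within each request: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange = np.concatenate([np.arange(n) for n in num_scheduled_tokens])
# Per-token request indices: [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(np.arange(len(num_scheduled_tokens)),
                        num_scheduled_tokens)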
@ -714,16 +638,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Copy the tensors to the GPU.
|
||||
self.input_ids[:total_num_scheduled_tokens].copy_(
|
||||
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
||||
# Common case (1D positions)
|
||||
self.positions[:total_num_scheduled_tokens].copy_(
|
||||
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
|
||||
if self.uses_mrope:
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
|
||||
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
|
||||
non_blocking=True)
|
||||
else:
|
||||
# Common case (1D positions)
|
||||
self.positions[:total_num_scheduled_tokens].copy_(
|
||||
self.positions_cpu[:total_num_scheduled_tokens],
|
||||
non_blocking=True)
|
||||
|
||||
use_spec_decode = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
@ -737,16 +659,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
spec_decode_metadata = None
|
||||
else:
|
||||
# Get the number of draft tokens for each request.
|
||||
# Iterate over the dictionary rather than all requests since not all
|
||||
# requests have draft tokens.
|
||||
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
|
||||
for req_id, draft_token_ids in (
|
||||
scheduler_output.scheduled_spec_decode_tokens.items()):
|
||||
batch_idx = req_id_to_batch_idx[req_id]
|
||||
num_draft_tokens[batch_idx] = len(draft_token_ids)
|
||||
|
||||
for i, req_id in enumerate(req_ids):
|
||||
draft_token_ids = (
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
|
||||
if draft_token_ids:
|
||||
num_draft_tokens[i] = len(draft_token_ids)
|
||||
spec_decode_metadata = self._calc_spec_decode_metadata(
|
||||
num_draft_tokens, cu_num_tokens)
|
||||
num_draft_tokens, self.query_start_loc_np[1:num_reqs + 1])
|
||||
logits_indices = spec_decode_metadata.logits_indices
|
||||
|
||||
logits_indices_padded = None
|
||||
@ -774,15 +694,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
|
||||
)
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
query_start_loc, self.positions[:total_num_scheduled_tokens])
|
||||
|
||||
# Used in the below loop.
|
||||
query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
|
||||
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
|
||||
num_computed_tokens_cpu = (
|
||||
self.input_batch.num_computed_tokens.cpu[:num_reqs])
|
||||
self.requests.num_computed_tokens.cpu[:num_reqs])
|
||||
spec_decode_common_attn_metadata = None
|
||||
|
||||
attn_metadata: dict[str, Any] = {}
|
||||
# Prepare the attention metadata for each KV cache group and make layers
|
||||
# in the same group share the same metadata.
|
||||
for kv_cache_group_id, kv_cache_group_spec in enumerate(
|
||||
@ -804,14 +726,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
non_blocking=True)
|
||||
num_common_prefix_blocks = 0
|
||||
else:
|
||||
blk_table = self.input_batch.block_table[kv_cache_group_id]
|
||||
blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
|
||||
slot_mapping = blk_table.slot_mapping[:
|
||||
total_num_scheduled_tokens]
|
||||
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
||||
# graph mode.
|
||||
blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
|
||||
blk_table_tensor = block_tables[kv_cache_group_id]
|
||||
slot_mapping = slot_mappings[kv_cache_group_id]
|
||||
num_common_prefix_blocks = (
|
||||
scheduler_output.
|
||||
num_common_prefix_blocks[kv_cache_group_id])
|
||||
@ -876,13 +792,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
continue
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
|
||||
# Hot-Swap lora model
|
||||
if self.lora_config:
|
||||
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||||
# # Hot-Swap lora model
|
||||
# if self.lora_config:
|
||||
# self.set_active_loras(input_batch, num_scheduled_tokens)
|
||||
|
||||
return (attn_metadata, logits_indices, spec_decode_metadata,
|
||||
num_scheduled_tokens, spec_decode_common_attn_metadata,
|
||||
max_num_scheduled_tokens)
|
||||
return InputBatch(
|
||||
req_ids=req_ids,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
req_id_to_batch_idx=req_id_to_batch_idx,
|
||||
idx_mapping=idx_mapping,
|
||||
idx_mapping_np=index_mapping_np,
|
||||
num_reqs=num_reqs,
|
||||
total_num_tokens=total_num_scheduled_tokens,
|
||||
max_num_tokens=max_num_scheduled_tokens,
|
||||
attn_metadata=attn_metadata,
|
||||
spec_decode_metadata=spec_decode_metadata,
|
||||
spec_decode_common_attn_metadata=spec_decode_common_attn_metadata,
|
||||
logits_indices=logits_indices,
|
||||
)
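The InputBatch container constructed here is defined elsewhere and not shown in this diff; the following is only an assumed sketch of its shape, inferred from the call site above (field names and types are guesses).

# Assumed sketch of InputBatch, inferred from the constructor call above;
# the real definition is not part of this hunk.
from dataclasses import dataclass
from typing import Any, Optional

import numpy as np
import torch

@dataclass
class InputBatch:
    req_ids: list[str]
    num_scheduled_tokens: np.ndarray               # [num_reqs]
    req_id_to_batch_idx: dict[str, int]
    idx_mapping: torch.Tensor                      # batch_idx -> req_idx (GPU)
    idx_mapping_np: np.ndarray                     # batch_idx -> req_idx (CPU)
    num_reqs: int
    total_num_tokens: int
    max_num_tokens: int
    attn_metadata: dict[str, Any]
    spec_decode_metadata: Optional[Any]
    spec_decode_common_attn_metadata: Optional[Any]
    logits_indices: torch.Tensor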

def _compute_cascade_attn_prefix_len(
self,
@ -955,7 +882,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_reqs = len(num_scheduled_tokens)
common_prefix_len = min(
common_prefix_len,
self.input_batch.num_computed_tokens.np[:num_reqs].min())
self.requests.num_computed_tokens.np[:num_reqs].min())
# common_prefix_len should be a multiple of the block size.
common_prefix_len = (common_prefix_len // kv_cache_spec.block_size *
kv_cache_spec.block_size)
@ -979,16 +906,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
return common_prefix_len if use_cascade else 0

def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
def _calc_mrope_positions(self, input_batch: InputBatch):
mrope_pos_ptr = 0
for index, req_id in enumerate(self.input_batch.req_ids):
for i, req_id in enumerate(input_batch.req_ids):
req = self.requests[req_id]
assert req.mrope_positions is not None

num_computed_tokens = \
self.input_batch.num_computed_tokens_cpu[index]
self.requests.num_computed_tokens_cpu[i]
num_scheduled_tokens = \
scheduler_output.num_scheduled_tokens[req_id]
input_batch.num_scheduled_tokens[i]
num_prompt_tokens = len(req.prompt_token_ids)

if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
@ -1159,17 +1086,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
input_batch: InputBatch,
shift_computed_tokens: int = 0,
) -> list[torch.Tensor]:
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
num_computed_tokens = \
req_state.num_computed_tokens + shift_computed_tokens
mm_positions = req_state.mm_positions
for i, req_id in enumerate(input_batch.req_ids):
num_scheduled_tokens = input_batch.num_scheduled_tokens[i]
req_idx = self.requests.req_id_to_index[req_id]
num_computed_tokens = (
self.requests.num_computed_tokens.np[req_idx] +
shift_computed_tokens)
req_data = self.requests.req_data[req_idx]
mm_positions = req_data.mm_positions
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
@ -1274,8 +1202,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# request in the batch, as the logit indices are offset by this amount.
struct_out_req_batch_indices: dict[str, int] = {}
cumulative_offset = 0
seq = sorted(self.input_batch.req_id_to_index.items(),
key=lambda x: x[1])
seq = sorted(self.requests.req_id_to_index.items(), key=lambda x: x[1])
for req_id, batch_index in seq:
logit_index = batch_index + cumulative_offset
cumulative_offset += len(
@ -1431,7 +1358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
req_id_to_index=self.input_batch.req_id_to_batch_idx,
sampled_token_ids=[],
logprobs=None,
prompt_logprobs_dict={},
@ -1455,21 +1382,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config)

# Prepare the decoder inputs.
(attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
max_query_len) = self._prepare_inputs(scheduler_output)

# FIXME
# batch_idx -> req_id
req_ids = list(scheduler_output.num_scheduled_tokens.keys())
# req_id -> batch_idx
req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)}
# batch_idx -> req_idx
idx_mapping = [
self.input_batch.req_id_to_index[req_id] for req_id in req_ids
]
# batch_idx -> req_idx
idx_mapping_tensor = self.input_batch.get_index_mapping(idx_mapping)
input_batch = self._prepare_inputs(scheduler_output)

num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
@ -1540,8 +1453,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True)

uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
num_scheduled_tokens == self.input_batch.num_reqs * max_query_len)
uniform_decode = (input_batch.max_num_tokens
== self.uniform_decode_query_len
and num_scheduled_tokens
== input_batch.num_reqs * input_batch.max_num_tokens)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=uniform_decode)
cudagraph_runtime_mode, batch_descriptor = \
@ -1550,7 +1465,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Run the model.
# Use persistent buffers for CUDA graphs.
with set_forward_context(
attn_metadata,
input_batch.attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
@ -1590,11 +1505,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
all_gather_group=get_tp_group())
logits = None
else:
if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np, kv_connector_output)

sample_hidden_states = hidden_states[logits_indices]
sample_hidden_states = hidden_states[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output:
model_output_broadcast_data = {
@ -1610,9 +1521,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.apply_grammar_bitmask(scheduler_output, logits)

# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.make_sampling_metadata(
idx_mapping_tensor)
if spec_decode_metadata is None:
sampling_metadata = self.requests.make_sampling_metadata(
input_batch.idx_mapping)
if input_batch.spec_decode_metadata is None:
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
@ -1623,7 +1534,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
bonus_logits = logits[
input_batch.spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
@ -1633,9 +1545,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
target_logits = logits[
input_batch.spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
input_batch.spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
@ -1643,6 +1556,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
sampler_output.sampled_token_ids = output_token_ids

for i in range(input_batch.num_reqs):
req_idx = input_batch.idx_mapping_np[i]
num_tokens = input_batch.num_scheduled_tokens[i]
self.requests.num_computed_tokens.np[req_idx] += num_tokens

num_nans_in_logits = {}
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
num_nans_in_logits = self._get_nans_in_logits(logits)
@ -1664,27 +1582,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
valid_sampled_token_ids_np = sampled_token_ids.cpu().numpy()
valid_sampled_token_ids = valid_sampled_token_ids_np.tolist()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids, self.input_batch.vocab_size)
# # Mask out the sampled tokens that should not be sampled.
# for i in discard_sampled_tokens_req_indices:
# valid_sampled_token_ids[i].clear()
sampled_token_ids, self.vocab_size)

# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
if not sampled_ids:
continue
self.input_batch.append_token_ids(req_idx, sampled_ids)
self.requests.append_sampled_token_ids(
input_batch.idx_mapping_np,
valid_sampled_token_ids,
)

if self.speculative_config:
assert spec_decode_common_attn_metadata is not None
assert input_batch.spec_decode_common_attn_metadata is not None
self._draft_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
@ -1692,15 +1608,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
spec_decode_common_attn_metadata,
input_batch.spec_decode_metadata,
input_batch.spec_decode_common_attn_metadata,
)

self.eplb_step()

return ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_batch_idx,
req_ids=input_batch.req_ids,
req_id_to_index=input_batch.req_id_to_batch_idx,
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
@ -1712,7 +1628,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
if self._draft_token_ids is None:
return None
req_ids = self.input_batch.req_ids
req_ids = self.requests.req_ids
if isinstance(self._draft_token_ids, torch.Tensor):
draft_token_ids = self._draft_token_ids.tolist()
else:
@ -1722,16 +1638,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

def propose_draft_token_ids(
self,
scheduler_output: "SchedulerOutput",
input_batch: InputBatch,
sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
hidden_states: torch.Tensor,
sample_hidden_states: torch.Tensor,
aux_hidden_states: Optional[torch.Tensor],
spec_decode_metadata: Optional[SpecDecodeMetadata],
common_attn_metadata: CommonAttentionMetadata,
) -> Union[list[list[int]], torch.Tensor]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
num_scheduled_tokens = input_batch.total_num_tokens
if self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
draft_token_ids = self.propose_ngram_draft_token_ids(
@ -1745,7 +1659,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
indices = []
offset = 0
for num_draft, tokens in zip(
spec_decode_metadata.num_draft_tokens,
input_batch.spec_decode_metadata.num_draft_tokens,
sampled_token_ids):
indices.append(offset + len(tokens) - 1)
offset += num_draft + 1
@ -1759,7 +1673,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
elif self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
# TODO(woosuk): Refactor the loop.
req_ids = self.input_batch.req_ids
req_ids = input_batch.req_ids
next_token_ids: list[int] = []
for i, token_ids in enumerate(sampled_token_ids):
if token_ids:
@ -1771,14 +1685,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_id = req_ids[i]
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
input_batch.num_scheduled_tokens[i])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)

if spec_decode_metadata is None:
if input_batch.spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
# TODO(woosuk): Support M-RoPE.
@ -1791,7 +1705,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
target_hidden_states = hidden_states[:num_scheduled_tokens]
else:
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_draft_tokens = input_batch.spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
@ -1812,7 +1726,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
target_hidden_states = hidden_states[token_indices]
mm_embeds = None
if self.supports_mm_inputs:
mm_embeds = self._gather_mm_embeddings(scheduler_output,
mm_embeds = self._gather_mm_embeddings(input_batch,
shift_computed_tokens=1)

draft_token_ids = self.drafter.propose(
@ -1828,10 +1742,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

def propose_ngram_draft_token_ids(
self,
input_batch: InputBatch,
sampled_token_ids: list[list[int]],
) -> list[list[int]]:
# TODO(woosuk): Optimize.
req_ids = self.input_batch.req_ids
draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
@ -1842,19 +1756,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

# Skip requests that require sampling parameters that are not
# supported with speculative decoding.
req_id = req_ids[i]
if req_id in self.input_batch.spec_decode_unsupported_reqs:
req_id = input_batch.req_ids[i]
if req_id in self.requests.spec_decode_unsupported_reqs:
draft_token_ids.append([])
continue

num_tokens = self.input_batch.num_tokens_no_spec[i]
num_tokens = self.requests.num_tokens_no_spec[i]
if num_tokens >= self.max_model_len:
# Skip requests that have already reached the max model length.
draft_token_ids.append([])
continue

drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :num_tokens])
self.requests.token_ids.np[i, :num_tokens])
if drafter_output is None or len(drafter_output) == 0:
draft_token_ids.append([])
else:
@ -1992,11 +1906,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hidden_states: torch.Tensor,
num_scheduled_tokens: dict[str, int],
) -> dict[str, Optional[LogprobsTensors]]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
num_prompt_logprobs_dict = self.requests.num_prompt_logprobs
if not num_prompt_logprobs_dict:
return {}

in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
in_progress_dict = self.requests.in_progress_prompt_logprobs_cpu
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

# Since prompt logprobs are a rare feature, prioritize simple,
@ -2045,7 +1959,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Get the logits corresponding to this req's prompt tokens.
# If this is a partial request (i.e. chunked prefill),
# then there is prompt logprob generated for each index.
req_idx = self.input_batch.req_id_to_index[req_id]
req_idx = 0
offset = self.query_start_loc_np[req_idx].item()
prompt_hidden_states = hidden_states[offset:offset + num_logits]
logits = self.model.compute_logits(prompt_hidden_states, None)
@ -2083,20 +1997,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

def _get_nans_in_logits(
self,
input_batch: InputBatch,
logits: Optional[torch.Tensor],
) -> dict[str, int]:
try:
if logits is None:
return {req_id: 0 for req_id in self.input_batch.req_ids}
return {req_id: 0 for req_id in input_batch.req_ids}

num_nans_in_logits = {}
num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy()
for req_id in self.input_batch.req_ids:
req_index = self.input_batch.req_id_to_index[req_id]
num_nans_in_logits[req_id] = (
int(num_nans_for_index[req_index])
if num_nans_for_index is not None
and req_index < logits.shape[0] else 0)
for i, req_id in input_batch.req_ids:
num_nans_in_logits[req_id] = (int(num_nans_for_index[i])
if num_nans_for_index is not None
and i < logits.shape[0] else 0)
return num_nans_in_logits
except IndexError:
return {}
@ -2248,17 +2161,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
1],
seq_lens=self.seq_lens[:num_reqs],
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
num_computed_tokens_cpu=self.input_batch.
num_computed_tokens_cpu_tensor[:num_reqs],
num_computed_tokens_cpu=self.requests.num_computed_tokens.
cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
max_seq_len=self.max_model_len,
block_table_tensor=self.input_batch.block_table[
kv_cache_group_id].get_device_tensor()[:num_reqs],
slot_mapping=self.input_batch.
block_table[kv_cache_group_id].slot_mapping[:num_tokens],
causal=True)
block_table_tensor=self.requests.
block_tables[kv_cache_group_id].gpu[:num_reqs],
slot_mapping=self.requests.slot_mappings[kv_cache_group_id]
[:num_tokens],
causal=True,
)

for attn_group in self.attn_groups[kv_cache_group_id]:
attn_metadata_i = attn_group.metadata_builder\

342
vllm/v1/worker/gpu_worker_states.py
Normal file
@ -0,0 +1,342 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch

from dataclasses import dataclass
from typing import Optional, Union

import numpy as np
import torch
import triton
import triton.language as tl
from typing_extensions import deprecated

from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (MultiModalKwargsItem,
MultiModalKwargsItems, PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata


@dataclass
class RequestData:

mm_kwargs: list[MultiModalKwargsItem]
mm_positions: list[PlaceholderRange]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]

# M-RoPE (only for Qwen2/2.5-VL)
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None

lora_request: Optional[LoRARequest] = None

# Temporary back-compatibility for plugins that define model runner
@property
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargsItems]:
return [
MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
]


class RequestAttribute:

def __init__(
self,
num_rows_cpu: int,
num_cols: int,
num_rows_gpu: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
is_scalar: bool = False,
):
self.cpu = torch.zeros(num_rows_cpu,
num_cols,
dtype=dtype,
device="cpu",
pin_memory=pin_memory)
self.np = self.cpu.numpy()
self.gpu = torch.zeros(num_rows_gpu,
num_cols,
dtype=dtype,
device=device)
if is_scalar:
assert num_cols == 1
self.cpu.squeeze_(1)
self.np = self.cpu.numpy()
self.gpu.squeeze_(1)

self.gpu_buffer = self.cpu.to(device)

def mirror_to_gpu(self) -> torch.Tensor:
return self.gpu_buffer.copy_(self.cpu, non_blocking=True)
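A minimal usage sketch of the RequestAttribute pattern above (not part of the commit; toy sizes, and a CPU device stands in for the runner's CUDA device):

# Illustrative sketch (not part of this commit): per-request values live in a
# CPU table indexed by req_idx and are staged to the GPU on demand.
import torch

device = torch.device("cpu")  # stand-in; the runner passes a CUDA device
temperature = RequestAttribute(num_rows_cpu=16,   # max_num_cached_reqs
                               num_cols=1,
                               num_rows_gpu=8,    # max_num_reqs
                               dtype=torch.float32,
                               device=device,
                               pin_memory=False,
                               is_scalar=True)
temperature.np[3] = 0.7               # write through the NumPy view
staged = temperature.mirror_to_gpu()  # async copy of the full CPU table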


class RequestState:

def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
max_num_cached_reqs: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.max_num_cached_reqs = max_num_cached_reqs
self.device = device
self.pin_memory = pin_memory
self.vocab_size = vocab_size
self.is_spec_decode = is_spec_decode
self.pooling_params = None
self.block_sizes = block_sizes
self.num_prompt_logprobs = {}

self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_cached_reqs))
# Used to construct the input batch.
self._add_scalar_attr("idx_mapping", torch.int32)

# Request states.
self.req_data: dict[int, RequestData] = {}
# TODO(woosuk): Because the token_ids tensor can be very big, we only
# initialize it on CPU memory.
self._add_vector_attr("token_ids",
self.max_model_len,
torch.int32,
cpu_only=True)
self._add_scalar_attr("num_prompt_tokens", torch.int32)
self._add_scalar_attr("num_tokens", torch.int32)
self._add_scalar_attr("num_computed_tokens", torch.int32)

# Sampling-related.
self._add_scalar_attr("temperature", torch.float32)
self.greedy_reqs: set[str] = set()
self.random_reqs: set[str] = set()
self._add_scalar_attr("top_p", torch.float32)
self.top_p_reqs: set[str] = set()
self._add_scalar_attr("top_k", torch.int32)
self.top_k_reqs: set[str] = set()
self._add_scalar_attr("frequency_penalties", torch.float32)
self.frequency_penalties_reqs: set[str] = set()
self._add_scalar_attr("presence_penalties", torch.float32)
self.presence_penalties_reqs: set[str] = set()
self._add_scalar_attr("repetition_penalties", torch.float32)
self.repetition_penalties_reqs: set[str] = set()

# req_idx -> generator
self.generators: dict[int, torch.Generator] = {}

def add_request(
self,
req_id: str,
prompt_token_ids: list[int],
num_computed_tokens: int,
sampling_params: SamplingParams,
) -> None:
req_idx = self.free_indices.pop()
self.req_id_to_index[req_id] = req_idx
self.index_to_req_id[req_idx] = req_id

self.num_prompt_tokens.np[req_idx] = len(prompt_token_ids)
self.num_computed_tokens.np[req_idx] = num_computed_tokens
self.append_token_ids(req_idx, prompt_token_ids)

self.temperature.np[req_idx] = sampling_params.temperature
if sampling_params.sampling_type == SamplingType.GREEDY:
# NOTE: Be careful about division by zero.
self.greedy_reqs.add(req_id)
elif sampling_params.sampling_type == SamplingType.RANDOM:
self.random_reqs.add(req_id)

self.top_p.np[req_idx] = sampling_params.top_p
if sampling_params.top_p < 1.0:
self.top_p_reqs.add(req_id)

top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k.np[req_idx] = top_k

self.frequency_penalties.np[
req_idx] = sampling_params.frequency_penalty
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties.np[req_idx] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties.np[
req_idx] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)

if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
self.generators[req_idx] = generator

def append_token_ids(
self,
req_idx: int,
token_ids: Union[list[int], np.ndarray],
) -> None:
start_idx = self.num_tokens.np[req_idx]
end_idx = start_idx + len(token_ids)
self.token_ids.np[req_idx, start_idx:end_idx] = token_ids
self.num_tokens.np[req_idx] = end_idx

def append_sampled_token_ids(
self,
idx_mapping: np.ndarray,
sampled_token_ids: np.ndarray,
) -> None:
num_reqs = idx_mapping.shape[0]
for i in range(num_reqs):
req_idx = idx_mapping[i]
self.append_token_ids(req_idx, sampled_token_ids[i])
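As a small illustration (not part of the commit), append_sampled_token_ids routes each batch slot's new token to its cached request row; the `state` name below is hypothetical:

# Illustrative sketch (not part of this commit).
import numpy as np

idx_mapping = np.array([5, 2, 9], dtype=np.int32)          # batch_idx -> req_idx
sampled_token_ids = np.array([[101], [202], [303]], dtype=np.int32)
# state.append_sampled_token_ids(idx_mapping, sampled_token_ids) appends 101 to
# request row 5, 202 to row 2, and 303 to row 9 (state is a RequestState).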

def remove_request(self, req_id: str) -> None:
req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None:
# Request not found.
return
self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx)

self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_idx, None)

def make_sampling_metadata(
self,
batch_idx_to_req_idx: torch.Tensor,
) -> SamplingMetadata:
batch_size = batch_idx_to_req_idx.shape[0]
_make_sampling_metadata_kernel[(batch_size, )](
batch_idx_to_req_idx,
self.temperature.mirror_to_gpu(),
self.temperature.gpu,
self.top_p.mirror_to_gpu(),
self.top_p.gpu,
self.top_k.mirror_to_gpu(),
self.top_k.gpu,
self.frequency_penalties.mirror_to_gpu(),
self.frequency_penalties.gpu,
self.presence_penalties.mirror_to_gpu(),
self.presence_penalties.gpu,
self.repetition_penalties.mirror_to_gpu(),
self.repetition_penalties.gpu,
num_warps=1,
num_stages=1,
)
no_penalties = not (self.frequency_penalties_reqs
or self.presence_penalties_reqs
or self.repetition_penalties_reqs)
return SamplingMetadata(
temperature=self.temperature.gpu[:batch_size],
all_greedy=not self.random_reqs,
all_random=not self.greedy_reqs,
top_p=self.top_p.gpu[:batch_size],
top_k=self.top_k.gpu[:batch_size],
frequency_penalties=self.frequency_penalties.gpu[:batch_size],
presence_penalties=self.presence_penalties.gpu[:batch_size],
repetition_penalties=self.repetition_penalties.gpu[:batch_size],
no_penalties=no_penalties,
# TODO
generators={},
token_ids=self.token_ids.cpu[:batch_size],
max_num_logprobs=None,
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=None,
)

@property
def num_cached_reqs(self) -> int:
return len(self.req_id_to_index)

def _add_scalar_attr(self, name: str, dtype: torch.dtype):
attr = RequestAttribute(self.max_num_cached_reqs,
1,
self.max_num_reqs,
dtype,
self.device,
self.pin_memory,
is_scalar=True)
setattr(self, name, attr)

def _add_vector_attr(
self,
name: str,
max_len: int,
dtype: torch.dtype,
cpu_only: bool = False,
):
if cpu_only:
num_rows_gpu = 0
else:
num_rows_gpu = self.max_num_reqs
attr = RequestAttribute(self.max_num_cached_reqs, max_len,
num_rows_gpu, dtype, self.device,
self.pin_memory)
setattr(self, name, attr)


@triton.jit
def _make_sampling_metadata_kernel(
batch_idx_to_req_idx, # [batch_size]
src_temperature,
dst_temperature,
src_top_p,
dst_top_p,
src_top_k,
dst_top_k,
src_frequency_penalties,
dst_frequency_penalties,
src_presence_penalties,
dst_presence_penalties,
src_repetition_penalties,
dst_repetition_penalties,
):
batch_idx = tl.program_id(0)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)

temperature = tl.load(src_temperature + req_idx)
tl.store(dst_temperature + batch_idx, temperature)

top_p = tl.load(src_top_p + req_idx)
tl.store(dst_top_p + batch_idx, top_p)

top_k = tl.load(src_top_k + req_idx)
tl.store(dst_top_k + batch_idx, top_k)

frequency_penalties = tl.load(src_frequency_penalties + req_idx)
tl.store(dst_frequency_penalties + batch_idx, frequency_penalties)

presence_penalties = tl.load(src_presence_penalties + req_idx)
tl.store(dst_presence_penalties + batch_idx, presence_penalties)

repetition_penalties = tl.load(src_repetition_penalties + req_idx)
tl.store(dst_repetition_penalties + batch_idx, repetition_penalties)
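A hedged sketch of how the kernel and make_sampling_metadata above are expected to be driven (not part of the commit; assumes a CUDA device and uses toy capacities):

# Illustrative sketch (not part of this commit): gather cached per-request
# sampling params into batch order for the current step.
import torch

state = RequestState(max_num_reqs=8,
                     max_model_len=1024,
                     max_num_batched_tokens=2048,
                     max_num_cached_reqs=64,
                     device=torch.device("cuda"),
                     pin_memory=True,
                     vocab_size=32000,
                     block_sizes=[16])
# batch_idx -> req_idx for the requests scheduled in this step.
batch_idx_to_req_idx = torch.tensor([3, 0, 7], dtype=torch.int32, device="cuda")
metadata = state.make_sampling_metadata(batch_idx_to_req_idx)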
@ -12,7 +12,6 @@ from vllm.sampling_params import SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState

_SAMPLING_EPS = 1e-5


@ -298,3 +298,32 @@ def bind_kv_cache(
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]


class CpuGpuBuffer:

def __init__(
self,
*args,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
self.cpu = torch.zeros(*args,
dtype=dtype,
device="cpu",
pin_memory=pin_memory)
self.np = self.cpu.numpy()
self.gpu = self.cpu.to(device)

def copy_to_gpu(self, n: Optional[int] = None) -> None:
if n is None:
return self.gpu.copy_(self.cpu, non_blocking=True)
else:
return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)

def copy_to_cpu(self, n: Optional[int] = None) -> None:
if n is None:
return self.cpu.copy_(self.gpu, non_blocking=True)
else:
return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)
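A minimal usage sketch of CpuGpuBuffer (not part of the commit; assumes a CUDA device is available):

# Illustrative sketch (not part of this commit): stage values on pinned CPU
# memory through the NumPy view, then copy only the filled prefix to the GPU.
import torch

buf = CpuGpuBuffer(1024, dtype=torch.int32, device=torch.device("cuda"),
                   pin_memory=True)
buf.np[:3] = [7, 8, 9]   # fill on the CPU side
buf.copy_to_gpu(3)       # non-blocking H2D copy of the first 3 elements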