Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-08-24 18:36:18 -07:00
parent 48bca9a109
commit a1e3745150
7 changed files with 964 additions and 811 deletions

View File

@@ -90,9 +90,9 @@ class Sampler(nn.Module):
# Apply bad words exclusion.
logits = self.apply_bad_words(logits, sampling_metadata)
# Apply logits processors which can impact greedy sampling
for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
logits = processor.apply(logits)
# # Apply logits processors which can impact greedy sampling
# for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
# logits = processor.apply(logits)
# Apply penalties (e.g., min_tokens, freq_penalties).
logits = self.apply_penalties(logits, sampling_metadata)
@@ -167,10 +167,10 @@ class Sampler(nn.Module):
# Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
# Apply logits processors that only apply to random sampling
# (argmax invariant)
for processor in sampling_metadata.logitsprocs.argmax_invariant:
logits = processor.apply(logits)
# # Apply logits processors that only apply to random sampling
# # (argmax invariant)
# for processor in sampling_metadata.logitsprocs.argmax_invariant:
# logits = processor.apply(logits)
# Apply top_k and/or top_p.
random_sampled, processed_logprobs = self.topk_topp_sampler(

View File

@@ -0,0 +1,312 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
import triton
import triton.language as tl
from vllm.utils import cdiv
from vllm.v1.worker.utils import CpuGpuBuffer
PAD_SLOT_ID = -1
class BlockTables:
def __init__(
self,
block_sizes: list[int],
max_num_reqs: int,
max_num_cached_reqs: int,
max_num_batched_tokens: int,
max_model_len: int,
device: torch.device,
pin_memory: bool,
):
self.block_sizes = block_sizes
self.max_num_reqs = max_num_reqs
self.max_num_cached_reqs = max_num_cached_reqs
self.max_num_batched_tokens = max_num_batched_tokens
self.max_model_len = max_model_len
self.device = device
self.pin_memory = pin_memory
self.num_kv_cache_groups = len(self.block_sizes)
# [num_kv_cache_groups, max_num_reqs, max_num_blocks]
self.block_tables: list[torch.Tensor] = []
# [num_kv_cache_groups, max_num_cached_reqs, max_num_blocks]
self.block_table_buffers: list[torch.Tensor] = []
# [num_kv_cache_groups, max_num_reqs]
self.num_blocks: list[torch.Tensor] = []
# [num_kv_cache_groups, max_num_tokens]
self.slot_mappings: list[torch.Tensor] = []
for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i]
max_num_blocks = cdiv(self.max_model_len, block_size)
block_table = torch.zeros(
self.max_num_reqs,
max_num_blocks,
dtype=torch.int32,
device=self.device,
)
self.block_tables.append(block_table)
block_table_buffer = torch.zeros(
self.max_num_cached_reqs,
max_num_blocks,
dtype=torch.int32,
device=self.device,
)
self.block_table_buffers.append(block_table_buffer)
num_blocks = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.num_blocks.append(num_blocks)
slot_mapping = torch.zeros(self.max_num_batched_tokens,
dtype=torch.int64,
device=self.device)
self.slot_mappings.append(slot_mapping)
self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
self.buffer_ptrs = self._make_ptr_tensor(self.block_table_buffers)
self.block_table_strides = torch.tensor(
[b.stride(0) for b in self.block_tables],
dtype=torch.int64,
device=self.device)
self.num_blocks_ptrs = self._make_ptr_tensor(self.num_blocks)
self.block_sizes_tensor = torch.tensor(self.block_sizes,
dtype=torch.int32,
device=self.device)
self.slot_mapping_ptrs = self._make_ptr_tensor(self.slot_mappings)
# Misc buffers.
self.req_indices = self._make_buffer(self.max_num_reqs, torch.int32)
self.overwrite = self._make_buffer(self.max_num_reqs, torch.bool)
self.cu_num_new_blocks: list[CpuGpuBuffer] = []
self.new_block_ids: list[CpuGpuBuffer] = []
for i in range(self.num_kv_cache_groups):
self.cu_num_new_blocks.append(
self._make_buffer(self.max_num_reqs + 1, torch.int32))
# NOTE(woosuk): Here, we assume that total number of new blocks
# is ALWAYS less than max_num_batched_tokens.
# TODO(woosuk): Rigorously verify that this assumption is correct.
self.new_block_ids.append(
self._make_buffer(self.max_num_batched_tokens, torch.int32))
def _make_buffer(self, n: int, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(n,
dtype=dtype,
pin_memory=self.pin_memory,
device=self.device)
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
ptrs_tensor_cpu = torch.tensor([t.data_ptr() for t in x],
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
return ptrs_tensor_cpu.to(self.device, non_blocking=True)
def append_block_ids(
self,
# [num_reqs]
req_indices: list[int],
# [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks: list[list[int]],
# [num_kv_cache_groups, num_new_blocks]
new_block_ids: list[list[int]],
# [num_reqs]
overwrite: list[bool],
) -> None:
# TODO(woosuk): Optimize & simplify this.
num_reqs = len(req_indices)
self.req_indices.np[:num_reqs] = req_indices
self.overwrite.np[:num_reqs] = overwrite
for i in range(self.num_kv_cache_groups):
self.cu_num_new_blocks[i].np[:num_reqs + 1] = cu_num_new_blocks[i]
n = len(new_block_ids[i])
self.new_block_ids[i].np[:n] = new_block_ids[i]
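            # (Illustrative guard, not in the original commit) This enforces
            # the NOTE in __init__ that the number of new blocks per step
            # never exceeds max_num_batched_tokens, the size of the staging
            # buffer written above.
            assert n <= self.max_num_batched_tokens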
cu_num_new_blocks_ptrs = self._make_ptr_tensor(
[x.copy_to_gpu(num_reqs + 1) for x in self.cu_num_new_blocks])
new_block_ids_ptrs = self._make_ptr_tensor([
x.copy_to_gpu(len(new_block_ids[i]))
for i, x in enumerate(self.new_block_ids)
])
_append_block_ids_kernel[(num_reqs, self.num_kv_cache_groups)](
self.req_indices.copy_to_gpu(num_reqs),
cu_num_new_blocks_ptrs,
new_block_ids_ptrs,
self.overwrite.copy_to_gpu(num_reqs),
self.block_table_strides,
self.buffer_ptrs,
self.num_blocks_ptrs,
BLOCK_SIZE=1024,
)
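    # Worked example (illustrative, not part of the original commit): with one
    # KV cache group and two requests receiving 2 and 1 new blocks,
    #   req_indices       = [5, 9]
    #   cu_num_new_blocks = [[0, 2, 3]]      # per-group prefix sums
    #   new_block_ids     = [[10, 11, 12]]   # flattened per group
    #   overwrite         = [True, False]
    # row 5 of the buffer is overwritten with [10, 11], and [12] is appended
    # to row 9 starting at its current num_blocks entry.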
def compute_block_tables(
self,
idx_mapping: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
batch_size = idx_mapping.shape[0]
_compute_block_tables_kernel[(batch_size, self.num_kv_cache_groups)](
idx_mapping,
self.buffer_ptrs,
self.block_table_ptrs,
self.block_table_strides,
self.num_blocks_ptrs,
BLOCK_SIZE=1024,
)
return tuple(b[:batch_size] for b in self.block_tables)
def compute_slot_mappings(
self,
cu_num_tokens: torch.Tensor,
pos: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
num_tokens = pos.shape[0]
num_reqs = cu_num_tokens.shape[0] - 1
num_groups = self.num_kv_cache_groups
_compute_slot_mappings_kernel[(num_reqs + 1, num_groups)](
num_tokens,
self.max_num_batched_tokens,
cu_num_tokens,
pos,
self.block_table_ptrs,
self.block_table_strides,
self.block_sizes_tensor,
self.slot_mapping_ptrs,
PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024,
)
return tuple(x[:num_tokens] for x in self.slot_mappings)
@triton.jit
def _append_block_ids_kernel(
# Inputs
req_indices, # [num_reqs]
cu_num_new_block_ptrs, # [num_kv_cache_groups, num_reqs + 1]
new_block_id_ptrs, # [num_kv_cache_groups, num_new_blocks]
overwrite, # [num_reqs]
block_table_strides, # [num_kv_cache_groups]
# Outputs
block_table_buffer_ptrs, # [num_kv_cache_groups]
num_block_ptrs, # [num_kv_cache_groups]
# Constants
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
group_id = tl.program_id(1)
req_idx = tl.load(req_indices + batch_idx)
do_overwrite = tl.load(overwrite + batch_idx)
cu_num_new_blocks_ptr = _load_ptr(cu_num_new_block_ptrs + group_id,
tl.int32)
start_idx = tl.load(cu_num_new_blocks_ptr + batch_idx)
end_idx = tl.load(cu_num_new_blocks_ptr + batch_idx + 1)
num_new_blocks = end_idx - start_idx
num_blocks_ptr = _load_ptr(num_block_ptrs + group_id, tl.int32)
if do_overwrite:
dst_start_idx = 0
else:
dst_start_idx = tl.load(num_blocks_ptr + req_idx)
dst_end_idx = dst_start_idx + num_new_blocks
tl.store(num_blocks_ptr + req_idx, dst_end_idx)
# Destination
block_table_buffer_ptr = _load_ptr(block_table_buffer_ptrs + group_id,
tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
buffer_row_ptr = block_table_buffer_ptr + req_idx * block_table_stride
new_block_ids_ptr = _load_ptr(new_block_id_ptrs + group_id, tl.int32)
for i in tl.range(0, num_new_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(new_block_ids_ptr + start_idx + offset,
mask=offset < num_new_blocks)
tl.store(buffer_row_ptr + dst_start_idx + offset,
block_ids,
mask=offset < num_new_blocks)
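# (Descriptive note, added for exposition) The kernel above runs on a
# (num_reqs, num_kv_cache_groups) grid: each program appends (or, on
# overwrite, replaces) the new block IDs for one request in one KV cache
# group and bumps that request's num_blocks counter.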
@triton.jit
def _compute_block_tables_kernel(
batch_idx_to_req_idx, # [batch_size]
src_block_table_ptrs, # [num_kv_cache_groups]
dst_block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
num_blocks_ptrs, # [num_kv_cache_groups]
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
# kv cache group id
group_id = tl.program_id(1)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
num_blocks_ptr = _load_ptr(num_blocks_ptrs + group_id, tl.int32)
num_blocks = tl.load(num_blocks_ptr + req_idx)
stride = tl.load(block_table_strides + group_id)
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
src_row_ptr = src_block_table_ptr + req_idx * stride
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
for i in tl.range(0, num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
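# (Descriptive note, added for exposition) One program per (request, group):
# it reads the request's row length from num_blocks and copies that row from
# the persistent per-request buffer into the dense per-batch block table.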
@triton.jit
def _compute_slot_mappings_kernel(
num_tokens,
max_num_tokens,
cu_num_tokens, # [num_reqs + 1]
pos, # [num_tokens]
block_table_ptrs, # [num_kv_cache_groups]
block_table_strides, # [num_kv_cache_groups]
page_sizes, # [num_kv_cache_groups]
slot_mapping_ptrs, # [num_kv_cache_groups]
PAD_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
# kv cache group id
group_id = tl.program_id(1)
slot_mapping_ptr = _load_ptr(slot_mapping_ptrs + group_id, tl.int64)
if req_idx == tl.num_programs(0) - 1:
# Pad remaining slots to -1. This is needed for CUDA graphs.
for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
tl.store(slot_mapping_ptr + offset,
PAD_ID,
mask=offset < max_num_tokens)
return
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
page_size = tl.load(page_sizes + group_id)
start_idx = tl.load(cu_num_tokens + req_idx)
end_idx = tl.load(cu_num_tokens + req_idx + 1)
for i in tl.range(start_idx, end_idx, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
block_indices = positions // page_size
block_numbers = tl.load(block_table_ptr +
req_idx * block_table_stride + block_indices)
slot_ids = block_numbers * page_size + positions % page_size
tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
@triton.jit
def _load_ptr(ptr_to_ptr, elem_dtype):
ptr = tl.load(ptr_to_ptr)
return tl.cast(ptr, tl.pointer_type(elem_dtype))
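# Illustrative reference (not part of the original commit): a pure-PyTorch
# equivalent of what _compute_slot_mappings_kernel writes for a single
# request in a single KV cache group. `_slot_mapping_ref` is a hypothetical
# helper, shown only to document the slot-id arithmetic.
def _slot_mapping_ref(block_table_row: torch.Tensor, positions: torch.Tensor,
                      page_size: int) -> torch.Tensor:
    # slot = block_table[pos // page_size] * page_size + pos % page_size
    block_numbers = block_table_row[positions // page_size]
    return block_numbers * page_size + positions % page_size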

View File

@@ -1,401 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a GPU input batch
from dataclasses import dataclass
from typing import Optional
from typing import Any, Optional
import numba
import numpy as np
import torch
import triton
import triton.language as tl
from typing_extensions import deprecated
from numba import types
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (MultiModalKwargsItem,
MultiModalKwargsItems, PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import cdiv, get_cuda_view_from_cpu_tensor, is_uva_available
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
PAD_SLOT_ID = -1
@dataclass
class RequestData:
mm_kwargs: list[MultiModalKwargsItem]
mm_positions: list[PlaceholderRange]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
# M-RoPE (only for Qwen2/2.5-VL)
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None
lora_request: Optional[LoRARequest] = None
# Temporary back-compatibility for plugins that define model runner
@property
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargsItems]:
return [
MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
]
class PerRequestAttribute:
def __init__(
self,
num_rows_cpu: int,
num_cols: int,
num_rows_gpu: int,
dtype: torch.dtype,
device: torch.device,
is_scalar: bool = False,
):
assert is_uva_available(), "UVA is not available."
self.cpu = torch.zeros(num_rows_cpu,
num_cols,
dtype=dtype,
device="cpu",
pin_memory=True)
self.np = self.cpu.numpy()
self.uva = get_cuda_view_from_cpu_tensor(self.cpu)
self.gpu = torch.zeros(num_rows_gpu,
num_cols,
dtype=dtype,
device=device)
if is_scalar:
assert num_cols == 1
self.cpu.squeeze_(1)
self.np = self.cpu.numpy()
self.uva.squeeze_(1)
self.gpu.squeeze_(1)
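# Usage sketch (illustrative, not in the original commit): writes land in the
# pinned CPU tensor through `.np`; a gather kernel can then read rows through
# the zero-copy `.uva` view and write the batch-ordered result into `.gpu`,
# without staging a full host-to-device copy of the table.
#
#   attr = PerRequestAttribute(num_rows_cpu=1024, num_cols=1, num_rows_gpu=256,
#                              dtype=torch.float32,
#                              device=torch.device("cuda"), is_scalar=True)
#   attr.np[req_idx] = 0.7  # plain NumPy write on the CPU side
#   # kernel: dst (attr.gpu)[batch_idx] = src (attr.uva)[req_idx]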
class RequestState:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
max_num_cached_reqs: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.max_num_cached_reqs = max_num_cached_reqs
self.device = device
self.pin_memory = pin_memory
self.vocab_size = vocab_size
self.is_spec_decode = is_spec_decode
self.pooling_params = None
self.block_sizes = block_sizes
self.num_prompt_logprobs = {}
self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_cached_reqs))
# Used to construct the input batch.
self._add_scalar_attr("idx_mapping", torch.int32)
# Request states.
self.req_data: dict[int, RequestData] = {}
# TODO(woosuk): Because the token_ids tensor can be very big, we only
# initialize it on CPU memory.
self._add_vector_attr_cpu("token_ids", self.max_model_len, torch.int32)
self._add_scalar_attr("num_prompt_tokens", torch.int32)
self._add_scalar_attr("num_tokens", torch.int32)
self._add_scalar_attr("num_computed_tokens", torch.int32)
# Sampling-related.
self._add_scalar_attr("temperature", torch.float32)
self.greedy_reqs: set[str] = set()
self.random_reqs: set[str] = set()
self._add_scalar_attr("top_p", torch.float32)
self.top_p_reqs: set[str] = set()
self._add_scalar_attr("top_k", torch.int32)
self.top_k_reqs: set[str] = set()
self._add_scalar_attr("frequency_penalties", torch.float32)
self.frequency_penalties_reqs: set[str] = set()
self._add_scalar_attr("presence_penalties", torch.float32)
self.presence_penalties_reqs: set[str] = set()
self._add_scalar_attr("repetition_penalties", torch.float32)
self.repetition_penalties_reqs: set[str] = set()
# req_idx -> generator
self.generators: dict[int, torch.Generator] = {}
# Block table(s).
self._init_block_tables()
def add_request(
self,
req_id: str,
prompt_token_ids: list[int],
num_computed_tokens: int,
block_ids: tuple[list[int], ...],
sampling_params: SamplingParams,
) -> None:
req_idx = self.free_indices.pop()
self.req_id_to_index[req_id] = req_idx
self.index_to_req_id[req_idx] = req_id
self.num_prompt_tokens.np[req_idx] = len(prompt_token_ids)
self.num_computed_tokens.np[req_idx] = num_computed_tokens
self.append_token_ids(req_idx, prompt_token_ids)
self.append_block_ids(req_idx, block_ids, overwrite=True)
self.temperature.np[req_idx] = sampling_params.temperature
if sampling_params.sampling_type == SamplingType.GREEDY:
# NOTE: Be careful about division by zero.
self.greedy_reqs.add(req_id)
elif sampling_params.sampling_type == SamplingType.RANDOM:
self.random_reqs.add(req_id)
self.top_p.np[req_idx] = sampling_params.top_p
if sampling_params.top_p < 1.0:
self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k.np[req_idx] = top_k
self.frequency_penalties.np[
req_idx] = sampling_params.frequency_penalty
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties.np[req_idx] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties.np[
req_idx] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
self.generators[req_idx] = generator
def append_token_ids(self, req_idx: int, token_ids: list[int]) -> None:
start_idx = self.num_tokens.np[req_idx]
end_idx = start_idx + len(token_ids)
self.token_ids.np[req_idx, start_idx:end_idx] = token_ids
self.num_tokens.np[req_idx] = end_idx
# TODO(woosuk): Further vectorize this to minimize overheads.
def append_block_ids(
self,
req_idx: int,
new_block_ids: tuple[list[int], ...],
overwrite: bool,
) -> None:
for i in range(self.num_block_tables):
block_table = self.block_tables[i]
num_blocks = self.num_blocks[i]
num_new_blocks = len(new_block_ids[i])
if overwrite:
# Replace the existing block IDs with the new ones.
# This happens when the request is resumed from preemption.
block_table.np[req_idx, :num_new_blocks] = new_block_ids[i]
num_blocks.np[req_idx] = num_new_blocks
else:
# Append the new block IDs to the existing ones (common case).
start_idx = num_blocks.np[req_idx]
end_idx = start_idx + num_new_blocks
block_table.np[req_idx, start_idx:end_idx] = new_block_ids[i]
num_blocks.np[req_idx] = end_idx
def remove_request(self, req_id: str) -> None:
req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None:
# Request not found.
return
self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx)
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_idx, None)
def get_index_mapping(self, idx_mapping: list[int]) -> torch.Tensor:
num_reqs = len(idx_mapping)
self.idx_mapping.np[:num_reqs] = idx_mapping
return self.idx_mapping.gpu[:num_reqs].copy_(
self.idx_mapping.uva[:num_reqs], non_blocking=True)
def make_sampling_metadata(
self,
batch_idx_to_req_idx: torch.Tensor,
) -> SamplingMetadata:
batch_size = batch_idx_to_req_idx.shape[0]
_make_sampling_metadata_kernel[(batch_size, )](
batch_idx_to_req_idx,
self.temperature.uva,
self.temperature.gpu,
self.top_p.uva,
self.top_p.gpu,
self.top_k.uva,
self.top_k.gpu,
self.frequency_penalties.uva,
self.frequency_penalties.gpu,
self.presence_penalties.uva,
self.presence_penalties.gpu,
self.repetition_penalties.uva,
self.repetition_penalties.gpu,
num_warps=1,
num_stages=1,
)
no_penalties = not (self.frequency_penalties_reqs
or self.presence_penalties_reqs
or self.repetition_penalties_reqs)
return SamplingMetadata(
temperature=self.temperature.gpu[:batch_size],
all_greedy=not self.random_reqs,
all_random=not self.greedy_reqs,
top_p=self.top_p.gpu[:batch_size],
top_k=self.top_k.gpu[:batch_size],
frequency_penalties=self.frequency_penalties.gpu[:batch_size],
presence_penalties=self.presence_penalties.gpu[:batch_size],
repetition_penalties=self.repetition_penalties.gpu[:batch_size],
no_penalties=no_penalties,
# TODO
generators={},
token_ids=self.token_ids.gpu[:batch_size],
max_num_logprobs=None,
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=None,
)
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
def _add_scalar_attr(self, name: str, dtype: torch.dtype):
attr = PerRequestAttribute(self.max_num_cached_reqs,
1,
self.max_num_reqs,
dtype,
self.device,
is_scalar=True)
setattr(self, name, attr)
def _add_vector_attr(self, name: str, max_len: int, dtype: torch.dtype):
attr = PerRequestAttribute(self.max_num_cached_reqs, max_len,
self.max_num_reqs, dtype, self.device)
setattr(self, name, attr)
def _add_vector_attr_cpu(self, name: str, max_len: int,
dtype: torch.dtype):
attr = PerRequestAttribute(self.max_num_cached_reqs, max_len, 0, dtype,
self.device)
setattr(self, name, attr)
def _init_block_tables(self):
self.num_block_tables = len(self.block_sizes)
self.block_tables = []
self.num_blocks = []
self.slot_mappings: list[torch.Tensor] = []
for i in range(self.num_block_tables):
max_num_blocks = cdiv(self.max_model_len, self.block_sizes[i])
block_table = PerRequestAttribute(self.max_num_cached_reqs,
max_num_blocks,
self.max_num_reqs, torch.int32,
self.device)
self.block_tables.append(block_table)
num_blocks = PerRequestAttribute(self.max_num_cached_reqs,
1,
self.max_num_reqs,
torch.int32,
self.device,
is_scalar=True)
self.num_blocks.append(num_blocks)
slot_mapping = torch.zeros(self.max_num_batched_tokens,
dtype=torch.int64,
device=self.device)
self.slot_mappings.append(slot_mapping)
def make_ptr_tensor(x: list[torch.Tensor]) -> torch.Tensor:
return torch.tensor([t.data_ptr() for t in x],
dtype=torch.int64,
device=self.device)
self.uva_block_table_ptrs = make_ptr_tensor(
[b.uva for b in self.block_tables])
self.gpu_block_table_ptrs = make_ptr_tensor(
[b.gpu for b in self.block_tables])
self.uva_num_blocks_ptrs = make_ptr_tensor(
[n.uva for n in self.num_blocks])
self.gpu_num_blocks_ptrs = make_ptr_tensor(
[n.gpu for n in self.num_blocks])
self.block_table_strides = torch.tensor(
[b.gpu.shape[1] for b in self.block_tables],
dtype=torch.int64,
device=self.device)
self.block_sizes_tensor = torch.tensor(self.block_sizes,
dtype=torch.int32,
device=self.device)
self.slot_mapping_ptrs = make_ptr_tensor(self.slot_mappings)
def make_block_tables(
self,
idx_mapping: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
batch_size = idx_mapping.shape[0]
_make_block_tables_kernel[(batch_size, self.num_block_tables)](
idx_mapping,
self.uva_block_table_ptrs,
self.gpu_block_table_ptrs,
self.block_table_strides,
self.uva_num_blocks_ptrs,
self.gpu_num_blocks_ptrs,
BLOCK_SIZE=1024,
)
return tuple(b.gpu[:batch_size] for b in self.block_tables)
def make_slot_mappings(
self,
cu_num_tokens: torch.Tensor,
pos: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
num_tokens = pos.shape[0]
num_reqs = cu_num_tokens.shape[0] - 1
_make_slot_mappings_kernel[(num_reqs + 1, self.num_block_tables)](
num_tokens,
self.max_num_batched_tokens,
cu_num_tokens,
pos,
self.gpu_block_table_ptrs,
self.block_table_strides,
self.block_sizes_tensor,
self.slot_mapping_ptrs,
PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024,
)
return tuple(x[:num_tokens] for x in self.slot_mappings)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@dataclass
@@ -403,134 +16,78 @@ class InputBatch:
# batch_idx -> req_id
req_ids: list[str]
# req_id -> batch_idx
req_id_to_batch_idx: dict[str, int]
# batch_idx -> req_state_idx
idx_mapping: torch.Tensor
idx_mapping_np: np.ndarray
# [num_kv_cache_groups, num_reqs, max_num_blocks_per_req]
block_tables: tuple[torch.Tensor, ...]
# [num_kv_cache_groups, num_tokens]
slot_mappings: tuple[torch.Tensor, ...]
# batch_idx -> num_scheduled_tokens
num_scheduled_tokens: np.ndarray
total_num_tokens: int
max_num_tokens: int
num_reqs: int
# [num_reqs] mostly
sampling_metadata: SamplingMetadata
attn_metadata: dict[str, Any]
spec_decode_common_attn_metadata: Optional[Any]
spec_decode_metadata: Optional[SpecDecodeMetadata]
logits_indices: torch.Tensor
@triton.jit
def _make_sampling_metadata_kernel(
    batch_idx_to_req_idx,  # [batch_size]
    src_temperature,
    dst_temperature,
    src_top_p,
    dst_top_p,
    src_top_k,
    dst_top_k,
    src_frequency_penalties,
    dst_frequency_penalties,
    src_presence_penalties,
    dst_presence_penalties,
    src_repetition_penalties,
    dst_repetition_penalties,
):
    batch_idx = tl.program_id(0)
    req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
    temperature = tl.load(src_temperature + req_idx)
    tl.store(dst_temperature + batch_idx, temperature)
    top_p = tl.load(src_top_p + req_idx)
    tl.store(dst_top_p + batch_idx, top_p)
    top_k = tl.load(src_top_k + req_idx)
    tl.store(dst_top_k + batch_idx, top_k)
    frequency_penalties = tl.load(src_frequency_penalties + req_idx)
    tl.store(dst_frequency_penalties + batch_idx, frequency_penalties)
    presence_penalties = tl.load(src_presence_penalties + req_idx)
    tl.store(dst_presence_penalties + batch_idx, presence_penalties)
    repetition_penalties = tl.load(src_repetition_penalties + req_idx)
    tl.store(dst_repetition_penalties + batch_idx, repetition_penalties)
@triton.jit
def _make_block_tables_kernel(
    batch_idx_to_req_idx,  # [batch_size]
    src_block_table_ptrs,  # [num_block_tables]
    dst_block_table_ptrs,  # [num_block_tables]
    block_table_strides,  # [num_block_tables]
    src_num_blocks_ptrs,  # [num_block_tables]
    dst_num_blocks_ptrs,  # [num_block_tables]
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    # kv cache group id
    group_id = tl.program_id(1)
    req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
    src_num_blocks_ptr = _load_ptr(src_num_blocks_ptrs, group_id, tl.int32)
    dst_num_blocks_ptr = _load_ptr(dst_num_blocks_ptrs, group_id, tl.int32)
    num_blocks = tl.load(src_num_blocks_ptr + req_idx)
    tl.store(dst_num_blocks_ptr + batch_idx, num_blocks)
    stride = tl.load(block_table_strides + group_id)
    src_block_table_ptr = _load_ptr(src_block_table_ptrs, group_id, tl.int32)
    src_row_ptr = src_block_table_ptr + req_idx * stride
    dst_block_table_ptr = _load_ptr(dst_block_table_ptrs, group_id, tl.int32)
    dst_row_ptr = dst_block_table_ptr + batch_idx * stride
    for i in tl.range(0, num_blocks, BLOCK_SIZE):
        offset = i + tl.arange(0, BLOCK_SIZE)
        block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
        tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
@triton.jit
def _make_slot_mappings_kernel(
    num_tokens,
    max_num_tokens,
    cu_num_tokens,  # [num_reqs + 1]
    pos,  # [num_tokens]
    block_table_ptrs,  # [num_block_tables]
    block_table_strides,  # [num_block_tables]
    page_sizes,  # [num_block_tables]
    slot_mapping_ptrs,  # [num_block_tables]
    PAD_ID: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    num_reqs = tl.num_programs(0)
    # kv cache group id
    group_id = tl.program_id(1)
    slot_mapping_ptr = _load_ptr(slot_mapping_ptrs, group_id, tl.int64)
    if req_idx == num_reqs - 1:
        # Pad remaining slots to -1. This is needed for CUDA graphs.
        for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE):
            offset = num_tokens + i + tl.arange(0, BLOCK_SIZE)
            tl.store(slot_mapping_ptr + offset,
                     PAD_ID,
                     mask=offset < max_num_tokens)
        return
    block_table_ptr = _load_ptr(block_table_ptrs, group_id, tl.int32)
    block_table_stride = tl.load(block_table_strides + group_id)
    page_size = tl.load(page_sizes + group_id)
    start_idx = tl.load(cu_num_tokens + req_idx)
    end_idx = tl.load(cu_num_tokens + req_idx + 1)
    for i in tl.range(start_idx, end_idx, BLOCK_SIZE):
        offset = start_idx + i + tl.arange(0, BLOCK_SIZE)
        positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
        block_indices = positions // page_size
        block_numbers = tl.load(block_table_ptr +
                                req_idx * block_table_stride + block_indices)
        slot_ids = block_numbers * page_size + positions % page_size
        tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
@triton.jit
def _load_ptr(base, offset, elem_dtype):
    ptr = tl.load(base + offset)
    return tl.cast(ptr, tl.pointer_type(elem_dtype))
# NOTE: With the type annotations, this function is pre-compiled
# before the first call.
@numba.jit(
    [
        types.none(
            types.int32[:],  # idx_mapping
            types.int32[:, :],  # token_ids
            types.int32[:],  # num_computed_tokens
            types.int32[:],  # num_scheduled_tokens
            types.int32[:],  # input_ids
            types.int32[:],  # query_start_loc
            types.int32[:],  # seq_lens
            types.int64[:],  # positions
        )
    ],
    nopython=True,
    cache=True,
)
def prepare_inputs(
    # Inputs
    idx_mapping: np.ndarray,  # batch_idx -> req_idx
    token_ids: np.ndarray,  # [N, max_model_len]
    num_computed_tokens: np.ndarray,  # [N]
    num_scheduled_tokens: np.ndarray,  # [B]
    # Outputs
    input_ids: np.ndarray,  # [num_input_tokens]
    query_start_loc: np.ndarray,  # [B + 1]
    seq_lens: np.ndarray,  # [B]
    positions: np.ndarray,  # [num_input_tokens]
) -> None:
    num_reqs = num_scheduled_tokens.shape[0]
    query_start_loc[0] = 0
    cu_num_tokens = 0
    for i in range(num_reqs):
        req_idx = idx_mapping[i]
        start = num_computed_tokens[req_idx]
        end = start + num_scheduled_tokens[i]
        seq_lens[i] = end
        start_idx = cu_num_tokens
        end_idx = start_idx + num_scheduled_tokens[i]
        input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
        positions[start_idx:end_idx] = np.arange(start, end)
        cu_num_tokens = end_idx
        query_start_loc[i + 1] = cu_num_tokens
    # Pad the inputs for CUDA graphs.
    # Note: pad query_start_loc to be non-decreasing, as kernels
    # like FlashAttention require that.
    query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
    # Fill unused with 0 for full cuda graph mode.
    seq_lens[num_reqs:].fill(0)
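# Illustrative note (not part of the original commit): for
# num_scheduled_tokens = [2, 5, 3], prepare_inputs yields
# query_start_loc[:4] == [0, 2, 7, 10], seq_lens[i] ==
# num_computed_tokens[req_idx] + num_scheduled_tokens[i], and positions that
# count up from each request's num_computed_tokens.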

View File

@@ -68,7 +68,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
LogprobsTensors, ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.sample.sampler import Sampler
@@ -76,7 +76,9 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_block_table import BlockTables
from vllm.v1.worker.gpu_input_batch import InputBatch, prepare_inputs
from vllm.v1.worker.gpu_worker_states import RequestState
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -200,18 +202,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.rejection_sampler = RejectionSampler()
# Request states.
self.requests: dict[str, CachedRequestState] = {}
# Input Batch
# NOTE(Chen): Ideally, we should initialize the input batch inside
# `initialize_kv_cache` based on the kv cache config. However, as in
# https://github.com/vllm-project/vllm/pull/18298, due to some unknown
# reasons, we have to initialize the input batch before `load_model`,
# quantization + weight offloading will fail otherwise. As a temporary
# solution, we initialize the input batch here, and re-initialize it
# in `initialize_kv_cache` if the block_sizes here is different from
# the block_sizes in the kv cache config.
self.input_batch = InputBatch(
self.requests = RequestState(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
@@ -220,12 +211,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.cache_config.block_size],
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=build_logitsprocs(
self.vllm_config, self.device, self.pin_memory,
self.is_pooling_model,
self.vllm_config.model_config.logits_processors),
is_pooling_model=self.is_pooling_model,
)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
@@ -253,9 +238,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.seq_lens = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.slot_mapping = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: Optional[IntermediateTensors] = None
@@ -290,12 +272,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=self.dtype,
device=self.device)
self.block_tables = BlockTables(
block_sizes=[self.cache_config.block_size],
max_num_reqs=self.max_num_reqs,
max_num_cached_reqs=2 * self.max_num_reqs,
max_num_batched_tokens=self.max_num_tokens,
max_model_len=self.max_model_len,
device=self.device,
pin_memory=self.pin_memory,
)
# OPTIMIZATION: Cache the tensors rather than creating them every step.
# Keep in int64 to avoid overflow with long context
self.arange_np = np.arange(max(self.max_num_reqs + 1,
self.max_model_len,
self.max_num_tokens),
dtype=np.int64)
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
# a faster version of creating a new tensor every time. Thus, we should
# not make any assumptions about the values in these tensors.
@@ -303,6 +296,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.input_ids_np = self.input_ids_cpu.numpy()
self.positions_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu",
@@ -319,6 +313,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
self.index_mapping_cpu = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.index_mapping_np = self.index_mapping_cpu.numpy()
self.index_mapping = self.index_mapping_cpu.to(self.device)
# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
# means this layer will perform attention using the keys and values
@@ -410,10 +411,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
The SamplingMetadata is updated and copied to the GPU if there is a
new/resumed/paused/finished request in the batch.
"""
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
@@ -421,7 +418,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# distinct requests - clearing the cached states for the first request
# and handling the second as a new request.
for req_id in scheduler_output.finished_req_ids:
self.input_batch.remove_request(req_id)
self.requests.remove_request(req_id)
self.encoder_cache.pop(req_id, None)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
@@ -431,120 +429,82 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
# Remove the unscheduled requests from the persistent batch.
# NOTE(woosuk): The unscheduled requests are either preempted requests
# or running requests that are not scheduled in this step. We remove
# them from the persistent batch but keep their cached states since
# they will be scheduled again sometime in the future.
scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
cached_req_ids = self.input_batch.req_id_to_index.keys()
unscheduled_req_ids = cached_req_ids - scheduled_req_ids
# NOTE(woosuk): The persistent batch optimization assumes that
# consecutive batches contain mostly the same requests. If batches
# have low request overlap (e.g., alternating between two distinct
# sets of requests), this optimization becomes very inefficient.
for req_id in unscheduled_req_ids:
self.input_batch.remove_request(req_id)
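        # Illustrative example (added for exposition): if consecutive steps
        # schedule {A, B} and then only {A}, B is removed from the batch here
        # while its cached state persists, so a later step can re-add it
        # cheaply. Alternating between two disjoint request sets defeats this
        # optimization, as the NOTE above explains.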
req_indices: list[int] = []
cu_num_new_blocks: list[list[int]] = [
[0] for _ in range(self.block_tables.num_kv_cache_groups)
]
new_block_ids: list[list[int]] = [
[] for _ in range(self.block_tables.num_kv_cache_groups)
]
overwrite: list[bool] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
sampling_params = new_req_data.sampling_params
pooling_params = new_req_data.pooling_params
if pooling_params:
task = pooling_params.task
assert task is not None, "You did not set `task` in the API"
model = cast(VllmModelForPooling, self.get_model())
to_update = model.pooler.get_pooling_updates(task)
to_update.apply(pooling_params)
req_state = CachedRequestState(
req_id=req_id,
mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
pooling_params=pooling_params,
lora_request=new_req_data.lora_request,
)
self.requests[req_id] = req_state
self.input_batch.add_request(
self.requests.add_request(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
block_ids=new_req_data.block_ids,
sampling_params=sampling_params,
sampling_params=new_req_data.sampling_params,
)
req_index = self.requests.req_id_to_index[req_id]
req_indices.append(req_index)
for i, block_ids in enumerate(new_req_data.block_ids):
x = cu_num_new_blocks[i][-1]
cu_num_new_blocks[i].append(x + len(block_ids))
new_block_ids[i].extend(block_ids)
overwrite.append(True)
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
if self.uses_mrope:
self._init_mrope_positions(req_state)
self._init_mrope_positions(req_id)
# Update the states of the running/resumed requests.
is_last_rank = get_pp_group().is_last_rank
req_data = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(req_data.req_ids):
req_index = self.input_batch.req_id_to_index[req_id]
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
req_index = self.requests.req_id_to_index[req_id]
# Update input batch.
if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker.
new_token_ids = req_data.new_token_ids[i]
self.input_batch.append_token_ids(req_index, new_token_ids)
new_token_ids = cached_reqs.new_token_ids[i]
self.requests.append_token_ids(req_index, new_token_ids)
new_block_ids = req_data.new_block_ids[i]
if new_block_ids is not None:
if cached_reqs.new_block_ids[i] is not None:
req_indices.append(req_index)
                # NOTE: use a separate loop variable for the KV cache group so
                # the outer per-request index `i` is not clobbered.
                for j, block_ids in enumerate(cached_reqs.new_block_ids[i]):
                    x = cu_num_new_blocks[j][-1]
                    cu_num_new_blocks[j].append(x + len(block_ids))
                    new_block_ids[j].extend(block_ids)
# If the request is resumed from preemption, we need to
# overwrite the existing block IDs.
self.input_batch.append_block_ids(
req_index,
new_block_ids,
overwrite=req_data.resumed_from_preemption[i],
)
overwrite.append(cached_reqs.resumed_from_preemption[i])
self.input_batch.num_computed_tokens.np[req_index] = (
req_data.num_computed_tokens[i])
self.requests.num_computed_tokens.np[req_index] = (
cached_reqs.num_computed_tokens[i])
    def _init_mrope_states(self, req_state: CachedRequestState) -> None:
        image_grid_thw = []
        video_grid_thw = []
        second_per_grid_ts = []
        audio_feature_lengths = []
        use_audio_in_video = False
        for mm_item in req_state.mm_kwargs:
            mm_input = mm_item.get_data()
            if (t := mm_input.get("image_grid_thw")) is not None:
                image_grid_thw.append(t.tolist())
            if (t := mm_input.get("video_grid_thw")) is not None:
                video_grid_thw.append(t.tolist())
            if (t := mm_input.get("second_per_grid_ts")) is not None:
                second_per_grid_ts.append(t)
            if (t := mm_input.get("audio_feature_lengths")) is not None:
                audio_feature_lengths.append(t)
            if mm_input.get("use_audio_in_video") is True:
                use_audio_in_video = True
        req_state.mrope_positions, req_state.mrope_position_delta = \
            MRotaryEmbedding.get_input_positions_tensor(
                req_state.prompt_token_ids,
                hf_config=self.model_config.hf_config,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
                audio_feature_lengths=audio_feature_lengths,
                use_audio_in_video=use_audio_in_video,
            )

        if req_indices:
            self.block_tables.append_block_ids(
                req_indices=req_indices,
                cu_num_new_blocks=cu_num_new_blocks,
                new_block_ids=new_block_ids,
                overwrite=overwrite,
            )
def _init_mrope_positions(self, req_state: CachedRequestState):
def _init_mrope_positions(self, req_id: str) -> None:
req_idx = self.requests.req_id_to_index[req_id]
req_data = self.requests.req_data[req_idx]
image_grid_thw = []
video_grid_thw = []
second_per_grid_ts = []
audio_feature_lengths = []
use_audio_in_video = False
for mm_item in req_state.mm_kwargs:
for mm_item in req_data.mm_kwargs:
mm_input = mm_item.get_data()
if (t := mm_input.get("image_grid_thw")) is not None:
image_grid_thw.append(t.tolist())
@@ -557,9 +517,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if mm_input.get("use_audio_in_video") is True:
use_audio_in_video = True
req_state.mrope_positions, req_state.mrope_position_delta = \
req_data.mrope_positions, req_data.mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
req_state.prompt_token_ids,
req_data.prompt_token_ids,
hf_config=self.model_config.hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
@@ -622,91 +582,55 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata],
np.ndarray, Optional[CommonAttentionMetadata], int]:
"""
:return: tuple[
attn_metadata: layer-to-attention_metadata mapping,
logits_indices, spec_decode_metadata
]
"""
) -> InputBatch:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
num_reqs = len(scheduler_output.num_scheduled_tokens)
# FIXME
# batch_idx -> req_id
req_ids = list(scheduler_output.num_scheduled_tokens.keys())
req_ids = sorted(scheduler_output.num_scheduled_tokens,
key=scheduler_output.num_scheduled_tokens.get)
# req_id -> batch_idx
req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)}
# batch_idx -> req_idx
idx_mapping = [
self.input_batch.req_id_to_index[req_id] for req_id in req_ids
idx_mapping_list = [
self.requests.req_id_to_index[req_id] for req_id in req_ids
]
# batch_idx -> req_idx
idx_mapping_tensor = self.input_batch.get_index_mapping(idx_mapping)
num_reqs = len(req_ids)
self.index_mapping_np[:num_reqs] = idx_mapping_list
index_mapping_np = self.index_mapping_np[:num_reqs]
idx_mapping = self.index_mapping[:num_reqs].copy_(
self.index_mapping_cpu[:num_reqs], non_blocking=True)
# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
block_tables = self.input_batch.make_block_tables(idx_mapping_tensor)
block_tables = self.block_tables.compute_block_tables(idx_mapping)
# Get the number of scheduled tokens for each request.
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens)
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
# Get positions.
positions_np = self.positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens.np[req_indices],
arange,
out=positions_np)
prepare_inputs(
idx_mapping=index_mapping_np,
token_ids=self.requests.token_ids.np,
num_computed_tokens=self.requests.num_computed_tokens.np,
num_scheduled_tokens=num_scheduled_tokens,
input_ids=self.input_ids_np,
query_start_loc=self.query_start_loc_np,
seq_lens=self.seq_lens_np,
positions=self.positions_np,
)
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
if self.uses_mrope:
self._calc_mrope_positions(scheduler_output)
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
# where M is the max_model_len.
token_indices = (positions_np +
req_indices * self.input_batch.token_ids.np.shape[1])
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
torch.index_select(self.input_batch.token_ids.cpu.flatten(),
0,
torch.from_numpy(token_indices),
out=self.input_ids_cpu[:total_num_scheduled_tokens])
# Prepare the attention metadata.
self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
# Note: pad query_start_loc to be non-decreasing, as kernels
# like FlashAttention require that
self.query_start_loc_np[num_reqs + 1:].fill(cu_num_tokens[-1])
self.query_start_loc.copy_(self.query_start_loc_cpu, non_blocking=True)
query_start_loc = self.query_start_loc[:num_reqs + 1]
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens.np[:num_reqs] +
num_scheduled_tokens)
# Fill unused with 0 for full cuda graph mode.
self.seq_lens_np[num_reqs:].fill(0)
self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True)
seq_lens = self.seq_lens[:num_reqs]
max_seq_len = self.seq_lens_np[:num_reqs].max().item()
@@ -714,16 +638,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Copy the tensors to the GPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
# Common case (1D positions)
self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
non_blocking=True)
else:
# Common case (1D positions)
self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens],
non_blocking=True)
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -737,16 +659,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
spec_decode_metadata = None
else:
# Get the number of draft tokens for each request.
# Iterate over the dictionary rather than all requests since not all
# requests have draft tokens.
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
batch_idx = req_id_to_batch_idx[req_id]
num_draft_tokens[batch_idx] = len(draft_token_ids)
for i, req_id in enumerate(req_ids):
draft_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
if draft_token_ids:
num_draft_tokens[i] = len(draft_token_ids)
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens)
num_draft_tokens, self.query_start_loc_np[1:num_reqs + 1])
logits_indices = spec_decode_metadata.logits_indices
logits_indices_padded = None
@@ -774,15 +694,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
)
attn_metadata: dict[str, Any] = {}
slot_mappings = self.block_tables.compute_slot_mappings(
query_start_loc, self.positions[:total_num_scheduled_tokens])
# Used in the loop below.
query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens.cpu[:num_reqs])
self.requests.num_computed_tokens.cpu[:num_reqs])
spec_decode_common_attn_metadata = None
attn_metadata: dict[str, Any] = {}
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -804,14 +726,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
non_blocking=True)
num_common_prefix_blocks = 0
else:
blk_table = self.input_batch.block_table[kv_cache_group_id]
blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
slot_mapping = blk_table.slot_mapping[:
total_num_scheduled_tokens]
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode.
blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
blk_table_tensor = block_tables[kv_cache_group_id]
slot_mapping = slot_mappings[kv_cache_group_id]
num_common_prefix_blocks = (
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id])
@@ -876,13 +792,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
continue
attn_metadata[layer_name] = attn_metadata_i
# Hot-Swap lora model
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
# # Hot-Swap lora model
# if self.lora_config:
# self.set_active_loras(input_batch, num_scheduled_tokens)
return (attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens, spec_decode_common_attn_metadata,
max_num_scheduled_tokens)
return InputBatch(
req_ids=req_ids,
num_scheduled_tokens=num_scheduled_tokens,
req_id_to_batch_idx=req_id_to_batch_idx,
idx_mapping=idx_mapping,
idx_mapping_np=index_mapping_np,
num_reqs=num_reqs,
total_num_tokens=total_num_scheduled_tokens,
max_num_tokens=max_num_scheduled_tokens,
attn_metadata=attn_metadata,
spec_decode_metadata=spec_decode_metadata,
spec_decode_common_attn_metadata=spec_decode_common_attn_metadata,
logits_indices=logits_indices,
)
def _compute_cascade_attn_prefix_len(
self,
@@ -955,7 +882,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_reqs = len(num_scheduled_tokens)
common_prefix_len = min(
common_prefix_len,
self.input_batch.num_computed_tokens.np[:num_reqs].min())
self.requests.num_computed_tokens.np[:num_reqs].min())
# common_prefix_len should be a multiple of the block size.
common_prefix_len = (common_prefix_len // kv_cache_spec.block_size *
kv_cache_spec.block_size)
@@ -979,16 +906,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
return common_prefix_len if use_cascade else 0
def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
def _calc_mrope_positions(self, input_batch: InputBatch):
mrope_pos_ptr = 0
for index, req_id in enumerate(self.input_batch.req_ids):
for i, req_id in enumerate(input_batch.req_ids):
req = self.requests[req_id]
assert req.mrope_positions is not None
num_computed_tokens = \
self.input_batch.num_computed_tokens_cpu[index]
self.requests.num_computed_tokens_cpu[i]
num_scheduled_tokens = \
scheduler_output.num_scheduled_tokens[req_id]
input_batch.num_scheduled_tokens[i]
num_prompt_tokens = len(req.prompt_token_ids)
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
@@ -1159,17 +1086,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
input_batch: InputBatch,
shift_computed_tokens: int = 0,
) -> list[torch.Tensor]:
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
num_computed_tokens = \
req_state.num_computed_tokens + shift_computed_tokens
mm_positions = req_state.mm_positions
for i, req_id in enumerate(input_batch.req_ids):
num_scheduled_tokens = input_batch.num_scheduled_tokens[i]
req_idx = self.requests.req_id_to_index[req_id]
num_computed_tokens = (
self.requests.num_computed_tokens.np[req_idx] +
shift_computed_tokens)
req_data = self.requests.req_data[req_idx]
mm_positions = req_data.mm_positions
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
@@ -1274,8 +1202,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# request in the batch, as the logit indices are offset by this amount.
struct_out_req_batch_indices: dict[str, int] = {}
cumulative_offset = 0
seq = sorted(self.input_batch.req_id_to_index.items(),
key=lambda x: x[1])
seq = sorted(self.requests.req_id_to_index.items(), key=lambda x: x[1])
for req_id, batch_index in seq:
logit_index = batch_index + cumulative_offset
cumulative_offset += len(
@@ -1431,7 +1358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
req_id_to_index=self.input_batch.req_id_to_batch_idx,
sampled_token_ids=[],
logprobs=None,
prompt_logprobs_dict={},
@@ -1455,21 +1382,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config)
# Prepare the decoder inputs.
(attn_metadata, logits_indices, spec_decode_metadata,
num_scheduled_tokens_np, spec_decode_common_attn_metadata,
max_query_len) = self._prepare_inputs(scheduler_output)
# FIXME
# batch_idx -> req_id
req_ids = list(scheduler_output.num_scheduled_tokens.keys())
# req_id -> batch_idx
req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)}
# batch_idx -> req_idx
idx_mapping = [
self.input_batch.req_id_to_index[req_id] for req_id in req_ids
]
# batch_idx -> req_idx
idx_mapping_tensor = self.input_batch.get_index_mapping(idx_mapping)
input_batch = self._prepare_inputs(scheduler_output)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
@@ -1540,8 +1453,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True)
uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
num_scheduled_tokens == self.input_batch.num_reqs * max_query_len)
uniform_decode = (input_batch.max_num_tokens
== self.uniform_decode_query_len
and num_scheduled_tokens
== input_batch.num_reqs * input_batch.max_num_tokens)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=uniform_decode)
cudagraph_runtime_mode, batch_descriptor = \
@@ -1550,7 +1465,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Run the model.
# Use persistent buffers for CUDA graphs.
with set_forward_context(
attn_metadata,
input_batch.attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
@@ -1590,11 +1505,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
all_gather_group=get_tp_group())
logits = None
else:
if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np, kv_connector_output)
sample_hidden_states = hidden_states[logits_indices]
sample_hidden_states = hidden_states[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if broadcast_pp_output:
model_output_broadcast_data = {
@@ -1610,9 +1521,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.apply_grammar_bitmask(scheduler_output, logits)
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.make_sampling_metadata(
idx_mapping_tensor)
if spec_decode_metadata is None:
sampling_metadata = self.requests.make_sampling_metadata(
input_batch.idx_mapping)
if input_batch.spec_decode_metadata is None:
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
@@ -1623,7 +1534,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
bonus_logits = logits[
input_batch.spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
@@ -1633,9 +1545,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
target_logits = logits[
input_batch.spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata,
input_batch.spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
@@ -1643,6 +1556,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
sampler_output.sampled_token_ids = output_token_ids
for i in range(input_batch.num_reqs):
req_idx = input_batch.idx_mapping_np[i]
num_tokens = input_batch.num_scheduled_tokens[i]
self.requests.num_computed_tokens.np[req_idx] += num_tokens
num_nans_in_logits = {}
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
num_nans_in_logits = self._get_nans_in_logits(logits)
@ -1664,27 +1582,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
valid_sampled_token_ids_np = sampled_token_ids.cpu().numpy()
valid_sampled_token_ids = valid_sampled_token_ids_np.tolist()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids, self.input_batch.vocab_size)
# # Mask out the sampled tokens that should not be sampled.
# for i in discard_sampled_tokens_req_indices:
# valid_sampled_token_ids[i].clear()
sampled_token_ids, self.vocab_size)
# Cache the sampled tokens in the model runner, so that the scheduler
# doesn't need to send them back.
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
if not sampled_ids:
continue
self.input_batch.append_token_ids(req_idx, sampled_ids)
self.requests.append_sampled_token_ids(
input_batch.idx_mapping_np,
valid_sampled_token_ids,
)
if self.speculative_config:
assert spec_decode_common_attn_metadata is not None
assert input_batch.spec_decode_common_attn_metadata is not None
self._draft_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
@ -1692,15 +1608,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
spec_decode_common_attn_metadata,
input_batch.spec_decode_metadata,
input_batch.spec_decode_common_attn_metadata,
)
self.eplb_step()
return ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_batch_idx,
req_ids=input_batch.req_ids,
req_id_to_index=input_batch.req_id_to_batch_idx,
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
@ -1712,7 +1628,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
if self._draft_token_ids is None:
return None
req_ids = self.input_batch.req_ids
req_ids = self.requests.req_ids
if isinstance(self._draft_token_ids, torch.Tensor):
draft_token_ids = self._draft_token_ids.tolist()
else:
@ -1722,16 +1638,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def propose_draft_token_ids(
self,
scheduler_output: "SchedulerOutput",
input_batch: InputBatch,
sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
hidden_states: torch.Tensor,
sample_hidden_states: torch.Tensor,
aux_hidden_states: Optional[torch.Tensor],
spec_decode_metadata: Optional[SpecDecodeMetadata],
common_attn_metadata: CommonAttentionMetadata,
) -> Union[list[list[int]], torch.Tensor]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
num_scheduled_tokens = input_batch.total_num_tokens
if self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
draft_token_ids = self.propose_ngram_draft_token_ids(
@ -1745,7 +1659,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
indices = []
offset = 0
for num_draft, tokens in zip(
spec_decode_metadata.num_draft_tokens,
input_batch.spec_decode_metadata.num_draft_tokens,
sampled_token_ids):
indices.append(offset + len(tokens) - 1)
offset += num_draft + 1
@ -1759,7 +1673,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
elif self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
# TODO(woosuk): Refactor the loop.
req_ids = self.input_batch.req_ids
req_ids = input_batch.req_ids
next_token_ids: list[int] = []
for i, token_ids in enumerate(sampled_token_ids):
if token_ids:
@ -1771,14 +1685,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_id = req_ids[i]
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
input_batch.num_scheduled_tokens[i])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
if spec_decode_metadata is None:
if input_batch.spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
# TODO(woosuk): Support M-RoPE.
@ -1791,7 +1705,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
target_hidden_states = hidden_states[:num_scheduled_tokens]
else:
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_draft_tokens = input_batch.spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
@ -1812,7 +1726,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
target_hidden_states = hidden_states[token_indices]
mm_embeds = None
if self.supports_mm_inputs:
mm_embeds = self._gather_mm_embeddings(scheduler_output,
mm_embeds = self._gather_mm_embeddings(input_batch,
shift_computed_tokens=1)
draft_token_ids = self.drafter.propose(
@ -1828,10 +1742,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def propose_ngram_draft_token_ids(
self,
input_batch: InputBatch,
sampled_token_ids: list[list[int]],
) -> list[list[int]]:
# TODO(woosuk): Optimize.
req_ids = self.input_batch.req_ids
draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
@ -1842,19 +1756,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Skip requests that require sampling parameters that are not
# supported with speculative decoding.
req_id = req_ids[i]
if req_id in self.input_batch.spec_decode_unsupported_reqs:
req_id = input_batch.req_ids[i]
if req_id in self.requests.spec_decode_unsupported_reqs:
draft_token_ids.append([])
continue
num_tokens = self.input_batch.num_tokens_no_spec[i]
num_tokens = self.requests.num_tokens_no_spec[i]
if num_tokens >= self.max_model_len:
# Skip requests that have already reached the max model length.
draft_token_ids.append([])
continue
drafter_output = self.drafter.propose(
self.input_batch.token_ids_cpu[i, :num_tokens])
self.requests.token_ids.np[i, :num_tokens])
if drafter_output is None or len(drafter_output) == 0:
draft_token_ids.append([])
else:
@ -1992,11 +1906,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hidden_states: torch.Tensor,
num_scheduled_tokens: dict[str, int],
) -> dict[str, Optional[LogprobsTensors]]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
num_prompt_logprobs_dict = self.requests.num_prompt_logprobs
if not num_prompt_logprobs_dict:
return {}
in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
in_progress_dict = self.requests.in_progress_prompt_logprobs_cpu
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
# Since prompt logprobs are a rare feature, prioritize simple,
@ -2045,7 +1959,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Get the logits corresponding to this req's prompt tokens.
# If this is a partial request (i.e. chunked prefill),
# then there is prompt logprob generated for each index.
req_idx = self.input_batch.req_id_to_index[req_id]
req_idx = 0
offset = self.query_start_loc_np[req_idx].item()
prompt_hidden_states = hidden_states[offset:offset + num_logits]
logits = self.model.compute_logits(prompt_hidden_states, None)
@ -2083,20 +1997,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _get_nans_in_logits(
self,
input_batch: InputBatch,
logits: Optional[torch.Tensor],
) -> dict[str, int]:
try:
if logits is None:
return {req_id: 0 for req_id in self.input_batch.req_ids}
return {req_id: 0 for req_id in input_batch.req_ids}
num_nans_in_logits = {}
num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy()
for req_id in self.input_batch.req_ids:
req_index = self.input_batch.req_id_to_index[req_id]
num_nans_in_logits[req_id] = (
int(num_nans_for_index[req_index])
if num_nans_for_index is not None
and req_index < logits.shape[0] else 0)
            for i, req_id in enumerate(input_batch.req_ids):
num_nans_in_logits[req_id] = (int(num_nans_for_index[i])
if num_nans_for_index is not None
and i < logits.shape[0] else 0)
return num_nans_in_logits
except IndexError:
return {}
@ -2248,17 +2161,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
1],
seq_lens=self.seq_lens[:num_reqs],
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
num_computed_tokens_cpu=self.input_batch.
num_computed_tokens_cpu_tensor[:num_reqs],
num_computed_tokens_cpu=self.requests.num_computed_tokens.
cpu[:num_reqs],
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
max_seq_len=self.max_model_len,
block_table_tensor=self.input_batch.block_table[
kv_cache_group_id].get_device_tensor()[:num_reqs],
slot_mapping=self.input_batch.
block_table[kv_cache_group_id].slot_mapping[:num_tokens],
causal=True)
block_table_tensor=self.requests.
block_tables[kv_cache_group_id].gpu[:num_reqs],
slot_mapping=self.requests.slot_mappings[kv_cache_group_id]
[:num_tokens],
causal=True,
)
for attn_group in self.attn_groups[kv_cache_group_id]:
attn_metadata_i = attn_group.metadata_builder\

View File

@ -0,0 +1,342 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Data structures defining a GPU input batch
from dataclasses import dataclass
from typing import Optional, Union
import numpy as np
import torch
import triton
import triton.language as tl
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (MultiModalKwargsItem,
MultiModalKwargsItems, PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
@dataclass
class RequestData:
mm_kwargs: list[MultiModalKwargsItem]
mm_positions: list[PlaceholderRange]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
# M-RoPE (only for Qwen2/2.5-VL)
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None
lora_request: Optional[LoRARequest] = None
# Temporary back-compatibility for plugins that define model runner
@property
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
"removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargsItems]:
return [
MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
]
class RequestAttribute:
def __init__(
self,
num_rows_cpu: int,
num_cols: int,
num_rows_gpu: int,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
is_scalar: bool = False,
):
self.cpu = torch.zeros(num_rows_cpu,
num_cols,
dtype=dtype,
device="cpu",
pin_memory=pin_memory)
self.np = self.cpu.numpy()
self.gpu = torch.zeros(num_rows_gpu,
num_cols,
dtype=dtype,
device=device)
if is_scalar:
assert num_cols == 1
self.cpu.squeeze_(1)
self.np = self.cpu.numpy()
self.gpu.squeeze_(1)
self.gpu_buffer = self.cpu.to(device)
def mirror_to_gpu(self) -> torch.Tensor:
return self.gpu_buffer.copy_(self.cpu, non_blocking=True)
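# Illustrative (commented-out) usage sketch for RequestAttribute; the sizes
# below are hypothetical and a CUDA device with pinned host memory is assumed.
# Per-request values live in a persistent CPU table indexed by the
# cached-request slot, and mirror_to_gpu() stages the whole table to the
# device with a single non-blocking copy.
#
#   temperature = RequestAttribute(num_rows_cpu=16384,  # max_num_cached_reqs
#                                  num_cols=1,
#                                  num_rows_gpu=1024,    # max_num_reqs
#                                  dtype=torch.float32,
#                                  device=torch.device("cuda"),
#                                  pin_memory=True,
#                                  is_scalar=True)
#   temperature.np[req_idx] = 0.7         # host-side write when adding a request
#   staged = temperature.mirror_to_gpu()  # async H2D copy of the CPU table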
class RequestState:
def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_batched_tokens: int,
max_num_cached_reqs: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_sizes: list[int], # The block_size of each kv cache group
logitsprocs: Optional[LogitsProcessors] = None,
is_spec_decode: bool = False,
is_pooling_model: bool = False,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_batched_tokens = max_num_batched_tokens
self.max_num_cached_reqs = max_num_cached_reqs
self.device = device
self.pin_memory = pin_memory
self.vocab_size = vocab_size
self.is_spec_decode = is_spec_decode
self.pooling_params = None
self.block_sizes = block_sizes
self.num_prompt_logprobs = {}
self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_cached_reqs))
# Used to construct the input batch.
self._add_scalar_attr("idx_mapping", torch.int32)
# Request states.
self.req_data: dict[int, RequestData] = {}
# TODO(woosuk): Because the token_ids tensor can be very big, we only
# initialize it on CPU memory.
self._add_vector_attr("token_ids",
self.max_model_len,
torch.int32,
cpu_only=True)
self._add_scalar_attr("num_prompt_tokens", torch.int32)
self._add_scalar_attr("num_tokens", torch.int32)
self._add_scalar_attr("num_computed_tokens", torch.int32)
# Sampling-related.
self._add_scalar_attr("temperature", torch.float32)
self.greedy_reqs: set[str] = set()
self.random_reqs: set[str] = set()
self._add_scalar_attr("top_p", torch.float32)
self.top_p_reqs: set[str] = set()
self._add_scalar_attr("top_k", torch.int32)
self.top_k_reqs: set[str] = set()
self._add_scalar_attr("frequency_penalties", torch.float32)
self.frequency_penalties_reqs: set[str] = set()
self._add_scalar_attr("presence_penalties", torch.float32)
self.presence_penalties_reqs: set[str] = set()
self._add_scalar_attr("repetition_penalties", torch.float32)
self.repetition_penalties_reqs: set[str] = set()
# req_idx -> generator
self.generators: dict[int, torch.Generator] = {}
def add_request(
self,
req_id: str,
prompt_token_ids: list[int],
num_computed_tokens: int,
sampling_params: SamplingParams,
) -> None:
req_idx = self.free_indices.pop()
self.req_id_to_index[req_id] = req_idx
self.index_to_req_id[req_idx] = req_id
self.num_prompt_tokens.np[req_idx] = len(prompt_token_ids)
self.num_computed_tokens.np[req_idx] = num_computed_tokens
self.append_token_ids(req_idx, prompt_token_ids)
self.temperature.np[req_idx] = sampling_params.temperature
if sampling_params.sampling_type == SamplingType.GREEDY:
            # NOTE: Greedy requests keep temperature == 0.0 here, so be
            # careful about division by zero when temperature is applied.
self.greedy_reqs.add(req_id)
elif sampling_params.sampling_type == SamplingType.RANDOM:
self.random_reqs.add(req_id)
self.top_p.np[req_idx] = sampling_params.top_p
if sampling_params.top_p < 1.0:
self.top_p_reqs.add(req_id)
top_k = sampling_params.top_k
if 0 < top_k < self.vocab_size:
self.top_k_reqs.add(req_id)
else:
top_k = self.vocab_size
self.top_k.np[req_idx] = top_k
self.frequency_penalties.np[
req_idx] = sampling_params.frequency_penalty
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties.np[req_idx] = sampling_params.presence_penalty
if sampling_params.presence_penalty != 0.0:
self.presence_penalties_reqs.add(req_id)
self.repetition_penalties.np[
req_idx] = sampling_params.repetition_penalty
if sampling_params.repetition_penalty != 1.0:
self.repetition_penalties_reqs.add(req_id)
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
self.generators[req_idx] = generator
def append_token_ids(
self,
req_idx: int,
token_ids: Union[list[int], np.ndarray],
) -> None:
start_idx = self.num_tokens.np[req_idx]
end_idx = start_idx + len(token_ids)
self.token_ids.np[req_idx, start_idx:end_idx] = token_ids
self.num_tokens.np[req_idx] = end_idx
def append_sampled_token_ids(
self,
idx_mapping: np.ndarray,
sampled_token_ids: np.ndarray,
) -> None:
num_reqs = idx_mapping.shape[0]
for i in range(num_reqs):
req_idx = idx_mapping[i]
self.append_token_ids(req_idx, sampled_token_ids[i])
def remove_request(self, req_id: str) -> None:
req_idx = self.req_id_to_index.pop(req_id, None)
if req_idx is None:
# Request not found.
return
self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx)
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_idx, None)
def make_sampling_metadata(
self,
batch_idx_to_req_idx: torch.Tensor,
) -> SamplingMetadata:
batch_size = batch_idx_to_req_idx.shape[0]
_make_sampling_metadata_kernel[(batch_size, )](
batch_idx_to_req_idx,
self.temperature.mirror_to_gpu(),
self.temperature.gpu,
self.top_p.mirror_to_gpu(),
self.top_p.gpu,
self.top_k.mirror_to_gpu(),
self.top_k.gpu,
self.frequency_penalties.mirror_to_gpu(),
self.frequency_penalties.gpu,
self.presence_penalties.mirror_to_gpu(),
self.presence_penalties.gpu,
self.repetition_penalties.mirror_to_gpu(),
self.repetition_penalties.gpu,
num_warps=1,
num_stages=1,
)
no_penalties = not (self.frequency_penalties_reqs
or self.presence_penalties_reqs
or self.repetition_penalties_reqs)
return SamplingMetadata(
temperature=self.temperature.gpu[:batch_size],
all_greedy=not self.random_reqs,
all_random=not self.greedy_reqs,
top_p=self.top_p.gpu[:batch_size],
top_k=self.top_k.gpu[:batch_size],
frequency_penalties=self.frequency_penalties.gpu[:batch_size],
presence_penalties=self.presence_penalties.gpu[:batch_size],
repetition_penalties=self.repetition_penalties.gpu[:batch_size],
no_penalties=no_penalties,
# TODO
generators={},
token_ids=self.token_ids.cpu[:batch_size],
max_num_logprobs=None,
allowed_token_ids_mask=None,
bad_words_token_ids={},
logitsprocs=None,
)
@property
def num_cached_reqs(self) -> int:
return len(self.req_id_to_index)
def _add_scalar_attr(self, name: str, dtype: torch.dtype):
attr = RequestAttribute(self.max_num_cached_reqs,
1,
self.max_num_reqs,
dtype,
self.device,
self.pin_memory,
is_scalar=True)
setattr(self, name, attr)
def _add_vector_attr(
self,
name: str,
max_len: int,
dtype: torch.dtype,
cpu_only: bool = False,
):
if cpu_only:
num_rows_gpu = 0
else:
num_rows_gpu = self.max_num_reqs
attr = RequestAttribute(self.max_num_cached_reqs, max_len,
num_rows_gpu, dtype, self.device,
self.pin_memory)
setattr(self, name, attr)
@triton.jit
def _make_sampling_metadata_kernel(
batch_idx_to_req_idx, # [batch_size]
src_temperature,
dst_temperature,
src_top_p,
dst_top_p,
src_top_k,
dst_top_k,
src_frequency_penalties,
dst_frequency_penalties,
src_presence_penalties,
dst_presence_penalties,
src_repetition_penalties,
dst_repetition_penalties,
):
batch_idx = tl.program_id(0)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
temperature = tl.load(src_temperature + req_idx)
tl.store(dst_temperature + batch_idx, temperature)
top_p = tl.load(src_top_p + req_idx)
tl.store(dst_top_p + batch_idx, top_p)
top_k = tl.load(src_top_k + req_idx)
tl.store(dst_top_k + batch_idx, top_k)
frequency_penalties = tl.load(src_frequency_penalties + req_idx)
tl.store(dst_frequency_penalties + batch_idx, frequency_penalties)
presence_penalties = tl.load(src_presence_penalties + req_idx)
tl.store(dst_presence_penalties + batch_idx, presence_penalties)
repetition_penalties = tl.load(src_repetition_penalties + req_idx)
tl.store(dst_repetition_penalties + batch_idx, repetition_penalties)
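# For reference, a (commented-out) pure-PyTorch sketch of what the kernel
# above computes: each batch row gathers its request's sampling parameters
# from the persistent per-request tables into contiguous per-batch tensors.
# The Triton version fuses all six gathers into a single kernel launch. The
# helper name and argument layout below are hypothetical.
#
#   def gather_sampling_params_reference(batch_idx_to_req_idx, srcs, dsts):
#       # srcs and dsts are parallel lists: source per-request tables and
#       # destination per-batch tensors (temperature, top_p, top_k, and the
#       # three penalty tensors).
#       batch_size = batch_idx_to_req_idx.numel()
#       for src, dst in zip(srcs, dsts):
#           dst[:batch_size] = src.index_select(0, batch_idx_to_req_idx)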

View File

@ -12,7 +12,6 @@ from vllm.sampling_params import SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState
_SAMPLING_EPS = 1e-5

View File

@ -298,3 +298,32 @@ def bind_kv_cache(
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
forward_context[layer_name].kv_cache = [kv_cache]
class CpuGpuBuffer:
def __init__(
self,
*args,
dtype: torch.dtype,
device: torch.device,
pin_memory: bool,
):
self.cpu = torch.zeros(*args,
dtype=dtype,
device="cpu",
pin_memory=pin_memory)
self.np = self.cpu.numpy()
self.gpu = self.cpu.to(device)
    def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
if n is None:
return self.gpu.copy_(self.cpu, non_blocking=True)
else:
return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
    def copy_to_cpu(self, n: Optional[int] = None) -> torch.Tensor:
if n is None:
return self.cpu.copy_(self.gpu, non_blocking=True)
else:
return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)
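# Illustrative (commented-out) usage of CpuGpuBuffer; the names and sizes are
# hypothetical and a CUDA device is assumed. Values are written through the
# numpy view of the pinned CPU tensor, then only the populated prefix is
# staged to the GPU with a non-blocking copy.
#
#   buf = CpuGpuBuffer(max_num_reqs,
#                      dtype=torch.int32,
#                      device=torch.device("cuda"),
#                      pin_memory=True)
#   buf.np[:num_reqs] = num_scheduled_tokens  # host-side write
#   buf.copy_to_gpu(num_reqs)                 # async H2D copy of the prefix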