commit 79e5eb3643
parent c472982746
Author: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Date:   2025-08-22 01:37:43 -07:00
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

2 changed files with 396 additions and 314 deletions

vllm/v1/worker/gpu_input_batch.py

@@ -6,6 +6,8 @@ from dataclasses import dataclass
 from typing import Optional

 import torch
+import triton
+import triton.language as tl
 from typing_extensions import deprecated

 from vllm.lora.request import LoRARequest
@@ -17,6 +19,8 @@ from vllm.utils import cdiv, get_cuda_view_from_cpu_tensor, is_uva_available
 from vllm.v1.sample.logits_processor import LogitsProcessors
 from vllm.v1.sample.metadata import SamplingMetadata

+PAD_SLOT_ID = -1
+

 @dataclass
 class CachedRequestState:
@@ -27,6 +31,7 @@ class CachedRequestState:
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]

     # M-RoPE (only for Qwen2/2.5-VL)
     mrope_positions: Optional[torch.Tensor] = None
     mrope_position_delta: Optional[int] = None
@@ -46,21 +51,31 @@ class PerRequestAttribute:
     def __init__(
         self,
-        N: int,
-        M: int,
-        K: int,
+        num_rows_cpu: int,
+        num_cols: int,
+        num_rows_gpu: int,
         dtype: torch.dtype,
         device: torch.device,
+        is_scalar: bool = False,
     ):
         assert is_uva_available(), "UVA is not available."
-        self.cpu_tensor = torch.zeros(N,
-                                      M,
-                                      dtype=dtype,
-                                      device="cpu",
-                                      pin_memory=True)
-        self.np = self.cpu_tensor.numpy()
-        self.uva_tensor = get_cuda_view_from_cpu_tensor(self.cpu_tensor)
-        self.gpu_tensor = torch.zeros(K, M, dtype=dtype, device=device)
+        self.cpu = torch.zeros(num_rows_cpu,
+                               num_cols,
+                               dtype=dtype,
+                               device="cpu",
+                               pin_memory=True)
+        self.np = self.cpu.numpy()
+        self.uva = get_cuda_view_from_cpu_tensor(self.cpu)
+        self.gpu = torch.zeros(num_rows_gpu,
+                               num_cols,
+                               dtype=dtype,
+                               device=device)
+        if is_scalar:
+            assert num_cols == 1
+            self.cpu.squeeze_(1)
+            self.np = self.cpu.numpy()
+            self.uva.squeeze_(1)
+            self.gpu.squeeze_(1)


 class InputBatch:
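Note: `PerRequestAttribute` keeps three views of one per-request array: a pinned CPU buffer (aliased as a NumPy array for cheap host-side writes), a UVA tensor that exposes the same pinned memory to CUDA kernels, and a separate dense GPU buffer that kernels gather into. A minimal standalone sketch of the pattern, with hypothetical shapes (`get_cuda_view_from_cpu_tensor` is the vLLM helper imported above):

```python
import torch

from vllm.utils import get_cuda_view_from_cpu_tensor

# Pinned host buffer, indexed by persistent request slot (16 slots here).
cpu = torch.zeros(16, 4, dtype=torch.int32, device="cpu", pin_memory=True)
np_view = cpu.numpy()                      # zero-copy NumPy alias
uva = get_cuda_view_from_cpu_tensor(cpu)   # CUDA-addressable view of `cpu`
gpu = torch.zeros(8, 4, dtype=torch.int32, device="cuda")  # batch-ordered

np_view[3] = [1, 2, 3, 4]  # plain host write; the GPU sees it via `uva`
# Gather slot 3 into batch row 0 without an explicit cudaMemcpy:
gpu[0].copy_(uva[3], non_blocking=True)
```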
@@ -87,15 +102,19 @@ class InputBatch:
         self.pin_memory = pin_memory
         self.vocab_size = vocab_size
         self.is_spec_decode = is_spec_decode
         self.pooling_params = None
+        self.block_sizes = block_sizes
         self.num_prompt_logprobs = {}

         self.req_id_to_index: dict[str, int] = {}
         self.index_to_req_id: dict[int, str] = {}
         self.free_indices = list(range(max_num_cached_reqs))
+        self._add_scalar_attr("idx_mapping", torch.int32)

         # Request states.
-        # TODO(woosuk): This buffer could be too large if max_model_len is big.
-        # Find a way to reduce the memory usage.
-        self._add_vector_attr("token_ids", self.max_model_len, torch.int32)
+        # TODO(woosuk): Because the token_ids tensor can be very big, we only
+        # initialize it on CPU memory.
+        self._add_vector_attr_cpu("token_ids", self.max_model_len, torch.int32)
         self._add_scalar_attr("num_prompt_tokens", torch.int32)
         self._add_scalar_attr("num_tokens", torch.int32)
         self._add_scalar_attr("num_computed_tokens", torch.int32)
@@ -119,20 +138,7 @@
         self.generators: dict[int, torch.Generator] = {}

         # Block table(s).
-        self.block_tables = []
-        self.num_blocks = []
-        for block_size in block_sizes:
-            max_num_blocks = cdiv(max_model_len, block_size)
-            block_table = PerRequestAttribute(self.max_num_cached_reqs,
-                                              max_num_blocks,
-                                              self.max_num_reqs, torch.int32,
-                                              self.device)
-            self.block_tables.append(block_table)
-            num_blocks = PerRequestAttribute(self.max_num_cached_reqs, 1,
-                                             self.max_num_reqs, torch.int32,
-                                             self.device)
-            self.num_blocks.append(num_blocks)
-        self.num_block_tables = len(block_sizes)
+        self._init_block_tables()

     def add_request(
         self,
@@ -146,11 +152,10 @@
         self.req_id_to_index[req_id] = req_idx
         self.index_to_req_id[req_idx] = req_id

-        num_prompt_tokens = len(prompt_token_ids)
-        self.token_ids.np[req_idx, :num_prompt_tokens] = prompt_token_ids
-        self.num_prompt_tokens.np[req_idx] = num_prompt_tokens
-        self.num_tokens.np[req_idx] = num_prompt_tokens
+        self.num_prompt_tokens.np[req_idx] = len(prompt_token_ids)
         self.num_computed_tokens.np[req_idx] = num_computed_tokens
+        self.append_token_ids(req_idx, prompt_token_ids)
+        self.append_block_ids(req_idx, block_ids, overwrite=True)

         self.temperature.np[req_idx] = sampling_params.temperature
         if sampling_params.sampling_type == SamplingType.GREEDY:
@@ -171,56 +176,48 @@
             self.top_k.np[req_idx] = top_k

         self.frequency_penalties.np[
-            req_idx] = sampling_params.frequency_penalties
-        if sampling_params.frequency_penalties != 0.0:
+            req_idx] = sampling_params.frequency_penalty
+        if sampling_params.frequency_penalty != 0.0:
             self.frequency_penalties_reqs.add(req_id)
-        self.presence_penalties.np[
-            req_idx] = sampling_params.presence_penalties
-        if sampling_params.presence_penalties != 0.0:
+        self.presence_penalties.np[req_idx] = sampling_params.presence_penalty
+        if sampling_params.presence_penalty != 0.0:
             self.presence_penalties_reqs.add(req_id)
         self.repetition_penalties.np[
-            req_idx] = sampling_params.repetition_penalties
-        if sampling_params.repetition_penalties != 1.0:
+            req_idx] = sampling_params.repetition_penalty
+        if sampling_params.repetition_penalty != 1.0:
             self.repetition_penalties_reqs.add(req_id)
-        if sampling_params.seed is not None:
+        if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
             generator = torch.Generator(device=self.device)
             generator.manual_seed(sampling_params.seed)
             self.generators[req_idx] = generator

-        for i in range(self.num_block_tables):
-            self.block_tables[i].np[req_idx, :len(block_ids[i])] = block_ids[i]
-            self.num_blocks[i].np[req_idx] = len(block_ids[i])
-
-    def append_token_ids(self, req_id: str, token_ids: list[int]) -> None:
-        req_idx = self.req_id_to_index.get(req_id)
-        assert req_idx is not None
+    def append_token_ids(self, req_idx: int, token_ids: list[int]) -> None:
         start_idx = self.num_tokens.np[req_idx]
         end_idx = start_idx + len(token_ids)
         self.token_ids.np[req_idx, start_idx:end_idx] = token_ids
         self.num_tokens.np[req_idx] = end_idx

+    # TODO(woosuk): Further vectorize this to minimize overheads.
     def append_block_ids(
         self,
-        req_id: str,
+        req_idx: int,
         new_block_ids: tuple[list[int], ...],
         overwrite: bool,
     ) -> None:
-        req_idx = self.req_id_to_index.get(req_id)
-        assert req_idx is not None
         for i in range(self.num_block_tables):
             block_table = self.block_tables[i]
             num_blocks = self.num_blocks[i]
+            num_new_blocks = len(new_block_ids[i])
             if overwrite:
                 # Replace the existing block IDs with the new ones.
                 # This happens when the request is resumed from preemption.
-                block_table.np[
-                    req_idx, :len(new_block_ids[i])] = new_block_ids[i]
-                num_blocks.np[req_idx] = len(new_block_ids[i])
+                block_table.np[req_idx, :num_new_blocks] = new_block_ids[i]
+                num_blocks.np[req_idx] = num_new_blocks
             else:
                 # Append the new block IDs to the existing ones (common case).
                 start_idx = num_blocks.np[req_idx]
-                end_idx = start_idx + len(new_block_ids[i])
+                end_idx = start_idx + num_new_blocks
                 block_table.np[req_idx, start_idx:end_idx] = new_block_ids[i]
                 num_blocks.np[req_idx] = end_idx
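Note: `append_block_ids` has two modes: `overwrite=True` (request resumed from preemption) resets the row, while `overwrite=False` appends at the current `num_blocks` cursor. A toy single-group illustration with hypothetical values:

```python
import numpy as np

block_table = np.zeros((1, 8), dtype=np.int32)  # one request row
num_blocks = np.zeros(1, dtype=np.int32)

def append_block_ids(req_idx: int, new_ids: list[int], overwrite: bool):
    n = len(new_ids)
    if overwrite:                       # resumed from preemption: replace row
        block_table[req_idx, :n] = new_ids
        num_blocks[req_idx] = n
    else:                               # common case: extend the row
        s = num_blocks[req_idx]
        block_table[req_idx, s:s + n] = new_ids
        num_blocks[req_idx] = s + n

append_block_ids(0, [10, 11], overwrite=True)   # row [10, 11, ...], count 2
append_block_ids(0, [12], overwrite=False)      # row [10, 11, 12, ...], count 3
```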
@@ -241,52 +238,50 @@ class InputBatch:
         self.repetition_penalties_reqs.discard(req_id)
         self.generators.pop(req_idx, None)

-    def make_block_table(self, req_idx: int) -> tuple[torch.Tensor, ...]:
-        pass
+    def get_index_mapping(self, idx_mapping: list[int]) -> torch.Tensor:
+        num_reqs = len(idx_mapping)
+        self.idx_mapping.np[:num_reqs] = idx_mapping
+        return self.idx_mapping.gpu[:num_reqs].copy_(
+            self.idx_mapping.uva[:num_reqs], non_blocking=True)

-    def make_sampling_metadata(self,
-                               req_indices: list[int]) -> SamplingMetadata:
-        batch_size = len(req_indices)
+    def make_sampling_metadata(
+        self,
+        batch_idx_to_req_idx: torch.Tensor,
+    ) -> SamplingMetadata:
+        batch_size = batch_idx_to_req_idx.shape[0]
         _make_sampling_metadata_kernel[(batch_size, )](
-            req_indices,
-            self.temperature.uva_tensor,
-            self.temperature.gpu_tensor,
-            self.top_p.uva_tensor,
-            self.top_p.gpu_tensor,
-            self.top_k.uva_tensor,
-            self.top_k.gpu_tensor,
-            self.frequency_penalties.uva_tensor,
-            self.frequency_penalties.gpu_tensor,
-            self.presence_penalties.uva_tensor,
-            self.presence_penalties.gpu_tensor,
-            self.repetition_penalties.uva_tensor,
-            self.repetition_penalties.gpu_tensor,
+            batch_idx_to_req_idx,
+            self.temperature.uva,
+            self.temperature.gpu,
+            self.top_p.uva,
+            self.top_p.gpu,
+            self.top_k.uva,
+            self.top_k.gpu,
+            self.frequency_penalties.uva,
+            self.frequency_penalties.gpu,
+            self.presence_penalties.uva,
+            self.presence_penalties.gpu,
+            self.repetition_penalties.uva,
+            self.repetition_penalties.gpu,
             num_warps=1,
             num_stages=1,
         )

-        generators = {}
-        if self.generators:
-            for i, req_idx in enumerate(req_indices):
-                generator = self.generators.get(req_idx)
-                if generator is not None:
-                    generators[i] = generator
         no_penalties = not (self.frequency_penalties_reqs
                             or self.presence_penalties_reqs
                             or self.repetition_penalties_reqs)
         return SamplingMetadata(
-            temperature=self.temperature.gpu_tensor[:batch_size],
+            temperature=self.temperature.gpu[:batch_size],
             all_greedy=not self.random_reqs,
             all_random=not self.greedy_reqs,
-            top_p=self.top_p.gpu_tensor[:batch_size],
-            top_k=self.top_k.gpu_tensor[:batch_size],
-            frequency_penalties=self.frequency_penalties.
-            gpu_tensor[:batch_size],
-            presence_penalties=self.presence_penalties.gpu_tensor[:batch_size],
-            repetition_penalties=self.repetition_penalties.
-            gpu_tensor[:batch_size],
+            top_p=self.top_p.gpu[:batch_size],
+            top_k=self.top_k.gpu[:batch_size],
+            frequency_penalties=self.frequency_penalties.gpu[:batch_size],
+            presence_penalties=self.presence_penalties.gpu[:batch_size],
+            repetition_penalties=self.repetition_penalties.gpu[:batch_size],
             no_penalties=no_penalties,
-            generators=generators,
-            token_ids=self.token_ids.gpu_tensor[:batch_size],
+            # TODO
+            generators={},
+            token_ids=self.token_ids.gpu[:batch_size],
             max_num_logprobs=None,
             allowed_token_ids_mask=None,
             bad_words_token_ids={},
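Note: `make_sampling_metadata` replaces per-step host-to-device copies with one fused kernel launch: each Triton program reads one request's sampling parameters straight out of pinned CPU memory (through the UVA views) and writes them densely at its batch index. Functionally it amounts to the gather below, written in plain torch only to make the data movement explicit:

```python
import torch

def gather_param(uva: torch.Tensor, gpu: torch.Tensor,
                 batch_idx_to_req_idx: torch.Tensor) -> torch.Tensor:
    # uva[i] holds the value for persistent request slot i; the output is
    # ordered by batch index. The fused kernel above does six such gathers
    # in a single launch.
    batch_size = batch_idx_to_req_idx.shape[0]
    gpu[:batch_size] = uva[batch_idx_to_req_idx]
    return gpu[:batch_size]
```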
@@ -297,22 +292,113 @@ class InputBatch:
     def num_reqs(self) -> int:
         return len(self.req_id_to_index)

+    def _add_scalar_attr(self, name: str, dtype: torch.dtype):
+        attr = PerRequestAttribute(self.max_num_cached_reqs,
+                                   1,
+                                   self.max_num_reqs,
+                                   dtype,
+                                   self.device,
+                                   is_scalar=True)
+        setattr(self, name, attr)
+
     def _add_vector_attr(self, name: str, max_len: int, dtype: torch.dtype):
         attr = PerRequestAttribute(self.max_num_cached_reqs, max_len,
                                    self.max_num_reqs, dtype, self.device)
         setattr(self, name, attr)

-    def _add_scalar_attr(self, name: str, dtype: torch.dtype):
-        self._add_vector_attr(name, max_len=1, dtype=dtype)
+    def _add_vector_attr_cpu(self, name: str, max_len: int,
+                             dtype: torch.dtype):
+        attr = PerRequestAttribute(self.max_num_cached_reqs, max_len, 0, dtype,
+                                   self.device)
+        setattr(self, name, attr)
+
+    def _init_block_tables(self):
+        self.num_block_tables = len(self.block_sizes)
+        self.block_tables = []
+        self.num_blocks = []
+        self.slot_mappings: list[torch.Tensor] = []
+        for i in range(self.num_block_tables):
+            max_num_blocks = cdiv(self.max_model_len, self.block_sizes[i])
+            block_table = PerRequestAttribute(self.max_num_cached_reqs,
+                                              max_num_blocks,
+                                              self.max_num_reqs, torch.int32,
+                                              self.device)
+            self.block_tables.append(block_table)
+            num_blocks = PerRequestAttribute(self.max_num_cached_reqs,
+                                             1,
+                                             self.max_num_reqs,
+                                             torch.int32,
+                                             self.device,
+                                             is_scalar=True)
+            self.num_blocks.append(num_blocks)
+            slot_mapping = torch.zeros(self.max_num_batched_tokens,
+                                       dtype=torch.int64,
+                                       device=self.device)
+            self.slot_mappings.append(slot_mapping)
+
+        def make_ptr_tensor(x: list[torch.Tensor]) -> torch.Tensor:
+            return torch.tensor([t.data_ptr() for t in x],
+                                dtype=torch.int64,
+                                device=self.device)
+
+        self.uva_block_table_ptrs = make_ptr_tensor(
+            [b.uva for b in self.block_tables])
+        self.gpu_block_table_ptrs = make_ptr_tensor(
+            [b.gpu for b in self.block_tables])
+        self.uva_num_blocks_ptrs = make_ptr_tensor(
+            [n.uva for n in self.num_blocks])
+        self.gpu_num_blocks_ptrs = make_ptr_tensor(
+            [n.gpu for n in self.num_blocks])
+        self.block_table_strides = torch.tensor(
+            [b.gpu.shape[1] for b in self.block_tables],
+            dtype=torch.int64,
+            device=self.device)
+        self.block_sizes_tensor = torch.tensor(self.block_sizes,
+                                               dtype=torch.int32,
+                                               device=self.device)
+        self.slot_mapping_ptrs = make_ptr_tensor(self.slot_mappings)
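Note: each KV cache group owns a block table with a different width, so the tables cannot be stacked into one tensor. Instead, `_init_block_tables` records their raw device addresses in int64 "pointer tensors"; a kernel indexes these by group id and casts the address back to a typed pointer (see `_load_ptr` below). A standalone sketch of the trick with hypothetical shapes:

```python
import torch

tables = [torch.zeros(4, 32, dtype=torch.int32, device="cuda"),
          torch.zeros(4, 16, dtype=torch.int32, device="cuda")]
# One raw address per group, passed to Triton as an ordinary int64 tensor.
table_ptrs = torch.tensor([t.data_ptr() for t in tables],
                          dtype=torch.int64, device="cuda")
strides = torch.tensor([t.shape[1] for t in tables],
                       dtype=torch.int64, device="cuda")
```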
+
+    def make_block_tables(
+        self,
+        idx_mapping: torch.Tensor,
+    ) -> tuple[torch.Tensor, ...]:
+        batch_size = idx_mapping.shape[0]
+        _make_block_tables_kernel[(batch_size, self.num_block_tables)](
+            idx_mapping,
+            self.uva_block_table_ptrs,
+            self.gpu_block_table_ptrs,
+            self.block_table_strides,
+            self.uva_num_blocks_ptrs,
+            self.gpu_num_blocks_ptrs,
+            BLOCK_SIZE=1024,
+        )
+        return tuple(b.gpu[:batch_size] for b in self.block_tables)
+
+    def make_slot_mappings(
+        self,
+        cu_num_tokens: torch.Tensor,
+        pos: torch.Tensor,
+    ) -> tuple[torch.Tensor, ...]:
+        num_tokens = pos.shape[0]
+        num_reqs = cu_num_tokens.shape[0] - 1
+        _make_slot_mappings_kernel[(num_reqs + 1, self.num_block_tables)](
+            num_tokens,
+            self.max_num_batched_tokens,
+            cu_num_tokens,
+            pos,
+            self.gpu_block_table_ptrs,
+            self.block_table_strides,
+            self.block_sizes_tensor,
+            self.slot_mapping_ptrs,
+            PAD_ID=PAD_SLOT_ID,
+            BLOCK_SIZE=1024,
+        )
+        return tuple(x[:num_tokens] for x in self.slot_mappings)

 @triton.jit
 def _make_sampling_metadata_kernel(
-    req_indices,  # [batch_size]
+    batch_idx_to_req_idx,  # [batch_size]
     src_temperature,
     dst_temperature,
     src_top_p,
@@ -327,52 +413,104 @@ def _make_sampling_metadata_kernel(
     dst_repetition_penalties,
 ):
     batch_idx = tl.program_id(0)
-    req_index = tl.load(req_indices + batch_idx)
+    req_idx = tl.load(batch_idx_to_req_idx + batch_idx)

-    temperature = tl.load(src_temperature + req_index)
-    tl.store(dst_temperature + req_index, temperature)
+    temperature = tl.load(src_temperature + req_idx)
+    tl.store(dst_temperature + batch_idx, temperature)

-    top_p = tl.load(src_top_p + req_index)
-    tl.store(dst_top_p + req_index, top_p)
+    top_p = tl.load(src_top_p + req_idx)
+    tl.store(dst_top_p + batch_idx, top_p)

-    top_k = tl.load(src_top_k + req_index)
-    tl.store(dst_top_k + req_index, top_k)
+    top_k = tl.load(src_top_k + req_idx)
+    tl.store(dst_top_k + batch_idx, top_k)

-    frequency_penalties = tl.load(src_frequency_penalties + req_index)
-    tl.store(dst_frequency_penalties + req_index, frequency_penalties)
+    frequency_penalties = tl.load(src_frequency_penalties + req_idx)
+    tl.store(dst_frequency_penalties + batch_idx, frequency_penalties)

-    presence_penalties = tl.load(src_presence_penalties + req_index)
-    tl.store(dst_presence_penalties + req_index, presence_penalties)
+    presence_penalties = tl.load(src_presence_penalties + req_idx)
+    tl.store(dst_presence_penalties + batch_idx, presence_penalties)

-    repetition_penalties = tl.load(src_repetition_penalties + req_index)
-    tl.store(dst_repetition_penalties + req_index, repetition_penalties)
+    repetition_penalties = tl.load(src_repetition_penalties + req_idx)
+    tl.store(dst_repetition_penalties + batch_idx, repetition_penalties)

 @triton.jit
-def _make_block_table_kernel(
-    req_indices,  # [batch_size]
-    src_block_table_ptrs,
-    dst_block_table_ptrs,
-    src_num_blocks_ptrs,
-    dst_num_blocks_ptrs,
-    num_block_tables: tl.constexpr,
+def _make_block_tables_kernel(
+    batch_idx_to_req_idx,  # [batch_size]
+    src_block_table_ptrs,  # [num_block_tables]
+    dst_block_table_ptrs,  # [num_block_tables]
+    block_table_strides,  # [num_block_tables]
+    src_num_blocks_ptrs,  # [num_block_tables]
+    dst_num_blocks_ptrs,  # [num_block_tables]
     BLOCK_SIZE: tl.constexpr,
 ):
     batch_idx = tl.program_id(0)
+    # kv cache group id
+    group_id = tl.program_id(1)
-    req_index = tl.load(req_indices + batch_idx)
+    req_idx = tl.load(batch_idx_to_req_idx + batch_idx)

-    for i in tl.range(num_block_tables):
-        src_num_blocks_ptr = tl.load(src_num_blocks_ptrs + i)
-        dst_num_blocks_ptr = tl.load(dst_num_blocks_ptrs + i)
-        num_blocks = tl.load(src_num_blocks_ptr + req_index)
-        tl.store(dst_num_blocks_ptr + req_index, num_blocks)
+    src_num_blocks_ptr = _load_ptr(src_num_blocks_ptrs, group_id, tl.int32)
+    dst_num_blocks_ptr = _load_ptr(dst_num_blocks_ptrs, group_id, tl.int32)
+    num_blocks = tl.load(src_num_blocks_ptr + req_idx)
+    tl.store(dst_num_blocks_ptr + batch_idx, num_blocks)

-        src_block_table_ptr = tl.load(src_block_table_ptrs + i)
-        dst_block_table_ptr = tl.load(dst_block_table_ptrs + i)
-        for j in tl.range(num_blocks, BLOCK_SIZE):
-            offset = tl.arange(0, BLOCK_SIZE)
-            block_ids = tl.load(src_block_table_ptr + j * BLOCK_SIZE + offset,
-                                mask=offset < num_blocks)
-            tl.store(dst_block_table_ptr + j * BLOCK_SIZE + offset,
-                     block_ids,
-                     mask=offset < num_blocks)
+    stride = tl.load(block_table_strides + group_id)
+    src_block_table_ptr = _load_ptr(src_block_table_ptrs, group_id, tl.int32)
+    src_row_ptr = src_block_table_ptr + req_idx * stride
+    dst_block_table_ptr = _load_ptr(dst_block_table_ptrs, group_id, tl.int32)
+    dst_row_ptr = dst_block_table_ptr + batch_idx * stride
+    for i in tl.range(0, num_blocks, BLOCK_SIZE):
+        offset = i + tl.arange(0, BLOCK_SIZE)
+        block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
+        tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
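Note: the launch grid is `(batch_size, num_block_tables)`, so each program instance copies one request's row of one group's block table in `BLOCK_SIZE` chunks. What a single instance computes, restated as plain Python for reference:

```python
def copy_one_row(src_table, dst_table, src_num_blocks, dst_num_blocks,
                 req_idx: int, batch_idx: int) -> None:
    # Gather cached row `req_idx` into dense batch row `batch_idx`.
    n = src_num_blocks[req_idx]
    dst_num_blocks[batch_idx] = n
    dst_table[batch_idx, :n] = src_table[req_idx, :n]
```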

+@triton.jit
+def _make_slot_mappings_kernel(
+    num_tokens,
+    max_num_tokens,
+    cu_num_tokens,  # [num_reqs + 1]
+    pos,  # [num_tokens]
+    block_table_ptrs,  # [num_block_tables]
+    block_table_strides,  # [num_block_tables]
+    page_sizes,  # [num_block_tables]
+    slot_mapping_ptrs,  # [num_block_tables]
+    PAD_ID: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    num_reqs = tl.num_programs(0)
+    # kv cache group id
+    group_id = tl.program_id(1)
+    slot_mapping_ptr = _load_ptr(slot_mapping_ptrs, group_id, tl.int64)
+
+    if req_idx == num_reqs - 1:
+        # Pad remaining slots to -1. This is needed for CUDA graphs.
+        for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE):
+            offset = i + tl.arange(0, BLOCK_SIZE)
+            tl.store(slot_mapping_ptr + offset,
+                     PAD_ID,
+                     mask=offset < max_num_tokens)
+        return
+
+    block_table_ptr = _load_ptr(block_table_ptrs, group_id, tl.int32)
+    block_table_stride = tl.load(block_table_strides + group_id)
+    page_size = tl.load(page_sizes + group_id)
+
+    start_idx = tl.load(cu_num_tokens + req_idx)
+    end_idx = tl.load(cu_num_tokens + req_idx + 1)
+    for i in tl.range(start_idx, end_idx, BLOCK_SIZE):
+        offset = i + tl.arange(0, BLOCK_SIZE)
+        positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
+        block_indices = positions // page_size
+        block_numbers = tl.load(block_table_ptr +
+                                req_idx * block_table_stride + block_indices)
+        slot_ids = block_numbers * page_size + positions % page_size
+        tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
+
+
+@triton.jit
+def _load_ptr(base, offset, elem_dtype):
+    ptr = tl.load(base + offset)
+    return tl.cast(ptr, tl.pointer_type(elem_dtype))
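Note: a slot id maps a (request, position) pair to a physical KV cache slot: the block table translates the logical block index `pos // page_size` into a physical block number, and `pos % page_size` is the offset within that block. A worked example with hypothetical values:

```python
page_size = 16
block_table_row = [5, 9, 7]   # physical blocks backing one request
pos = 37                      # token position within the sequence

block_number = block_table_row[pos // page_size]      # 37 // 16 = 2 -> 7
slot_id = block_number * page_size + pos % page_size  # 7 * 16 + 5
assert slot_id == 117
```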

vllm/v1/worker/gpu_model_runner.py

@@ -47,7 +47,6 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
                                     PlaceholderRange)
 from vllm.multimodal.utils import group_mm_kwargs_by_modality
 from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
@@ -215,6 +214,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
             max_num_batched_tokens=self.max_num_tokens,
+            max_num_cached_reqs=2 * self.max_num_reqs,
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
@@ -289,6 +289,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                        dtype=self.dtype,
                                        device=self.device)

+        # OPTIMIZATION: Cache the tensors rather than creating them every step.
+        # Keep in int64 to avoid overflow with long context
+        self.arange_np = np.arange(max(self.max_num_reqs + 1,
+                                       self.max_model_len,
+                                       self.max_num_tokens),
+                                   dtype=np.int64)
+        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
+        # a faster version of creating a new tensor every time. Thus, we should
+        # not make any assumptions about the values in these tensors.
+        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
+                                         dtype=torch.int32,
+                                         device="cpu",
+                                         pin_memory=self.pin_memory)
+        self.positions_cpu = torch.zeros(self.max_num_tokens,
+                                         dtype=torch.int64,
+                                         device="cpu",
+                                         pin_memory=self.pin_memory)
+        self.positions_np = self.positions_cpu.numpy()
+        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
+                                               dtype=torch.int32,
+                                               device="cpu",
+                                               pin_memory=self.pin_memory)
+        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
+        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
+                                        dtype=torch.int32,
+                                        device="cpu",
+                                        pin_memory=self.pin_memory)
+        self.seq_lens_np = self.seq_lens_cpu.numpy()
+
         # Layer pairings for cross-layer KV sharing.
         # If an Attention layer `layer_name` is in the keys of this dict, it
         # means this layer will perform attention using the keys and values
@@ -318,6 +347,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                            torch.Tensor]] = None

     def _init_model_kwargs(self, num_tokens: int):
+        return {}
         model_kwargs = dict[str, Any]()
         num_reqs = self.input_batch.num_reqs
@@ -410,20 +440,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         for req_id in unscheduled_req_ids:
             self.input_batch.remove_request(req_id)

-        reqs_to_add: list[CachedRequestState] = []
         # Add new requests to the cached states.
         for new_req_data in scheduler_output.scheduled_new_reqs:
             req_id = new_req_data.req_id
             sampling_params = new_req_data.sampling_params
             pooling_params = new_req_data.pooling_params

-            if sampling_params and \
-                    sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-                generator = torch.Generator(device=self.device)
-                generator.manual_seed(sampling_params.seed)
-            else:
-                generator = None
-
             if pooling_params:
                 task = pooling_params.task
                 assert task is not None, "You did not set `task` in the API"
@@ -434,141 +456,81 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_state = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
                 mm_kwargs=new_req_data.mm_kwargs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
                 pooling_params=pooling_params,
-                generator=generator,
                 block_ids=new_req_data.block_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],
                 lora_request=new_req_data.lora_request,
             )
             self.requests[req_id] = req_state
+            self.input_batch.add_request(
+                req_id=req_id,
+                prompt_token_ids=new_req_data.prompt_token_ids,
+                num_computed_tokens=new_req_data.num_computed_tokens,
+                block_ids=new_req_data.block_ids,
+                sampling_params=sampling_params,
+            )

             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             if self.uses_mrope:
-                image_grid_thw = []
-                video_grid_thw = []
-                second_per_grid_ts = []
-                audio_feature_lengths = []
-                use_audio_in_video = False
-                for mm_item in req_state.mm_kwargs:
-                    mm_input = mm_item.get_data()
-                    if (t := mm_input.get("image_grid_thw")) is not None:
-                        image_grid_thw.append(t.tolist())
-                    if (t := mm_input.get("video_grid_thw")) is not None:
-                        video_grid_thw.append(t.tolist())
-                    if (t := mm_input.get("second_per_grid_ts")) is not None:
-                        second_per_grid_ts.append(t)
-                    if (t :=
-                            mm_input.get("audio_feature_lengths")) is not None:
-                        audio_feature_lengths.append(t)
-                    if mm_input.get("use_audio_in_video") is True:
-                        use_audio_in_video = True
-
-                hf_config = self.model_config.hf_config
-
-                req_state.mrope_positions, req_state.mrope_position_delta = \
-                    MRotaryEmbedding.get_input_positions_tensor(
-                        req_state.prompt_token_ids,
-                        hf_config=hf_config,
-                        image_grid_thw=image_grid_thw,
-                        video_grid_thw=video_grid_thw,
-                        second_per_grid_ts=second_per_grid_ts,
-                        audio_feature_lengths=audio_feature_lengths,
-                        use_audio_in_video=use_audio_in_video,
-                    )
-
-            reqs_to_add.append(req_state)
+                self._init_mrope_states(req_state)
         # Update the states of the running/resumed requests.
         is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
-            new_block_ids = req_data.new_block_ids[i]
-            resumed_from_preemption = req_data.resumed_from_preemption[i]

             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
+            req_index = self.input_batch.req_id_to_index[req_id]

+            # Update input batch.
             if not is_last_rank:
                 # When using PP, the scheduler sends the sampled tokens back,
                 # because there's no direct communication between the first-
                 # stage worker and the last-stage worker.
                 new_token_ids = req_data.new_token_ids[i]
                 # Add the sampled token(s) from the previous step (if any).
                 # This doesn't include "unverified" tokens like spec tokens.
-                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
-                                  req_state.num_tokens)
-                if num_new_tokens == 1:
-                    # Avoid slicing list in most common case.
-                    req_state.output_token_ids.append(new_token_ids[-1])
-                elif num_new_tokens > 0:
-                    req_state.output_token_ids.extend(
-                        new_token_ids[-num_new_tokens:])
+                self.input_batch.append_token_ids(req_index, new_token_ids)

-            # Update the block IDs.
-            if not resumed_from_preemption:
-                if new_block_ids is not None:
-                    # Append the new blocks to the existing block IDs.
-                    for block_ids, new_ids in zip(req_state.block_ids,
-                                                  new_block_ids):
-                        block_ids.extend(new_ids)
-            else:
-                assert new_block_ids is not None
-                # The request is resumed from preemption.
-                # Replace the existing block IDs with the new ones.
-                req_state.block_ids = new_block_ids
-
-            req_index = self.input_batch.req_id_to_index.get(req_id)
-            if req_index is None:
-                # The request is not in the persistent batch.
-                # The request was either preempted and resumed later, or was not
-                # scheduled in the previous step and needs to be added again.
-                reqs_to_add.append(req_state)
-                continue
-
-            # Update the persistent batch.
-            self.input_batch.num_computed_tokens_cpu[req_index] = (
-                num_computed_tokens)
-            if new_block_ids is not None:
-                self.input_batch.block_table.append_row(
-                    new_block_ids, req_index)
-
-            # For the last rank, we don't need to update the token_ids_cpu
-            # because the sampled tokens are already cached.
-            if not is_last_rank:
-                # Add new_token_ids to token_ids_cpu.
-                start_token_index = num_computed_tokens
-                end_token_index = num_computed_tokens + len(new_token_ids)
-                self.input_batch.token_ids_cpu[
-                    req_index,
-                    start_token_index:end_token_index] = new_token_ids
-                self.input_batch.num_tokens_no_spec[
-                    req_index] = end_token_index
-                self.input_batch.num_tokens[req_index] = end_token_index
-
-            # Add spec_token_ids to token_ids_cpu.
-            spec_token_ids = (
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
-            if spec_token_ids:
-                num_spec_tokens = len(spec_token_ids)
-                start_index = self.input_batch.num_tokens_no_spec[req_index]
-                end_token_index = start_index + num_spec_tokens
-                self.input_batch.token_ids_cpu[
-                    req_index, start_index:end_token_index] = spec_token_ids
-                # NOTE(woosuk): `num_tokens` here may include spec tokens.
-                self.input_batch.num_tokens[req_index] += num_spec_tokens
-
-        # Add the new or resumed requests to the persistent batch.
-        # The smaller empty indices are filled first.
-        for request in reqs_to_add:
-            self.input_batch.add_request(request)
+            new_block_ids = req_data.new_block_ids[i]
+            if new_block_ids is not None:
+                # If the request is resumed from preemption, we need to
+                # overwrite the existing block IDs.
+                self.input_batch.append_block_ids(
+                    req_index,
+                    new_block_ids,
+                    overwrite=req_data.resumed_from_preemption[i],
+                )
+            self.input_batch.num_computed_tokens.np[req_index] = (
+                req_data.num_computed_tokens[i])

+    def _init_mrope_states(self, req_state: CachedRequestState) -> None:
+        image_grid_thw = []
+        video_grid_thw = []
+        second_per_grid_ts = []
+        audio_feature_lengths = []
+        use_audio_in_video = False
+        for mm_item in req_state.mm_kwargs:
+            mm_input = mm_item.get_data()
+            if (t := mm_input.get("image_grid_thw")) is not None:
+                image_grid_thw.append(t.tolist())
+            if (t := mm_input.get("video_grid_thw")) is not None:
+                video_grid_thw.append(t.tolist())
+            if (t := mm_input.get("second_per_grid_ts")) is not None:
+                second_per_grid_ts.append(t)
+            if (t := mm_input.get("audio_feature_lengths")) is not None:
+                audio_feature_lengths.append(t)
+            if mm_input.get("use_audio_in_video") is True:
+                use_audio_in_video = True
+        req_state.mrope_positions, req_state.mrope_position_delta = \
+            MRotaryEmbedding.get_input_positions_tensor(
+                req_state.prompt_token_ids,
+                hf_config=self.model_config.hf_config,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                second_per_grid_ts=second_per_grid_ts,
+                audio_feature_lengths=audio_feature_lengths,
+                use_audio_in_video=use_audio_in_video,
+            )

     def _extract_mm_kwargs(
         self,
@@ -637,12 +599,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0

+        # FIXME
+        # batch_idx -> req_id
+        req_ids = list(scheduler_output.num_scheduled_tokens.keys())
+        # req_id -> batch_idx
+        req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)}
+        # batch_idx -> req_idx
+        idx_mapping = [
+            self.input_batch.req_id_to_index[req_id] for req_id in req_ids
+        ]
+        # batch_idx -> req_idx
+        idx_mapping_tensor = self.input_batch.get_index_mapping(idx_mapping)
+        num_reqs = len(req_ids)
+
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
-        self.input_batch.block_table.commit_block_table(num_reqs)
+        block_tables = self.input_batch.make_block_tables(idx_mapping_tensor)

         # Get the number of scheduled tokens for each request.
-        req_ids = self.input_batch.req_ids
         tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
         num_scheduled_tokens = np.array(tokens, dtype=np.int32)
         max_num_scheduled_tokens = max(tokens)
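Note: three index spaces meet here: `batch_idx` (position in this step's batch, fixed by scheduling order), `req_id` (the string identifier), and `req_idx` (the request's persistent slot among the `max_num_cached_reqs` cached rows). With hypothetical values:

```python
req_ids = ["a", "b"]                    # batch_idx -> req_id
req_id_to_batch_idx = {"a": 0, "b": 1}  # req_id -> batch_idx
idx_mapping = [4, 1]                    # batch_idx -> req_idx (cached slot)
```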
@@ -659,7 +633,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         # Get positions.
         positions_np = self.positions_np[:total_num_scheduled_tokens]
-        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
+        np.add(self.input_batch.num_computed_tokens.np[req_indices],
                arange,
                out=positions_np)
@@ -673,21 +647,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
         # where M is the max_model_len.
         token_indices = (positions_np +
-                         req_indices * self.input_batch.token_ids_cpu.shape[1])
+                         req_indices * self.input_batch.token_ids.np.shape[1])
         # NOTE(woosuk): We use torch.index_select instead of np.take here
         # because torch.index_select is much faster than np.take for large
         # tensors.
-        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
+        torch.index_select(self.input_batch.token_ids.cpu.flatten(),
                            0,
                            torch.from_numpy(token_indices),
                            out=self.input_ids_cpu[:total_num_scheduled_tokens])

-        self.input_batch.block_table.compute_slot_mapping(
-            req_indices, positions_np)
-        self.input_batch.block_table.commit_slot_mapping(
-            total_num_scheduled_tokens)
-
         # Prepare the attention metadata.
         self.query_start_loc_np[0] = 0
         self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
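Note: since `token_ids` is a dense `[max_num_cached_reqs, max_model_len]` matrix, the per-token gather above flattens it and indexes with `pos + req_idx * max_model_len`, exactly as the comment describes. A small worked example with `M = 4` standing in for `max_model_len`:

```python
import numpy as np

M = 4
positions = np.array([0, 1, 0])    # position of each token in its request
req_indices = np.array([0, 0, 2])  # owning request row per token
token_indices = positions + req_indices * M   # -> [0, 1, 8]
# token_ids.flatten()[token_indices] then gathers all tokens in one shot.
```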
@@ -698,7 +667,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         query_start_loc = self.query_start_loc[:num_reqs + 1]

         self.seq_lens_np[:num_reqs] = (
-            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
+            self.input_batch.num_computed_tokens.np[:num_reqs] +
             num_scheduled_tokens)
         # Fill unused with 0 for full cuda graph mode.
         self.seq_lens_np[num_reqs:].fill(0)
@@ -737,8 +706,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
             for req_id, draft_token_ids in (
                     scheduler_output.scheduled_spec_decode_tokens.items()):
-                req_idx = self.input_batch.req_id_to_index[req_id]
-                num_draft_tokens[req_idx] = len(draft_token_ids)
+                batch_idx = req_id_to_batch_idx[req_id]
+                num_draft_tokens[batch_idx] = len(draft_token_ids)

             spec_decode_metadata = self._calc_spec_decode_metadata(
                 num_draft_tokens, cu_num_tokens)
@@ -788,11 +757,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 per_layer_metadata[layer_name]
             attn_metadata[layer_name] = encoder_attn_metadata

+        slot_mappings = self.input_batch.make_slot_mappings(
+            query_start_loc,
+            self.positions[:total_num_scheduled_tokens],
+        )
         # Used in the below loop.
         query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
         seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
         num_computed_tokens_cpu = (
-            self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
+            self.input_batch.num_computed_tokens.cpu[:num_reqs])
         spec_decode_common_attn_metadata = None

         # Prepare the attention metadata for each KV cache group and make layers
@@ -800,14 +773,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
-            blk_table = self.input_batch.block_table[kv_cache_group_id]
-            blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
-            slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens]
-
-            # Fill unused with -1. Needed for reshape_and_cache in full cuda
-            # graph mode.
-            blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+            blk_table_tensor = block_tables[kv_cache_group_id]
+            slot_mapping = slot_mappings[kv_cache_group_id]

             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=query_start_loc,
                 query_start_loc_cpu=query_start_loc_cpu,
@@ -948,7 +915,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_reqs = len(num_scheduled_tokens)
         common_prefix_len = min(
             common_prefix_len,
-            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
+            self.input_batch.num_computed_tokens.np[:num_reqs].min())
         # common_prefix_len should be a multiple of the block size.
         common_prefix_len = (common_prefix_len // kv_cache_spec.block_size *
                              kv_cache_spec.block_size)
@@ -1454,6 +1421,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
          num_scheduled_tokens_np, spec_decode_common_attn_metadata,
          max_query_len) = (self._prepare_inputs(scheduler_output))

+        # FIXME
+        # batch_idx -> req_id
+        req_ids = list(scheduler_output.num_scheduled_tokens.keys())
+        # req_id -> batch_idx
+        req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)}
+        # batch_idx -> req_idx
+        idx_mapping = [
+            self.input_batch.req_id_to_index[req_id] for req_id in req_ids
+        ]
+        # batch_idx -> req_idx
+        idx_mapping_tensor = self.input_batch.get_index_mapping(idx_mapping)
+
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
                 and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@@ -1593,7 +1572,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.apply_grammar_bitmask(scheduler_output, logits)

         # Sample the next token and get logprobs if needed.
-        sampling_metadata = self.input_batch.sampling_metadata
+        sampling_metadata = self.input_batch.make_sampling_metadata(
+            idx_mapping_tensor)
         if spec_decode_metadata is None:
             sampler_output = self.sampler(
                 logits=logits,
@@ -1629,24 +1609,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
             num_nans_in_logits = self._get_nans_in_logits(logits)

-        # TODO(woosuk): The following loop can be slow since it iterates over
-        # the requests one by one. Optimize.
-        discard_sampled_tokens_req_indices = []
-        for i, req_id in enumerate(self.input_batch.req_ids):
-            req_state = self.requests[req_id]
-            seq_len = (req_state.num_computed_tokens +
-                       scheduler_output.num_scheduled_tokens[req_id])
-            if seq_len < req_state.num_tokens:
-                # Ignore the sampled token for partial prefills.
-                # Rewind the generator state as if the token was not sampled.
-                # This relies on cuda-specific torch-internal impl details
-                generator = self.input_batch.generators.get(i)
-                if generator is not None:
-                    generator.set_offset(generator.get_offset() - 4)
-                # Record the index of the request that should not be sampled,
-                # so that we could clear the sampled tokens before returning.
-                discard_sampled_tokens_req_indices.append(i)
-
         # NOTE: GPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
         logprobs_tensors = sampler_output.logprobs_tensors
@@ -1668,37 +1630,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         else:
             # Includes spec decode tokens.
             valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids,
-                self.input_batch.vocab_size,
-            )
-            # Mask out the sampled tokens that should not be sampled.
-            for i in discard_sampled_tokens_req_indices:
-                valid_sampled_token_ids[i].clear()
+                sampled_token_ids, self.input_batch.vocab_size)
+            # # Mask out the sampled tokens that should not be sampled.
+            # for i in discard_sampled_tokens_req_indices:
+            #     valid_sampled_token_ids[i].clear()

         # Cache the sampled tokens in the model runner, so that the scheduler
         # doesn't need to send them back.
         # NOTE(woosuk): As an exception, when using PP, the scheduler sends
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
-        req_ids = self.input_batch.req_ids
         for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
             if not sampled_ids:
                 continue
-
-            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
-            end_idx = start_idx + len(sampled_ids)
-            assert end_idx <= self.max_model_len, (
-                "Sampled token IDs exceed the max model length. "
-                f"Total number of tokens: {end_idx} > max_model_len: "
-                f"{self.max_model_len}")
-
-            self.input_batch.token_ids_cpu[req_idx,
-                                           start_idx:end_idx] = sampled_ids
-            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
-            self.input_batch.num_tokens[req_idx] = end_idx
-
-            req_id = req_ids[req_idx]
-            req_state = self.requests[req_id]
-            req_state.output_token_ids.extend(sampled_ids)
+            self.input_batch.append_token_ids(req_idx, sampled_ids)

         if self.speculative_config:
             assert spec_decode_common_attn_metadata is not None
@@ -1716,8 +1661,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.eplb_step()

         return ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids,
-            req_id_to_index=self.input_batch.req_id_to_index,
+            req_ids=req_ids,
+            req_id_to_index=req_id_to_batch_idx,
             sampled_token_ids=valid_sampled_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
@@ -2389,14 +2334,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             generators={},
             max_num_logprobs=None,
             no_penalties=True,
-            prompt_token_ids=None,
             frequency_penalties=dummy_tensors(0.1),
             presence_penalties=dummy_tensors(0.1),
             repetition_penalties=dummy_tensors(0.1),
-            output_token_ids=[[] for _ in range(num_reqs)],
             allowed_token_ids_mask=None,
             bad_words_token_ids={},
             logitsprocs=LogitsProcessors(),
+            token_ids=None,
         )

         try:
             sampler_output = self.sampler(logits=logits,