From 79e5eb36434c402cd0a08d3bcd6a06f80f53753a Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 22 Aug 2025 01:37:43 -0700 Subject: [PATCH] wip Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_input_batch.py | 394 +++++++++++++++++++---------- vllm/v1/worker/gpu_model_runner.py | 316 ++++++++++------------- 2 files changed, 396 insertions(+), 314 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index dd67bec3df4d2..62ec917276eb8 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -6,6 +6,8 @@ from dataclasses import dataclass from typing import Optional import torch +import triton +import triton.language as tl from typing_extensions import deprecated from vllm.lora.request import LoRARequest @@ -17,6 +19,8 @@ from vllm.utils import cdiv, get_cuda_view_from_cpu_tensor, is_uva_available from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata +PAD_SLOT_ID = -1 + @dataclass class CachedRequestState: @@ -27,6 +31,7 @@ class CachedRequestState: sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] + # M-RoPE (only for Qwen2/2.5-VL) mrope_positions: Optional[torch.Tensor] = None mrope_position_delta: Optional[int] = None @@ -46,21 +51,31 @@ class PerRequestAttribute: def __init__( self, - N: int, - M: int, - K: int, + num_rows_cpu: int, + num_cols: int, + num_rows_gpu: int, dtype: torch.dtype, device: torch.device, + is_scalar: bool = False, ): assert is_uva_available(), "UVA is not available." 
- self.cpu_tensor = torch.zeros(N, - M, - dtype=dtype, - device="cpu", - pin_memory=True) - self.np = self.cpu_tensor.numpy() - self.uva_tensor = get_cuda_view_from_cpu_tensor(self.cpu_tensor) - self.gpu_tensor = torch.zeros(K, M, dtype=dtype, device=device) + self.cpu = torch.zeros(num_rows_cpu, + num_cols, + dtype=dtype, + device="cpu", + pin_memory=True) + self.np = self.cpu.numpy() + self.uva = get_cuda_view_from_cpu_tensor(self.cpu) + self.gpu = torch.zeros(num_rows_gpu, + num_cols, + dtype=dtype, + device=device) + if is_scalar: + assert num_cols == 1 + self.cpu.squeeze_(1) + self.np = self.cpu.numpy() + self.uva.squeeze_(1) + self.gpu.squeeze_(1) class InputBatch: @@ -87,15 +102,19 @@ class InputBatch: self.pin_memory = pin_memory self.vocab_size = vocab_size self.is_spec_decode = is_spec_decode + self.pooling_params = None + self.block_sizes = block_sizes + self.num_prompt_logprobs = {} self.req_id_to_index: dict[str, int] = {} self.index_to_req_id: dict[int, str] = {} self.free_indices = list(range(max_num_cached_reqs)) + self._add_scalar_attr("idx_mapping", torch.int32) # Request states. - # TODO(woosuk): This buffer could be too large if max_model_len is big. - # Find a way to reduce the memory usage. - self._add_vector_attr("token_ids", self.max_model_len, torch.int32) + # TODO(woosuk): Because the token_ids tensor can be very big, we only + # initialize it on CPU memory. + self._add_vector_attr_cpu("token_ids", self.max_model_len, torch.int32) self._add_scalar_attr("num_prompt_tokens", torch.int32) self._add_scalar_attr("num_tokens", torch.int32) self._add_scalar_attr("num_computed_tokens", torch.int32) @@ -119,20 +138,7 @@ class InputBatch: self.generators: dict[int, torch.Generator] = {} # Block table(s). 
- self.block_tables = [] - self.num_blocks = [] - for block_size in block_sizes: - max_num_blocks = cdiv(max_model_len, block_size) - block_table = PerRequestAttribute(self.max_num_cached_reqs, - max_num_blocks, - self.max_num_reqs, torch.int32, - self.device) - self.block_tables.append(block_table) - num_blocks = PerRequestAttribute(self.max_num_cached_reqs, 1, - self.max_num_reqs, torch.int32, - self.device) - self.num_blocks.append(num_blocks) - self.num_block_tables = len(block_sizes) + self._init_block_tables() def add_request( self, @@ -146,11 +152,10 @@ class InputBatch: self.req_id_to_index[req_id] = req_idx self.index_to_req_id[req_idx] = req_id - num_prompt_tokens = len(prompt_token_ids) - self.token_ids.np[req_idx, :num_prompt_tokens] = prompt_token_ids - self.num_prompt_tokens.np[req_idx] = num_prompt_tokens - self.num_tokens.np[req_idx] = num_prompt_tokens + self.num_prompt_tokens.np[req_idx] = len(prompt_token_ids) self.num_computed_tokens.np[req_idx] = num_computed_tokens + self.append_token_ids(req_idx, prompt_token_ids) + self.append_block_ids(req_idx, block_ids, overwrite=True) self.temperature.np[req_idx] = sampling_params.temperature if sampling_params.sampling_type == SamplingType.GREEDY: @@ -171,56 +176,48 @@ class InputBatch: self.top_k.np[req_idx] = top_k self.frequency_penalties.np[ - req_idx] = sampling_params.frequency_penalties - if sampling_params.frequency_penalties != 0.0: + req_idx] = sampling_params.frequency_penalty + if sampling_params.frequency_penalty != 0.0: self.frequency_penalties_reqs.add(req_id) - self.presence_penalties.np[ - req_idx] = sampling_params.presence_penalties - if sampling_params.presence_penalties != 0.0: + self.presence_penalties.np[req_idx] = sampling_params.presence_penalty + if sampling_params.presence_penalty != 0.0: self.presence_penalties_reqs.add(req_id) self.repetition_penalties.np[ - req_idx] = sampling_params.repetition_penalties - if sampling_params.repetition_penalties != 1.0: + req_idx] = 
sampling_params.repetition_penalty + if sampling_params.repetition_penalty != 1.0: self.repetition_penalties_reqs.add(req_id) - if sampling_params.seed is not None: + if sampling_params.sampling_type == SamplingType.RANDOM_SEED: generator = torch.Generator(device=self.device) generator.manual_seed(sampling_params.seed) self.generators[req_idx] = generator - for i in range(self.num_block_tables): - self.block_tables[i].np[req_idx, :len(block_ids[i])] = block_ids[i] - self.num_blocks[i].np[req_idx] = len(block_ids[i]) - - def append_token_ids(self, req_id: str, token_ids: list[int]) -> None: - req_idx = self.req_id_to_index.get(req_id) - assert req_idx is not None + def append_token_ids(self, req_idx: int, token_ids: list[int]) -> None: start_idx = self.num_tokens.np[req_idx] end_idx = start_idx + len(token_ids) self.token_ids.np[req_idx, start_idx:end_idx] = token_ids self.num_tokens.np[req_idx] = end_idx + # TODO(woosuk): Further vectorize this to minimize overheads. def append_block_ids( self, - req_id: str, + req_idx: int, new_block_ids: tuple[list[int], ...], overwrite: bool, ) -> None: - req_idx = self.req_id_to_index.get(req_id) - assert req_idx is not None for i in range(self.num_block_tables): block_table = self.block_tables[i] num_blocks = self.num_blocks[i] + num_new_blocks = len(new_block_ids[i]) if overwrite: # Replace the existing block IDs with the new ones. # This happens when the request is resumed from preemption. - block_table.np[ - req_idx, :len(new_block_ids[i])] = new_block_ids[i] - num_blocks.np[req_idx] = len(new_block_ids[i]) + block_table.np[req_idx, :num_new_blocks] = new_block_ids[i] + num_blocks.np[req_idx] = num_new_blocks else: # Append the new block IDs to the existing ones (common case). 
start_idx = num_blocks.np[req_idx] - end_idx = start_idx + len(new_block_ids[i]) + end_idx = start_idx + num_new_blocks block_table.np[req_idx, start_idx:end_idx] = new_block_ids[i] num_blocks.np[req_idx] = end_idx @@ -241,52 +238,50 @@ class InputBatch: self.repetition_penalties_reqs.discard(req_id) self.generators.pop(req_idx, None) - def make_block_table(self, req_idx: int) -> tuple[torch.Tensor, ...]: - pass + def get_index_mapping(self, idx_mapping: list[int]) -> torch.Tensor: + num_reqs = len(idx_mapping) + self.idx_mapping.np[:num_reqs] = idx_mapping + return self.idx_mapping.gpu[:num_reqs].copy_( + self.idx_mapping.uva[:num_reqs], non_blocking=True) - def make_sampling_metadata(self, - req_indices: list[int]) -> SamplingMetadata: - batch_size = len(req_indices) + def make_sampling_metadata( + self, + batch_idx_to_req_idx: torch.Tensor, + ) -> SamplingMetadata: + batch_size = batch_idx_to_req_idx.shape[0] _make_sampling_metadata_kernel[(batch_size, )]( - req_indices, - self.temperature.uva_tensor, - self.temperature.gpu_tensor, - self.top_p.uva_tensor, - self.top_p.gpu_tensor, - self.top_k.uva_tensor, - self.top_k.gpu_tensor, - self.frequency_penalties.uva_tensor, - self.frequency_penalties.gpu_tensor, - self.presence_penalties.uva_tensor, - self.presence_penalties.gpu_tensor, - self.repetition_penalties.uva_tensor, - self.repetition_penalties.gpu_tensor, + batch_idx_to_req_idx, + self.temperature.uva, + self.temperature.gpu, + self.top_p.uva, + self.top_p.gpu, + self.top_k.uva, + self.top_k.gpu, + self.frequency_penalties.uva, + self.frequency_penalties.gpu, + self.presence_penalties.uva, + self.presence_penalties.gpu, + self.repetition_penalties.uva, + self.repetition_penalties.gpu, num_warps=1, num_stages=1, ) - generators = {} - if self.generators: - for i, req_idx in enumerate(req_indices): - generator = self.generators.get(req_idx) - if generator is not None: - generators[i] = generator no_penalties = not (self.frequency_penalties_reqs or 
self.presence_penalties_reqs or self.repetition_penalties_reqs) return SamplingMetadata( - temperature=self.temperature.gpu_tensor[:batch_size], + temperature=self.temperature.gpu[:batch_size], all_greedy=not self.random_reqs, all_random=not self.greedy_reqs, - top_p=self.top_p.gpu_tensor[:batch_size], - top_k=self.top_k.gpu_tensor[:batch_size], - frequency_penalties=self.frequency_penalties. - gpu_tensor[:batch_size], - presence_penalties=self.presence_penalties.gpu_tensor[:batch_size], - repetition_penalties=self.repetition_penalties. - gpu_tensor[:batch_size], + top_p=self.top_p.gpu[:batch_size], + top_k=self.top_k.gpu[:batch_size], + frequency_penalties=self.frequency_penalties.gpu[:batch_size], + presence_penalties=self.presence_penalties.gpu[:batch_size], + repetition_penalties=self.repetition_penalties.gpu[:batch_size], no_penalties=no_penalties, - generators=generators, - token_ids=self.token_ids.gpu_tensor[:batch_size], + # TODO + generators={}, + token_ids=self.token_ids.gpu[:batch_size], max_num_logprobs=None, allowed_token_ids_mask=None, bad_words_token_ids={}, @@ -297,22 +292,113 @@ class InputBatch: def num_reqs(self) -> int: return len(self.req_id_to_index) + def _add_scalar_attr(self, name: str, dtype: torch.dtype): + attr = PerRequestAttribute(self.max_num_cached_reqs, + 1, + self.max_num_reqs, + dtype, + self.device, + is_scalar=True) + setattr(self, name, attr) + def _add_vector_attr(self, name: str, max_len: int, dtype: torch.dtype): attr = PerRequestAttribute(self.max_num_cached_reqs, max_len, self.max_num_reqs, dtype, self.device) setattr(self, name, attr) - def _add_scalar_attr(self, name: str, dtype: torch.dtype): - self._add_vector_attr(name, max_len=1, dtype=dtype) + def _add_vector_attr_cpu(self, name: str, max_len: int, + dtype: torch.dtype): + attr = PerRequestAttribute(self.max_num_cached_reqs, max_len, 0, dtype, + self.device) + setattr(self, name, attr) + def _init_block_tables(self): + self.num_block_tables = len(self.block_sizes) + 
self.block_tables = [] + self.num_blocks = [] + self.slot_mappings: list[torch.Tensor] = [] + for i in range(self.num_block_tables): + max_num_blocks = cdiv(self.max_model_len, self.block_sizes[i]) + block_table = PerRequestAttribute(self.max_num_cached_reqs, + max_num_blocks, + self.max_num_reqs, torch.int32, + self.device) + self.block_tables.append(block_table) + num_blocks = PerRequestAttribute(self.max_num_cached_reqs, + 1, + self.max_num_reqs, + torch.int32, + self.device, + is_scalar=True) + self.num_blocks.append(num_blocks) + slot_mapping = torch.zeros(self.max_num_batched_tokens, + dtype=torch.int64, + device=self.device) + self.slot_mappings.append(slot_mapping) -import triton -import triton.language as tl + def make_ptr_tensor(x: list[torch.Tensor]) -> torch.Tensor: + return torch.tensor([t.data_ptr() for t in x], + dtype=torch.int64, + device=self.device) + + self.uva_block_table_ptrs = make_ptr_tensor( + [b.uva for b in self.block_tables]) + self.gpu_block_table_ptrs = make_ptr_tensor( + [b.gpu for b in self.block_tables]) + self.uva_num_blocks_ptrs = make_ptr_tensor( + [n.uva for n in self.num_blocks]) + self.gpu_num_blocks_ptrs = make_ptr_tensor( + [n.gpu for n in self.num_blocks]) + self.block_table_strides = torch.tensor( + [b.gpu.shape[1] for b in self.block_tables], + dtype=torch.int64, + device=self.device) + self.block_sizes_tensor = torch.tensor(self.block_sizes, + dtype=torch.int32, + device=self.device) + self.slot_mapping_ptrs = make_ptr_tensor(self.slot_mappings) + + def make_block_tables( + self, + idx_mapping: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + batch_size = idx_mapping.shape[0] + _make_block_tables_kernel[(batch_size, self.num_block_tables)]( + idx_mapping, + self.uva_block_table_ptrs, + self.gpu_block_table_ptrs, + self.block_table_strides, + self.uva_num_blocks_ptrs, + self.gpu_num_blocks_ptrs, + BLOCK_SIZE=1024, + ) + return tuple(b.gpu[:batch_size] for b in self.block_tables) + + def make_slot_mappings( + self, + 
cu_num_tokens: torch.Tensor, + pos: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: + num_tokens = pos.shape[0] + num_reqs = cu_num_tokens.shape[0] - 1 + _make_slot_mappings_kernel[(num_reqs + 1, self.num_block_tables)]( + num_tokens, + self.max_num_batched_tokens, + cu_num_tokens, + pos, + self.gpu_block_table_ptrs, + self.block_table_strides, + self.block_sizes_tensor, + self.slot_mapping_ptrs, + PAD_ID=PAD_SLOT_ID, + BLOCK_SIZE=1024, + ) + return tuple(x[:num_tokens] for x in self.slot_mappings) @triton.jit def _make_sampling_metadata_kernel( - req_indices, # [batch_size] + batch_idx_to_req_idx, # [batch_size] src_temperature, dst_temperature, src_top_p, @@ -327,52 +413,104 @@ def _make_sampling_metadata_kernel( dst_repetition_penalties, ): batch_idx = tl.program_id(0) - req_index = tl.load(req_indices + batch_idx) + req_idx = tl.load(batch_idx_to_req_idx + batch_idx) - temperature = tl.load(src_temperature + req_index) - tl.store(dst_temperature + req_index, temperature) + temperature = tl.load(src_temperature + req_idx) + tl.store(dst_temperature + batch_idx, temperature) - top_p = tl.load(src_top_p + req_index) - tl.store(dst_top_p + req_index, top_p) + top_p = tl.load(src_top_p + req_idx) + tl.store(dst_top_p + batch_idx, top_p) - top_k = tl.load(src_top_k + req_index) - tl.store(dst_top_k + req_index, top_k) + top_k = tl.load(src_top_k + req_idx) + tl.store(dst_top_k + batch_idx, top_k) - frequency_penalties = tl.load(src_frequency_penalties + req_index) - tl.store(dst_frequency_penalties + req_index, frequency_penalties) + frequency_penalties = tl.load(src_frequency_penalties + req_idx) + tl.store(dst_frequency_penalties + batch_idx, frequency_penalties) - presence_penalties = tl.load(src_presence_penalties + req_index) - tl.store(dst_presence_penalties + req_index, presence_penalties) + presence_penalties = tl.load(src_presence_penalties + req_idx) + tl.store(dst_presence_penalties + batch_idx, presence_penalties) - repetition_penalties = 
tl.load(src_repetition_penalties + req_index) - tl.store(dst_repetition_penalties + req_index, repetition_penalties) + repetition_penalties = tl.load(src_repetition_penalties + req_idx) + tl.store(dst_repetition_penalties + batch_idx, repetition_penalties) @triton.jit -def _make_block_table_kernel( - req_indices, # [batch_size] - src_block_table_ptrs, - dst_block_table_ptrs, - src_num_blocks_ptrs, - dst_num_blocks_ptrs, - num_block_tables: tl.constexpr, +def _make_block_tables_kernel( + batch_idx_to_req_idx, # [batch_size] + src_block_table_ptrs, # [num_block_tables] + dst_block_table_ptrs, # [num_block_tables] + block_table_strides, # [num_block_tables] + src_num_blocks_ptrs, # [num_block_tables] + dst_num_blocks_ptrs, # [num_block_tables] BLOCK_SIZE: tl.constexpr, ): batch_idx = tl.program_id(0) - req_index = tl.load(req_indices + batch_idx) + # kv cache group id + group_id = tl.program_id(1) + req_idx = tl.load(batch_idx_to_req_idx + batch_idx) - for i in tl.range(num_block_tables): - src_num_blocks_ptr = tl.load(src_num_blocks_ptrs + i) - dst_num_blocks_ptr = tl.load(dst_num_blocks_ptrs + i) - num_blocks = tl.load(src_num_blocks_ptr + req_index) - tl.store(dst_num_blocks_ptr + req_index, num_blocks) + src_num_blocks_ptr = _load_ptr(src_num_blocks_ptrs, group_id, tl.int32) + dst_num_blocks_ptr = _load_ptr(dst_num_blocks_ptrs, group_id, tl.int32) + num_blocks = tl.load(src_num_blocks_ptr + req_idx) + tl.store(dst_num_blocks_ptr + batch_idx, num_blocks) - src_block_table_ptr = tl.load(src_block_table_ptrs + i) - dst_block_table_ptr = tl.load(dst_block_table_ptrs + i) - for j in tl.range(num_blocks, BLOCK_SIZE): - offset = tl.arange(0, BLOCK_SIZE) - block_ids = tl.load(src_block_table_ptr + j * BLOCK_SIZE + offset, - mask=offset < num_blocks) - tl.store(dst_block_table_ptr + j * BLOCK_SIZE + offset, - block_ids, - mask=offset < num_blocks) + stride = tl.load(block_table_strides + group_id) + src_block_table_ptr = _load_ptr(src_block_table_ptrs, group_id, tl.int32) 
+ src_row_ptr = src_block_table_ptr + req_idx * stride + dst_block_table_ptr = _load_ptr(dst_block_table_ptrs, group_id, tl.int32) + dst_row_ptr = dst_block_table_ptr + batch_idx * stride + + for i in tl.range(0, num_blocks, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks) + tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks) + + +@triton.jit +def _make_slot_mappings_kernel( + num_tokens, + max_num_tokens, + cu_num_tokens, # [num_reqs + 1] + pos, # [num_tokens] + block_table_ptrs, # [num_block_tables] + block_table_strides, # [num_block_tables] + page_sizes, # [num_block_tables] + slot_mapping_ptrs, # [num_block_tables] + PAD_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + num_reqs = tl.num_programs(0) + # kv cache group id + group_id = tl.program_id(1) + slot_mapping_ptr = _load_ptr(slot_mapping_ptrs, group_id, tl.int64) + + if req_idx == num_reqs - 1: + # Pad remaining slots to -1. This is needed for CUDA graphs. 
+ for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + tl.store(slot_mapping_ptr + offset, + PAD_ID, + mask=offset < max_num_tokens) + return + + block_table_ptr = _load_ptr(block_table_ptrs, group_id, tl.int32) + block_table_stride = tl.load(block_table_strides + group_id) + page_size = tl.load(page_sizes + group_id) + + start_idx = tl.load(cu_num_tokens + req_idx) + end_idx = tl.load(cu_num_tokens + req_idx + 1) + for i in tl.range(start_idx, end_idx, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + positions = tl.load(pos + offset, mask=offset < end_idx, other=0) + block_indices = positions // page_size + block_numbers = tl.load(block_table_ptr + + req_idx * block_table_stride + block_indices) + slot_ids = block_numbers * page_size + positions % page_size + tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx) + + +@triton.jit +def _load_ptr(base, offset, elem_dtype): + ptr = tl.load(base + offset) + return tl.cast(ptr, tl.pointer_type(elem_dtype)) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9c62b6746666f..815103f71ac85 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -47,7 +47,6 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem, PlaceholderRange) from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.pooling_params import PoolingParams -from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.tasks import GenerationTask, PoolingTask, SupportedTask from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, @@ -215,6 +214,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, max_num_batched_tokens=self.max_num_tokens, + max_num_cached_reqs=2 * self.max_num_reqs, 
device=self.device, pin_memory=self.pin_memory, vocab_size=self.model_config.get_vocab_size(), @@ -289,6 +289,35 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): dtype=self.dtype, device=self.device) + # OPTIMIZATION: Cache the tensors rather than creating them every step. + # Keep in int64 to avoid overflow with long context + self.arange_np = np.arange(max(self.max_num_reqs + 1, + self.max_model_len, + self.max_num_tokens), + dtype=np.int64) + # NOTE(woosuk): These tensors are "stateless", i.e., they are literally + # a faster version of creating a new tensor every time. Thus, we should + # not make any assumptions about the values in these tensors. + self.input_ids_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.positions_cpu = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + self.positions_np = self.positions_cpu.numpy() + self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.query_start_loc_np = self.query_start_loc_cpu.numpy() + self.seq_lens_cpu = torch.zeros(self.max_num_reqs, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.seq_lens_np = self.seq_lens_cpu.numpy() + # Layer pairings for cross-layer KV sharing. 
# If an Attention layer `layer_name` is in the keys of this dict, it # means this layer will perform attention using the keys and values @@ -318,6 +347,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): torch.Tensor]] = None def _init_model_kwargs(self, num_tokens: int): + return {} model_kwargs = dict[str, Any]() num_reqs = self.input_batch.num_reqs @@ -410,20 +440,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for req_id in unscheduled_req_ids: self.input_batch.remove_request(req_id) - reqs_to_add: list[CachedRequestState] = [] # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: req_id = new_req_data.req_id sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params - if sampling_params and \ - sampling_params.sampling_type == SamplingType.RANDOM_SEED: - generator = torch.Generator(device=self.device) - generator.manual_seed(sampling_params.seed) - else: - generator = None - if pooling_params: task = pooling_params.task assert task is not None, "You did not set `task` in the API" @@ -434,141 +456,81 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_state = CachedRequestState( req_id=req_id, - prompt_token_ids=new_req_data.prompt_token_ids, mm_kwargs=new_req_data.mm_kwargs, mm_positions=new_req_data.mm_positions, sampling_params=sampling_params, pooling_params=pooling_params, - generator=generator, - block_ids=new_req_data.block_ids, - num_computed_tokens=new_req_data.num_computed_tokens, - output_token_ids=[], lora_request=new_req_data.lora_request, ) - self.requests[req_id] = req_state + self.input_batch.add_request( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + block_ids=new_req_data.block_ids, + sampling_params=sampling_params, + ) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: - 
image_grid_thw = [] - video_grid_thw = [] - second_per_grid_ts = [] - audio_feature_lengths = [] - use_audio_in_video = False - for mm_item in req_state.mm_kwargs: - mm_input = mm_item.get_data() - if (t := mm_input.get("image_grid_thw")) is not None: - image_grid_thw.append(t.tolist()) - if (t := mm_input.get("video_grid_thw")) is not None: - video_grid_thw.append(t.tolist()) - if (t := mm_input.get("second_per_grid_ts")) is not None: - second_per_grid_ts.append(t) - if (t := - mm_input.get("audio_feature_lengths")) is not None: - audio_feature_lengths.append(t) - if mm_input.get("use_audio_in_video") is True: - use_audio_in_video = True - - hf_config = self.model_config.hf_config - - req_state.mrope_positions, req_state.mrope_position_delta = \ - MRotaryEmbedding.get_input_positions_tensor( - req_state.prompt_token_ids, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - - reqs_to_add.append(req_state) + self._init_mrope_states(req_state) # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank req_data = scheduler_output.scheduled_cached_reqs for i, req_id in enumerate(req_data.req_ids): - req_state = self.requests[req_id] - num_computed_tokens = req_data.num_computed_tokens[i] - new_block_ids = req_data.new_block_ids[i] - resumed_from_preemption = req_data.resumed_from_preemption[i] - - # Update the cached states. - req_state.num_computed_tokens = num_computed_tokens + req_index = self.input_batch.req_id_to_index[req_id] + # Update input batch. if not is_last_rank: # When using PP, the scheduler sends the sampled tokens back, # because there's no direct communication between the first- # stage worker and the last-stage worker. new_token_ids = req_data.new_token_ids[i] - # Add the sampled token(s) from the previous step (if any). 
- # This doesn't include "unverified" tokens like spec tokens. - num_new_tokens = (num_computed_tokens + len(new_token_ids) - - req_state.num_tokens) - if num_new_tokens == 1: - # Avoid slicing list in most common case. - req_state.output_token_ids.append(new_token_ids[-1]) - elif num_new_tokens > 0: - req_state.output_token_ids.extend( - new_token_ids[-num_new_tokens:]) + self.input_batch.append_token_ids(req_index, new_token_ids) - # Update the block IDs. - if not resumed_from_preemption: - if new_block_ids is not None: - # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): - block_ids.extend(new_ids) - else: - assert new_block_ids is not None - # The request is resumed from preemption. - # Replace the existing block IDs with the new ones. - req_state.block_ids = new_block_ids - - req_index = self.input_batch.req_id_to_index.get(req_id) - if req_index is None: - # The request is not in the persistent batch. - # The request was either preempted and resumed later, or was not - # scheduled in the previous step and needs to be added again. - reqs_to_add.append(req_state) - continue - - # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) + new_block_ids = req_data.new_block_ids[i] if new_block_ids is not None: - self.input_batch.block_table.append_row( - new_block_ids, req_index) - - # For the last rank, we don't need to update the token_ids_cpu - # because the sampled tokens are already cached. - if not is_last_rank: - # Add new_token_ids to token_ids_cpu. - start_token_index = num_computed_tokens - end_token_index = num_computed_tokens + len(new_token_ids) - self.input_batch.token_ids_cpu[ + # If the request is resumed from preemption, we need to + # overwrite the existing block IDs. 
+ self.input_batch.append_block_ids( req_index, - start_token_index:end_token_index] = new_token_ids - self.input_batch.num_tokens_no_spec[ - req_index] = end_token_index - self.input_batch.num_tokens[req_index] = end_token_index + new_block_ids, + overwrite=req_data.resumed_from_preemption[i], + ) - # Add spec_token_ids to token_ids_cpu. - spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) - if spec_token_ids: - num_spec_tokens = len(spec_token_ids) - start_index = self.input_batch.num_tokens_no_spec[req_index] - end_token_index = start_index + num_spec_tokens - self.input_batch.token_ids_cpu[ - req_index, start_index:end_token_index] = spec_token_ids - # NOTE(woosuk): `num_tokens` here may include spec tokens. - self.input_batch.num_tokens[req_index] += num_spec_tokens + self.input_batch.num_computed_tokens.np[req_index] = ( + req_data.num_computed_tokens[i]) - # Add the new or resumed requests to the persistent batch. - # The smaller empty indices are filled first. 
- for request in reqs_to_add: - self.input_batch.add_request(request) + def _init_mrope_states(self, req_state: CachedRequestState) -> None: + image_grid_thw = [] + video_grid_thw = [] + second_per_grid_ts = [] + audio_feature_lengths = [] + use_audio_in_video = False + for mm_item in req_state.mm_kwargs: + mm_input = mm_item.get_data() + if (t := mm_input.get("image_grid_thw")) is not None: + image_grid_thw.append(t.tolist()) + if (t := mm_input.get("video_grid_thw")) is not None: + video_grid_thw.append(t.tolist()) + if (t := mm_input.get("second_per_grid_ts")) is not None: + second_per_grid_ts.append(t) + if (t := mm_input.get("audio_feature_lengths")) is not None: + audio_feature_lengths.append(t) + if mm_input.get("use_audio_in_video") is True: + use_audio_in_video = True + + req_state.mrope_positions, req_state.mrope_position_delta = \ + MRotaryEmbedding.get_input_positions_tensor( + req_state.prompt_token_ids, + hf_config=self.model_config.hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + audio_feature_lengths=audio_feature_lengths, + use_audio_in_video=use_audio_in_video, + ) def _extract_mm_kwargs( self, @@ -637,12 +599,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_reqs = self.input_batch.num_reqs assert num_reqs > 0 + # FIXME + # batch_idx -> req_id + req_ids = list(scheduler_output.num_scheduled_tokens.keys()) + # req_id -> batch_idx + req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)} + # batch_idx -> req_idx + idx_mapping = [ + self.input_batch.req_id_to_index[req_id] for req_id in req_ids + ] + # batch_idx -> req_idx + idx_mapping_tensor = self.input_batch.get_index_mapping(idx_mapping) + num_reqs = len(req_ids) + # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. 
- self.input_batch.block_table.commit_block_table(num_reqs) + block_tables = self.input_batch.make_block_tables(idx_mapping_tensor) # Get the number of scheduled tokens for each request. - req_ids = self.input_batch.req_ids tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] num_scheduled_tokens = np.array(tokens, dtype=np.int32) max_num_scheduled_tokens = max(tokens) @@ -659,7 +633,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # Get positions. positions_np = self.positions_np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + np.add(self.input_batch.num_computed_tokens.np[req_indices], arange, out=positions_np) @@ -673,21 +647,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # where M is the max_model_len. token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) + req_indices * self.input_batch.token_ids.np.shape[1]) # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + torch.index_select(self.input_batch.token_ids.cpu.flatten(), 0, torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - self.input_batch.block_table.compute_slot_mapping( - req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping( - total_num_scheduled_tokens) - # Prepare the attention metadata. 
self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens @@ -698,7 +667,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): query_start_loc = self.query_start_loc[:num_reqs + 1] self.seq_lens_np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + + self.input_batch.num_computed_tokens.np[:num_reqs] + num_scheduled_tokens) # Fill unused with 0 for full cuda graph mode. self.seq_lens_np[num_reqs:].fill(0) @@ -737,8 +706,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) for req_id, draft_token_ids in ( scheduler_output.scheduled_spec_decode_tokens.items()): - req_idx = self.input_batch.req_id_to_index[req_id] - num_draft_tokens[req_idx] = len(draft_token_ids) + batch_idx = req_id_to_batch_idx[req_id] + num_draft_tokens[batch_idx] = len(draft_token_ids) spec_decode_metadata = self._calc_spec_decode_metadata( num_draft_tokens, cu_num_tokens) @@ -788,11 +757,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): per_layer_metadata[layer_name] attn_metadata[layer_name] = encoder_attn_metadata + slot_mappings = self.input_batch.make_slot_mappings( + query_start_loc, + self.positions[:total_num_scheduled_tokens], + ) # Used in the below loop. 
query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1] seq_lens_cpu = self.seq_lens_cpu[:num_reqs] num_computed_tokens_cpu = ( - self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) + self.input_batch.num_computed_tokens.cpu[:num_reqs]) spec_decode_common_attn_metadata = None # Prepare the attention metadata for each KV cache group and make layers @@ -800,14 +773,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): - blk_table = self.input_batch.block_table[kv_cache_group_id] - blk_table_tensor = blk_table.get_device_tensor()[:num_reqs] - slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens] - - # Fill unused with -1. Needed for reshape_and_cache in full cuda - # graph mode. - blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1) - + blk_table_tensor = block_tables[kv_cache_group_id] + slot_mapping = slot_mappings[kv_cache_group_id] common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, @@ -948,7 +915,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_reqs = len(num_scheduled_tokens) common_prefix_len = min( common_prefix_len, - self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) + self.input_batch.num_computed_tokens.np[:num_reqs].min()) # common_prefix_len should be a multiple of the block size. 
common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size) @@ -1454,6 +1421,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_scheduled_tokens_np, spec_decode_common_attn_metadata, max_query_len) = (self._prepare_inputs(scheduler_output)) + # FIXME + # batch_idx -> req_id + req_ids = list(scheduler_output.num_scheduled_tokens.keys()) + # req_id -> batch_idx + req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)} + # batch_idx -> req_idx + idx_mapping = [ + self.input_batch.req_id_to_index[req_id] for req_id in req_ids + ] + # batch_idx -> req_idx + idx_mapping_tensor = self.input_batch.get_index_mapping(idx_mapping) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1593,7 +1572,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.apply_grammar_bitmask(scheduler_output, logits) # Sample the next token and get logprobs if needed. - sampling_metadata = self.input_batch.sampling_metadata + sampling_metadata = self.input_batch.make_sampling_metadata( + idx_mapping_tensor) if spec_decode_metadata is None: sampler_output = self.sampler( logits=logits, @@ -1629,24 +1609,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): if envs.VLLM_COMPUTE_NANS_IN_LOGITS: num_nans_in_logits = self._get_nans_in_logits(logits) - # TODO(woosuk): The following loop can be slow since it iterates over - # the requests one by one. Optimize. - discard_sampled_tokens_req_indices = [] - for i, req_id in enumerate(self.input_batch.req_ids): - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - if seq_len < req_state.num_tokens: - # Ignore the sampled token for partial prefills. 
- # Rewind the generator state as if the token was not sampled. - # This relies on cuda-specific torch-internal impl details - generator = self.input_batch.generators.get(i) - if generator is not None: - generator.set_offset(generator.get_offset() - 4) - # Record the index of the request that should not be sampled, - # so that we could clear the sampled tokens before returning. - discard_sampled_tokens_req_indices.append(i) - # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. logprobs_tensors = sampler_output.logprobs_tensors @@ -1668,37 +1630,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): else: # Includes spec decode tokens. valid_sampled_token_ids = self.rejection_sampler.parse_output( - sampled_token_ids, - self.input_batch.vocab_size, - ) - # Mask out the sampled tokens that should not be sampled. - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() + sampled_token_ids, self.input_batch.vocab_size) + # # Mask out the sampled tokens that should not be sampled. + # for i in discard_sampled_tokens_req_indices: + # valid_sampled_token_ids[i].clear() # Cache the sampled tokens in the model runner, so that the scheduler # doesn't need to send them back. # NOTE(woosuk): As an exception, when using PP, the scheduler sends # the sampled tokens back, because there's no direct communication # between the first-stage worker and the last-stage worker. - req_ids = self.input_batch.req_ids for req_idx, sampled_ids in enumerate(valid_sampled_token_ids): if not sampled_ids: continue - - start_idx = self.input_batch.num_tokens_no_spec[req_idx] - end_idx = start_idx + len(sampled_ids) - assert end_idx <= self.max_model_len, ( - "Sampled token IDs exceed the max model length. 
" - f"Total number of tokens: {end_idx} > max_model_len: " - f"{self.max_model_len}") - - self.input_batch.token_ids_cpu[req_idx, - start_idx:end_idx] = sampled_ids - self.input_batch.num_tokens_no_spec[req_idx] = end_idx - self.input_batch.num_tokens[req_idx] = end_idx - req_id = req_ids[req_idx] - req_state = self.requests[req_id] - req_state.output_token_ids.extend(sampled_ids) + self.input_batch.append_token_ids(req_idx, sampled_ids) if self.speculative_config: assert spec_decode_common_attn_metadata is not None @@ -1716,8 +1661,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.eplb_step() return ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, + req_ids=req_ids, + req_id_to_index=req_id_to_batch_idx, sampled_token_ids=valid_sampled_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, @@ -2389,14 +2334,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): generators={}, max_num_logprobs=None, no_penalties=True, - prompt_token_ids=None, frequency_penalties=dummy_tensors(0.1), presence_penalties=dummy_tensors(0.1), repetition_penalties=dummy_tensors(0.1), - output_token_ids=[[] for _ in range(num_reqs)], allowed_token_ids_mask=None, bad_words_token_ids={}, logitsprocs=LogitsProcessors(), + token_ids=None, ) try: sampler_output = self.sampler(logits=logits,