diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index 546531a91610f..4fb6654ef00cb 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -90,9 +90,9 @@ class Sampler(nn.Module):
         # Apply bad words exclusion.
         logits = self.apply_bad_words(logits, sampling_metadata)

-        # Apply logits processors which can impact greedy sampling
-        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
-            logits = processor.apply(logits)
+        # # Apply logits processors which can impact greedy sampling
+        # for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
+        #     logits = processor.apply(logits)

         # Apply penalties (e.g., min_tokens, freq_penalties).
         logits = self.apply_penalties(logits, sampling_metadata)
@@ -167,10 +167,10 @@ class Sampler(nn.Module):
         # Apply temperature.
         logits = self.apply_temperature(logits, sampling_metadata.temperature)

-        # Apply logits processors that only apply to random sampling
-        # (argmax invariant)
-        for processor in sampling_metadata.logitsprocs.argmax_invariant:
-            logits = processor.apply(logits)
+        # # Apply logits processors that only apply to random sampling
+        # # (argmax invariant)
+        # for processor in sampling_metadata.logitsprocs.argmax_invariant:
+        #     logits = processor.apply(logits)

         # Apply top_k and/or top_p.
         random_sampled, processed_logprobs = self.topk_topp_sampler(
diff --git a/vllm/v1/worker/gpu_block_table.py b/vllm/v1/worker/gpu_block_table.py
new file mode 100644
index 0000000000000..c8c6f9615a464
--- /dev/null
+++ b/vllm/v1/worker/gpu_block_table.py
@@ -0,0 +1,312 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Iterable
+
+import torch
+import triton
+import triton.language as tl
+
+from vllm.utils import cdiv
+from vllm.v1.worker.utils import CpuGpuBuffer
+
+PAD_SLOT_ID = -1
+
+
+class BlockTables:
+
+    def __init__(
+        self,
+        block_sizes: list[int],
+        max_num_reqs: int,
+        max_num_cached_reqs: int,
+        max_num_batched_tokens: int,
+        max_model_len: int,
+        device: torch.device,
+        pin_memory: bool,
+    ):
+        self.block_sizes = block_sizes
+        self.max_num_reqs = max_num_reqs
+        self.max_num_cached_reqs = max_num_cached_reqs
+        self.max_num_batched_tokens = max_num_batched_tokens
+        self.max_model_len = max_model_len
+        self.device = device
+        self.pin_memory = pin_memory
+
+        self.num_kv_cache_groups = len(self.block_sizes)
+        # [num_kv_cache_groups, max_num_reqs, max_num_blocks]
+        self.block_tables: list[torch.Tensor] = []
+        # [num_kv_cache_groups, max_num_cached_reqs, max_num_blocks]
+        self.block_table_buffers: list[torch.Tensor] = []
+        # [num_kv_cache_groups, max_num_cached_reqs]
+        self.num_blocks: list[torch.Tensor] = []
+        # [num_kv_cache_groups, max_num_tokens]
+        self.slot_mappings: list[torch.Tensor] = []
+        for i in range(self.num_kv_cache_groups):
+            block_size = self.block_sizes[i]
+            max_num_blocks = cdiv(self.max_model_len, block_size)
+
+            block_table = torch.zeros(
+                self.max_num_reqs,
+                max_num_blocks,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            self.block_tables.append(block_table)
+
+            block_table_buffer = torch.zeros(
+                self.max_num_cached_reqs,
+                max_num_blocks,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            self.block_table_buffers.append(block_table_buffer)
+
+            # NOTE: num_blocks is indexed by the cached request index in the
+            # kernels below, so it must be sized by max_num_cached_reqs
+            # (not max_num_reqs).
+            num_blocks = torch.zeros(self.max_num_cached_reqs,
+                                     dtype=torch.int32,
+                                     device=self.device)
+            self.num_blocks.append(num_blocks)
+
+            slot_mapping = torch.zeros(self.max_num_batched_tokens,
+                                       dtype=torch.int64,
+                                       device=self.device)
+            self.slot_mappings.append(slot_mapping)
+
+        self.block_table_ptrs = self._make_ptr_tensor(self.block_tables)
+        self.buffer_ptrs = self._make_ptr_tensor(self.block_table_buffers)
+        self.block_table_strides = torch.tensor(
+            [b.stride(0) for b in self.block_tables],
+            dtype=torch.int64,
+            device=self.device)
+        self.num_blocks_ptrs = self._make_ptr_tensor(self.num_blocks)
+        self.block_sizes_tensor = torch.tensor(self.block_sizes,
+                                               dtype=torch.int32,
+                                               device=self.device)
+        self.slot_mapping_ptrs = self._make_ptr_tensor(self.slot_mappings)
+
+        # Misc buffers.
+        self.req_indices = self._make_buffer(self.max_num_reqs, torch.int32)
+        self.overwrite = self._make_buffer(self.max_num_reqs, torch.bool)
+        self.cu_num_new_blocks: list[CpuGpuBuffer] = []
+        self.new_block_ids: list[CpuGpuBuffer] = []
+        for i in range(self.num_kv_cache_groups):
+            self.cu_num_new_blocks.append(
+                self._make_buffer(self.max_num_reqs + 1, torch.int32))
+            # NOTE(woosuk): Here, we assume that the total number of new
+            # blocks is ALWAYS less than max_num_batched_tokens.
+            # TODO(woosuk): Rigorously verify that this assumption is correct.
+            self.new_block_ids.append(
+                self._make_buffer(self.max_num_batched_tokens, torch.int32))
+
+    def _make_buffer(self, n: int, dtype: torch.dtype) -> CpuGpuBuffer:
+        return CpuGpuBuffer(n,
+                            dtype=dtype,
+                            pin_memory=self.pin_memory,
+                            device=self.device)
+
+    def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
+        ptrs_tensor_cpu = torch.tensor([t.data_ptr() for t in x],
+                                       dtype=torch.int64,
+                                       device="cpu",
+                                       pin_memory=self.pin_memory)
+        return ptrs_tensor_cpu.to(self.device, non_blocking=True)
+
+    def append_block_ids(
+        self,
+        # [num_reqs]
+        req_indices: list[int],
+        # [num_kv_cache_groups, num_reqs + 1]
+        cu_num_new_blocks: list[list[int]],
+        # [num_kv_cache_groups, num_new_blocks]
+        new_block_ids: list[list[int]],
+        # [num_reqs]
+        overwrite: list[bool],
+    ) -> None:
+        # TODO(woosuk): Optimize & simplify this.
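+        # The new block IDs are staged in pinned CPU buffers, copied to the
+        # GPU, and scattered into the per-request block table buffers by a
+        # Triton kernel. For each KV cache group i, cu_num_new_blocks[i] is a
+        # prefix sum over the requests in req_indices, so each kernel program
+        # can locate its slice of new_block_ids[i].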
+        num_reqs = len(req_indices)
+        self.req_indices.np[:num_reqs] = req_indices
+        self.overwrite.np[:num_reqs] = overwrite
+        for i in range(self.num_kv_cache_groups):
+            self.cu_num_new_blocks[i].np[:num_reqs + 1] = cu_num_new_blocks[i]
+            n = len(new_block_ids[i])
+            self.new_block_ids[i].np[:n] = new_block_ids[i]
+
+        cu_num_new_blocks_ptrs = self._make_ptr_tensor(
+            [x.copy_to_gpu(num_reqs + 1) for x in self.cu_num_new_blocks])
+        new_block_ids_ptrs = self._make_ptr_tensor([
+            x.copy_to_gpu(len(new_block_ids[i]))
+            for i, x in enumerate(self.new_block_ids)
+        ])
+        _append_block_ids_kernel[(num_reqs, self.num_kv_cache_groups)](
+            self.req_indices.copy_to_gpu(num_reqs),
+            cu_num_new_blocks_ptrs,
+            new_block_ids_ptrs,
+            self.overwrite.copy_to_gpu(num_reqs),
+            self.block_table_strides,
+            self.buffer_ptrs,
+            self.num_blocks_ptrs,
+            BLOCK_SIZE=1024,
+        )
+
+    def compute_block_tables(
+        self,
+        idx_mapping: torch.Tensor,
+    ) -> tuple[torch.Tensor, ...]:
+        batch_size = idx_mapping.shape[0]
+        _compute_block_tables_kernel[(batch_size, self.num_kv_cache_groups)](
+            idx_mapping,
+            self.buffer_ptrs,
+            self.block_table_ptrs,
+            self.block_table_strides,
+            self.num_blocks_ptrs,
+            BLOCK_SIZE=1024,
+        )
+        return tuple(b[:batch_size] for b in self.block_tables)
+
+    def compute_slot_mappings(
+        self,
+        cu_num_tokens: torch.Tensor,
+        pos: torch.Tensor,
+    ) -> tuple[torch.Tensor, ...]:
+        num_tokens = pos.shape[0]
+        num_reqs = cu_num_tokens.shape[0] - 1
+        num_groups = self.num_kv_cache_groups
+        _compute_slot_mappings_kernel[(num_reqs + 1, num_groups)](
+            num_tokens,
+            self.max_num_batched_tokens,
+            cu_num_tokens,
+            pos,
+            self.block_table_ptrs,
+            self.block_table_strides,
+            self.block_sizes_tensor,
+            self.slot_mapping_ptrs,
+            PAD_ID=PAD_SLOT_ID,
+            BLOCK_SIZE=1024,
+        )
+        return tuple(x[:num_tokens] for x in self.slot_mappings)
+
+
+@triton.jit
+def _append_block_ids_kernel(
+    # Inputs
+    req_indices,  # [num_reqs]
+    cu_num_new_block_ptrs,  # [num_kv_cache_groups, num_reqs + 1]
+    new_block_id_ptrs,  # [num_kv_cache_groups, num_new_blocks]
+    overwrite,  # [num_reqs]
+    block_table_strides,  # [num_kv_cache_groups]
+    # Outputs
+    block_table_buffer_ptrs,  # [num_kv_cache_groups]
+    num_block_ptrs,  # [num_kv_cache_groups]
+    # Constants
+    BLOCK_SIZE: tl.constexpr,
+):
+    batch_idx = tl.program_id(0)
+    group_id = tl.program_id(1)
+    req_idx = tl.load(req_indices + batch_idx)
+    do_overwrite = tl.load(overwrite + batch_idx)
+
+    cu_num_new_blocks_ptr = _load_ptr(cu_num_new_block_ptrs + group_id,
+                                      tl.int32)
+    start_idx = tl.load(cu_num_new_blocks_ptr + batch_idx)
+    end_idx = tl.load(cu_num_new_blocks_ptr + batch_idx + 1)
+    num_new_blocks = end_idx - start_idx
+
+    num_blocks_ptr = _load_ptr(num_block_ptrs + group_id, tl.int32)
+    if do_overwrite:
+        dst_start_idx = 0
+    else:
+        dst_start_idx = tl.load(num_blocks_ptr + req_idx)
+    dst_end_idx = dst_start_idx + num_new_blocks
+    tl.store(num_blocks_ptr + req_idx, dst_end_idx)
+
+    # Destination
+    block_table_buffer_ptr = _load_ptr(block_table_buffer_ptrs + group_id,
+                                       tl.int32)
+    block_table_stride = tl.load(block_table_strides + group_id)
+    buffer_row_ptr = block_table_buffer_ptr + req_idx * block_table_stride
+
+    new_block_ids_ptr = _load_ptr(new_block_id_ptrs + group_id, tl.int32)
+    for i in tl.range(0, num_new_blocks, BLOCK_SIZE):
+        offset = i + tl.arange(0, BLOCK_SIZE)
+        block_ids = tl.load(new_block_ids_ptr + start_idx + offset,
+                            mask=offset < num_new_blocks)
+        tl.store(buffer_row_ptr + dst_start_idx + offset,
+                 block_ids,
+                 mask=offset < num_new_blocks)
+
+
+@triton.jit
+def _compute_block_tables_kernel(
+    batch_idx_to_req_idx,  # [batch_size]
+    src_block_table_ptrs,  # [num_kv_cache_groups]
+    dst_block_table_ptrs,  # [num_kv_cache_groups]
+    block_table_strides,  # [num_kv_cache_groups]
+    num_blocks_ptrs,  # [num_kv_cache_groups]
+    BLOCK_SIZE: tl.constexpr,
+):
+    batch_idx = tl.program_id(0)
+    # kv cache group id
+    group_id = tl.program_id(1)
+    req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
+
+    num_blocks_ptr = _load_ptr(num_blocks_ptrs + group_id, tl.int32)
+    num_blocks = tl.load(num_blocks_ptr + req_idx)
+
+    stride = tl.load(block_table_strides + group_id)
+    src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
+    src_row_ptr = src_block_table_ptr + req_idx * stride
+    dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
+    dst_row_ptr = dst_block_table_ptr + batch_idx * stride
+
+    for i in tl.range(0, num_blocks, BLOCK_SIZE):
+        offset = i + tl.arange(0, BLOCK_SIZE)
+        block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
+        tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
+
+
+@triton.jit
+def _compute_slot_mappings_kernel(
+    num_tokens,
+    max_num_tokens,
+    cu_num_tokens,  # [num_reqs + 1]
+    pos,  # [num_tokens]
+    block_table_ptrs,  # [num_kv_cache_groups]
+    block_table_strides,  # [num_kv_cache_groups]
+    page_sizes,  # [num_kv_cache_groups]
+    slot_mapping_ptrs,  # [num_kv_cache_groups]
+    PAD_ID: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    # kv cache group id
+    group_id = tl.program_id(1)
+    slot_mapping_ptr = _load_ptr(slot_mapping_ptrs + group_id, tl.int64)
+
+    if req_idx == tl.num_programs(0) - 1:
+        # Pad remaining slots to -1. This is needed for CUDA graphs.
+        for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE):
+            offset = i + tl.arange(0, BLOCK_SIZE)
+            tl.store(slot_mapping_ptr + offset,
+                     PAD_ID,
+                     mask=offset < max_num_tokens)
+        return
+
+    block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
+    block_table_stride = tl.load(block_table_strides + group_id)
+    page_size = tl.load(page_sizes + group_id)
+
+    start_idx = tl.load(cu_num_tokens + req_idx)
+    end_idx = tl.load(cu_num_tokens + req_idx + 1)
+    for i in tl.range(start_idx, end_idx, BLOCK_SIZE):
+        offset = i + tl.arange(0, BLOCK_SIZE)
+        positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
+        block_indices = positions // page_size
+        block_numbers = tl.load(block_table_ptr +
+                                req_idx * block_table_stride + block_indices)
+        slot_ids = block_numbers * page_size + positions % page_size
+        tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
+
+
+@triton.jit
+def _load_ptr(ptr_to_ptr, elem_dtype):
+    ptr = tl.load(ptr_to_ptr)
+    return tl.cast(ptr, tl.pointer_type(elem_dtype))
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 1cd3bb3f59a85..699b273642b07 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -1,401 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-# Datastructures defining a GPU input batch
-
 from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Optional

+import numba
 import numpy as np
 import torch
-import triton
-import triton.language as tl
-from typing_extensions import deprecated
+from numba import types

-from vllm.lora.request import LoRARequest
-from vllm.multimodal.inputs import (MultiModalKwargsItem,
-                                    MultiModalKwargsItems, PlaceholderRange)
-from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.utils import cdiv, get_cuda_view_from_cpu_tensor, is_uva_available
-from vllm.v1.sample.logits_processor import LogitsProcessors
-from vllm.v1.sample.metadata import SamplingMetadata
-
-PAD_SLOT_ID = -1
-
-
-@dataclass
-class RequestData:
-
-    mm_kwargs: list[MultiModalKwargsItem]
-    mm_positions: list[PlaceholderRange]
-    sampling_params: Optional[SamplingParams]
-    pooling_params: Optional[PoolingParams]
-
-    # M-RoPE (only for Qwen2/2.5-VL)
-    mrope_positions: Optional[torch.Tensor] = None
-    mrope_position_delta: Optional[int] = None
-
-    lora_request: Optional[LoRARequest] = None
-
-    # Temporary back-compatibility for plugins that define model runner
-    @property
-    @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
-                "removed in v0.13. Please use `mm_kwargs` instead.")
-    def mm_inputs(self) -> list[MultiModalKwargsItems]:
-        return [
-            MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
-        ]
-
-
-class PerRequestAttribute:
-
-    def __init__(
-        self,
-        num_rows_cpu: int,
-        num_cols: int,
-        num_rows_gpu: int,
-        dtype: torch.dtype,
-        device: torch.device,
-        is_scalar: bool = False,
-    ):
-        assert is_uva_available(), "UVA is not available."
-        self.cpu = torch.zeros(num_rows_cpu,
-                               num_cols,
-                               dtype=dtype,
-                               device="cpu",
-                               pin_memory=True)
-        self.np = self.cpu.numpy()
-        self.uva = get_cuda_view_from_cpu_tensor(self.cpu)
-        self.gpu = torch.zeros(num_rows_gpu,
-                               num_cols,
-                               dtype=dtype,
-                               device=device)
-        if is_scalar:
-            assert num_cols == 1
-            self.cpu.squeeze_(1)
-            self.np = self.cpu.numpy()
-            self.uva.squeeze_(1)
-            self.gpu.squeeze_(1)
-
-
-class RequestState:
-
-    def __init__(
-        self,
-        max_num_reqs: int,
-        max_model_len: int,
-        max_num_batched_tokens: int,
-        max_num_cached_reqs: int,
-        device: torch.device,
-        pin_memory: bool,
-        vocab_size: int,
-        block_sizes: list[int],  # The block_size of each kv cache group
-        logitsprocs: Optional[LogitsProcessors] = None,
-        is_spec_decode: bool = False,
-        is_pooling_model: bool = False,
-    ):
-        self.max_num_reqs = max_num_reqs
-        self.max_model_len = max_model_len
-        self.max_num_batched_tokens = max_num_batched_tokens
-        self.max_num_cached_reqs = max_num_cached_reqs
-        self.device = device
-        self.pin_memory = pin_memory
-        self.vocab_size = vocab_size
-        self.is_spec_decode = is_spec_decode
-        self.pooling_params = None
-        self.block_sizes = block_sizes
-        self.num_prompt_logprobs = {}
-
-        self.req_id_to_index: dict[str, int] = {}
-        self.index_to_req_id: dict[int, str] = {}
-        self.free_indices = list(range(max_num_cached_reqs))
-        # Used to construct the input batch.
-        self._add_scalar_attr("idx_mapping", torch.int32)
-
-        # Request states.
-        self.req_data: dict[int, RequestData] = {}
-        # TODO(woosuk): Because the token_ids tensor can be very big, we only
-        # initialize it on CPU memory.
-        self._add_vector_attr_cpu("token_ids", self.max_model_len, torch.int32)
-        self._add_scalar_attr("num_prompt_tokens", torch.int32)
-        self._add_scalar_attr("num_tokens", torch.int32)
-        self._add_scalar_attr("num_computed_tokens", torch.int32)
-
-        # Sampling-related.
-        self._add_scalar_attr("temperature", torch.float32)
-        self.greedy_reqs: set[str] = set()
-        self.random_reqs: set[str] = set()
-        self._add_scalar_attr("top_p", torch.float32)
-        self.top_p_reqs: set[str] = set()
-        self._add_scalar_attr("top_k", torch.int32)
-        self.top_k_reqs: set[str] = set()
-        self._add_scalar_attr("frequency_penalties", torch.float32)
-        self.frequency_penalties_reqs: set[str] = set()
-        self._add_scalar_attr("presence_penalties", torch.float32)
-        self.presence_penalties_reqs: set[str] = set()
-        self._add_scalar_attr("repetition_penalties", torch.float32)
-        self.repetition_penalties_reqs: set[str] = set()
-
-        # req_idx -> generator
-        self.generators: dict[int, torch.Generator] = {}
-
-        # Block table(s).
-        self._init_block_tables()
-
-    def add_request(
-        self,
-        req_id: str,
-        prompt_token_ids: list[int],
-        num_computed_tokens: int,
-        block_ids: tuple[list[int], ...],
-        sampling_params: SamplingParams,
-    ) -> None:
-        req_idx = self.free_indices.pop()
-        self.req_id_to_index[req_id] = req_idx
-        self.index_to_req_id[req_idx] = req_id
-
-        self.num_prompt_tokens.np[req_idx] = len(prompt_token_ids)
-        self.num_computed_tokens.np[req_idx] = num_computed_tokens
-        self.append_token_ids(req_idx, prompt_token_ids)
-        self.append_block_ids(req_idx, block_ids, overwrite=True)
-
-        self.temperature.np[req_idx] = sampling_params.temperature
-        if sampling_params.sampling_type == SamplingType.GREEDY:
-            # NOTE: Be careful about division by zero.
-            self.greedy_reqs.add(req_id)
-        elif sampling_params.sampling_type == SamplingType.RANDOM:
-            self.random_reqs.add(req_id)
-
-        self.top_p.np[req_idx] = sampling_params.top_p
-        if sampling_params.top_p < 1.0:
-            self.top_p_reqs.add(req_id)
-
-        top_k = sampling_params.top_k
-        if 0 < top_k < self.vocab_size:
-            self.top_k_reqs.add(req_id)
-        else:
-            top_k = self.vocab_size
-        self.top_k.np[req_idx] = top_k
-
-        self.frequency_penalties.np[
-            req_idx] = sampling_params.frequency_penalty
-        if sampling_params.frequency_penalty != 0.0:
-            self.frequency_penalties_reqs.add(req_id)
-        self.presence_penalties.np[req_idx] = sampling_params.presence_penalty
-        if sampling_params.presence_penalty != 0.0:
-            self.presence_penalties_reqs.add(req_id)
-        self.repetition_penalties.np[
-            req_idx] = sampling_params.repetition_penalty
-        if sampling_params.repetition_penalty != 1.0:
-            self.repetition_penalties_reqs.add(req_id)
-
-        if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-            generator = torch.Generator(device=self.device)
-            generator.manual_seed(sampling_params.seed)
-            self.generators[req_idx] = generator
-
-    def append_token_ids(self, req_idx: int, token_ids: list[int]) -> None:
-        start_idx = self.num_tokens.np[req_idx]
-        end_idx = start_idx + len(token_ids)
-        self.token_ids.np[req_idx, start_idx:end_idx] = token_ids
-        self.num_tokens.np[req_idx] = end_idx
-
-    # TODO(woosuk): Further vectorize this to minimize overheads.
-    def append_block_ids(
-        self,
-        req_idx: int,
-        new_block_ids: tuple[list[int], ...],
-        overwrite: bool,
-    ) -> None:
-        for i in range(self.num_block_tables):
-            block_table = self.block_tables[i]
-            num_blocks = self.num_blocks[i]
-            num_new_blocks = len(new_block_ids[i])
-            if overwrite:
-                # Replace the existing block IDs with the new ones.
-                # This happens when the request is resumed from preemption.
-                block_table.np[req_idx, :num_new_blocks] = new_block_ids[i]
-                num_blocks.np[req_idx] = num_new_blocks
-            else:
-                # Append the new block IDs to the existing ones (common case).
-                start_idx = num_blocks.np[req_idx]
-                end_idx = start_idx + num_new_blocks
-                block_table.np[req_idx, start_idx:end_idx] = new_block_ids[i]
-                num_blocks.np[req_idx] = end_idx
-
-    def remove_request(self, req_id: str) -> None:
-        req_idx = self.req_id_to_index.pop(req_id, None)
-        if req_idx is None:
-            # Request not found.
-            return
-        self.index_to_req_id.pop(req_idx, None)
-        self.free_indices.append(req_idx)
-
-        self.greedy_reqs.discard(req_id)
-        self.random_reqs.discard(req_id)
-        self.top_p_reqs.discard(req_id)
-        self.top_k_reqs.discard(req_id)
-        self.frequency_penalties_reqs.discard(req_id)
-        self.presence_penalties_reqs.discard(req_id)
-        self.repetition_penalties_reqs.discard(req_id)
-        self.generators.pop(req_idx, None)
-
-    def get_index_mapping(self, idx_mapping: list[int]) -> torch.Tensor:
-        num_reqs = len(idx_mapping)
-        self.idx_mapping.np[:num_reqs] = idx_mapping
-        return self.idx_mapping.gpu[:num_reqs].copy_(
-            self.idx_mapping.uva[:num_reqs], non_blocking=True)
-
-    def make_sampling_metadata(
-        self,
-        batch_idx_to_req_idx: torch.Tensor,
-    ) -> SamplingMetadata:
-        batch_size = batch_idx_to_req_idx.shape[0]
-        _make_sampling_metadata_kernel[(batch_size, )](
-            batch_idx_to_req_idx,
-            self.temperature.uva,
-            self.temperature.gpu,
-            self.top_p.uva,
-            self.top_p.gpu,
-            self.top_k.uva,
-            self.top_k.gpu,
-            self.frequency_penalties.uva,
-            self.frequency_penalties.gpu,
-            self.presence_penalties.uva,
-            self.presence_penalties.gpu,
-            self.repetition_penalties.uva,
-            self.repetition_penalties.gpu,
-            num_warps=1,
-            num_stages=1,
-        )
-        no_penalties = not (self.frequency_penalties_reqs
-                            or self.presence_penalties_reqs
-                            or self.repetition_penalties_reqs)
-        return SamplingMetadata(
-            temperature=self.temperature.gpu[:batch_size],
-            all_greedy=not self.random_reqs,
-            all_random=not self.greedy_reqs,
-            top_p=self.top_p.gpu[:batch_size],
-            top_k=self.top_k.gpu[:batch_size],
-            frequency_penalties=self.frequency_penalties.gpu[:batch_size],
-            presence_penalties=self.presence_penalties.gpu[:batch_size],
-            repetition_penalties=self.repetition_penalties.gpu[:batch_size],
-            no_penalties=no_penalties,
-            # TODO
-            generators={},
-            token_ids=self.token_ids.gpu[:batch_size],
-            max_num_logprobs=None,
-            allowed_token_ids_mask=None,
-            bad_words_token_ids={},
-            logitsprocs=None,
-        )
-
-    @property
-    def num_reqs(self) -> int:
-        return len(self.req_id_to_index)
-
-    def _add_scalar_attr(self, name: str, dtype: torch.dtype):
-        attr = PerRequestAttribute(self.max_num_cached_reqs,
-                                   1,
-                                   self.max_num_reqs,
-                                   dtype,
-                                   self.device,
-                                   is_scalar=True)
-        setattr(self, name, attr)
-
-    def _add_vector_attr(self, name: str, max_len: int, dtype: torch.dtype):
-        attr = PerRequestAttribute(self.max_num_cached_reqs, max_len,
-                                   self.max_num_reqs, dtype, self.device)
-        setattr(self, name, attr)
-
-    def _add_vector_attr_cpu(self, name: str, max_len: int,
-                             dtype: torch.dtype):
-        attr = PerRequestAttribute(self.max_num_cached_reqs, max_len, 0, dtype,
-                                   self.device)
-        setattr(self, name, attr)
-
-    def _init_block_tables(self):
-        self.num_block_tables = len(self.block_sizes)
-        self.block_tables = []
-        self.num_blocks = []
-        self.slot_mappings: list[torch.Tensor] = []
-        for i in range(self.num_block_tables):
-            max_num_blocks = cdiv(self.max_model_len, self.block_sizes[i])
-            block_table = PerRequestAttribute(self.max_num_cached_reqs,
-                                              max_num_blocks,
-                                              self.max_num_reqs, torch.int32,
-                                              self.device)
-            self.block_tables.append(block_table)
-            num_blocks = PerRequestAttribute(self.max_num_cached_reqs,
-                                             1,
-                                             self.max_num_reqs,
-                                             torch.int32,
-                                             self.device,
-                                             is_scalar=True)
-            self.num_blocks.append(num_blocks)
-            slot_mapping = torch.zeros(self.max_num_batched_tokens,
-                                       dtype=torch.int64,
-                                       device=self.device)
-            self.slot_mappings.append(slot_mapping)
-
-        def make_ptr_tensor(x: list[torch.Tensor]) -> torch.Tensor:
-            return torch.tensor([t.data_ptr() for t in x],
-                                dtype=torch.int64,
-                                device=self.device)
-
-        self.uva_block_table_ptrs = make_ptr_tensor(
-            [b.uva for b in self.block_tables])
-        self.gpu_block_table_ptrs = make_ptr_tensor(
-            [b.gpu for b in self.block_tables])
-        self.uva_num_blocks_ptrs = make_ptr_tensor(
-            [n.uva for n in self.num_blocks])
-        self.gpu_num_blocks_ptrs = make_ptr_tensor(
-            [n.gpu for n in self.num_blocks])
-        self.block_table_strides = torch.tensor(
-            [b.gpu.shape[1] for b in self.block_tables],
-            dtype=torch.int64,
-            device=self.device)
-        self.block_sizes_tensor = torch.tensor(self.block_sizes,
-                                               dtype=torch.int32,
-                                               device=self.device)
-        self.slot_mapping_ptrs = make_ptr_tensor(self.slot_mappings)
-
-    def make_block_tables(
-        self,
-        idx_mapping: torch.Tensor,
-    ) -> tuple[torch.Tensor, ...]:
-        batch_size = idx_mapping.shape[0]
-        _make_block_tables_kernel[(batch_size, self.num_block_tables)](
-            idx_mapping,
-            self.uva_block_table_ptrs,
-            self.gpu_block_table_ptrs,
-            self.block_table_strides,
-            self.uva_num_blocks_ptrs,
-            self.gpu_num_blocks_ptrs,
-            BLOCK_SIZE=1024,
-        )
-        return tuple(b.gpu[:batch_size] for b in self.block_tables)
-
-    def make_slot_mappings(
-        self,
-        cu_num_tokens: torch.Tensor,
-        pos: torch.Tensor,
-    ) -> tuple[torch.Tensor, ...]:
-        num_tokens = pos.shape[0]
-        num_reqs = cu_num_tokens.shape[0] - 1
-        _make_slot_mappings_kernel[(num_reqs + 1, self.num_block_tables)](
-            num_tokens,
-            self.max_num_batched_tokens,
-            cu_num_tokens,
-            pos,
-            self.gpu_block_table_ptrs,
-            self.block_table_strides,
-            self.block_sizes_tensor,
-            self.slot_mapping_ptrs,
-            PAD_ID=PAD_SLOT_ID,
-            BLOCK_SIZE=1024,
-        )
-        return tuple(x[:num_tokens] for x in self.slot_mappings)
+from vllm.v1.spec_decode.metadata import SpecDecodeMetadata


 @dataclass
@@ -403,134 +16,78 @@ class InputBatch:

     # batch_idx -> req_id
     req_ids: list[str]
+
+    # req_id -> batch_idx
+    req_id_to_batch_idx: dict[str, int]
+
     # batch_idx -> req_state_idx
     idx_mapping: torch.Tensor
     idx_mapping_np: np.ndarray

-    # [num_kv_cache_groups, num_reqs, max_num_blocks_per_req]
-    block_tables: tuple[torch.Tensor, ...]
-    # [num_kv_cache_groups, num_tokens]
-    slot_mappings: tuple[torch.Tensor, ...]
+    # batch_idx -> num_scheduled_tokens
+    num_scheduled_tokens: np.ndarray
+    total_num_tokens: int
+    max_num_tokens: int
+    num_reqs: int

-    # [num_reqs] mostly
-    sampling_metadata: SamplingMetadata
+    attn_metadata: dict[str, Any]
+    spec_decode_common_attn_metadata: Optional[Any]
+    spec_decode_metadata: Optional[SpecDecodeMetadata]
+
+    logits_indices: torch.Tensor


-@triton.jit
-def _make_sampling_metadata_kernel(
-    batch_idx_to_req_idx,  # [batch_size]
-    src_temperature,
-    dst_temperature,
-    src_top_p,
-    dst_top_p,
-    src_top_k,
-    dst_top_k,
-    src_frequency_penalties,
-    dst_frequency_penalties,
-    src_presence_penalties,
-    dst_presence_penalties,
-    src_repetition_penalties,
-    dst_repetition_penalties,
-):
-    batch_idx = tl.program_id(0)
-    req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
+# NOTE: With the type annotations, this function is pre-compiled
+# before the first call.
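+# Providing explicit signatures makes numba compile eagerly at import time
+# (instead of lazily on the first call), and cache=True persists the compiled
+# code on disk so later runs can skip compilation entirely.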
+@numba.jit(
+    [
+        types.none(
+            types.int32[:],  # idx_mapping
+            types.int32[:, :],  # token_ids
+            types.int32[:],  # num_computed_tokens
+            types.int32[:],  # num_scheduled_tokens
+            types.int32[:],  # input_ids
+            types.int32[:],  # query_start_loc
+            types.int32[:],  # seq_lens
+            types.int64[:],  # positions
+        )
+    ],
+    nopython=True,
+    cache=True,
+)
+def prepare_inputs(
+    # Inputs
+    idx_mapping: np.ndarray,  # batch_idx -> req_idx
+    token_ids: np.ndarray,  # [N, max_model_len]
+    num_computed_tokens: np.ndarray,  # [N]
+    num_scheduled_tokens: np.ndarray,  # [B]
+    # Outputs
+    input_ids: np.ndarray,  # [num_input_tokens]
+    query_start_loc: np.ndarray,  # [B + 1]
+    seq_lens: np.ndarray,  # [B]
+    positions: np.ndarray,  # [num_input_tokens]
+) -> None:
+    num_reqs = num_scheduled_tokens.shape[0]
+    query_start_loc[0] = 0
-
-    temperature = tl.load(src_temperature + req_idx)
-    tl.store(dst_temperature + batch_idx, temperature)
+    cu_num_tokens = 0
+    for i in range(num_reqs):
+        req_idx = idx_mapping[i]
+        start = num_computed_tokens[req_idx]
+        end = start + num_scheduled_tokens[i]
+        seq_lens[i] = end
-
-    top_p = tl.load(src_top_p + req_idx)
-    tl.store(dst_top_p + batch_idx, top_p)
+        start_idx = cu_num_tokens
+        end_idx = start_idx + num_scheduled_tokens[i]
+        input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
+        positions[start_idx:end_idx] = np.arange(start, end)
-
-    top_k = tl.load(src_top_k + req_idx)
-    tl.store(dst_top_k + batch_idx, top_k)
+        cu_num_tokens = end_idx
+        query_start_loc[i + 1] = cu_num_tokens
-
-    frequency_penalties = tl.load(src_frequency_penalties + req_idx)
-    tl.store(dst_frequency_penalties + batch_idx, frequency_penalties)
-
-    presence_penalties = tl.load(src_presence_penalties + req_idx)
-    tl.store(dst_presence_penalties + batch_idx, presence_penalties)
-
-    repetition_penalties = tl.load(src_repetition_penalties + req_idx)
-    tl.store(dst_repetition_penalties + batch_idx, repetition_penalties)
-
-
-@triton.jit
-def _make_block_tables_kernel(
-    batch_idx_to_req_idx,  # [batch_size]
-    src_block_table_ptrs,  # [num_block_tables]
-    dst_block_table_ptrs,  # [num_block_tables]
-    block_table_strides,  # [num_block_tables]
-    src_num_blocks_ptrs,  # [num_block_tables]
-    dst_num_blocks_ptrs,  # [num_block_tables]
-    BLOCK_SIZE: tl.constexpr,
-):
-    batch_idx = tl.program_id(0)
-    # kv cache group id
-    group_id = tl.program_id(1)
-    req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
-
-    src_num_blocks_ptr = _load_ptr(src_num_blocks_ptrs, group_id, tl.int32)
-    dst_num_blocks_ptr = _load_ptr(dst_num_blocks_ptrs, group_id, tl.int32)
-    num_blocks = tl.load(src_num_blocks_ptr + req_idx)
-    tl.store(dst_num_blocks_ptr + batch_idx, num_blocks)
-
-    stride = tl.load(block_table_strides + group_id)
-    src_block_table_ptr = _load_ptr(src_block_table_ptrs, group_id, tl.int32)
-    src_row_ptr = src_block_table_ptr + req_idx * stride
-    dst_block_table_ptr = _load_ptr(dst_block_table_ptrs, group_id, tl.int32)
-    dst_row_ptr = dst_block_table_ptr + batch_idx * stride
-
-    for i in tl.range(0, num_blocks, BLOCK_SIZE):
-        offset = i + tl.arange(0, BLOCK_SIZE)
-        block_ids = tl.load(src_row_ptr + offset, mask=offset < num_blocks)
-        tl.store(dst_row_ptr + offset, block_ids, mask=offset < num_blocks)
-
-
-@triton.jit
-def _make_slot_mappings_kernel(
-    num_tokens,
-    max_num_tokens,
-    cu_num_tokens,  # [num_reqs + 1]
-    pos,  # [num_tokens]
-    block_table_ptrs,  # [num_block_tables]
-    block_table_strides,  # [num_block_tables]
-    page_sizes,  # [num_block_tables]
-    slot_mapping_ptrs,  # [num_block_tables]
-    PAD_ID: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-):
-    req_idx = tl.program_id(0)
-    num_reqs = tl.num_programs(0)
-    # kv cache group id
-    group_id = tl.program_id(1)
-    slot_mapping_ptr = _load_ptr(slot_mapping_ptrs, group_id, tl.int64)
-
-    if req_idx == num_reqs - 1:
-        # Pad remaining slots to -1. This is needed for CUDA graphs.
-        for i in tl.range(num_tokens, max_num_tokens, BLOCK_SIZE):
-            offset = num_tokens + i + tl.arange(0, BLOCK_SIZE)
-            tl.store(slot_mapping_ptr + offset,
-                     PAD_ID,
-                     mask=offset < max_num_tokens)
-        return
-
-    block_table_ptr = _load_ptr(block_table_ptrs, group_id, tl.int32)
-    block_table_stride = tl.load(block_table_strides + group_id)
-    page_size = tl.load(page_sizes + group_id)
-
-    start_idx = tl.load(cu_num_tokens + req_idx)
-    end_idx = tl.load(cu_num_tokens + req_idx + 1)
-    for i in tl.range(start_idx, end_idx, BLOCK_SIZE):
-        offset = start_idx + i + tl.arange(0, BLOCK_SIZE)
-        positions = tl.load(pos + offset, mask=offset < end_idx, other=0)
-        block_indices = positions // page_size
-        block_numbers = tl.load(block_table_ptr +
-                                req_idx * block_table_stride + block_indices)
-        slot_ids = block_numbers * page_size + positions % page_size
-        tl.store(slot_mapping_ptr + offset, slot_ids, mask=offset < end_idx)
-
-
-@triton.jit
-def _load_ptr(base, offset, elem_dtype):
-    ptr = tl.load(base + offset)
-    return tl.cast(ptr, tl.pointer_type(elem_dtype))
+    # Pad the inputs for CUDA graphs.
+    # Note: pad query_start_loc to be non-decreasing, as kernels
+    # like FlashAttention require that
+    query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
+    # Fill unused with 0 for full cuda graph mode.
+    seq_lens[num_reqs:].fill(0)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index f19f04766e6df..eb23aeb70dfd7 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -68,7 +68,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
                              LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
-from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
+from vllm.v1.sample.logits_processor import LogitsProcessors
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import RejectionSampler
 from vllm.v1.sample.sampler import Sampler
@@ -76,7 +76,9 @@ from vllm.v1.spec_decode.eagle import EagleProposer
 from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
-from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.gpu_block_table import BlockTables
+from vllm.v1.worker.gpu_input_batch import InputBatch, prepare_inputs
+from vllm.v1.worker.gpu_worker_states import RequestState
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
     KVConnectorModelRunnerMixin, KVConnectorOutput)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -200,18 +202,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.rejection_sampler = RejectionSampler()

         # Request states.
-        self.requests: dict[str, CachedRequestState] = {}
-
-        # Input Batch
-        # NOTE(Chen): Ideally, we should initialize the input batch inside
-        # `initialize_kv_cache` based on the kv cache config. However, as in
-        # https://github.com/vllm-project/vllm/pull/18298, due to some unknown
-        # reasons, we have to initialize the input batch before `load_model`,
-        # quantization + weight offloading will fail otherwise. As a temporary
-        # solution, we initialize the input batch here, and re-initialize it
-        # in `initialize_kv_cache` if the block_sizes here is different from
-        # the block_sizes in the kv cache config.
-        self.input_batch = InputBatch(
+        self.requests = RequestState(
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
             max_num_batched_tokens=self.max_num_tokens,
@@ -220,12 +211,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
             block_sizes=[self.cache_config.block_size],
-            is_spec_decode=bool(self.vllm_config.speculative_config),
-            logitsprocs=build_logitsprocs(
-                self.vllm_config, self.device, self.pin_memory,
-                self.is_pooling_model,
-                self.vllm_config.model_config.logits_processors),
-            is_pooling_model=self.is_pooling_model,
         )

         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
@@ -253,9 +238,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.seq_lens = torch.zeros(self.max_num_reqs,
                                     dtype=torch.int32,
                                     device=self.device)
-        self.slot_mapping = torch.zeros(self.max_num_tokens,
-                                        dtype=torch.int64,
-                                        device=self.device)

         # None in the first PP rank. The rest are set after load_model.
         self.intermediate_tensors: Optional[IntermediateTensors] = None
@@ -290,12 +272,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                            dtype=self.dtype,
                                            device=self.device)

+        self.block_tables = BlockTables(
+            block_sizes=[self.cache_config.block_size],
+            max_num_reqs=self.max_num_reqs,
+            max_num_cached_reqs=2 * self.max_num_reqs,
+            max_num_batched_tokens=self.max_num_tokens,
+            max_model_len=self.max_model_len,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
         # OPTIMIZATION: Cache the tensors rather than creating them every step.
         # Keep in int64 to avoid overflow with long context
         self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                        self.max_model_len,
                                        self.max_num_tokens),
                                    dtype=np.int64)
+
         # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
         # a faster version of creating a new tensor every time. Thus, we should
         # not make any assumptions about the values in these tensors.
@@ -303,6 +296,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                          dtype=torch.int32,
                                          device="cpu",
                                          pin_memory=self.pin_memory)
+        self.input_ids_np = self.input_ids_cpu.numpy()
         self.positions_cpu = torch.zeros(self.max_num_tokens,
                                          dtype=torch.int64,
                                          device="cpu",
@@ -319,6 +313,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                         pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()

+        self.index_mapping_cpu = torch.zeros(self.max_num_reqs,
+                                             dtype=torch.int32,
+                                             device="cpu",
+                                             pin_memory=self.pin_memory)
+        self.index_mapping_np = self.index_mapping_cpu.numpy()
+        self.index_mapping = self.index_mapping_cpu.to(self.device)
+
         # Layer pairings for cross-layer KV sharing.
         # If an Attention layer `layer_name` is in the keys of this dict, it
         # means this layer will perform attention using the keys and values
@@ -410,10 +411,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         The SamplingMetadata is updated and copied to the GPU if there is a
         new/resumed/paused/finished request in the batch.
         """
-        # Remove finished requests from the cached states.
-        for req_id in scheduler_output.finished_req_ids:
-            self.requests.pop(req_id, None)
-            self.encoder_cache.pop(req_id, None)
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
         # scheduled_req_ids overlap. This happens when a request is aborted and
@@ -421,7 +418,8 @@
         # distinct requests - clearing the cached states for the first request
         # and handling the second as a new request.
         for req_id in scheduler_output.finished_req_ids:
-            self.input_batch.remove_request(req_id)
+            self.requests.remove_request(req_id)
+            self.encoder_cache.pop(req_id, None)

         # Free the cached encoder outputs.
         for req_id, input_id in scheduler_output.free_encoder_input_ids:
@@ -431,120 +429,82 @@
             if not encoder_outputs:
                 self.encoder_cache.pop(req_id, None)

-        # Remove the unscheduled requests from the persistent batch.
-        # NOTE(woosuk): The unscheduled requests are either preempted requests
-        # or running requests that are not scheduled in this step. We remove
-        # them from the persistent batch but keep their cached states since
-        # they will be scheduled again sometime in the future.
-        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
-        cached_req_ids = self.input_batch.req_id_to_index.keys()
-        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
-        # NOTE(woosuk): The persistent batch optimization assumes that
-        # consecutive batches contain mostly the same requests. If batches
-        # have low request overlap (e.g., alternating between two distinct
-        # sets of requests), this optimization becomes very inefficient.
-        for req_id in unscheduled_req_ids:
-            self.input_batch.remove_request(req_id)
+        req_indices: list[int] = []
+        cu_num_new_blocks: list[list[int]] = [
+            [0] for _ in range(self.block_tables.num_kv_cache_groups)
+        ]
+        new_block_ids: list[list[int]] = [
+            [] for _ in range(self.block_tables.num_kv_cache_groups)
+        ]
+        overwrite: list[bool] = []

         # Add new requests to the cached states.
         for new_req_data in scheduler_output.scheduled_new_reqs:
             req_id = new_req_data.req_id
-            sampling_params = new_req_data.sampling_params
-            pooling_params = new_req_data.pooling_params
-
-            if pooling_params:
-                task = pooling_params.task
-                assert task is not None, "You did not set `task` in the API"
-
-                model = cast(VllmModelForPooling, self.get_model())
-                to_update = model.pooler.get_pooling_updates(task)
-                to_update.apply(pooling_params)
-
-            req_state = CachedRequestState(
-                req_id=req_id,
-                mm_kwargs=new_req_data.mm_kwargs,
-                mm_positions=new_req_data.mm_positions,
-                sampling_params=sampling_params,
-                pooling_params=pooling_params,
-                lora_request=new_req_data.lora_request,
-            )
-            self.requests[req_id] = req_state
-            self.input_batch.add_request(
+            self.requests.add_request(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
-                block_ids=new_req_data.block_ids,
-                sampling_params=sampling_params,
+                sampling_params=new_req_data.sampling_params,
             )
+            req_index = self.requests.req_id_to_index[req_id]
+            req_indices.append(req_index)
+            for i, block_ids in enumerate(new_req_data.block_ids):
+                x = cu_num_new_blocks[i][-1]
+                cu_num_new_blocks[i].append(x + len(block_ids))
+                new_block_ids[i].extend(block_ids)
+            overwrite.append(True)

             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             if self.uses_mrope:
-                self._init_mrope_positions(req_state)
+                self._init_mrope_positions(req_id)

         # Update the states of the running/resumed requests.
         is_last_rank = get_pp_group().is_last_rank
-        req_data = scheduler_output.scheduled_cached_reqs
-        for i, req_id in enumerate(req_data.req_ids):
-            req_index = self.input_batch.req_id_to_index[req_id]
+        cached_reqs = scheduler_output.scheduled_cached_reqs
+        for i, req_id in enumerate(cached_reqs.req_ids):
+            req_index = self.requests.req_id_to_index[req_id]

             # Update input batch.
             if not is_last_rank:
                 # When using PP, the scheduler sends the sampled tokens back,
                 # because there's no direct communication between the first-
                 # stage worker and the last-stage worker.
-                new_token_ids = req_data.new_token_ids[i]
-                self.input_batch.append_token_ids(req_index, new_token_ids)
+                new_token_ids = cached_reqs.new_token_ids[i]
+                self.requests.append_token_ids(req_index, new_token_ids)

-            new_block_ids = req_data.new_block_ids[i]
-            if new_block_ids is not None:
+            if cached_reqs.new_block_ids[i] is not None:
+                req_indices.append(req_index)
+                for j, block_ids in enumerate(cached_reqs.new_block_ids[i]):
+                    x = cu_num_new_blocks[j][-1]
+                    cu_num_new_blocks[j].append(x + len(block_ids))
+                    new_block_ids[j].extend(block_ids)
                 # If the request is resumed from preemption, we need to
                 # overwrite the existing block IDs.
-                self.input_batch.append_block_ids(
-                    req_index,
-                    new_block_ids,
-                    overwrite=req_data.resumed_from_preemption[i],
-                )
+                overwrite.append(cached_reqs.resumed_from_preemption[i])

-            self.input_batch.num_computed_tokens.np[req_index] = (
-                req_data.num_computed_tokens[i])
+            self.requests.num_computed_tokens.np[req_index] = (
+                cached_reqs.num_computed_tokens[i])

+        if req_indices:
+            self.block_tables.append_block_ids(
+                req_indices=req_indices,
+                cu_num_new_blocks=cu_num_new_blocks,
+                new_block_ids=new_block_ids,
+                overwrite=overwrite,
+            )

-    def _init_mrope_positions(self, req_state: CachedRequestState):
+    def _init_mrope_positions(self, req_id: str) -> None:
+        req_idx = self.requests.req_id_to_index[req_id]
+        req_data = self.requests.req_data[req_idx]
+
         image_grid_thw = []
         video_grid_thw = []
         second_per_grid_ts = []
         audio_feature_lengths = []
         use_audio_in_video = False
-        for mm_item in req_state.mm_kwargs:
+        for mm_item in req_data.mm_kwargs:
             mm_input = mm_item.get_data()
             if (t := mm_input.get("image_grid_thw")) is not None:
                 image_grid_thw.append(t.tolist())
@@ -557,9 +517,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             if mm_input.get("use_audio_in_video") is True:
                 use_audio_in_video = True

-        req_state.mrope_positions, req_state.mrope_position_delta = \
+        req_data.mrope_positions, req_data.mrope_position_delta = \
             MRotaryEmbedding.get_input_positions_tensor(
-                req_state.prompt_token_ids,
+                req_data.prompt_token_ids,
                 hf_config=self.model_config.hf_config,
                 image_grid_thw=image_grid_thw,
                 video_grid_thw=video_grid_thw,
@@ -622,91 +582,55 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata],
-               np.ndarray, Optional[CommonAttentionMetadata], int]:
-        """
-        :return: tuple[
-            attn_metadata: layer-to-attention_metadata mapping,
-            logits_indices, spec_decode_metadata
-        ]
-        """
+    ) -> InputBatch:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
-        num_reqs = self.input_batch.num_reqs
-        assert num_reqs > 0
+        num_reqs = len(scheduler_output.num_scheduled_tokens)

-        # FIXME
         # batch_idx -> req_id
-        req_ids = list(scheduler_output.num_scheduled_tokens.keys())
+        req_ids = sorted(scheduler_output.num_scheduled_tokens,
+                         key=scheduler_output.num_scheduled_tokens.get)
+
+        # req_id -> batch_idx
+        req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)}
+
         # batch_idx -> req_idx
-        idx_mapping = [
-            self.input_batch.req_id_to_index[req_id] for req_id in req_ids
+        idx_mapping_list = [
+            self.requests.req_id_to_index[req_id] for req_id in req_ids
         ]
-        # batch_idx -> req_idx
-        idx_mapping_tensor = self.input_batch.get_index_mapping(idx_mapping)
-        num_reqs = len(req_ids)
+        self.index_mapping_np[:num_reqs] = idx_mapping_list
+        index_mapping_np = self.index_mapping_np[:num_reqs]
+        idx_mapping = self.index_mapping[:num_reqs].copy_(
+            self.index_mapping_cpu[:num_reqs], non_blocking=True)

         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
-        block_tables = self.input_batch.make_block_tables(idx_mapping_tensor)
+        block_tables = self.block_tables.compute_block_tables(idx_mapping)

         # Get the number of scheduled tokens for each request.
         tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
         num_scheduled_tokens = np.array(tokens, dtype=np.int32)
         max_num_scheduled_tokens = max(tokens)

-        # Get request indices.
-        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
-        req_indices = np.repeat(self.arange_np[:num_reqs],
-                                num_scheduled_tokens)
-
-        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
-        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
-        cu_num_tokens, arange = self._get_cumsum_and_arange(
-            num_scheduled_tokens)
-
-        # Get positions.
-        positions_np = self.positions_np[:total_num_scheduled_tokens]
-        np.add(self.input_batch.num_computed_tokens.np[req_indices],
-               arange,
-               out=positions_np)
-
+        prepare_inputs(
+            idx_mapping=index_mapping_np,
+            token_ids=self.requests.token_ids.np,
+            num_computed_tokens=self.requests.num_computed_tokens.np,
+            num_scheduled_tokens=num_scheduled_tokens,
+            input_ids=self.input_ids_np,
+            query_start_loc=self.query_start_loc_np,
+            seq_lens=self.seq_lens_np,
+            positions=self.positions_np,
+        )
         # Calculate M-RoPE positions.
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
             self._calc_mrope_positions(scheduler_output)

-        # Get token indices.
-        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
-        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
-        # where M is the max_model_len.
-        token_indices = (positions_np +
-                         req_indices * self.input_batch.token_ids.np.shape[1])
-
-        # NOTE(woosuk): We use torch.index_select instead of np.take here
-        # because torch.index_select is much faster than np.take for large
-        # tensors.
-        torch.index_select(self.input_batch.token_ids.cpu.flatten(),
-                           0,
-                           torch.from_numpy(token_indices),
-                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
-
         # Prepare the attention metadata.
-        self.query_start_loc_np[0] = 0
-        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
-        # Note: pad query_start_loc to be non-decreasing, as kernels
-        # like FlashAttention requires that
-        self.query_start_loc_np[num_reqs + 1:].fill(cu_num_tokens[-1])
         self.query_start_loc.copy_(self.query_start_loc_cpu, non_blocking=True)
         query_start_loc = self.query_start_loc[:num_reqs + 1]

-        self.seq_lens_np[:num_reqs] = (
-            self.input_batch.num_computed_tokens.np[:num_reqs] +
-            num_scheduled_tokens)
-        # Fill unused with 0 for full cuda graph mode.
-        self.seq_lens_np[num_reqs:].fill(0)
         self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True)
         seq_lens = self.seq_lens[:num_reqs]
         max_seq_len = self.seq_lens_np[:num_reqs].max().item()

@@ -714,16 +638,14 @@
         # Copy the tensors to the GPU.
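+        # NOTE: The CPU-side buffers are pinned, so the non_blocking copies
+        # below are asynchronous and overlap with the CPU work that follows.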
         self.input_ids[:total_num_scheduled_tokens].copy_(
             self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
+        # Common case (1D positions)
+        self.positions[:total_num_scheduled_tokens].copy_(
+            self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
         if self.uses_mrope:
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                 self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                 non_blocking=True)
-        else:
-            # Common case (1D positions)
-            self.positions[:total_num_scheduled_tokens].copy_(
-                self.positions_cpu[:total_num_scheduled_tokens],
-                non_blocking=True)

         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -737,16 +659,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             spec_decode_metadata = None
         else:
             # Get the number of draft tokens for each request.
-            # Iterate over the dictionary rather than all requests since not all
-            # requests have draft tokens.
             num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
-            for req_id, draft_token_ids in (
-                    scheduler_output.scheduled_spec_decode_tokens.items()):
-                batch_idx = req_id_to_batch_idx[req_id]
-                num_draft_tokens[batch_idx] = len(draft_token_ids)
-
+            for i, req_id in enumerate(req_ids):
+                draft_token_ids = (
+                    scheduler_output.scheduled_spec_decode_tokens.get(req_id))
+                if draft_token_ids:
+                    num_draft_tokens[i] = len(draft_token_ids)
             spec_decode_metadata = self._calc_spec_decode_metadata(
-                num_draft_tokens, cu_num_tokens)
+                num_draft_tokens, self.query_start_loc_np[1:num_reqs + 1])
             logits_indices = spec_decode_metadata.logits_indices

         logits_indices_padded = None
@@ -774,15 +694,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
             )

-        attn_metadata: dict[str, Any] = {}
+        slot_mappings = self.block_tables.compute_slot_mappings(
+            query_start_loc, self.positions[:total_num_scheduled_tokens])

         # Used in the below loop.
         query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
         seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
         num_computed_tokens_cpu = (
-            self.input_batch.num_computed_tokens.cpu[:num_reqs])
+            self.requests.num_computed_tokens.cpu[:num_reqs])
         spec_decode_common_attn_metadata = None

+        attn_metadata: dict[str, Any] = {}
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
@@ -804,14 +726,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     non_blocking=True)
                 num_common_prefix_blocks = 0
             else:
-                blk_table = self.input_batch.block_table[kv_cache_group_id]
-                blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
-                slot_mapping = blk_table.slot_mapping[:
-                                                      total_num_scheduled_tokens]
-
-                # Fill unused with -1. Needed for reshape_and_cache in full cuda
-                # graph mode.
-                blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
+                blk_table_tensor = block_tables[kv_cache_group_id]
+                slot_mapping = slot_mappings[kv_cache_group_id]
                 num_common_prefix_blocks = (
                     scheduler_output.
                     num_common_prefix_blocks[kv_cache_group_id])
@@ -876,13 +792,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 continue
             attn_metadata[layer_name] = attn_metadata_i

-        # Hot-Swap lora model
-        if self.lora_config:
-            self.set_active_loras(self.input_batch, num_scheduled_tokens)
+        # # Hot-Swap lora model
+        # if self.lora_config:
+        #     self.set_active_loras(input_batch, num_scheduled_tokens)

-        return (attn_metadata, logits_indices, spec_decode_metadata,
-                num_scheduled_tokens, spec_decode_common_attn_metadata,
-                max_num_scheduled_tokens)
+        return InputBatch(
+            req_ids=req_ids,
+            num_scheduled_tokens=num_scheduled_tokens,
+            req_id_to_batch_idx=req_id_to_batch_idx,
+            idx_mapping=idx_mapping,
+            idx_mapping_np=index_mapping_np,
+            num_reqs=num_reqs,
+            total_num_tokens=total_num_scheduled_tokens,
+            max_num_tokens=max_num_scheduled_tokens,
+            attn_metadata=attn_metadata,
+            spec_decode_metadata=spec_decode_metadata,
+            spec_decode_common_attn_metadata=spec_decode_common_attn_metadata,
+            logits_indices=logits_indices,
+        )

     def _compute_cascade_attn_prefix_len(
         self,
@@ -955,7 +882,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_reqs = len(num_scheduled_tokens)
         common_prefix_len = min(
             common_prefix_len,
-            self.input_batch.num_computed_tokens.np[:num_reqs].min())
+            self.requests.num_computed_tokens.np[:num_reqs].min())
         # common_prefix_len should be a multiple of the block size.
         common_prefix_len = (common_prefix_len // kv_cache_spec.block_size *
                              kv_cache_spec.block_size)
@@ -979,16 +906,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )
         return common_prefix_len if use_cascade else 0

-    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
+    def _calc_mrope_positions(self, input_batch: InputBatch):
         mrope_pos_ptr = 0
-        for index, req_id in enumerate(self.input_batch.req_ids):
+        for i, req_id in enumerate(input_batch.req_ids):
             req = self.requests[req_id]
             assert req.mrope_positions is not None

             num_computed_tokens = \
-                self.input_batch.num_computed_tokens_cpu[index]
+                self.requests.num_computed_tokens_cpu[i]
             num_scheduled_tokens = \
-                scheduler_output.num_scheduled_tokens[req_id]
+                input_batch.num_scheduled_tokens[i]
             num_prompt_tokens = len(req.prompt_token_ids)

             if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
@@ -1159,17 +1086,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def _gather_mm_embeddings(
         self,
-        scheduler_output: "SchedulerOutput",
+        input_batch: InputBatch,
         shift_computed_tokens: int = 0,
     ) -> list[torch.Tensor]:
         mm_embeds: list[torch.Tensor] = []
-        for req_id in self.input_batch.req_ids:
-            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
-                req_id]
-            req_state = self.requests[req_id]
-            num_computed_tokens = \
-                req_state.num_computed_tokens + shift_computed_tokens
-            mm_positions = req_state.mm_positions
+        for i, req_id in enumerate(input_batch.req_ids):
+            num_scheduled_tokens = input_batch.num_scheduled_tokens[i]
+            req_idx = self.requests.req_id_to_index[req_id]
+            num_computed_tokens = (
+                self.requests.num_computed_tokens.np[req_idx] +
+                shift_computed_tokens)
+            req_data = self.requests.req_data[req_idx]
+            mm_positions = req_data.mm_positions
             for i, pos_info in enumerate(mm_positions):
                 start_pos = pos_info.offset
                 num_encoder_tokens = pos_info.length
@@ -1274,8 +1202,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # request in the batch, as the logit indices are offset by this amount.
         struct_out_req_batch_indices: dict[str, int] = {}
         cumulative_offset = 0
-        seq = sorted(self.input_batch.req_id_to_index.items(),
-                     key=lambda x: x[1])
+        seq = sorted(self.requests.req_id_to_index.items(), key=lambda x: x[1])
         for req_id, batch_index in seq:
             logit_index = batch_index + cumulative_offset
             cumulative_offset += len(
@@ -1431,7 +1358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
-            req_id_to_index=self.input_batch.req_id_to_index,
+            req_id_to_index=self.input_batch.req_id_to_batch_idx,
             sampled_token_ids=[],
             logprobs=None,
             prompt_logprobs_dict={},
@@ -1455,21 +1382,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                               self.vllm_config)

         # Prepare the decoder inputs.
-        (attn_metadata, logits_indices, spec_decode_metadata,
-         num_scheduled_tokens_np, spec_decode_common_attn_metadata,
-         max_query_len) = self._prepare_inputs(scheduler_output)
-
-        # FIXME
-        # batch_idx -> req_id
-        req_ids = list(scheduler_output.num_scheduled_tokens.keys())
-        # req_id -> batch_idx
-        req_id_to_batch_idx = {req_id: i for i, req_id in enumerate(req_ids)}
-        # batch_idx -> req_idx
-        idx_mapping = [
-            self.input_batch.req_id_to_index[req_id] for req_id in req_ids
-        ]
-        # batch_idx -> req_idx
-        idx_mapping_tensor = self.input_batch.get_index_mapping(idx_mapping)
+        input_batch = self._prepare_inputs(scheduler_output)

         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
@@ -1540,8 +1453,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                 num_input_tokens, intermediate_tensors, True)

-        uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
-            num_scheduled_tokens == self.input_batch.num_reqs * max_query_len)
+        uniform_decode = (input_batch.max_num_tokens
+                          == self.uniform_decode_query_len
+                          and num_scheduled_tokens
+                          == input_batch.num_reqs * input_batch.max_num_tokens)
         batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                            uniform_decode=uniform_decode)
         cudagraph_runtime_mode, batch_descriptor = \
@@ -1550,7 +1465,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Run the model.
         # Use persistent buffers for CUDA graphs.
         with set_forward_context(
-                attn_metadata,
+                input_batch.attn_metadata,
                 self.vllm_config,
                 num_tokens=num_input_tokens,
                 num_tokens_across_dp=num_tokens_across_dp,
@@ -1590,11 +1505,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                          all_gather_group=get_tp_group())
             logits = None
         else:
-            if self.input_batch.pooling_params:
-                return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np, kv_connector_output)
-
-            sample_hidden_states = hidden_states[logits_indices]
+            sample_hidden_states = hidden_states[input_batch.logits_indices]
             logits = self.model.compute_logits(sample_hidden_states, None)
         if broadcast_pp_output:
             model_output_broadcast_data = {
@@ -1610,9 +1521,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.apply_grammar_bitmask(scheduler_output, logits)

         # Sample the next token and get logprobs if needed.
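+        # The sampling parameters are gathered from the persistent
+        # per-request state via idx_mapping (batch_idx -> req_state_idx), so
+        # only the requests in the current batch are materialized.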

         # Sample the next token and get logprobs if needed.
-        sampling_metadata = self.input_batch.make_sampling_metadata(
-            idx_mapping_tensor)
-        if spec_decode_metadata is None:
+        sampling_metadata = self.requests.make_sampling_metadata(
+            input_batch.idx_mapping)
+        if input_batch.spec_decode_metadata is None:
             sampler_output = self.sampler(
                 logits=logits,
                 sampling_metadata=sampling_metadata,
@@ -1623,7 +1534,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # logits tensor. This means any in-place operations on bonus_logits
             # won't affect the original logits tensor.
             assert logits is not None
-            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
+            bonus_logits = logits[
+                input_batch.spec_decode_metadata.bonus_logits_indices]
             sampler_output = self.sampler(
                 logits=bonus_logits,
                 sampling_metadata=sampling_metadata,
@@ -1633,9 +1545,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # Just like `bonus_logits`, `target_logits` is a new tensor with
             # separate storage from the original `logits` tensor. Therefore,
             # it is safe to update `target_logits` in place.
-            target_logits = logits[spec_decode_metadata.target_logits_indices]
+            target_logits = logits[
+                input_batch.spec_decode_metadata.target_logits_indices]
             output_token_ids = self.rejection_sampler(
-                spec_decode_metadata,
+                input_batch.spec_decode_metadata,
                 None,  # draft_probs
                 target_logits,
                 bonus_token_ids,
@@ -1643,6 +1556,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             )
             sampler_output.sampled_token_ids = output_token_ids

+        for i in range(input_batch.num_reqs):
+            req_idx = input_batch.idx_mapping_np[i]
+            num_tokens = input_batch.num_scheduled_tokens[i]
+            self.requests.num_computed_tokens.np[req_idx] += num_tokens
+
         num_nans_in_logits = {}
         if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
-            num_nans_in_logits = self._get_nans_in_logits(logits)
+            num_nans_in_logits = self._get_nans_in_logits(input_batch, logits)
@@ -1664,27 +1582,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         max_gen_len = sampled_token_ids.shape[-1]
         if max_gen_len == 1:
             # No spec decode tokens.
-            valid_sampled_token_ids = sampled_token_ids.tolist()
+            valid_sampled_token_ids_np = sampled_token_ids.cpu().numpy()
+            valid_sampled_token_ids = valid_sampled_token_ids_np.tolist()
         else:
             # Includes spec decode tokens.
             valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids, self.input_batch.vocab_size)
-            # # Mask out the sampled tokens that should not be sampled.
-            # for i in discard_sampled_tokens_req_indices:
-            #     valid_sampled_token_ids[i].clear()
+                sampled_token_ids, self.vocab_size)

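The bookkeeping loop added above scatters each batch position's scheduled-token count back into the slot-indexed `num_computed_tokens` array through `idx_mapping`. An equivalent NumPy sketch (illustrative names; `np.add.at` is the vectorized form of the explicit loop and also handles repeated indices):

```python
import numpy as np

num_computed_tokens = np.zeros(8, dtype=np.int32)   # slot-indexed state
idx_mapping = np.array([5, 2], dtype=np.int32)      # batch_idx -> slot
num_scheduled = np.array([1, 16], dtype=np.int32)   # per batch_idx

# Scatter-add the per-batch counts into the persistent slots.
np.add.at(num_computed_tokens, idx_mapping, num_scheduled)
assert num_computed_tokens[5] == 1
assert num_computed_tokens[2] == 16
```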
         # Cache the sampled tokens in the model runner, so that the scheduler
         # doesn't need to send them back.
         # NOTE(woosuk): As an exception, when using PP, the scheduler sends
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
-        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
-            if not sampled_ids:
-                continue
-            self.input_batch.append_token_ids(req_idx, sampled_ids)
+        self.requests.append_sampled_token_ids(
+            input_batch.idx_mapping_np,
+            valid_sampled_token_ids,
+        )

         if self.speculative_config:
-            assert spec_decode_common_attn_metadata is not None
+            assert input_batch.spec_decode_common_attn_metadata is not None
             self._draft_token_ids = self.propose_draft_token_ids(
-                scheduler_output,
+                input_batch,
                 valid_sampled_token_ids,
@@ -1692,15 +1608,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 hidden_states,
                 sample_hidden_states,
                 aux_hidden_states,
-                spec_decode_metadata,
-                spec_decode_common_attn_metadata,
             )

         self.eplb_step()

         return ModelRunnerOutput(
-            req_ids=req_ids,
-            req_id_to_index=req_id_to_batch_idx,
+            req_ids=input_batch.req_ids,
+            req_id_to_index=input_batch.req_id_to_batch_idx,
             sampled_token_ids=valid_sampled_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
@@ -1712,7 +1628,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         if self._draft_token_ids is None:
            return None
-        req_ids = self.input_batch.req_ids
+        req_ids = self.requests.req_ids
         if isinstance(self._draft_token_ids, torch.Tensor):
             draft_token_ids = self._draft_token_ids.tolist()
         else:
@@ -1722,16 +1638,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

     def propose_draft_token_ids(
         self,
-        scheduler_output: "SchedulerOutput",
+        input_batch: InputBatch,
         sampled_token_ids: list[list[int]],
         sampling_metadata: SamplingMetadata,
         hidden_states: torch.Tensor,
         sample_hidden_states: torch.Tensor,
         aux_hidden_states: Optional[torch.Tensor],
-        spec_decode_metadata: Optional[SpecDecodeMetadata],
-        common_attn_metadata: CommonAttentionMetadata,
    ) -> Union[list[list[int]], torch.Tensor]:
-        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        num_scheduled_tokens = input_batch.total_num_tokens
         if self.speculative_config.method == "ngram":
             assert isinstance(self.drafter, NgramProposer)
             draft_token_ids = self.propose_ngram_draft_token_ids(
@@ -1745,7 +1659,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             indices = []
             offset = 0
             for num_draft, tokens in zip(
-                    spec_decode_metadata.num_draft_tokens,
+                    input_batch.spec_decode_metadata.num_draft_tokens,
                     sampled_token_ids):
                 indices.append(offset + len(tokens) - 1)
                 offset += num_draft + 1
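A worked example of the `indices`/`offset` arithmetic in the spec-decode branch above: each request owns `num_draft + 1` rows of sampled tokens, and the last *accepted* token for request `i` sits at `offset_i + len(accepted_i) - 1`. The numbers below are made up for illustration:

```python
num_draft_tokens = [3, 3]                         # drafts proposed per request
sampled_token_ids = [[11, 12], [21, 22, 23, 24]]  # accepted (+ bonus) tokens

indices, offset = [], 0
for num_draft, tokens in zip(num_draft_tokens, sampled_token_ids):
    indices.append(offset + len(tokens) - 1)
    offset += num_draft + 1                       # rows owned by this request

# Request 0 accepted 2 of its 4 rows -> index 1;
# request 1 accepted all 4 rows -> index 4 + 4 - 1 = 7.
assert indices == [1, 7]
```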
@@ -1759,7 +1673,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         elif self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             # TODO(woosuk): Refactor the loop.
-            req_ids = self.input_batch.req_ids
+            req_ids = input_batch.req_ids
             next_token_ids: list[int] = []
             for i, token_ids in enumerate(sampled_token_ids):
                 if token_ids:
@@ -1771,14 +1685,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     req_id = req_ids[i]
                     req_state = self.requests[req_id]
                     seq_len = (req_state.num_computed_tokens +
-                               scheduler_output.num_scheduled_tokens[req_id])
+                               input_batch.num_scheduled_tokens[i])
                     next_token_id = req_state.get_token_id(seq_len)
                     next_token_ids.append(next_token_id)
             next_token_ids = torch.tensor(next_token_ids,
                                           dtype=torch.int32,
                                           device=self.device)

-            if spec_decode_metadata is None:
+            if input_batch.spec_decode_metadata is None:
                 # input_ids can be None for multimodal models.
                 target_token_ids = self.input_ids[:num_scheduled_tokens]
                 # TODO(woosuk): Support M-RoPE.
@@ -1791,7 +1705,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 target_hidden_states = hidden_states[:num_scheduled_tokens]
             else:
                 # TODO(woosuk): Refactor this.
-                num_draft_tokens = spec_decode_metadata.num_draft_tokens
+                num_draft_tokens = \
+                    input_batch.spec_decode_metadata.num_draft_tokens
                 num_rejected_tokens = [
                     n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
                     for i, n in enumerate(num_draft_tokens)
@@ -1812,7 +1726,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 target_hidden_states = hidden_states[token_indices]
             mm_embeds = None
             if self.supports_mm_inputs:
-                mm_embeds = self._gather_mm_embeddings(scheduler_output,
+                mm_embeds = self._gather_mm_embeddings(input_batch,
                                                        shift_computed_tokens=1)

             draft_token_ids = self.drafter.propose(
@@ -1828,10 +1742,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

     def propose_ngram_draft_token_ids(
         self,
+        input_batch: InputBatch,
         sampled_token_ids: list[list[int]],
     ) -> list[list[int]]:
         # TODO(woosuk): Optimize.
-        req_ids = self.input_batch.req_ids
         draft_token_ids: list[list[int]] = []
         for i, sampled_ids in enumerate(sampled_token_ids):
             num_sampled_ids = len(sampled_ids)
@@ -1842,19 +1756,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

             # Skip requests that require sampling parameters that are not
             # supported with speculative decoding.
-            req_id = req_ids[i]
-            if req_id in self.input_batch.spec_decode_unsupported_reqs:
+            req_id = input_batch.req_ids[i]
+            if req_id in self.requests.spec_decode_unsupported_reqs:
                 draft_token_ids.append([])
                 continue

-            num_tokens = self.input_batch.num_tokens_no_spec[i]
+            req_idx = self.requests.req_id_to_index[req_id]
+            num_tokens = self.requests.num_tokens_no_spec[req_idx]
             if num_tokens >= self.max_model_len:
                 # Skip requests that have already reached the max model length.
                 draft_token_ids.append([])
                 continue
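The ngram path only needs a request's flat token history (`token_ids[:num_tokens]`), which is why the loop above slices a single row of the slot-indexed token array. A toy drafter in the same spirit — not vLLM's `NgramProposer`, just an illustration of why the history slice suffices:

```python
import numpy as np


def toy_ngram_propose(tokens: np.ndarray, n: int = 2, k: int = 3):
    """Find an earlier occurrence of the trailing n-gram and propose up to
    the k tokens that followed it, or None if there is no match."""
    if len(tokens) <= n:
        return None
    tail = tokens[-n:]
    # Scan backwards so the most recent match wins.
    for start in range(len(tokens) - n - 1, -1, -1):
        if np.array_equal(tokens[start:start + n], tail):
            follow = tokens[start + n:start + n + k]
            return follow if len(follow) > 0 else None
    return None


history = np.array([1, 2, 3, 4, 1, 2], dtype=np.int32)
draft = toy_ngram_propose(history)
assert draft.tolist() == [3, 4, 1]  # tokens that followed the earlier [1, 2]
```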

             drafter_output = self.drafter.propose(
-                self.input_batch.token_ids_cpu[i, :num_tokens])
+                self.requests.token_ids.np[req_idx, :num_tokens])
             if drafter_output is None or len(drafter_output) == 0:
                 draft_token_ids.append([])
             else:
@@ -1992,11 +1906,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         hidden_states: torch.Tensor,
         num_scheduled_tokens: dict[str, int],
     ) -> dict[str, Optional[LogprobsTensors]]:
-        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
+        num_prompt_logprobs_dict = self.requests.num_prompt_logprobs
         if not num_prompt_logprobs_dict:
             return {}

-        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
+        in_progress_dict = self.requests.in_progress_prompt_logprobs_cpu
         prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

         # Since prompt logprobs are a rare feature, prioritize simple,
@@ -2045,7 +1959,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # Get the logits corresponding to this req's prompt tokens.
             # If this is a partial request (i.e. chunked prefill),
             # then there is prompt logprob generated for each index.
-            req_idx = self.input_batch.req_id_to_index[req_id]
+            # FIXME(WIP): 0 is a placeholder that is only correct for a
+            # single-request batch; map req_id to its batch index instead.
+            req_idx = 0
             offset = self.query_start_loc_np[req_idx].item()
             prompt_hidden_states = hidden_states[offset:offset + num_logits]
             logits = self.model.compute_logits(prompt_hidden_states, None)
@@ -2083,20 +1997,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

     def _get_nans_in_logits(
         self,
+        input_batch: InputBatch,
         logits: Optional[torch.Tensor],
     ) -> dict[str, int]:
         try:
             if logits is None:
-                return {req_id: 0 for req_id in self.input_batch.req_ids}
+                return {req_id: 0 for req_id in input_batch.req_ids}

             num_nans_in_logits = {}
             num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy()
-            for req_id in self.input_batch.req_ids:
-                req_index = self.input_batch.req_id_to_index[req_id]
-                num_nans_in_logits[req_id] = (
-                    int(num_nans_for_index[req_index])
-                    if num_nans_for_index is not None
-                    and req_index < logits.shape[0] else 0)
+            for i, req_id in enumerate(input_batch.req_ids):
+                num_nans_in_logits[req_id] = (int(num_nans_for_index[i])
+                                              if num_nans_for_index is not None
+                                              and i < logits.shape[0] else 0)
             return num_nans_in_logits
         except IndexError:
             return {}
@@ -2248,17 +2161,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                                           1],
             seq_lens=self.seq_lens[:num_reqs],
             seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-            num_computed_tokens_cpu=self.input_batch.
-            num_computed_tokens_cpu_tensor[:num_reqs],
+            num_computed_tokens_cpu=self.requests.num_computed_tokens.
+            cpu[:num_reqs],
             num_reqs=num_reqs,
             num_actual_tokens=num_tokens,
             max_query_len=max_query_len,
             max_seq_len=self.max_model_len,
-            block_table_tensor=self.input_batch.block_table[
-                kv_cache_group_id].get_device_tensor()[:num_reqs],
-            slot_mapping=self.input_batch.
-            block_table[kv_cache_group_id].slot_mapping[:num_tokens],
-            causal=True)
+            block_table_tensor=self.requests.
+            block_tables[kv_cache_group_id].gpu[:num_reqs],
+            slot_mapping=self.requests.slot_mappings[kv_cache_group_id]
+            [:num_tokens],
+            causal=True,
+        )

         for attn_group in self.attn_groups[kv_cache_group_id]:
             attn_metadata_i = attn_group.metadata_builder\
diff --git a/vllm/v1/worker/gpu_worker_states.py b/vllm/v1/worker/gpu_worker_states.py
new file mode 100644
index 0000000000000..6b149701632a2
--- /dev/null
+++ b/vllm/v1/worker/gpu_worker_states.py
@@ -0,0 +1,351 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Data structures defining the per-request states kept on the GPU worker

+from dataclasses import dataclass
+from typing import Optional, Union
+
+import numpy as np
+import torch
+import triton
+import triton.language as tl
+from typing_extensions import deprecated
+
+from vllm.lora.request import LoRARequest
+from vllm.multimodal.inputs import (MultiModalKwargsItem,
+                                    MultiModalKwargsItems, PlaceholderRange)
+from vllm.pooling_params import PoolingParams
+from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.v1.sample.logits_processor import LogitsProcessors
+from vllm.v1.sample.metadata import SamplingMetadata
+
+
+@dataclass
+class RequestData:
+
+    mm_kwargs: list[MultiModalKwargsItem]
+    mm_positions: list[PlaceholderRange]
+    sampling_params: Optional[SamplingParams]
+    pooling_params: Optional[PoolingParams]
+
+    # M-RoPE (only for Qwen2/2.5-VL)
+    mrope_positions: Optional[torch.Tensor] = None
+    mrope_position_delta: Optional[int] = None
+
+    lora_request: Optional[LoRARequest] = None
+
+    # Temporary back-compatibility for plugins that define model runner
+    @property
+    @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
+                "removed in v0.13. Please use `mm_kwargs` instead.")
+    def mm_inputs(self) -> list[MultiModalKwargsItems]:
+        return [
+            MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
+        ]
+
+
+class RequestAttribute:
+
+    def __init__(
+        self,
+        num_rows_cpu: int,
+        num_cols: int,
+        num_rows_gpu: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        pin_memory: bool,
+        is_scalar: bool = False,
+    ):
+        self.cpu = torch.zeros(num_rows_cpu,
+                               num_cols,
+                               dtype=dtype,
+                               device="cpu",
+                               pin_memory=pin_memory)
+        self.gpu = torch.zeros(num_rows_gpu,
+                               num_cols,
+                               dtype=dtype,
+                               device=device)
+        if is_scalar:
+            assert num_cols == 1
+            self.cpu.squeeze_(1)
+            self.gpu.squeeze_(1)
+        # NOTE: Create the NumPy view after the optional squeeze so that it
+        # always matches the CPU tensor's final shape.
+        self.np = self.cpu.numpy()
+
+        self.gpu_buffer = self.cpu.to(device)
+
+    def mirror_to_gpu(self) -> torch.Tensor:
+        return self.gpu_buffer.copy_(self.cpu, non_blocking=True)
+
+
+class RequestState:
+
+    def __init__(
+        self,
+        max_num_reqs: int,
+        max_model_len: int,
+        max_num_batched_tokens: int,
+        max_num_cached_reqs: int,
+        device: torch.device,
+        pin_memory: bool,
+        vocab_size: int,
+        block_sizes: list[int],  # The block_size of each kv cache group
+        logitsprocs: Optional[LogitsProcessors] = None,
+        is_spec_decode: bool = False,
+        is_pooling_model: bool = False,
+    ):
+        self.max_num_reqs = max_num_reqs
+        self.max_model_len = max_model_len
+        self.max_num_batched_tokens = max_num_batched_tokens
+        self.max_num_cached_reqs = max_num_cached_reqs
+        self.device = device
+        self.pin_memory = pin_memory
+        self.vocab_size = vocab_size
+        self.is_spec_decode = is_spec_decode
+        self.pooling_params = None
+        self.block_sizes = block_sizes
+        self.num_prompt_logprobs = {}
+
+        self.req_id_to_index: dict[str, int] = {}
+        self.index_to_req_id: dict[int, str] = {}
+        self.free_indices = list(range(max_num_cached_reqs))
+        # Used to construct the input batch.
+        self._add_scalar_attr("idx_mapping", torch.int32)
+
+        # Request states.
+        self.req_data: dict[int, RequestData] = {}
+        # TODO(woosuk): Because the token_ids tensor can be very big, we only
+        # initialize it on CPU memory.
+        self._add_vector_attr("token_ids",
+                              self.max_model_len,
+                              torch.int32,
+                              cpu_only=True)
+        self._add_scalar_attr("num_prompt_tokens", torch.int32)
+        self._add_scalar_attr("num_tokens", torch.int32)
+        self._add_scalar_attr("num_computed_tokens", torch.int32)
+
+        # Sampling-related.
+        self._add_scalar_attr("temperature", torch.float32)
+        self.greedy_reqs: set[str] = set()
+        self.random_reqs: set[str] = set()
+        self._add_scalar_attr("top_p", torch.float32)
+        self.top_p_reqs: set[str] = set()
+        self._add_scalar_attr("top_k", torch.int32)
+        self.top_k_reqs: set[str] = set()
+        self._add_scalar_attr("frequency_penalties", torch.float32)
+        self.frequency_penalties_reqs: set[str] = set()
+        self._add_scalar_attr("presence_penalties", torch.float32)
+        self.presence_penalties_reqs: set[str] = set()
+        self._add_scalar_attr("repetition_penalties", torch.float32)
+        self.repetition_penalties_reqs: set[str] = set()
+
+        # req_idx -> generator
+        self.generators: dict[int, torch.Generator] = {}
+
+    def add_request(
+        self,
+        req_id: str,
+        prompt_token_ids: list[int],
+        num_computed_tokens: int,
+        sampling_params: SamplingParams,
+    ) -> None:
+        req_idx = self.free_indices.pop()
+        self.req_id_to_index[req_id] = req_idx
+        self.index_to_req_id[req_idx] = req_id
+
+        self.num_prompt_tokens.np[req_idx] = len(prompt_token_ids)
+        self.num_computed_tokens.np[req_idx] = num_computed_tokens
+        # Reset the token counter in case this slot is being reused.
+        self.num_tokens.np[req_idx] = 0
+        self.append_token_ids(req_idx, prompt_token_ids)
+
+        temperature = sampling_params.temperature
+        if sampling_params.sampling_type == SamplingType.GREEDY:
+            # NOTE: Store -1.0 as a sentinel so that apply_temperature never
+            # divides by zero for greedy requests.
+            temperature = -1.0
+        self.temperature.np[req_idx] = temperature
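A standalone sketch of the free-list slot allocation used by `add_request` above and `remove_request` below (illustrative, not the PR's code). Because slots are recycled LIFO, per-slot counters such as `num_tokens` must be reset when a slot is reused:

```python
free_indices = list(range(4))        # all 4 slots free
req_id_to_index: dict[str, int] = {}


def add(req_id: str) -> int:
    idx = free_indices.pop()         # O(1) LIFO allocation
    req_id_to_index[req_id] = idx
    return idx


def remove(req_id: str) -> None:
    idx = req_id_to_index.pop(req_id, None)
    if idx is not None:
        free_indices.append(idx)     # slot becomes reusable


a, b = add("a"), add("b")            # slots 3 and 2
remove("a")
assert add("c") == a                 # slot 3 is reused immediately
```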
+        if sampling_params.sampling_type == SamplingType.GREEDY:
+            self.greedy_reqs.add(req_id)
+        elif sampling_params.sampling_type == SamplingType.RANDOM:
+            self.random_reqs.add(req_id)
+
+        self.top_p.np[req_idx] = sampling_params.top_p
+        if sampling_params.top_p < 1.0:
+            self.top_p_reqs.add(req_id)
+
+        top_k = sampling_params.top_k
+        if 0 < top_k < self.vocab_size:
+            self.top_k_reqs.add(req_id)
+        else:
+            top_k = self.vocab_size
+        self.top_k.np[req_idx] = top_k
+
+        self.frequency_penalties.np[
+            req_idx] = sampling_params.frequency_penalty
+        if sampling_params.frequency_penalty != 0.0:
+            self.frequency_penalties_reqs.add(req_id)
+        self.presence_penalties.np[req_idx] = sampling_params.presence_penalty
+        if sampling_params.presence_penalty != 0.0:
+            self.presence_penalties_reqs.add(req_id)
+        self.repetition_penalties.np[
+            req_idx] = sampling_params.repetition_penalty
+        if sampling_params.repetition_penalty != 1.0:
+            self.repetition_penalties_reqs.add(req_id)
+
+        if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
+            generator = torch.Generator(device=self.device)
+            generator.manual_seed(sampling_params.seed)
+            self.generators[req_idx] = generator
+
+    def append_token_ids(
+        self,
+        req_idx: int,
+        token_ids: Union[list[int], np.ndarray],
+    ) -> None:
+        start_idx = self.num_tokens.np[req_idx]
+        end_idx = start_idx + len(token_ids)
+        self.token_ids.np[req_idx, start_idx:end_idx] = token_ids
+        self.num_tokens.np[req_idx] = end_idx
+
+    def append_sampled_token_ids(
+        self,
+        idx_mapping: np.ndarray,
+        # NOTE: May be ragged under spec decode, so this is a list of lists
+        # rather than a 2D array.
+        sampled_token_ids: list[list[int]],
+    ) -> None:
+        num_reqs = idx_mapping.shape[0]
+        for i in range(num_reqs):
+            req_idx = idx_mapping[i]
+            self.append_token_ids(req_idx, sampled_token_ids[i])
+
+    def remove_request(self, req_id: str) -> None:
+        req_idx = self.req_id_to_index.pop(req_id, None)
+        if req_idx is None:
+            # Request not found.
+ return + self.index_to_req_id.pop(req_idx, None) + self.free_indices.append(req_idx) + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.frequency_penalties_reqs.discard(req_id) + self.presence_penalties_reqs.discard(req_id) + self.repetition_penalties_reqs.discard(req_id) + self.generators.pop(req_idx, None) + + def make_sampling_metadata( + self, + batch_idx_to_req_idx: torch.Tensor, + ) -> SamplingMetadata: + batch_size = batch_idx_to_req_idx.shape[0] + _make_sampling_metadata_kernel[(batch_size, )]( + batch_idx_to_req_idx, + self.temperature.mirror_to_gpu(), + self.temperature.gpu, + self.top_p.mirror_to_gpu(), + self.top_p.gpu, + self.top_k.mirror_to_gpu(), + self.top_k.gpu, + self.frequency_penalties.mirror_to_gpu(), + self.frequency_penalties.gpu, + self.presence_penalties.mirror_to_gpu(), + self.presence_penalties.gpu, + self.repetition_penalties.mirror_to_gpu(), + self.repetition_penalties.gpu, + num_warps=1, + num_stages=1, + ) + no_penalties = not (self.frequency_penalties_reqs + or self.presence_penalties_reqs + or self.repetition_penalties_reqs) + return SamplingMetadata( + temperature=self.temperature.gpu[:batch_size], + all_greedy=not self.random_reqs, + all_random=not self.greedy_reqs, + top_p=self.top_p.gpu[:batch_size], + top_k=self.top_k.gpu[:batch_size], + frequency_penalties=self.frequency_penalties.gpu[:batch_size], + presence_penalties=self.presence_penalties.gpu[:batch_size], + repetition_penalties=self.repetition_penalties.gpu[:batch_size], + no_penalties=no_penalties, + # TODO + generators={}, + token_ids=self.token_ids.cpu[:batch_size], + max_num_logprobs=None, + allowed_token_ids_mask=None, + bad_words_token_ids={}, + logitsprocs=None, + ) + + @property + def num_cached_reqs(self) -> int: + return len(self.req_id_to_index) + + def _add_scalar_attr(self, name: str, dtype: torch.dtype): + attr = RequestAttribute(self.max_num_cached_reqs, + 1, + self.max_num_reqs, + dtype, + self.device, + self.pin_memory, + is_scalar=True) + setattr(self, name, attr) + + def _add_vector_attr( + self, + name: str, + max_len: int, + dtype: torch.dtype, + cpu_only: bool = False, + ): + if cpu_only: + num_rows_gpu = 0 + else: + num_rows_gpu = self.max_num_reqs + attr = RequestAttribute(self.max_num_cached_reqs, max_len, + num_rows_gpu, dtype, self.device, + self.pin_memory) + setattr(self, name, attr) + + +@triton.jit +def _make_sampling_metadata_kernel( + batch_idx_to_req_idx, # [batch_size] + src_temperature, + dst_temperature, + src_top_p, + dst_top_p, + src_top_k, + dst_top_k, + src_frequency_penalties, + dst_frequency_penalties, + src_presence_penalties, + dst_presence_penalties, + src_repetition_penalties, + dst_repetition_penalties, +): + batch_idx = tl.program_id(0) + req_idx = tl.load(batch_idx_to_req_idx + batch_idx) + + temperature = tl.load(src_temperature + req_idx) + tl.store(dst_temperature + batch_idx, temperature) + + top_p = tl.load(src_top_p + req_idx) + tl.store(dst_top_p + batch_idx, top_p) + + top_k = tl.load(src_top_k + req_idx) + tl.store(dst_top_k + batch_idx, top_k) + + frequency_penalties = tl.load(src_frequency_penalties + req_idx) + tl.store(dst_frequency_penalties + batch_idx, frequency_penalties) + + presence_penalties = tl.load(src_presence_penalties + req_idx) + tl.store(dst_presence_penalties + batch_idx, presence_penalties) + + repetition_penalties = tl.load(src_repetition_penalties + req_idx) + tl.store(dst_repetition_penalties + batch_idx, 
repetition_penalties)
diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py
index 81c798685cb3a..489edf772e1c9 100644
--- a/vllm/v1/worker/tpu_input_batch.py
+++ b/vllm/v1/worker/tpu_input_batch.py
@@ -12,7 +12,6 @@ from vllm.sampling_params import SamplingType
 from vllm.utils import swap_dict_values
 from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.worker.block_table import MultiGroupBlockTable
-from vllm.v1.worker.gpu_input_batch import CachedRequestState

 _SAMPLING_EPS = 1e-5

diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index ffc1a11bc3ba1..2ce2bf4531b57 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -298,3 +298,32 @@ def bind_kv_cache(
     for layer_name, kv_cache in kv_caches.items():
         # NOTE: Use list because of v0 PP virtual engine.
         forward_context[layer_name].kv_cache = [kv_cache]
+
+
+class CpuGpuBuffer:
+
+    def __init__(
+        self,
+        *args,
+        dtype: torch.dtype,
+        device: torch.device,
+        pin_memory: bool,
+    ):
+        self.cpu = torch.zeros(*args,
+                               dtype=dtype,
+                               device="cpu",
+                               pin_memory=pin_memory)
+        self.np = self.cpu.numpy()
+        self.gpu = self.cpu.to(device)
+
+    def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
+        if n is None:
+            return self.gpu.copy_(self.cpu, non_blocking=True)
+        else:
+            return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
+
+    def copy_to_cpu(self, n: Optional[int] = None) -> torch.Tensor:
+        if n is None:
+            return self.cpu.copy_(self.gpu, non_blocking=True)
+        else:
+            return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)
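A usage sketch for the `CpuGpuBuffer` helper added above, assuming a CUDA device is available (pinned memory requires one). The CPU side is filled through the NumPy view, then a single async H2D copy moves just the valid prefix; `copy_to_gpu` returns the destination GPU view, so the result can be used directly:

```python
import torch

from vllm.v1.worker.utils import CpuGpuBuffer

buf = CpuGpuBuffer(1024,
                   dtype=torch.int32,
                   device=torch.device("cuda"),
                   pin_memory=True)
buf.np[:3] = [7, 8, 9]           # cheap CPU-side writes via the NumPy view
gpu_view = buf.copy_to_gpu(n=3)  # non-blocking copy of the first 3 elements
torch.cuda.synchronize()         # wait for the async copy before reading
assert gpu_view.cpu().tolist() == [7, 8, 9]
```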