From 8b3c13c48551036380ceaacb3d05cf9066e58816 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 15 Sep 2025 11:17:54 -0700 Subject: [PATCH] wip Signed-off-by: Woosuk Kwon --- vllm/v1/core/sched/output.py | 47 +- vllm/v1/kv_cache_interface.py | 8 + vllm/v1/sample/metadata.py | 25 - vllm/v1/worker/gpu/__init__.py | 0 vllm/v1/worker/gpu/attn_utils.py | 80 + .../block_table.py} | 75 +- vllm/v1/worker/gpu/init_utils.py | 26 + .../input_batch.py} | 54 +- vllm/v1/worker/gpu/model_runner.py | 290 +++ vllm/v1/worker/gpu/sampler.py | 238 +++ vllm/v1/worker/gpu/states.py | 143 ++ vllm/v1/worker/gpu_model_runner.py | 1609 +---------------- vllm/v1/worker/gpu_worker_states.py | 291 --- 13 files changed, 886 insertions(+), 2000 deletions(-) create mode 100644 vllm/v1/worker/gpu/__init__.py create mode 100644 vllm/v1/worker/gpu/attn_utils.py rename vllm/v1/worker/{gpu_block_table.py => gpu/block_table.py} (84%) create mode 100644 vllm/v1/worker/gpu/init_utils.py rename vllm/v1/worker/{gpu_input_batch.py => gpu/input_batch.py} (64%) create mode 100644 vllm/v1/worker/gpu/model_runner.py create mode 100644 vllm/v1/worker/gpu/sampler.py create mode 100644 vllm/v1/worker/gpu/states.py delete mode 100644 vllm/v1/worker/gpu_worker_states.py diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 56ab396d6d937..c7f0af29372fb 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -30,7 +30,6 @@ class NewRequestData: mm_features: list[MultiModalFeatureSpec] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] - block_ids: tuple[list[int], ...] 
num_computed_tokens: int lora_request: Optional[LoRARequest] @@ -46,7 +45,6 @@ class NewRequestData: mm_features=request.mm_features, sampling_params=request.sampling_params, pooling_params=request.pooling_params, - block_ids=block_ids, num_computed_tokens=request.num_computed_tokens, lora_request=request.lora_request, ) @@ -57,7 +55,6 @@ class NewRequestData: f"prompt_token_ids={self.prompt_token_ids}," f"mm_features={self.mm_features}," f"sampling_params={self.sampling_params}," - f"block_ids={self.block_ids}," f"num_computed_tokens={self.num_computed_tokens}," f"lora_request={self.lora_request}" ")") @@ -69,7 +66,6 @@ class NewRequestData: f"prompt_token_ids_len={len(self.prompt_token_ids)}," f"mm_features={self.mm_features}," f"sampling_params={self.sampling_params}," - f"block_ids={self.block_ids}," f"num_computed_tokens={self.num_computed_tokens}," f"lora_request={self.lora_request}" ")") @@ -77,52 +73,17 @@ class NewRequestData: @bc_linter_include @dataclass -class CachedRequestData: +class SchedulerOutput: req_ids: list[str] - # If resumed_from_preemption is False, new_block_ids will be appended to - # the request's block IDs. If True, new_block_ids will be used as the - # request's block IDs instead of appending to the existing block IDs. - resumed_from_preemption: list[bool] - # NOTE(woosuk): new_token_ids is only used for pipeline parallelism. - # When PP is not used, new_token_ids will be empty. - new_token_ids: list[list[int]] - new_block_ids: list[Optional[tuple[list[int], ...]]] - num_computed_tokens: list[int] - - @property - def num_reqs(self) -> int: - return len(self.req_ids) - - @classmethod - def make_empty(cls) -> CachedRequestData: - return cls( - req_ids=[], - resumed_from_preemption=[], - new_token_ids=[], - new_block_ids=[], - num_computed_tokens=[], - ) - - -@bc_linter_include -@dataclass -class SchedulerOutput: + cu_new_block_ids: tuple[np.ndarray, ...] # list of the requests that are scheduled for the first time. 
# We cache the request's data in each worker process, so that we don't # need to re-send it every scheduling step. scheduled_new_reqs: list[NewRequestData] - # list of the requests that have been scheduled before. - # Since the request's data is already cached in the worker processes, - # we only send the diff to minimize the communication cost. - scheduled_cached_reqs: CachedRequestData - # req_id -> num_scheduled_tokens - # Number of tokens scheduled for each request. num_scheduled_tokens: dict[str, int] - # Total number of tokens scheduled for all requests. - # Equal to sum(num_scheduled_tokens.values()) total_num_scheduled_tokens: int # req_id -> spec_token_ids # If a request does not have any spec decode tokens, it will not be @@ -136,13 +97,11 @@ class SchedulerOutput: # This can be used for cascade attention. num_common_prefix_blocks: list[int] + preempted_req_ids: set[str] # Request IDs that are finished in between the previous and the current # steps. This is used to notify the workers about the finished requests # so that they can free the cached states for those requests. finished_req_ids: set[str] - # list of mm_hash strings associated with the encoder outputs to be - # freed from the encoder cache. - free_encoder_mm_hashes: list[str] # Dict of request ids to their index within the batch # for filling the next token bitmask diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 6e8f569fff0e3..ec3d5e8f29df1 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -3,6 +3,7 @@ import copy from dataclasses import dataclass, fields +from functools import cached_property from math import prod from typing import Optional @@ -273,3 +274,10 @@ class KVCacheConfig: see `_get_kv_cache_config_uniform_page_size` for more details. 
""" kv_cache_groups: list[KVCacheGroupSpec] + + @cached_property + def block_sizes(self) -> list[int]: + return [ + kv_cache_group.kv_cache_spec.block_size + for kv_cache_group in self.kv_cache_groups + ] diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 4749542cb6306..af8179baca0ed 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -6,39 +6,14 @@ from typing import Optional import torch -from vllm.v1.sample.logits_processor import LogitsProcessors - @dataclass class SamplingMetadata: temperature: torch.Tensor - all_greedy: bool - all_random: bool top_p: Optional[torch.Tensor] top_k: Optional[torch.Tensor] - generators: dict[int, torch.Generator] - # None means no logprobs, 0 means sampled token logprobs only max_num_logprobs: Optional[int] - - no_penalties: bool - frequency_penalties: torch.Tensor - presence_penalties: torch.Tensor - repetition_penalties: torch.Tensor - - token_ids: Optional[torch.Tensor] - num_tokens: Optional[torch.Tensor] - num_prompt_tokens: Optional[torch.Tensor] - - # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size, - # vocab size). 
- allowed_token_ids_mask: Optional[torch.Tensor] - - # req_index -> bad_words_token_ids - bad_words_token_ids: dict[int, list[list[int]]] - - # Loaded logits processors - logitsprocs: LogitsProcessors diff --git a/vllm/v1/worker/gpu/__init__.py b/vllm/v1/worker/gpu/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/v1/worker/gpu/attn_utils.py b/vllm/v1/worker/gpu/attn_utils.py new file mode 100644 index 0000000000000..c36422c3ce8a2 --- /dev/null +++ b/vllm/v1/worker/gpu/attn_utils.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.layer import Attention +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, + SlidingWindowSpec) + + +def get_kv_cache_spec( + vllm_config: VllmConfig, + kv_cache_dtype: torch.dtype, +) -> dict[str, KVCacheSpec]: + block_size = vllm_config.cache_config.block_size + use_mla = vllm_config.model_config.use_mla + + kv_cache_spec: dict[str, KVCacheSpec] = {} + attn_layers = get_layers_from_vllm_config(vllm_config, Attention) + for layer_name, attn_module in attn_layers.items(): + assert attn_module.attn_type == AttentionType.DECODER + if attn_module.sliding_window is not None: + kv_cache_spec[layer_name] = SlidingWindowSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=kv_cache_dtype, + sliding_window=attn_module.sliding_window, + use_mla=use_mla, + ) + else: + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=kv_cache_dtype, + use_mla=use_mla, + ) + return kv_cache_spec + + +def init_attn_backend(vllm_config: VllmConfig): + attn_layers = get_layers_from_vllm_config(vllm_config, 
Attention) + + +def _allocate_kv_cache( + kv_cache_config: KVCacheConfig, + device: torch.device, +): + kv_cache_raw_tensors: dict[str, torch.Tensor] = {} + for kv_cache_tensor in kv_cache_config.kv_cache_tensors: + tensor = torch.zeros(kv_cache_tensor.size, + dtype=torch.int8, + device=device) + for layer_name in kv_cache_tensor.shared_by: + kv_cache_raw_tensors[layer_name] = tensor + + layer_names = set() + for group in kv_cache_config.kv_cache_groups: + for layer_name in group.layer_names: + layer_names.add(layer_name) + assert layer_names == set(kv_cache_raw_tensors.keys() + ), "Some layers are not correctly initialized" + return kv_cache_raw_tensors + + +def _reshape_kv_cache( + kv_cache_config: KVCacheConfig, + kv_cache_raw_tensors: dict[str, torch.Tensor], +): + pass + + +def init_kv_cache( + kv_cache_config: KVCacheConfig, + device: torch.device, +): + kv_cache_raw_tensors = _allocate_kv_cache(kv_cache_config, device) + kv_caches = _reshape_kv_cache(kv_cache_config, kv_cache_raw_tensors) diff --git a/vllm/v1/worker/gpu_block_table.py b/vllm/v1/worker/gpu/block_table.py similarity index 84% rename from vllm/v1/worker/gpu_block_table.py rename to vllm/v1/worker/gpu/block_table.py index e0584d94f8be1..d50a852867adc 100644 --- a/vllm/v1/worker/gpu_block_table.py +++ b/vllm/v1/worker/gpu/block_table.py @@ -18,7 +18,6 @@ class BlockTables: self, block_sizes: list[int], max_num_reqs: int, - max_num_cached_reqs: int, max_num_batched_tokens: int, max_model_len: int, device: torch.device, @@ -26,21 +25,17 @@ class BlockTables: ): self.block_sizes = block_sizes self.max_num_reqs = max_num_reqs - self.max_num_cached_reqs = max_num_cached_reqs self.max_num_batched_tokens = max_num_batched_tokens self.max_model_len = max_model_len self.device = device self.pin_memory = pin_memory self.num_kv_cache_groups = len(self.block_sizes) - # [num_kv_cache_groups, max_num_reqs, max_num_blocks] + # num_kv_cache_groups x [max_num_reqs, max_num_blocks] self.block_tables: 
list[torch.Tensor] = [] - # [num_kv_cache_groups, max_num_cached_reqs, max_num_blocks] - self.block_table_buffers: list[torch.Tensor] = [] for i in range(self.num_kv_cache_groups): block_size = self.block_sizes[i] max_num_blocks = cdiv(self.max_model_len, block_size) - block_table = torch.zeros( self.max_num_reqs, max_num_blocks, @@ -48,17 +43,16 @@ class BlockTables: device=self.device, ) self.block_tables.append(block_table) - - block_table_buffer = torch.zeros( - self.max_num_cached_reqs, - max_num_blocks, - dtype=torch.int32, - device=self.device, - ) - self.block_table_buffers.append(block_table_buffer) - self.block_table_ptrs = self._make_ptr_tensor(self.block_tables) - self.buffer_ptrs = self._make_ptr_tensor(self.block_table_buffers) + + # Block tables used for model's forward pass. + # num_kv_cache_groups x [max_num_reqs, max_num_blocks] + self.input_block_tables: list[torch.Tensor] = [ + torch.zeros_like(block_table) for block_table in self.block_tables + ] + self.input_block_table_ptrs = self._make_ptr_tensor( + self.input_block_tables) + self.block_table_strides = torch.tensor( [b.stride(0) for b in self.block_tables], dtype=torch.int64, @@ -67,7 +61,7 @@ class BlockTables: dtype=torch.int32, device=self.device) self.num_blocks = torch.zeros(self.num_kv_cache_groups, - self.max_num_cached_reqs, + self.max_num_reqs, dtype=torch.int32, device=self.device) self.slot_mappings = torch.zeros(self.num_kv_cache_groups, @@ -107,7 +101,6 @@ class BlockTables: # [num_reqs] overwrite: list[bool], ) -> None: - # TODO(woosuk): Optimize & simplify this. num_reqs = len(req_indices) self.req_indices.np[:num_reqs] = req_indices self.overwrite.np[:num_reqs] = overwrite @@ -115,19 +108,21 @@ class BlockTables: self.cu_num_new_blocks.np[i, :num_reqs + 1] = cu_num_new_blocks[i] # NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's - # no clear upper bound on the number of new blocks. 
- new_block_ids_cpu = torch.empty( + # no clear upper bound to the number of new blocks in a single step. + # NOTE(woosuk): The buffer has to be cached, because otherwise we cannot + # guarantee that the buffer is not freed before the copy is completed. + self.new_block_ids_cpu = torch.empty( self.num_kv_cache_groups, max(len(x) for x in new_block_ids), dtype=torch.int32, device="cpu", pin_memory=self.pin_memory, ) - new_block_ids_np = new_block_ids_cpu.numpy() + new_block_ids_np = self.new_block_ids_cpu.numpy() for i in range(self.num_kv_cache_groups): new_block_ids_np[i, :len(new_block_ids[i])] = new_block_ids[i] - new_block_ids_gpu = new_block_ids_cpu.to(self.device, - non_blocking=True) + new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, + non_blocking=True) _append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)]( self.req_indices.copy_to_gpu(num_reqs), @@ -137,33 +132,34 @@ class BlockTables: new_block_ids_gpu.stride(0), self.overwrite.copy_to_gpu(num_reqs), self.block_table_strides, - self.buffer_ptrs, + self.block_table_ptrs, self.num_blocks, self.num_blocks.stride(0), BLOCK_SIZE=1024, ) - def compute_block_tables( + def gather_block_tables( self, idx_mapping: torch.Tensor, ) -> tuple[torch.Tensor, ...]: - batch_size = idx_mapping.shape[0] - _compute_block_tables_kernel[(self.num_kv_cache_groups, batch_size)]( + num_reqs = idx_mapping.shape[0] + _gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)]( idx_mapping, - self.buffer_ptrs, self.block_table_ptrs, + self.input_block_table_ptrs, self.block_table_strides, self.num_blocks, self.num_blocks.stride(0), BLOCK_SIZE=1024, ) - return tuple(b[:batch_size] for b in self.block_tables) + return tuple(block_table[:num_reqs] + for block_table in self.input_block_tables) def compute_slot_mappings( self, query_start_loc: torch.Tensor, positions: torch.Tensor, - ) -> tuple[torch.Tensor, ...]: + ) -> torch.Tensor: num_reqs = query_start_loc.shape[0] - 1 num_tokens = positions.shape[0] 
num_groups = self.num_kv_cache_groups @@ -172,7 +168,7 @@ class BlockTables: self.max_num_batched_tokens, query_start_loc, positions, - self.block_table_ptrs, + self.input_block_table_ptrs, self.block_table_strides, self.block_sizes_tensor, self.slot_mappings, @@ -180,7 +176,7 @@ class BlockTables: PAD_ID=PAD_SLOT_ID, BLOCK_SIZE=1024, ) - return tuple(x[:num_tokens] for x in self.slot_mappings) + return self.slot_mappings[:, :num_tokens] @triton.jit @@ -194,8 +190,8 @@ def _append_block_ids_kernel( overwrite, # [num_reqs] block_table_strides, # [num_kv_cache_groups] # Outputs - block_table_buffer_ptrs, # [num_kv_cache_groups] - num_blocks_ptr, # [num_kv_cache_groups, max_num_cached_reqs] + block_table_ptrs, # [num_kv_cache_groups] + num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs] num_blocks_stride, # Constants BLOCK_SIZE: tl.constexpr, @@ -220,10 +216,9 @@ def _append_block_ids_kernel( tl.store(group_num_blocks_ptr + req_idx, dst_end_idx) # Destination - block_table_buffer_ptr = _load_ptr(block_table_buffer_ptrs + group_id, - tl.int32) + block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32) block_table_stride = tl.load(block_table_strides + group_id) - buffer_row_ptr = block_table_buffer_ptr + req_idx * block_table_stride + row_ptr = block_table_ptr + req_idx * block_table_stride group_new_block_ids_ptr = (new_block_ids_ptr + group_id * new_block_ids_stride) @@ -231,18 +226,18 @@ def _append_block_ids_kernel( offset = i + tl.arange(0, BLOCK_SIZE) block_ids = tl.load(group_new_block_ids_ptr + start_idx + offset, mask=offset < num_new_blocks) - tl.store(buffer_row_ptr + dst_start_idx + offset, + tl.store(row_ptr + dst_start_idx + offset, block_ids, mask=offset < num_new_blocks) @triton.jit -def _compute_block_tables_kernel( +def _gather_block_tables_kernel( batch_idx_to_req_idx, # [batch_size] src_block_table_ptrs, # [num_kv_cache_groups] dst_block_table_ptrs, # [num_kv_cache_groups] block_table_strides, # [num_kv_cache_groups] - num_blocks_ptr, # 
[num_kv_cache_groups, max_num_cached_reqs] + num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs] num_blocks_stride, BLOCK_SIZE: tl.constexpr, ): diff --git a/vllm/v1/worker/gpu/init_utils.py b/vllm/v1/worker/gpu/init_utils.py new file mode 100644 index 0000000000000..22b9f3660eaa0 --- /dev/null +++ b/vllm/v1/worker/gpu/init_utils.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model_loader +from vllm.utils import DeviceMemoryProfiler, GiB_bytes + +logger = init_logger(__name__) + + +def load_model(vllm_config: VllmConfig): + time_before_load = time.perf_counter() + + with DeviceMemoryProfiler() as m: + model_loader = get_model_loader(vllm_config.load_config) + logger.info("Loading model from scratch...") + model = model_loader.load_model(vllm_config=vllm_config, + model_config=vllm_config.model_config) + + time_after_load = time.perf_counter() + logger.info("Model loading took %.4f GiB and %.6f seconds", + m.consumed_memory / GiB_bytes, + time_after_load - time_before_load) + return model diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu/input_batch.py similarity index 64% rename from vllm/v1/worker/gpu_input_batch.py rename to vllm/v1/worker/gpu/input_batch.py index 250b390422c64..63fdb02c1155b 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -1,14 +1,42 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Any, Optional +from typing import Any import numba import numpy as np import torch from numba import types -from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.utils import CpuGpuBuffer + + +class InputBuffers: + + def __init__( + self, + max_num_reqs: int, 
+ max_num_tokens: int, + device: torch.device, + pin_memory: bool, + ): + self.max_num_reqs = max_num_reqs + self.max_num_tokens = max_num_tokens + self.device = device + self.pin_memory = pin_memory + + self.idx_mapping = self._make_buffer(max_num_reqs, dtype=torch.int32) + self.input_ids = self._make_buffer(max_num_tokens, dtype=torch.int32) + self.positions = self._make_buffer(max_num_tokens, dtype=torch.int64) + self.query_start_loc = self._make_buffer(max_num_reqs + 1, + dtype=torch.int32) + self.seq_lens = self._make_buffer(max_num_reqs, dtype=torch.int32) + + def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer: + return CpuGpuBuffer(*args, + dtype=dtype, + pin_memory=self.pin_memory, + device=self.device) @dataclass @@ -16,9 +44,7 @@ class InputBatch: # batch_idx -> req_id req_ids: list[str] - - # req_id -> batch_idx - req_id_to_batch_idx: dict[str, int] + num_reqs: int # batch_idx -> req_state_idx idx_mapping: torch.Tensor @@ -26,14 +52,18 @@ class InputBatch: # batch_idx -> num_scheduled_tokens num_scheduled_tokens: np.ndarray - total_num_tokens: int - max_query_len: int - num_reqs: int + # sum(num_scheduled_tokens) + num_tokens: int + # [max_num_batched_tokens] + input_ids: torch.Tensor + # [max_num_batched_tokens] + positions: torch.Tensor + + # layer_name -> Metadata attn_metadata: dict[str, Any] - spec_decode_common_attn_metadata: Optional[Any] - spec_decode_metadata: Optional[SpecDecodeMetadata] + # [num_reqs] logits_indices: torch.Tensor @@ -47,9 +77,9 @@ class InputBatch: types.int32[:], # num_computed_tokens types.int32[:], # num_scheduled_tokens types.int32[:], # input_ids + types.int64[:], # positions types.int32[:], # query_start_loc types.int32[:], # seq_lens - types.int64[:], # positions ) ], nopython=True, @@ -61,9 +91,9 @@ def prepare_inputs( num_computed_tokens: np.ndarray, # [N] num_scheduled_tokens: np.ndarray, # [B] input_ids: np.ndarray, # [num_input_tokens] + positions: np.ndarray, # [num_input_tokens] query_start_loc: 
np.ndarray, # [B + 1] seq_lens: np.ndarray, # [B] - positions: np.ndarray, # [num_input_tokens] ) -> None: num_reqs = num_scheduled_tokens.shape[0] query_start_loc[0] = 0 diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py new file mode 100644 index 0000000000000..7c29a4b392398 --- /dev/null +++ b/vllm/v1/worker/gpu/model_runner.py @@ -0,0 +1,290 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from copy import deepcopy +from typing import Any + +import numpy as np +import torch + +from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.sample.sampler import SamplerOutput +from vllm.v1.worker.gpu.attn_utils import get_kv_cache_spec, init_attn_backend +from vllm.v1.worker.gpu.block_table import BlockTables +from vllm.v1.worker.gpu.init_utils import load_model +from vllm.v1.worker.gpu.input_batch import (InputBatch, InputBuffers, + prepare_inputs) +from vllm.v1.worker.gpu.sampler import Sampler +from vllm.v1.worker.gpu.states import RequestState + +logger = init_logger(__name__) + + +class GPUModelRunner: + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.compilation_config = vllm_config.compilation_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.speculative_config = vllm_config.speculative_config + self.observability_config = vllm_config.observability_config + + self.device = device + 
self.pin_memory = is_pin_memory_available() + self.dtype = self.model_config.dtype + if self.cache_config.cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + self.cache_config.cache_dtype] + + self.vocab_size = self.model_config.get_vocab_size() + self.max_model_len = self.model_config.max_model_len + self.max_num_tokens = self.scheduler_config.max_num_batched_tokens + self.max_num_reqs = self.scheduler_config.max_num_seqs + + self.req_states = RequestState( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + vocab_size=self.vocab_size, + device=self.device, + pin_memory=self.pin_memory, + ) + self.input_buffers = InputBuffers( + max_num_reqs=self.max_num_reqs, + max_num_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + ) + self.sampler = Sampler() + + def load_model(self) -> None: + self.model = load_model(self.vllm_config) + + def get_kv_cache_spec(self): + return get_kv_cache_spec(self.vllm_config, self.kv_cache_dtype) + + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + kv_cache_config = deepcopy(kv_cache_config) + self.kv_cache_config = kv_cache_config + block_sizes = kv_cache_config.block_sizes + + self.block_tables = BlockTables( + block_sizes=block_sizes, + max_num_reqs=self.max_num_reqs, + max_num_batched_tokens=self.max_num_tokens, + max_model_len=self.max_model_len, + device=self.device, + pin_memory=self.pin_memory, + ) + self.attn_metadata_builders = init_attn_backend(self.vllm_config) + + def update_states(self, scheduler_output: SchedulerOutput) -> None: + for req_id in scheduler_output.preempted_req_ids: + self.req_states.remove_request(req_id) + for req_id in scheduler_output.finished_req_ids: + self.req_states.remove_request(req_id) + + # TODO(woosuk): Change SchedulerOutput. 
+ req_indices: list[int] = [] + cu_num_new_blocks = tuple( + [0] for _ in range(self.block_tables.num_kv_cache_groups)) + new_block_ids = tuple( + [] for _ in range(self.block_tables.num_kv_cache_groups)) + overwrite: list[bool] = [] + + # Add new requests to the cached states. + for new_req_data in scheduler_output.scheduled_new_reqs: + req_id = new_req_data.req_id + self.req_states.add_request( + req_id=req_id, + prompt_token_ids=new_req_data.prompt_token_ids, + num_computed_tokens=new_req_data.num_computed_tokens, + sampling_params=new_req_data.sampling_params, + ) + + req_index = self.req_states.req_id_to_index[req_id] + req_indices.append(req_index) + for i, block_ids in enumerate(new_req_data.block_ids): + x = cu_num_new_blocks[i][-1] + cu_num_new_blocks[i].append(x + len(block_ids)) + new_block_ids[i].extend(block_ids) + overwrite.append(True) + + # Update the states of the running/resumed requests. + cached_reqs = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + req_index = self.req_states.req_id_to_index[req_id] + + req_new_block_ids = cached_reqs.new_block_ids[i] + if req_new_block_ids is not None: + req_indices.append(req_index) + for group_id, block_ids in enumerate(req_new_block_ids): + x = cu_num_new_blocks[group_id][-1] + cu_num_new_blocks[group_id].append(x + len(block_ids)) + new_block_ids[group_id].extend(block_ids) + overwrite.append(False) + + self.req_states.num_computed_tokens[req_index] = ( + cached_reqs.num_computed_tokens[i]) + + if req_indices: + self.block_tables.append_block_ids( + req_indices=req_indices, + cu_num_new_blocks=cu_num_new_blocks, + new_block_ids=new_block_ids, + overwrite=overwrite, + ) + + def prepare_inputs(self, scheduler_output: SchedulerOutput) -> InputBatch: + num_tokens = scheduler_output.total_num_scheduled_tokens + assert num_tokens > 0 + num_reqs = len(scheduler_output.num_scheduled_tokens) + + # Decode first, then prefill. 
+ # batch_idx -> req_id + req_ids = sorted(scheduler_output.num_scheduled_tokens, + key=scheduler_output.num_scheduled_tokens.get) + num_scheduled_tokens = np.array( + [scheduler_output.num_scheduled_tokens[i] for i in req_ids], + dtype=np.int32) + + idx_mapping_list = [ + self.req_states.req_id_to_index[req_id] for req_id in req_ids + ] + self.input_buffers.idx_mapping.np[:num_reqs] = idx_mapping_list + idx_mapping_np = self.input_buffers.idx_mapping.np[:num_reqs] + idx_mapping = self.input_buffers.idx_mapping.copy_to_gpu(num_reqs) + + # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks] + block_tables = self.block_tables.gather_block_tables(idx_mapping) + + input_ids = self.input_buffers.input_ids + positions = self.input_buffers.positions + query_start_loc = self.input_buffers.query_start_loc + seq_lens = self.input_buffers.seq_lens + + prepare_inputs( + idx_mapping_np, + self.req_states.token_ids, + self.req_states.num_computed_tokens, + num_scheduled_tokens, + input_ids.np, + positions.np, + query_start_loc.np, + seq_lens.np, + ) + input_ids.copy_to_gpu(num_tokens) + positions.copy_to_gpu(num_tokens) + + # NOTE(woosuk): We should copy the whole query_start_loc and seq_lens + # tensors from CPU to GPU, because they may include paddings needed + # for full CUDA graph mode. 
+ query_start_loc.copy_to_gpu() + query_start_loc = query_start_loc.gpu[:num_reqs + 1] + max_query_len = int(num_scheduled_tokens.max()) + + seq_lens.copy_to_gpu() + seq_lens_np = seq_lens.np[:num_reqs] + max_seq_len = int(seq_lens_np.max()) + seq_lens = seq_lens.gpu[:num_reqs] + + # Slot mappings: [num_kv_cache_groups, num_tokens] + slot_mappings = self.block_tables.compute_slot_mappings( + query_start_loc, positions.gpu[:num_tokens]) + + logits_indices = query_start_loc[1:] - 1 + + attn_metadata: dict[str, Any] = {} + for i, kv_cache_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + block_table = block_tables[i] + slot_mapping = slot_mappings[i] + + return InputBatch( + req_ids=req_ids, + num_reqs=num_reqs, + idx_mapping=idx_mapping, + idx_mapping_np=idx_mapping_np, + num_scheduled_tokens=num_scheduled_tokens, + num_tokens=num_tokens, + input_ids=input_ids, + positions=positions, + attn_metadata=attn_metadata, + logits_indices=logits_indices, + ) + + def sample( + self, + input_batch: InputBatch, + logits: torch.Tensor, + ) -> SamplerOutput: + sampling_metadata = self.req_states.make_sampling_metadata( + input_batch.idx_mapping_np) + sampler_output = self.sampler( + logits=logits, + sampling_metadata=sampling_metadata, + ) + + return sampler_output + + def postprocess( + self, + input_batch: InputBatch, + sampler_output: SamplerOutput, + ) -> np.ndarray: + # Get the number of sampled tokens. + # Handle requests that are chunked-prefilling. + idx_mapping_np = input_batch.idx_mapping_np + num_computed_tokens = self.req_states.num_computed_tokens[ + idx_mapping_np] + post_num_computed_tokens = (num_computed_tokens + + input_batch.num_scheduled_tokens) + num_tokens = self.req_states.num_tokens[idx_mapping_np] + + is_chunked_prefilling = post_num_computed_tokens < num_tokens + # 0 if chunked-prefilling, 1 if not. + num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32) + # Increment the number of tokens. 
# ---------------------------------------------------------------------------
# vllm/v1/worker/gpu/model_runner.py — GPUModelRunner.execute_model.
# (Shown dedented here; in the file it is a method of the model runner.)
# ---------------------------------------------------------------------------
def execute_model(
    self,
    scheduler_output: "SchedulerOutput",
):
    """Run one model step for the scheduled batch.

    Updates cached request states, builds the input batch, runs the forward
    pass under the attention forward-context, samples next tokens, and
    post-processes the sampler output.
    """
    self.update_states(scheduler_output)
    if scheduler_output.total_num_scheduled_tokens == 0:
        # Nothing scheduled on this worker in this step.
        return None

    input_batch = self.prepare_inputs(scheduler_output)

    with set_forward_context(
            input_batch.attn_metadata,
            self.vllm_config,
    ):
        hidden_states = self.model(
            input_ids=input_batch.input_ids,
            positions=input_batch.positions,
        )

    # Only the positions marked by logits_indices produce logits.
    sampling_hidden_states = hidden_states[input_batch.logits_indices]
    logits = self.model.compute_logits(sampling_hidden_states, None)
    sampler_output = self.sample(input_batch, logits)

    # postprocess() appends the sampled tokens to the per-request states;
    # its count return value is not needed here.
    self.postprocess(input_batch, sampler_output)
    # FIX: the original ended with `return output`, but `output` was never
    # defined (NameError on every step). Return the sampler output instead.
    return sampler_output


# ---------------------------------------------------------------------------
# vllm/v1/worker/gpu/sampler.py — GPU sampler with Triton kernels.
# ---------------------------------------------------------------------------

# Temperatures below this threshold are treated as greedy sampling.
_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
    """Sample next tokens from logits.

    Pipeline: temperature scaling (fp32) -> top-k/top-p masking ->
    softmax -> Gumbel-max sampling -> optional logprob gathering.
    """

    def __init__(
        self,
        logprobs_mode: LogprobsMode = LogprobsMode.PROCESSED_LOGPROBS,
    ):
        super().__init__()
        # Only processed (post temperature/top-k/top-p) logprobs supported.
        assert logprobs_mode == LogprobsMode.PROCESSED_LOGPROBS
        self.logprobs_mode = logprobs_mode

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        # Divide logits by temperature, in FP32.
        logits = apply_temperature(logits, sampling_metadata.temperature)
        # Apply top_k and/or top_p masking.
        logits = apply_top_k_top_p(
            logits,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        # Sample the next token (int64). seeds/pos are None: use a fresh
        # global random seed inside gumbel_sample.
        sampled = gumbel_sample(
            probs,
            sampling_metadata.temperature,
            None,  # seeds
            None,  # pos
        )

        logprobs_tensors = None
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            assert num_logprobs >= 0
            # Logprobs of the processed (temperature/top-k/top-p) logits.
            logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float32)
            # Gather the logprobs of the topk and sampled token.
            logprobs_tensors = self.gather_logprobs(
                logprobs,
                num_logprobs,
                sampled,
            )

        # These are GPU tensors. The sampled tokens are expanded to a 2D
        # tensor [num_requests, 1]: one generated token per request.
        sampler_output = SamplerOutput(
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        sampled: torch.Tensor,
    ) -> LogprobsTensors:
        """Gather the sampled token's logprob/rank and, when requested,
        the top-`num_logprobs` logprobs as well."""
        sampled = sampled.unsqueeze(-1)
        sampled_logprobs = logprobs.gather(-1, sampled)
        # 0-based rank: count of tokens with strictly higher logprob.
        sampled_ranks = (logprobs > sampled_logprobs).sum(-1)
        if num_logprobs == 0:
            # Return only the logprob of the sampled token.
            logprobs_tensors = LogprobsTensors(
                sampled,
                sampled_logprobs,
                sampled_ranks,
            )
        else:
            # Return (num_logprobs + 1) logprobs, sampled token first.
            topk_logprobs, topk_indices = torch.topk(
                logprobs,
                num_logprobs,
                dim=-1,
            )
            logprobs_tensors = LogprobsTensors(
                torch.cat((sampled, topk_indices), dim=1),
                torch.cat((sampled_logprobs, topk_logprobs), dim=1),
                sampled_ranks,
            )
        return logprobs_tensors


@triton.jit
def _apply_temp_kernel(
    logits,  # bf16[batch_size, vocab_size]
    logits_stride,
    output,  # fp32[batch_size, vocab_size]
    output_stride,
    temperature,  # fp32[batch_size]
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    EPSILON: tl.constexpr,
):
    """Cast one BLOCK_SIZE slice of logits to fp32 and divide by the
    request's temperature (division by 1.0 for greedy requests)."""
    batch_idx = tl.program_id(0)
    block_idx = tl.program_id(1)

    temp = tl.load(temperature + batch_idx)
    if temp < EPSILON:
        # Greedy sampling. Don't apply temperature.
        # NOTE(woosuk): In this case, we assume that its logprobs are not
        # used.
        temp = tl.ones([1], dtype=tl.float32)

    offset = tl.arange(0, BLOCK_SIZE)
    block = block_idx * BLOCK_SIZE + offset

    # Load the logits, scale, and store in fp32.
    x = tl.load(logits + batch_idx * logits_stride + block,
                mask=block < vocab_size)
    x = x.to(tl.float32)
    x = x / temp
    tl.store(output + batch_idx * output_stride + block,
             x,
             mask=block < vocab_size)


def apply_temperature(
    logits: torch.Tensor,
    temperature: torch.Tensor,
) -> torch.Tensor:
    """Return fp32 logits scaled by the per-request temperature."""
    batch_size, vocab_size = logits.shape
    output = torch.empty_like(logits, dtype=torch.float32)
    BLOCK_SIZE = 8192
    _apply_temp_kernel[(batch_size, triton.cdiv(vocab_size, BLOCK_SIZE))](
        logits,
        logits.stride(0),
        output,
        output.stride(0),
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
        EPSILON=_SAMPLING_EPS,
    )
    return output


@triton.jit
def _apply_gumbel_kernel(
    probs_ptr,  # fp32[num_reqs, vocab_size], updated in place
    probs_stride,
    seeds_ptr,  # int64[num_reqs]
    temp_ptr,  # fp32[num_reqs]
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    EPSILON: tl.constexpr,
):
    """Divide probs by per-element exponential noise so a subsequent
    argmax implements sampling from the distribution (Gumbel-max)."""
    req_idx = tl.program_id(0)
    seed = tl.load(seeds_ptr + req_idx)
    temp = tl.load(temp_ptr + req_idx)

    if temp < EPSILON:
        # Greedy sampling. Don't apply gumbel noise; argmax picks the mode.
        return

    block_id = tl.program_id(1)
    r_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    q = tl.rand(seed, r_offset)

    # NOTE(woosuk): This logic makes sure q is not 0 before taking log.
    RMAX = 0.9999999403953552
    RMAX_LOG = -5.960464477539063e-08
    q = tl.where(q >= RMAX, RMAX_LOG, tl.math.log(q))
    q = -1.0 * q

    p = tl.load(probs_ptr + req_idx * probs_stride + r_offset,
                mask=r_offset < vocab_size)
    p = p / q

    tl.store(probs_ptr + req_idx * probs_stride + r_offset,
             p,
             mask=r_offset < vocab_size)


def gumbel_sample(
    # fp32[num_reqs, vocab_size]; NOTE: clobbered in place.
    probs: torch.Tensor,
    # fp32[num_reqs]
    temperature: torch.Tensor,
    # int64[num_reqs]
    seeds: Optional[torch.Tensor],
    # int64[num_reqs]
    pos: Optional[torch.Tensor],
) -> torch.Tensor:
    """Sample one token per request via the Gumbel-max trick.

    With per-request `seeds`, `pos` offsets the seed so draws differ per
    decoding step; otherwise a fresh global random seed is drawn for every
    request in the batch.
    """
    num_reqs = probs.shape[0]
    vocab_size = probs.shape[1]
    if seeds is not None:
        # Per-request seed.
        assert pos is not None
        gumbel_seeds = seeds + pos
    else:
        # Global seed.
        assert pos is None
        seed_dtype = torch.int64
        gumbel_seeds = torch.randint(
            torch.iinfo(seed_dtype).min,
            torch.iinfo(seed_dtype).max,
            (num_reqs, ),
            dtype=seed_dtype,
            device=probs.device,
        )

    # Update the probs in-place.
    BLOCK_SIZE = 8192
    _apply_gumbel_kernel[(num_reqs, triton.cdiv(vocab_size, BLOCK_SIZE))](
        probs,
        probs.stride(0),
        gumbel_seeds,
        temperature,
        vocab_size,
        # FIX: pass as a keyword for consistency with _apply_temp_kernel's
        # launch (was positional).
        BLOCK_SIZE=BLOCK_SIZE,
        EPSILON=_SAMPLING_EPS,
    )
    # Gumbel-max: argmax of the noised probs is the sampled token.
    return probs.argmax(dim=-1).view(-1)
# Reconstructed from the patch (new file: vllm/v1/worker/gpu/states.py).
class RequestState:
    """Per-request state for the GPU model runner, slot-indexed.

    Token ids live in a CPU numpy buffer because the
    [max_num_reqs, max_model_len] array can be very large; only small
    per-request tensors (e.g. the last sampled token) live on the GPU.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        device: torch.device,
        pin_memory: bool,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.device = device
        self.pin_memory = pin_memory

        # req_id <-> dense slot index mapping; freed slots are recycled.
        self.req_id_to_index: dict[str, int] = {}
        self.index_to_req_id: dict[int, str] = {}
        self.free_indices = list(range(max_num_reqs))

        # TODO(woosuk): Because the token_ids tensor can be very big, we only
        # initialize it on CPU memory.
        self.token_ids = np.zeros(
            (self.max_num_reqs, self.max_model_len),
            dtype=np.int32,
        )
        self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
        self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)

        # Last sampled token ids (GPU-resident).
        self.last_token = torch.zeros(
            self.max_num_reqs,
            dtype=torch.int32,
            device=self.device,
        )

        # Per-request sampling parameters (CPU-resident).
        self.temperature = np.zeros(self.max_num_reqs, dtype=np.float32)
        self.top_p = np.zeros(self.max_num_reqs, dtype=np.float32)
        self.top_k = np.zeros(self.max_num_reqs, dtype=np.int32)
        self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
        # -1 means no logprobs are requested.
        self.num_logprobs.fill(-1)

    @property
    def num_reqs(self) -> int:
        """Number of currently tracked requests."""
        return len(self.req_id_to_index)

    def add_request(
        self,
        req_id: str,
        prompt_token_ids: list[int],
        num_computed_tokens: int,
        sampling_params: SamplingParams,
    ) -> None:
        """Allocate a slot for a new request and record its prompt tokens
        and sampling parameters."""
        assert len(self.free_indices) > 0
        req_idx = self.free_indices.pop()
        self.req_id_to_index[req_id] = req_idx
        self.index_to_req_id[req_idx] = req_id

        prompt_len = len(prompt_token_ids)
        self.num_tokens[req_idx] = prompt_len
        self.token_ids[req_idx, :prompt_len] = prompt_token_ids
        self.num_computed_tokens[req_idx] = num_computed_tokens

        self.temperature[req_idx] = sampling_params.temperature
        self.top_p[req_idx] = sampling_params.top_p
        # Normalize "disabled" top_k values to the full vocab size.
        if 0 < sampling_params.top_k < self.vocab_size:
            top_k = sampling_params.top_k
        else:
            top_k = self.vocab_size
        self.top_k[req_idx] = top_k

        # FIX: SamplingParams exposes the requested logprob count as
        # `logprobs`, not `num_logprobs` (the latter does not exist and
        # raised AttributeError for every added request).
        if sampling_params.logprobs is not None:
            num_logprobs = sampling_params.logprobs
        else:
            num_logprobs = -1
        self.num_logprobs[req_idx] = num_logprobs

    def append_token_ids(
        self,
        req_idx: int,
        token_ids: Union[list[int], np.ndarray],
    ) -> None:
        """Append newly produced tokens to the request's token buffer."""
        start_idx = self.num_tokens[req_idx]
        end_idx = start_idx + len(token_ids)
        self.token_ids[req_idx, start_idx:end_idx] = token_ids
        self.num_tokens[req_idx] = end_idx

    def remove_request(self, req_id: str) -> None:
        """Release the slot held by `req_id`; no-op if the id is unknown."""
        req_idx = self.req_id_to_index.pop(req_id, None)
        if req_idx is None:
            # Request not found.
            return
        self.index_to_req_id.pop(req_idx, None)
        self.free_indices.append(req_idx)

    def make_sampling_metadata(
        self,
        idx_mapping: np.ndarray,
    ) -> SamplingMetadata:
        """Build SamplingMetadata for the slots selected by `idx_mapping`
        (request slot indices in batch order).

        top_p/top_k tensors are elided (None) when every request in the
        batch uses the no-op value, skipping the H2D copy.
        """
        temperature = self._copy_np_to_gpu(self.temperature[idx_mapping])

        top_p = self.top_p[idx_mapping]
        no_top_p = np.all(top_p == 1.0)
        top_p = self._copy_np_to_gpu(top_p) if not no_top_p else None

        top_k = self.top_k[idx_mapping]
        no_top_k = np.all(top_k == self.vocab_size)
        top_k = self._copy_np_to_gpu(top_k) if not no_top_k else None

        num_logprobs = self.num_logprobs[idx_mapping]
        # Coerce the numpy scalar to a plain int before handing it
        # downstream; -1 (no request wants logprobs) maps to None.
        max_num_logprobs = int(np.max(num_logprobs))
        if max_num_logprobs == -1:
            max_num_logprobs = None

        return SamplingMetadata(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_num_logprobs=max_num_logprobs,
        )

    def _copy_np_to_gpu(self, src: np.ndarray) -> torch.Tensor:
        """Async host-to-device copy, via pinned memory when available."""
        cpu_tensor = torch.from_numpy(src)
        if self.pin_memory:
            cpu_tensor = cpu_tensor.pin_memory()
        return cpu_tensor.to(device=self.device, non_blocking=True)
vllm.compilation.monitor import set_cudagraph_capturing_enabled -from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, - get_layers_from_vllm_config, update_config) +from vllm.config import CUDAGraphMode, VllmConfig, get_layers_from_vllm_config from vllm.distributed.eplb.eplb_state import EplbState -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group) -from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks -from vllm.distributed.parallel_state import ( - get_pp_group, get_tp_group, graph_capture, is_global_first_rank, - prepare_communication_buffer_for_model) -from vllm.forward_context import (BatchDescriptor, DPMetadata, - set_forward_context) +from vllm.distributed.kv_transfer import has_kv_transfer_group +from vllm.distributed.parallel_state import (get_pp_group, get_tp_group, + graph_capture, + is_global_first_rank) +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.mamba.abstract import MambaBase -from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader -from vllm.model_executor.models.interfaces import (is_mixture_of_experts, - supports_eagle3, - supports_transcription) -from vllm.model_executor.models.interfaces_base import ( - VllmModelForPooling, is_pooling_model, is_text_generation_model) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem, - PlaceholderRange) -from vllm.multimodal.utils import group_mm_kwargs_by_modality -from vllm.pooling_params import PoolingParams -from vllm.sampling_params import SamplingType -from vllm.sequence import IntermediateTensors, PoolerOutput -from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from 
vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size, - is_pin_memory_available, round_up, supports_dynamo) +from vllm.sequence import IntermediateTensors +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LazyLoader, check_use_alibi, + is_pin_memory_available, round_up) from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + AttentionCGSupport, CommonAttentionMetadata, create_fast_prefill_custom_backend) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher # yapf conflicts with isort for this block # yapf: disable from vllm.v1.kv_cache_interface import (AttentionSpec, - ChunkedLocalAttentionSpec, - CrossAttentionSpec, EncoderOnlyAttentionSpec, - FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - MambaSpec, SlidingWindowSpec) + KVCacheConfig, KVCacheSpec) # yapf: enable from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - DraftTokenIds, LogprobsLists, LogprobsTensors, - ModelRunnerOutput, SamplerOutput) -from vllm.v1.pool.metadata import PoolingMetadata -from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs + LogprobsLists, LogprobsTensors, ModelRunnerOutput, + SamplerOutput) +from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler @@ -82,23 +54,18 @@ from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer -from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext -from vllm.v1.worker.gpu_worker_states import RequestState -from 
vllm.v1.worker.gpu_block_table import BlockTables
+from vllm.v1.utils import record_function_or_nullcontext
 from vllm.v1.worker.gpu.input_batch import InputBatch, prepare_inputs
+from vllm.v1.worker.gpu.states import RequestState
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
     KVConnectorModelRunnerMixin)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
-from .utils import (AttentionGroup, MultiModalBudget,
-                    add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache,
-                    gather_mm_placeholders, sanity_check_mm_encoder_outputs,
-                    scatter_mm_placeholders)
+from .utils import AttentionGroup, MultiModalBudget, bind_kv_cache
 
 if TYPE_CHECKING:
     import xgrammar as xgr
 
-    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
     from vllm.v1.core.sched.output import SchedulerOutput
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")
@@ -387,225 +354,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                          device="cpu",
                                          pin_memory=self.pin_memory)
 
-    def _make_buffer(self,
-                     *size: Union[int, torch.SymInt],
-                     dtype: torch.dtype,
-                     numpy: bool = True) -> CpuGpuBuffer:
-        # Bfloat16 torch tensors cannot be directly cast to a numpy array, so
-        # if a bfloat16 buffer is needed without a corresponding numpy array,
-        # don't bother instantiating the numpy array.
- return CpuGpuBuffer(*size, - dtype=dtype, - device=self.device, - pin_memory=self.pin_memory, - with_numpy=numpy) - - def _init_model_kwargs(self, num_tokens: int): - model_kwargs = dict[str, Any]() - - if not self.is_pooling_model: - return model_kwargs - - num_reqs = self.input_batch.num_reqs - pooling_params = self.input_batch.get_pooling_params() - - token_type_id_requests = dict[int, Any]() - for i, param in enumerate(pooling_params): - if param.extra_kwargs is not None and \ - (token_types := param.extra_kwargs.get( - "compressed_token_type_ids")) is not None: - token_type_id_requests[i] = token_types - - if len(token_type_id_requests) == 0: - return model_kwargs - - seq_lens = self.seq_lens.gpu[:num_reqs] - token_type_ids = [] - - for i in range(num_reqs): - pos = token_type_id_requests.get(i, seq_lens[i]) - ids = (torch.arange(seq_lens[i]) >= pos).int() - token_type_ids.append(ids) - - model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to( - device=self.device) - return model_kwargs - - # Note: used for model runner override. - def _init_device_properties(self) -> None: - """Initialize attributes from torch.cuda.get_device_properties - """ - self.device_properties = torch.cuda.get_device_properties(self.device) - self.num_sms = self.device_properties.multi_processor_count - - # Note: used for model runner override. - def _sync_device(self) -> None: - torch.cuda.synchronize() - - def _update_states(self, scheduler_output: "SchedulerOutput") -> None: - """Update the cached states and the persistent batch with the scheduler - output. - - The updated states are used by the `_prepare_inputs` function to create - the input GPU tensors for the model. - - The SamplingMetadata is updated and copied to the GPU if there is a - new/resumed/paused/finished request in the batch. - """ - # Remove finished requests from the cached states. - # NOTE(woosuk): There could be an edge case where finished_req_ids and - # scheduled_req_ids overlap. 
This happens when a request is aborted and - # then resubmitted with the same ID. In this case, we treat them as two - # distinct requests - clearing the cached states for the first request - # and handling the second as a new request. - for req_id in scheduler_output.finished_req_ids: - self.req_states.remove_request(req_id) - self.encoder_cache.pop(req_id, None) - - # Free the cached encoder outputs. - for mm_hash in scheduler_output.free_encoder_mm_hashes: - self.encoder_cache.pop(mm_hash, None) - - req_indices: list[int] = [] - cu_num_new_blocks = tuple( - [0] for _ in range(self.block_tables.num_kv_cache_groups) - ) - new_block_ids = tuple( - [] for _ in range(self.block_tables.num_kv_cache_groups) - ) - overwrite: list[bool] = [] - - # Add new requests to the cached states. - for new_req_data in scheduler_output.scheduled_new_reqs: - req_id = new_req_data.req_id - self.req_states.add_request( - req_id=req_id, - prompt_token_ids=new_req_data.prompt_token_ids, - num_computed_tokens=new_req_data.num_computed_tokens, - sampling_params=new_req_data.sampling_params, - ) - - req_index = self.req_states.req_id_to_index[req_id] - req_indices.append(req_index) - for i, block_ids in enumerate(new_req_data.block_ids): - x = cu_num_new_blocks[i][-1] - cu_num_new_blocks[i].append(x + len(block_ids)) - new_block_ids[i].extend(block_ids) - overwrite.append(True) - - # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - if self.uses_mrope: - self._init_mrope_positions(req_id) - - # Update the states of the running/resumed requests. - is_last_rank = get_pp_group().is_last_rank - cached_reqs = scheduler_output.scheduled_cached_reqs - for i, req_id in enumerate(cached_reqs.req_ids): - req_index = self.req_states.req_id_to_index[req_id] - - # Update input batch. - if not is_last_rank: - # When using PP, the scheduler sends the sampled tokens back, - # because there's no direct communication between the first- - # stage worker and the last-stage worker. 
- new_token_ids = cached_reqs.new_token_ids[i] - self.req_states.append_token_ids(req_index, new_token_ids) - - req_new_block_ids = cached_reqs.new_block_ids[i] - if req_new_block_ids is not None: - req_indices.append(req_index) - for group_id, block_ids in enumerate(req_new_block_ids): - x = cu_num_new_blocks[group_id][-1] - cu_num_new_blocks[group_id].append(x + len(block_ids)) - new_block_ids[group_id].extend(block_ids) - # If the request is resumed from preemption, we need to - # overwrite the existing block IDs. - overwrite.append(cached_reqs.resumed_from_preemption[i]) - - self.req_states.num_computed_tokens.np[req_index] = ( - cached_reqs.num_computed_tokens[i]) - - if req_indices: - self.block_tables.append_block_ids( - req_indices=req_indices, - cu_num_new_blocks=cu_num_new_blocks, - new_block_ids=new_block_ids, - overwrite=overwrite, - ) - - def _init_mrope_positions(self, req_id: str) -> None: - req_idx = self.req_states.req_id_to_index[req_id] - req_data = self.req_states.req_data[req_idx] - prompt_len = self.req_states.num_prompt_tokens.np[req_idx] - prompt_token_ids = self.req_states.token_ids.np[req_idx, :prompt_len] - prompt_token_ids = prompt_token_ids.tolist() - - image_grid_thw = [] - video_grid_thw = [] - second_per_grid_ts = [] - audio_feature_lengths = [] - use_audio_in_video = False - for mm_feature in req_state.mm_features: - mm_item = mm_feature.data - if mm_item is None: - continue - mm_input = mm_item.get_data() - if (t := mm_input.get("image_grid_thw")) is not None: - image_grid_thw.append(t.tolist()) - if (t := mm_input.get("video_grid_thw")) is not None: - video_grid_thw.append(t.tolist()) - if (t := mm_input.get("second_per_grid_ts")) is not None: - second_per_grid_ts.append(t) - if (t := mm_input.get("audio_feature_lengths")) is not None: - audio_feature_lengths.append(t) - if mm_input.get("use_audio_in_video") is True: - use_audio_in_video = True - - req_data.mrope_positions, req_data.mrope_position_delta = \ - 
MRotaryEmbedding.get_input_positions_tensor( - prompt_token_ids, - hf_config=self.model_config.hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - - def _extract_mm_kwargs( - self, - scheduler_output: "SchedulerOutput", - ) -> BatchedTensorInputs: - if not scheduler_output or not self.is_multimodal_raw_input_only_model: - return {} - - mm_kwargs = list[MultiModalKwargsItem]() - for req in scheduler_output.scheduled_new_reqs: - for feature in req.mm_features: - if feature.data is not None: - mm_kwargs.append(feature.data) - - # Input all modalities at once - mm_kwargs_combined: BatchedTensorInputs = {} - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - ): - mm_kwargs_combined.update(mm_kwargs_group) - - return mm_kwargs_combined - - def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs: - if not self.is_multimodal_raw_input_only_model: - return {} - - mm_budget = self.mm_budget - assert mm_budget is not None - - dummy_modality = mm_budget.get_modality_with_max_tokens() - return self._get_mm_dummy_batch(dummy_modality, num_seqs) - def _prepare_inputs( self, scheduler_output: "SchedulerOutput", @@ -815,589 +563,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logits_indices=logits_indices, ) - def _compute_cascade_attn_prefix_len( - self, - num_scheduled_tokens: np.ndarray, - num_common_prefix_blocks: int, - kv_cache_spec: KVCacheSpec, - attn_metadata_builder: AttentionMetadataBuilder, - ) -> int: - """Compute the length of the common prefix for cascade attention. - - NOTE(woosuk): The common prefix length returned by this function - represents the length used specifically for cascade attention, not the - actual number of tokens shared between requests. 
When cascade attention - is disabled (use_cascade=False), this function returns 0 even if - requests share common tokens. Additionally, the common prefix length is - truncated to a multiple of the block size and may be further truncated - due to implementation details explained below. - - Args: - num_scheduled_tokens: Number of tokens scheduled per request. - num_common_prefix_blocks: Number of shared KV cache blocks. - - Returns: - int: Length of common prefix in tokens. - """ - common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size - if common_prefix_len == 0: - # Common case. - return 0 - - # NOTE(woosuk): Cascade attention uses two attention kernels: one - # for the common prefix and the other for the rest. For the first - # kernel, we concatenate all the query tokens (possibly from - # different requests) and treat them as if they are from the same - # request. Then, we use bi-directional attention to process the - # common prefix in the KV cache. Importantly, this means that the - # first kernel does not do any masking. - - # Consider the following example: - # Request 1's input query: [D, E, X] - # Request 1's kv cache: [A, B, C, D, E, X] - # Request 1's num_computed_tokens: 3 (i.e., [A, B, C]) - # Request 2's input query: [E, Y] - # Request 2's kv cache: [A, B, C, D, E, Y] - # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D]) - - # If we use [A, B, C, D, E] as the common prefix, then the - # first kernel will compute the bi-directional attention between - # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E]. - # However, this is wrong because D in Request 1 should not attend to - # E in the common prefix (i.e., we need masking). - # To avoid this, [A, B, C, D] should be the common prefix. - # That is, the common prefix should be capped by the minimum - # num_computed_tokens among the requests, and plus one to include - # the first token of the query. 
- - # In practice, we use [A, B, C] as the common prefix, instead of - # [A, B, C, D] (i.e., the common prefix is capped by the minimum - # num_computed_tokens, without plus one). - # This is because of an implementation detail: We want to always - # use two kernels for cascade attention. Let's imagine: - # Request 3's input query: [D] - # Request 3's kv cache: [A, B, C, D] - # Request 3's num_computed_tokens: 3 (i.e., [A, B, C]) - # If we use [A, B, C, D] as the common prefix for Request 1-3, - # then Request 3 will be processed only by the first kernel, - # and the second kernel will get an empty input. While this is not - # a fundamental problem, our current implementation does not support - # this case. - num_reqs = len(num_scheduled_tokens) - common_prefix_len = min( - common_prefix_len, - self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) - # common_prefix_len should be a multiple of the block size. - common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * - kv_cache_spec.block_size) - use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or - (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.sliding_window is not None)) - use_local_attention = ( - isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) - or (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.attention_chunk_size is not None)) - assert isinstance(kv_cache_spec, AttentionSpec) - use_cascade = attn_metadata_builder.use_cascade_attention( - common_prefix_len=common_prefix_len, - query_lens=num_scheduled_tokens, - num_query_heads=self.num_query_heads, - num_kv_heads=kv_cache_spec.num_kv_heads, - use_alibi=self.use_alibi, - use_sliding_window=use_sliding_window, - use_local_attention=use_local_attention, - num_sms=self.num_sms, - ) - return common_prefix_len if use_cascade else 0 - - def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): - mrope_pos_ptr = 0 - for index, req_id in enumerate(self.input_batch.req_ids): - req 
= self.requests[req_id] - assert req.mrope_positions is not None - - num_computed_tokens = \ - self.input_batch.num_computed_tokens_cpu[index] - num_scheduled_tokens = \ - scheduler_output.num_scheduled_tokens[req_id] - num_prompt_tokens = len(req.prompt_token_ids) - - if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: - prompt_part_len = max(0, - num_prompt_tokens - num_computed_tokens) - completion_part_len = max( - 0, num_scheduled_tokens - prompt_part_len) - else: - prompt_part_len = num_scheduled_tokens - completion_part_len = 0 - - assert num_scheduled_tokens == prompt_part_len + completion_part_len - - if prompt_part_len > 0: - # prompt's mrope_positions are pre-computed - dst_start = mrope_pos_ptr - dst_end = mrope_pos_ptr + prompt_part_len - src_start = num_computed_tokens - src_end = num_computed_tokens + prompt_part_len - - self.mrope_positions.cpu[:, dst_start:dst_end] = ( - req.mrope_positions[:, src_start:src_end]) - mrope_pos_ptr += prompt_part_len - - if completion_part_len > 0: - # compute completion's mrope_positions on-the-fly - dst_start = mrope_pos_ptr - dst_end = mrope_pos_ptr + completion_part_len - - MRotaryEmbedding.get_next_input_positions_tensor( - out=self.mrope_positions.np, - out_offset=dst_start, - mrope_position_delta=req.mrope_position_delta, - context_len=num_computed_tokens + prompt_part_len, - num_new_tokens=completion_part_len, - ) - - mrope_pos_ptr += completion_part_len - - def _prepare_spec_decode_metadata( - self, - req_ids: list[str], - req_id_to_draft_token_ids: dict[str, list[int]], - query_start_loc: torch.Tensor, - ) -> SpecDecodeMetadata: - # Get the number of draft tokens for each request. 
- num_reqs = len(req_ids) - num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) - for i, req_id in enumerate(req_ids): - draft_token_ids = req_id_to_draft_token_ids.get(req_id) - if draft_token_ids: - num_draft_tokens[i] = len(draft_token_ids) - np.cumsum(num_draft_tokens, - dtype=np.int32, - out=self.cu_num_draft_tokens.np[:num_reqs]) - cu_num_draft_tokens = self.cu_num_draft_tokens.copy_to_gpu(num_reqs) - return self.req_states.make_spec_decode_metadata( - query_start_loc, - cu_num_draft_tokens, - cu_num_draft_tokens.np[:num_reqs], - self.input_ids.gpu, - ) - - def _prepare_kv_sharing_fast_prefill( - self, - logits_indices: torch.Tensor, - ) -> torch.Tensor: - assert self.kv_sharing_fast_prefill_logits_indices is not None - num_logits = logits_indices.shape[0] - assert num_logits > 0 - self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_( - logits_indices) - # There might have leftover indices in logits_indices[num_logits:] - # from previous iterations, whose values may be greater than the - # batch size in the current iteration. To ensure indices are always - # valid, we fill the padded indices with the last index. - self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( - logits_indices[-1].item()) - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_logits <= self.cudagraph_batch_sizes[-1]): - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits) - else: - num_logits_padded = num_logits - logits_indices_padded = ( - self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]) - return logits_indices_padded - - def _batch_mm_kwargs_from_scheduler( - self, - scheduler_output: "SchedulerOutput", - ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]: - """Batch multimodal kwargs from scheduled encoder inputs. - - Args: - scheduler_output: The scheduler output containing scheduled encoder - inputs. 
- - Returns: - A tuple of (mm_kwargs, req_ids_pos) where: - - mm_kwargs: List of multimodal kwargs items to be batched - - mm_hashes_pos: List of (mm_hash, position_info) tuples - """ - scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs - if not scheduled_encoder_inputs: - return [], [] - # Batch the multi-modal inputs. - mm_kwargs = list[MultiModalKwargsItem]() - # list of tuple (mm_hash, position_info) - mm_hashes_pos = list[tuple[str, PlaceholderRange]]() - for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): - req_state = self.requests[req_id] - - for mm_input_id in encoder_input_ids: - mm_feature = req_state.mm_features[mm_input_id] - mm_hash = mm_feature.identifier - mm_kwargs.append(mm_feature.data) - mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) - - return mm_kwargs, mm_hashes_pos - - def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): - # Batch the multi-modal inputs using the helper method. - mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( - scheduler_output) - - if not mm_kwargs: - return - - # Batch mm inputs as much as we can: if a request in the batch has - # multiple modalities or a different modality than the previous one, - # we process it separately to preserve item order. - # FIXME(ywang96): This is a hacky way to deal with multiple modalities - # in the same batch while still being able to benefit from batching - # multimodal inputs. The proper solution should be reordering the - # encoder outputs. - encoder_outputs = [] - for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - ): - # Run the encoder. - # `curr_group_outputs` is either of the following: - # 1. A tensor of shape (num_items, feature_size, hidden_size) - # in case feature_size is fixed across all multimodal items. - # 2. 
A list or tuple (length: num_items) of tensors, each of shape - # (feature_size, hidden_size) in case the feature size is dynamic - # depending on the input multimodal items. - curr_group_outputs = self.model.get_multimodal_embeddings( - **mm_kwargs_group) - - sanity_check_mm_encoder_outputs( - curr_group_outputs, - expected_num_items=num_items, - ) - - for output in curr_group_outputs: - encoder_outputs.append(output) - - # Cache the encoder outputs by mm_hash - for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): - self.encoder_cache[mm_hash] = scatter_mm_placeholders( - output, - is_embed=pos_info.is_embed, - ) - - def _gather_mm_embeddings( - self, - scheduler_output: "SchedulerOutput", - shift_computed_tokens: int = 0, - ) -> list[torch.Tensor]: - mm_embeds: list[torch.Tensor] = [] - for req_id in self.input_batch.req_ids: - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] - req_state = self.requests[req_id] - num_computed_tokens = \ - req_state.num_computed_tokens + shift_computed_tokens - for mm_feature in req_state.mm_features: - pos_info = mm_feature.mm_position - start_pos = pos_info.offset - num_encoder_tokens = pos_info.length - - # The encoder output is needed if the two ranges overlap: - # [num_computed_tokens, - # num_computed_tokens + num_scheduled_tokens) and - # [start_pos, start_pos + num_encoder_tokens) - if start_pos >= num_computed_tokens + num_scheduled_tokens: - # The encoder output is not needed in this step. - break - if start_pos + num_encoder_tokens <= num_computed_tokens: - # The encoder output is already processed and stored - # in the decoder's KV cache. 
- continue - - start_idx = max(num_computed_tokens - start_pos, 0) - end_idx = min( - num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens, - ) - assert start_idx < end_idx - - mm_hash = mm_feature.identifier - encoder_output = self.encoder_cache.get(mm_hash, None) - assert encoder_output is not None,\ - f"Encoder cache miss for {mm_hash}." - - if (is_embed := pos_info.is_embed) is not None: - is_embed = is_embed[start_idx:end_idx] - - mm_embeds_item = gather_mm_placeholders( - encoder_output[start_idx:end_idx], - is_embed=is_embed, - ) - mm_embeds.append(mm_embeds_item) - return mm_embeds - - def _extract_encoder_inputs( - self, - scheduler_output: "SchedulerOutput", - ) -> dict[str, torch.Tensor]: - """Extract encoder inputs for encoder-decoder models. - - This method extracts multimodal input features from scheduled encoder - inputs and formats them for the encoder-decoder model forward pass. - """ - # Batch the multi-modal inputs using the helper method. - mm_kwargs, _ = self._batch_mm_kwargs_from_scheduler(scheduler_output) - - if not mm_kwargs: - return {} - - # Group MM kwargs by modality and extract features - encoder_features = {} - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - ): - # Add the grouped features to encoder_features dict - # This allows the model to receive them as kwargs (e.g., - # input_features=...) - encoder_features.update(mm_kwargs_group) - - return encoder_features - - def get_model(self) -> nn.Module: - # get raw model out of the cudagraph wrapper. 
- if isinstance(self.model, CUDAGraphWrapper): - return self.model.unwrap() - return self.model - - def get_supported_generation_tasks(self) -> list[GenerationTask]: - model = self.get_model() - supported_tasks = list[GenerationTask]() - - if is_text_generation_model(model): - supported_tasks.append("generate") - - if supports_transcription(model): - if model.supports_transcription_only: - return ["transcription"] - - supported_tasks.append("transcription") - - return supported_tasks - - def get_supported_pooling_tasks(self) -> list[PoolingTask]: - model = self.get_model() - if not is_pooling_model(model): - return [] - - supported_tasks = list(model.pooler.get_supported_tasks()) - - if (self.scheduler_config.chunked_prefill_enabled - and "encode" in supported_tasks): - supported_tasks.remove("encode") - - logger.debug_once("Chunked prefill is not supported with " - "encode task which using ALL pooling. " - "Please turn off chunked prefill by " - "`--no-enable-chunked-prefill` before using it.") - - if "score" in supported_tasks: - num_labels = getattr(self.model_config.hf_config, "num_labels", 0) - if num_labels != 1: - supported_tasks.remove("score") - logger.debug_once( - "Score API is only enabled for num_labels == 1.") - - return supported_tasks - - def get_supported_tasks(self) -> tuple[SupportedTask, ...]: - tasks = list[SupportedTask]() - - if self.model_config.runner_type == "generate": - tasks.extend(self.get_supported_generation_tasks()) - if self.model_config.runner_type == "pooling": - tasks.extend(self.get_supported_pooling_tasks()) - - return tuple(tasks) - - def apply_grammar_bitmask( - self, - scheduler_output: "SchedulerOutput", - logits: torch.Tensor, - ): - grammar_bitmask = scheduler_output.grammar_bitmask - if grammar_bitmask is None: - return - - # We receive the structured output bitmask from the scheduler, - # compacted to contain bitmasks only for structured output requests. 
- # The order of the requests in the bitmask is not guaranteed to be the - # same as the order of the requests in the gpu runner's batch. We need - # to sort the bitmask to match the order of the requests used here. - - # Get the batch indices of the structured output requests. - # Keep track of the number of speculative tokens scheduled for every - # request in the batch, as the logit indices are offset by this amount. - struct_out_req_batch_indices: dict[str, int] = {} - cumulative_offset = 0 - seq = sorted(self.input_batch.req_id_to_index.items(), - key=lambda x: x[1]) - for req_id, batch_index in seq: - logit_index = batch_index + cumulative_offset - cumulative_offset += len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) - if req_id in scheduler_output.structured_output_request_ids: - struct_out_req_batch_indices[req_id] = logit_index - - out_indices = [] - - # Reorder the bitmask to match the order of the requests in the batch. - sorted_bitmask = np.full(shape=(logits.shape[0], - grammar_bitmask.shape[1]), - fill_value=-1, - dtype=grammar_bitmask.dtype) - cumulative_index = 0 - seq = sorted(scheduler_output.structured_output_request_ids.items(), - key=lambda x: x[1]) - for req_id, _ in seq: - logit_index = struct_out_req_batch_indices[req_id] - num_spec_tokens = len( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) - for i in range(1 + num_spec_tokens): - sorted_bitmask[logit_index + i] = \ - grammar_bitmask[cumulative_index + i] - out_indices.append(logit_index + i) - cumulative_index += 1 + num_spec_tokens - grammar_bitmask = sorted_bitmask - - # If the length of out indices and the logits have the same shape - # we don't need to pass indices to the kernel, - # since the bitmask is already aligned with the logits. - skip_out_indices = len(out_indices) == logits.shape[0] - - # Serialization of np.ndarray is much more efficient than a tensor, - # so we receive it in that format. 
- grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous() - - xgr.apply_token_bitmask_inplace( - logits, - grammar_bitmask.to(self.device, non_blocking=True), - indices=out_indices if not skip_out_indices else None, - ) - - def sync_and_slice_intermediate_tensors( - self, num_tokens: int, intermediate_tensors: IntermediateTensors, - sync_self: bool) -> IntermediateTensors: - - assert self.intermediate_tensors is not None - - tp = self.vllm_config.parallel_config.tensor_parallel_size - enabled_sp = self.compilation_config.pass_config. \ - enable_sequence_parallelism - if enabled_sp: - # When sequence parallelism is enabled, we always pad num_tokens - # to be a multiple of tensor_parallel_size (tp) earlier - assert num_tokens % tp == 0 - is_residual_scattered = tp > 1 and enabled_sp \ - and num_tokens % tp == 0 - - # When sequence parallelism is enabled, the "residual" tensor is sharded - # across tensor parallel ranks, so each rank only needs its own slice. - if sync_self: - assert intermediate_tensors is not None - for k, v in intermediate_tensors.items(): - is_scattered = k == "residual" and is_residual_scattered - copy_len = num_tokens // tp if is_scattered else \ - num_tokens - self.intermediate_tensors[k][:copy_len].copy_( - v[:copy_len], non_blocking=True) - - return IntermediateTensors({ - k: - v[:num_tokens // tp] - if k == "residual" and is_residual_scattered else v[:num_tokens] - for k, v in self.intermediate_tensors.items() - }) - - def eplb_step(self, - is_dummy: bool = False, - is_profile: bool = False) -> None: - """ - Step for the EPLB (Expert Parallelism Load Balancing) state. 
- """ - if not self.parallel_config.enable_eplb: - return - - assert self.eplb_state is not None - model = self.get_model() - assert is_mixture_of_experts(model) - self.eplb_state.step( - model, - is_dummy, - is_profile, - log_stats=self.parallel_config.eplb_config.log_balancedness, - ) - - def get_dp_padding(self, - num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: - dp_size = self.vllm_config.parallel_config.data_parallel_size - dp_rank = self.vllm_config.parallel_config.data_parallel_rank - - # For DP: Don't pad when setting enforce_eager. - # This lets us set enforce_eager on the prefiller in a P/D setup and - # still use CUDA graphs (enabled by this padding) on the decoder. - # - # TODO(tms) : There are many cases where padding is enabled for - # prefills, causing unnecessary and excessive padding of activations. - - if dp_size == 1 or self.vllm_config.model_config.enforce_eager: - # Early exit. - return 0, None - - num_tokens_across_dp = DPMetadata.num_tokens_across_dp( - num_tokens, dp_size, dp_rank) - max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() - num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * - dp_size, - device="cpu", - dtype=torch.int32) - return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding - - def _pool( - self, - hidden_states: torch.Tensor, - input_batch: InputBatch, - ) -> ModelRunnerOutput: - hidden_states = hidden_states[:num_scheduled_tokens] - pooling_metadata = self.req_states.get_pooling_metadata() - pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), - device=hidden_states.device) - seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs] - - # Pooling models D2H & synchronize occurs in pooler.py:build_output - raw_pooler_output = self.model.pooler( - hidden_states=hidden_states, pooling_metadata=pooling_metadata) - - pooler_output: list[Optional[torch.Tensor]] = [] - for raw_output, seq_len, prompt_len in zip( - raw_pooler_output, seq_lens_cpu, 
pooling_metadata.prompt_lens): - - output = raw_output.data if seq_len == prompt_len else None - pooler_output.append(output) - - return ModelRunnerOutput( - req_ids=self.input_batch.req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=[], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=pooler_output, - ) - def _preprocess( self, scheduler_output: "SchedulerOutput", @@ -1709,7 +874,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): positions, intermediate_tensors, model_kwargs, - ) = self._preprocess(scheduler_output, input_batch, intermediate_tensors) + ) = self._preprocess(scheduler_output, input_batch, + intermediate_tensors) uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( @@ -1847,284 +1013,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): async_output_copy_stream=self.async_output_copy_stream, ) - def take_draft_token_ids(self) -> Optional[DraftTokenIds]: - if self._draft_token_ids is None: - return None - req_ids = self.input_batch.req_ids - if isinstance(self._draft_token_ids, torch.Tensor): - draft_token_ids = self._draft_token_ids.tolist() - else: - draft_token_ids = self._draft_token_ids - self._draft_token_ids = None - return DraftTokenIds(req_ids, draft_token_ids) - - def propose_draft_token_ids( - self, - scheduler_output: "SchedulerOutput", - sampled_token_ids: list[list[int]], - sampling_metadata: SamplingMetadata, - hidden_states: torch.Tensor, - sample_hidden_states: torch.Tensor, - aux_hidden_states: Optional[torch.Tensor], - spec_decode_metadata: Optional[SpecDecodeMetadata], - common_attn_metadata: CommonAttentionMetadata, - ) -> Union[list[list[int]], torch.Tensor]: - num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - if self.speculative_config.method == "ngram": - assert isinstance(self.drafter, NgramProposer) - draft_token_ids = self.propose_ngram_draft_token_ids( - sampled_token_ids) - elif 
self.speculative_config.method == "medusa": - assert isinstance(self.drafter, MedusaProposer) - if sample_hidden_states.shape[0] == len(sampled_token_ids): - # The input to the target model does not include draft tokens. - hidden_states = sample_hidden_states - else: - indices = [] - offset = 0 - for num_draft, tokens in zip( - spec_decode_metadata.num_draft_tokens, - sampled_token_ids): - indices.append(offset + len(tokens) - 1) - offset += num_draft + 1 - indices = torch.tensor(indices, device=self.device) - hidden_states = sample_hidden_states[indices] - - draft_token_ids = self.drafter.propose( - target_hidden_states=hidden_states, - sampling_metadata=sampling_metadata, - ) - elif self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) - # TODO(woosuk): Refactor the loop. - req_ids = self.input_batch.req_ids - next_token_ids: list[int] = [] - for i, token_ids in enumerate(sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. - req_id = req_ids[i] - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) - - if spec_decode_metadata is None: - # input_ids can be None for multimodal models. - target_token_ids = self.input_ids.gpu[:num_scheduled_tokens] - # TODO(woosuk): Support M-RoPE. - target_positions = self.positions.gpu[:num_scheduled_tokens] - if self.use_aux_hidden_state_outputs: - target_hidden_states = torch.cat( - [h[:num_scheduled_tokens] for h in aux_hidden_states], - dim=-1) - else: - target_hidden_states = hidden_states[:num_scheduled_tokens] - else: - # TODO(woosuk): Refactor this. 
- num_draft_tokens = spec_decode_metadata.num_draft_tokens - num_rejected_tokens = [ - n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens, - dtype=torch.int32) - common_attn_metadata, token_indices =\ - self.drafter.prepare_inputs( - common_attn_metadata, num_rejected_tokens_cpu) - - target_token_ids = self.input_ids.gpu[token_indices] - # TODO(woosuk): Support M-RoPE. - target_positions = self.positions.gpu[token_indices] - if self.use_aux_hidden_state_outputs: - target_hidden_states = torch.cat( - [h[token_indices] for h in aux_hidden_states], dim=-1) - else: - target_hidden_states = hidden_states[token_indices] - mm_embeds = None - if self.supports_mm_inputs: - mm_embeds = self._gather_mm_embeddings(scheduler_output, - shift_computed_tokens=1) - - draft_token_ids = self.drafter.propose( - target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, - sampling_metadata=sampling_metadata, - common_attn_metadata=common_attn_metadata, - mm_embeds=mm_embeds, - ) - return draft_token_ids - - def propose_ngram_draft_token_ids( - self, - sampled_token_ids: list[list[int]], - ) -> list[list[int]]: - # TODO(woosuk): Optimize. - req_ids = self.input_batch.req_ids - draft_token_ids: list[list[int]] = [] - for i, sampled_ids in enumerate(sampled_token_ids): - num_sampled_ids = len(sampled_ids) - if not num_sampled_ids: - # Skip speculative decoding. - draft_token_ids.append([]) - continue - - # Skip requests that require sampling parameters that are not - # supported with speculative decoding. - req_id = req_ids[i] - if req_id in self.input_batch.spec_decode_unsupported_reqs: - draft_token_ids.append([]) - continue - - num_tokens = self.input_batch.num_tokens_no_spec[i] - if num_tokens >= self.max_model_len: - # Skip requests that have already reached the max model length. 
- draft_token_ids.append([]) - continue - - drafter_output = self.drafter.propose( - self.input_batch.token_ids_cpu[i, :num_tokens]) - if drafter_output is None or len(drafter_output) == 0: - draft_token_ids.append([]) - else: - draft_token_ids.append(drafter_output.tolist()) - return draft_token_ids - - def update_config(self, overrides: dict[str, Any]) -> None: - allowed_config_names = {"load_config", "model_config"} - for config_name, config_overrides in overrides.items(): - assert config_name in allowed_config_names, \ - f"Config `{config_name}` not supported. " \ - f"Allowed configs: {allowed_config_names}" - config = getattr(self, config_name) - new_config = update_config(config, config_overrides) - setattr(self, config_name, new_config) - - def load_model(self, eep_scale_up: bool = False) -> None: - """ - Args: - eep_scale_up: the model loading is for elastic EP scale up. - """ - logger.info("Starting to load model %s...", self.model_config.model) - if eep_scale_up: - from vllm.distributed.parallel_state import get_ep_group - num_local_physical_experts = torch.empty(1, - dtype=torch.int32, - device="cpu") - torch.distributed.broadcast(num_local_physical_experts, - group=get_ep_group().cpu_group, - group_src=0) - num_local_physical_experts = int(num_local_physical_experts.item()) - new_ep_size = get_ep_group().world_size - global_expert_load, old_global_expert_indices = ( - EplbState.recv_state()) - num_logical_experts = global_expert_load.shape[1] - self.parallel_config.eplb_config.num_redundant_experts = ( - num_local_physical_experts * new_ep_size - num_logical_experts) - assert old_global_expert_indices.shape[ - 1] % num_local_physical_experts == 0 - old_ep_size = old_global_expert_indices.shape[ - 1] // num_local_physical_experts - rank_mapping = { - old_ep_rank: old_ep_rank - for old_ep_rank in range(old_ep_size) - } - else: - global_expert_load = None - old_global_expert_indices = None - rank_mapping = None - - with DeviceMemoryProfiler() as m: - 
time_before_load = time.perf_counter() - model_loader = get_model_loader(self.load_config) - logger.info("Loading model from scratch...") - self.model = model_loader.load_model( - vllm_config=self.vllm_config, model_config=self.model_config) - if self.lora_config: - self.model = self.load_lora_model(self.model, - self.model_config, - self.scheduler_config, - self.lora_config, - self.device) - if hasattr(self, "drafter"): - logger.info("Loading drafter model...") - self.drafter.load_model(self.model) - if self.use_aux_hidden_state_outputs: - if supports_eagle3(self.model): - self.model.set_aux_hidden_state_layers( - self.model.get_eagle3_aux_hidden_state_layers()) - else: - raise RuntimeError( - "Model does not support EAGLE3 interface but " - "aux_hidden_state_outputs was requested") - time_after_load = time.perf_counter() - self.model_memory_usage = m.consumed_memory - logger.info("Model loading took %.4f GiB and %.6f seconds", - self.model_memory_usage / GiB_bytes, - time_after_load - time_before_load) - prepare_communication_buffer_for_model(self.model) - - if is_mixture_of_experts( - self.model) and self.parallel_config.enable_eplb: - logger.info("EPLB is enabled for model %s.", - self.model_config.model) - self.eplb_state = EplbState.build( - self.model, - self.device, - self.parallel_config, - global_expert_load, - old_global_expert_indices, - rank_mapping, - ) - - if ( - self.vllm_config.compilation_config.level == \ - CompilationLevel.DYNAMO_AS_IS and supports_dynamo() - ): - backend = self.vllm_config.compilation_config.init_backend( - self.vllm_config) - compilation_counter.dynamo_as_is_count += 1 - self.model.compile( - fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend=backend) - return - # for other compilation levels, cudagraph behavior is controlled by - # CudagraphWraper and CudagraphDispatcher of vllm. - - # wrap the model with full cudagraph wrapper if needed. 
- if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - self.model = CUDAGraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) - - def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, \ - "Cannot reload weights before model is loaded." - model_loader = get_model_loader(self.load_config) - logger.info("Reloading weights inplace...") - model = self.get_model() - model_loader.load_weights(model, model_config=self.model_config) - - def save_tensorized_model( - self, - tensorizer_config: "TensorizerConfig", - ) -> None: - model = self.get_model() - TensorizerLoader.save_model( - model, - tensorizer_config=tensorizer_config, - model_config=self.model_config, - ) - def _get_prompt_logprobs_dict( self, hidden_states: torch.Tensor, @@ -2219,82 +1107,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): return prompt_logprobs_dict - def _get_nans_in_logits( - self, - logits: Optional[torch.Tensor], - ) -> dict[str, int]: - try: - if logits is None: - return {req_id: 0 for req_id in self.input_batch.req_ids} - - num_nans_in_logits = {} - num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy() - for req_id in self.input_batch.req_ids: - req_index = self.input_batch.req_id_to_index[req_id] - num_nans_in_logits[req_id] = ( - int(num_nans_for_index[req_index]) - if num_nans_for_index is not None - and req_index < logits.shape[0] else 0) - return num_nans_in_logits - except IndexError: - return {} - - @contextmanager - def maybe_randomize_inputs(self, input_ids: torch.Tensor): - """ - Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set. 
- This is to help balance expert-selection - - during profile_run - - during DP rank dummy run - """ - dp_size = self.vllm_config.parallel_config.data_parallel_size - randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1 - if not randomize_inputs: - yield - else: - import functools - - @functools.cache - def rand_input_ids() -> torch.Tensor: - return torch.randint_like( - self.input_ids.gpu, - low=0, - high=self.model_config.get_vocab_size(), - dtype=input_ids.dtype) - - logger.debug_once("Randomizing dummy data for DP Rank") - input_ids.copy_(rand_input_ids()[:input_ids.size(0)], - non_blocking=True) - yield - input_ids.fill_(0) - - def _get_mm_dummy_batch( - self, - modality: str, - max_items_per_batch: int, - ) -> BatchedTensorInputs: - """Dummy data for profiling and precompiling multimodal models.""" - assert self.mm_budget is not None - - dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( - model_config=self.model_config, - seq_len=self.max_num_tokens, - mm_counts={modality: 1}, - cache=self.mm_budget.cache, - ) - dummy_mm_data = dummy_decoder_data.multi_modal_data - - # Result in the maximum GPU consumption of the model - dummy_mm_item = dummy_mm_data[modality][0] - dummy_mm_items = [dummy_mm_item] * max_items_per_batch - - return next(mm_kwargs_group - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - dummy_mm_items, - device=self.device, - pin_memory=self.pin_memory, - )) - @torch.inference_mode() def _dummy_run( self, @@ -2590,134 +1402,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) return sampler_output - def _dummy_pooler_run_task( - self, - hidden_states: torch.Tensor, - task: PoolingTask, - ) -> PoolerOutput: - num_tokens = hidden_states.shape[0] - max_num_reqs = self.scheduler_config.max_num_seqs - num_reqs = min(num_tokens, max_num_reqs) - min_tokens_per_req = num_tokens // num_reqs - num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs - num_scheduled_tokens_list[-1] += 
num_tokens % num_reqs - assert sum(num_scheduled_tokens_list) == num_tokens - assert len(num_scheduled_tokens_list) == num_reqs - - req_num_tokens = num_tokens // num_reqs - - dummy_prompt_lens = torch.tensor( - num_scheduled_tokens_list, - device="cpu", - ) - dummy_token_ids = torch.zeros((num_reqs, req_num_tokens), - dtype=torch.int32, - device=self.device) - - model = cast(VllmModelForPooling, self.get_model()) - dummy_pooling_params = PoolingParams(task=task) - to_update = model.pooler.get_pooling_updates(task) - to_update.apply(dummy_pooling_params) - - dummy_metadata = PoolingMetadata( - prompt_lens=dummy_prompt_lens, - prompt_token_ids=dummy_token_ids, - pooling_params=[dummy_pooling_params] * num_reqs, - ) - - dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, - device=hidden_states.device) - - try: - return model.pooler(hidden_states=hidden_states, - pooling_metadata=dummy_metadata) - except RuntimeError as e: - if 'out of memory' in str(e): - raise RuntimeError( - "CUDA out of memory occurred when warming up pooler " - f"({task=}) with {num_reqs} dummy requests. Please try " - "lowering `max_num_seqs` or `gpu_memory_utilization` when " - "initializing the engine.") from e - else: - raise e - - @torch.inference_mode() - def _dummy_pooler_run( - self, - hidden_states: torch.Tensor, - ) -> PoolerOutput: - # Find the task that has the largest output for subsequent steps - output_size = dict[PoolingTask, float]() - for task in self.get_supported_pooling_tasks(): - # Run a full batch with each task to ensure none of them OOMs - output = self._dummy_pooler_run_task(hidden_states, task) - output_size[task] = output.get_data_nbytes() - del output # Allow GC - - max_task = max(output_size.items(), key=lambda x: x[1])[0] - return self._dummy_pooler_run_task(hidden_states, max_task) - def profile_run(self) -> None: - # Profile with multimodal encoder & encoder cache. 
- if self.supports_mm_inputs: - if self.model_config.multimodal_config.skip_mm_profiling: - logger.info( - "Skipping memory profiling for multimodal encoder and " - "encoder cache.") - else: - mm_budget = self.mm_budget - assert mm_budget is not None - - if (encoder_budget := mm_budget.get_encoder_budget()) > 0: - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget \ - .max_items_per_batch_by_modality[dummy_modality] - - logger.info( - "Encoder cache will be initialized with a budget of " - "%s tokens, and profiled with %s %s items of the " - "maximum feature size.", - encoder_budget, - max_mm_items_per_batch, - dummy_modality, - ) - - # Create dummy batch of multimodal inputs. - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - dummy_modality, - max_mm_items_per_batch, - ) - - # Run multimodal encoder. - dummy_encoder_outputs = \ - self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) - - sanity_check_mm_encoder_outputs( - dummy_encoder_outputs, - expected_num_items=max_mm_items_per_batch, - ) - - # Cache the dummy encoder outputs. 
- self.encoder_cache["tmp"] = dict( - enumerate(dummy_encoder_outputs)) - # Add `is_profile` here to pre-allocate communication buffers hidden_states, last_hidden_states \ = self._dummy_run(self.max_num_tokens, is_profile=True) - if get_pp_group().is_last_rank: - if self.is_pooling_model: - output = self._dummy_pooler_run(hidden_states) - else: - output = self._dummy_sampler_run(last_hidden_states) - else: - output = None - self._sync_device() + output = self._dummy_sampler_run(last_hidden_states) del hidden_states, output - self.encoder_cache.clear() gc.collect() def capture_model(self) -> int: @@ -3065,62 +1755,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_caches[layer_name] = kv_cache_raw_tensors[ layer_name].view(dtype).view(kv_cache_shape).permute( *inv_order) - elif isinstance(kv_cache_spec, MambaSpec): - has_mamba = True - raw_tensor = kv_cache_raw_tensors[layer_name] - state_tensors = [] - storage_offset_bytes = 0 - for (shape, dtype) in zip(kv_cache_spec.shapes, - kv_cache_spec.dtypes): - dtype_size = get_dtype_size(dtype) - num_element_per_page = ( - kv_cache_spec.page_size_bytes // dtype_size) - target_shape = (num_blocks, *shape) - stride = torch.empty(target_shape).stride() - target_stride = (num_element_per_page, *stride[1:]) - assert storage_offset_bytes % dtype_size == 0 - tensor = torch.as_strided( - raw_tensor.view(dtype), - size=target_shape, - stride=target_stride, - storage_offset=storage_offset_bytes // dtype_size, - ) - state_tensors.append(tensor) - storage_offset_bytes += stride[0] * dtype_size - - kv_caches[layer_name] = state_tensors - else: - raise NotImplementedError - - if has_attn and has_mamba: - self._update_hybrid_attention_mamba_layout(kv_caches) return kv_caches - def _update_hybrid_attention_mamba_layout( - self, kv_caches: dict[str, torch.Tensor]) -> None: - """ - Update the layout of attention layers from (2, num_blocks, ...) to - (num_blocks, 2, ...). 
- - Args: - kv_caches: The KV cache buffer of each layer. - """ - - for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): - for layer_name in group.layer_names: - kv_cache = kv_caches[layer_name] - if (isinstance(kv_cache_spec, AttentionSpec) - and kv_cache.shape[0] == 2): - assert kv_cache.shape[1] != 2, \ - "Fail to determine whether the layout is " \ - "(2, num_blocks, ...) or (num_blocks, 2, ...) for " \ - f"a tensor of shape {kv_cache.shape}" - hidden_size = kv_cache.shape[2:].numel() - kv_cache.as_strided_(size=kv_cache.shape, - stride=(hidden_size, 2 * hidden_size, - *kv_cache.stride()[2:])) - def initialize_kv_cache_tensors( self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: """ @@ -3138,62 +1775,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, kv_cache_raw_tensors) - # Set up cross-layer KV cache sharing - for layer_name, target_layer_name in self.shared_kv_cache_layers.items( - ): - logger.debug("%s reuses KV cache of %s", layer_name, - target_layer_name) - kv_caches[layer_name] = kv_caches[target_layer_name] - bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches) return kv_caches - def init_block_tables(self, kv_cache_config: KVCacheConfig) -> None: - block_sizes = [ - kv_cache_group.kv_cache_spec.block_size - for kv_cache_group in kv_cache_config.kv_cache_groups - ] - self.block_tables = BlockTables( - block_sizes=block_sizes, - max_num_reqs=self.max_num_reqs, - max_num_cached_reqs=self.max_num_cached_reqs, - max_num_batched_tokens=self.max_num_tokens, - max_model_len=self.max_model_len, - device=self.device, - pin_memory=self.pin_memory, - ) - - def maybe_add_kv_sharing_layers_to_kv_cache_groups( - self, kv_cache_config: KVCacheConfig) -> None: - """ - Add layers that re-use KV cache to KV cache group of its target layer. 
- Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` - """ - if not self.shared_kv_cache_layers: - # No cross-layer KV sharing, return - return - - add_kv_sharing_layers_to_kv_cache_groups( - self.shared_kv_cache_layers, - kv_cache_config.kv_cache_groups, - self.runner_only_attn_layers, - ) - - if self.cache_config.kv_sharing_fast_prefill: - # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other - # similar KV sharing setups, only the layers that generate KV caches - # are involved in the prefill phase, enabling prefill to early exit. - attn_layers = get_layers_from_vllm_config(self.vllm_config, - Attention) - for layer_name in reversed(attn_layers): - if layer_name in self.shared_kv_cache_layers: - self.kv_sharing_fast_prefill_eligible_layers.add( - layer_name) - else: - break - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -3203,162 +1789,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config - self.init_block_tables(kv_cache_config) - self.may_add_encoder_only_layers_to_kv_cache_config() - self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.initialize_attn_backend(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) - if self.speculative_config and self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) - # validate all draft model layers belong to the same kv cache - # group - self.drafter.validate_same_kv_cache_group(kv_cache_config) - - if has_kv_transfer_group(): - get_kv_transfer_group().register_kv_caches(kv_caches) - if self.device.type == 'xpu': - get_kv_transfer_group().set_host_xfer_buffer_ops( - copy_kv_blocks) - - if self.dcp_world_size > 1: - layer_names = self.attn_groups[0][0].layer_names - layers = get_layers_from_vllm_config(self.vllm_config, - 
AttentionLayerBase, - layer_names) - for layer in layers.values(): - assert layer.impl.need_to_return_lse_for_decode, ( - "DCP requires attention impls to return" - " the softmax lse for decode, but the impl " - f"{layer.impl.__class__.__name__} " - "does not return the softmax lse for decode.") - - def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: - """ - Add encoder-only layers to the KV cache config. - """ - block_size = self.vllm_config.cache_config.block_size - use_mla = self.vllm_config.model_config.use_mla - encoder_only_attn_specs: dict[AttentionSpec, - list[str]] = defaultdict(list) - attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) - for layer_name, attn_module in attn_layers.items(): - if attn_module.attn_type == AttentionType.ENCODER_ONLY: - attn_spec: AttentionSpec = EncoderOnlyAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=use_mla) - encoder_only_attn_specs[attn_spec].append(layer_name) - self.runner_only_attn_layers.add(layer_name) - if len(encoder_only_attn_specs) > 0: - assert len( - encoder_only_attn_specs - ) == 1, "Only support one encoder-only attention spec now" - spec, layer_names = encoder_only_attn_specs.popitem() - self.kv_cache_config.kv_cache_groups.append( - KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) - - def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: - """ - Generates the KVCacheSpec by parsing the kv cache format from each - Attention module in the static forward context. - Returns: - KVCacheSpec: A dictionary mapping layer names to their KV cache - format. Layers that do not need KV cache are not included. 
- """ - - block_size = self.vllm_config.cache_config.block_size - use_mla = self.vllm_config.model_config.use_mla - kv_cache_spec: dict[str, KVCacheSpec] = {} - attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) - for layer_name, attn_module in attn_layers.items(): - if (kv_tgt_layer := - attn_module.kv_sharing_target_layer_name) is not None: - # The layer doesn't need its own KV cache and will use that of - # the target layer. We skip creating a KVCacheSpec for it, so - # that KV cache management logic will act as this layer does - # not exist, and doesn't allocate KV cache for the layer. This - # enables the memory saving of cross-layer kv sharing, allowing - # a given amount of memory to accommodate longer context lengths - # or enable more requests to be processed simultaneously. - self.shared_kv_cache_layers[layer_name] = kv_tgt_layer - continue - - # TODO(lucas): move the attention specs into the model layers like - # the attention backends - if attn_module.attn_type == AttentionType.DECODER: - if attn_module.sliding_window is not None: - kv_cache_spec[layer_name] = SlidingWindowSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window, - use_mla=use_mla) - elif self.attention_chunk_size is not None \ - and isinstance(attn_module, ChunkedLocalAttention): - kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - attention_chunk_size=self.attention_chunk_size, - use_mla=use_mla) - else: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=use_mla) - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - kv_cache_spec[layer_name] = CrossAttentionSpec( - 
block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - use_mla=use_mla) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): - # encoder-only attention does not need KV cache. - continue - else: - raise ValueError( - f"Unknown attention type: {attn_module.attn_type}") - - mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) - if len(mamba_layers) > 0: - if (self.vllm_config.speculative_config is not None - and self.vllm_config.model_config.hf_config.model_type - not in ["qwen3_next"]): - raise NotImplementedError( - "Mamba with speculative decoding is not supported yet.") - if self.vllm_config.cache_config.enable_prefix_caching: - raise NotImplementedError( - "Prefix caching is not supported for Mamba yet.") - max_model_len = self.vllm_config.model_config.max_model_len - - page_size_padded = ( - self.vllm_config.cache_config.mamba_page_size_padded) - - # Set block_size to max_model_len, so that mamba model will always - # have only one block in the KV cache. - for layer_name, mamba_module in mamba_layers.items(): - kv_cache_spec[layer_name] = MambaSpec( - shapes=mamba_module.get_state_shape(), - dtypes=mamba_module.get_state_dtype(), - block_size=max_model_len, - page_size_padded=page_size_padded, - mamba_type=mamba_module.mamba_type, - num_speculative_blocks=( - self.speculative_config.num_speculative_tokens - if self.speculative_config else 0), - ) - - return kv_cache_spec - def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: # This is a short term mitigation for issue mentioned in # https://github.com/vllm-project/vllm/issues/22754. 
diff --git a/vllm/v1/worker/gpu_worker_states.py b/vllm/v1/worker/gpu_worker_states.py deleted file mode 100644 index 8417ea6f0faf2..0000000000000 --- a/vllm/v1/worker/gpu_worker_states.py +++ /dev/null @@ -1,291 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass -from typing import Optional, Union - -import numpy as np -import torch -import triton -import triton.language as tl - -from vllm.lora.request import LoRARequest -from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange -from vllm.pooling_params import PoolingParams -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.v1.sample.logits_processor import LogitsProcessors -from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.metadata import SpecDecodeMetadata - -_MAX_SPEC_LEN = 32 - - -@dataclass -class RequestData: - - mm_kwargs: list[MultiModalKwargsItem] - mm_positions: list[PlaceholderRange] - sampling_params: Optional[SamplingParams] - pooling_params: Optional[PoolingParams] - - mm_hashes: list[str] - # M-RoPE (only for Qwen2/2.5-VL) - mrope_positions: Optional[torch.Tensor] = None - mrope_position_delta: Optional[int] = None - - lora_request: Optional[LoRARequest] = None - - -class RequestState: - - def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_batched_tokens: int, - max_num_cached_reqs: int, - device: torch.device, - pin_memory: bool, - vocab_size: int, - logitsprocs: Optional[LogitsProcessors] = None, - is_spec_decode: bool = False, - is_pooling_model: bool = False, - ): - self.max_num_reqs = max_num_reqs - self.max_model_len = max_model_len - self.max_num_batched_tokens = max_num_batched_tokens - self.max_num_cached_reqs = max_num_cached_reqs - self.device = device - self.pin_memory = pin_memory - self.vocab_size = vocab_size - self.is_spec_decode = is_spec_decode - self.pooling_params = None - 
self.num_prompt_logprobs: dict[int, int] = {} - - self.req_id_to_index: dict[str, int] = {} - self.index_to_req_id: dict[int, str] = {} - self.free_indices = list(range(max_num_cached_reqs)) - - # Request states. - self.req_data: dict[int, RequestData] = {} - # TODO(woosuk): Because the token_ids tensor can be very big, we only - # initialize it on CPU memory. - self.token_ids = np.zeros( - (self.max_num_cached_reqs, self.max_model_len), - dtype=np.int32, - ) - self.num_prompt_tokens = np.zeros(self.max_num_cached_reqs, - dtype=np.int32) - self.num_tokens = np.zeros(self.max_num_cached_reqs, dtype=np.int32) - self.num_computed_tokens = np.zeros(self.max_num_cached_reqs, - dtype=np.int32) - - # Last sampled token ids. - self.last_sampled_token = torch.zeros( - self.max_num_cached_reqs, - dtype=torch.int32, - device=self.device, - ) - - self.temperature = np.zeros(self.max_num_cached_reqs, dtype=np.float32) - self.top_p = np.zeros(self.max_num_cached_reqs, dtype=np.float32) - self.top_k = np.zeros(self.max_num_cached_reqs, dtype=np.int32) - - self.frequency_penalties = np.zeros(self.max_num_cached_reqs, - dtype=np.float32) - self.presence_penalties = np.zeros(self.max_num_cached_reqs, - dtype=np.float32) - self.repetition_penalties = np.zeros(self.max_num_cached_reqs, - dtype=np.float32) - - self.generators: dict[int, torch.Generator] = {} - - def add_request( - self, - req_id: str, - prompt_token_ids: list[int], - num_computed_tokens: int, - sampling_params: SamplingParams, - ) -> None: - assert len(self.free_indices) > 0, "No free space in GPU worker states" - req_idx = self.free_indices.pop() - self.req_id_to_index[req_id] = req_idx - self.index_to_req_id[req_idx] = req_id - - prompt_len = len(prompt_token_ids) - self.num_prompt_tokens[req_idx] = prompt_len - self.num_tokens[req_idx] = prompt_len - self.token_ids[req_idx, :prompt_len] = prompt_token_ids - self.num_computed_tokens[req_idx] = num_computed_tokens - - self.temperature[req_idx] = 
sampling_params.temperature - self.top_p[req_idx] = sampling_params.top_p - if 0 < sampling_params.top_k < self.vocab_size: - top_k = sampling_params.top_k - else: - top_k = self.vocab_size - self.top_k[req_idx] = top_k - self.frequency_penalties[req_idx] = sampling_params.frequency_penalty - self.presence_penalties[req_idx] = sampling_params.presence_penalty - self.repetition_penalties[req_idx] = sampling_params.repetition_penalty - - if sampling_params.sampling_type == SamplingType.RANDOM_SEED: - generator = torch.Generator(device=self.device) - generator.manual_seed(sampling_params.seed) - self.generators[req_idx] = generator - - @property - def num_cached_reqs(self) -> int: - return len(self.req_id_to_index) - - def append_token_ids( - self, - req_idx: int, - token_ids: Union[list[int], np.ndarray], - ) -> None: - start_idx = self.num_tokens.np[req_idx] - end_idx = start_idx + len(token_ids) - self.token_ids.np[req_idx, start_idx:end_idx] = token_ids - self.num_tokens.np[req_idx] = end_idx - - def remove_request(self, req_id: str) -> None: - req_idx = self.req_id_to_index.pop(req_id, None) - if req_idx is None: - # Request not found. 
- return - self.index_to_req_id.pop(req_idx, None) - self.free_indices.append(req_idx) - - def make_sampling_metadata( - self, - idx_mapping: np.ndarray, - ) -> SamplingMetadata: - temperature = self.temperature[idx_mapping] - all_greedy = np.all(temperature == 0.0) - all_random = np.all(temperature != 0.0) - temperature = self._copy_np_to_gpu(temperature) - - top_p = self.top_p[idx_mapping] - no_top_p = np.all(top_p == 1.0) - top_p = self._copy_np_to_gpu(top_p) if not no_top_p else None - top_k = self.top_k[idx_mapping] - no_top_k = np.all(top_k == self.vocab_size) - top_k = self._copy_np_to_gpu(top_k) if not no_top_k else None - - frequency_penalties = self.frequency_penalties[idx_mapping] - presence_penalties = self.presence_penalties[idx_mapping] - repetition_penalties = self.repetition_penalties[idx_mapping] - no_penalties = (np.all(frequency_penalties == 0.0) - and np.all(presence_penalties == 0.0) - and np.all(repetition_penalties == 1.0)) - if no_penalties: - frequency_penalties = None - presence_penalties = None - repetition_penalties = None - else: - frequency_penalties = self._copy_np_to_gpu(frequency_penalties) - presence_penalties = self._copy_np_to_gpu(presence_penalties) - repetition_penalties = self._copy_np_to_gpu(repetition_penalties) - - if self.generators: - generators = { - req_idx: self.generators[req_idx] - for req_idx in idx_mapping if req_idx in self.generators - } - else: - generators = {} - - return SamplingMetadata( - temperature=temperature, - all_greedy=all_greedy, - all_random=all_random, - top_p=top_p, - top_k=top_k, - frequency_penalties=frequency_penalties, - presence_penalties=presence_penalties, - repetition_penalties=repetition_penalties, - no_penalties=no_penalties, - generators=generators, - token_ids=None, - num_tokens=None, - num_prompt_tokens=None, - max_num_logprobs=None, - allowed_token_ids_mask=None, - bad_words_token_ids={}, - logitsprocs=None, - ) - - def _copy_np_to_gpu(self, src: np.ndarray) -> torch.Tensor: - 
cpu_tensor = torch.from_numpy(src) - return cpu_tensor.to(device=self.device, non_blocking=True) - - def make_spec_decode_metadata( - self, - query_start_loc: torch.Tensor, - cu_num_draft_tokens: torch.Tensor, - cu_num_draft_tokens_np: np.ndarray, - input_ids: torch.Tensor, - ) -> SpecDecodeMetadata: - batch_size = query_start_loc.shape[0] - 1 - total_num_draft_tokens = cu_num_draft_tokens_np[batch_size - 1] - logits_indices = torch.empty(total_num_draft_tokens + batch_size, - dtype=torch.int32, - device=self.device) - target_logits_indices = torch.empty(total_num_draft_tokens, - dtype=torch.int32, - device=self.device) - bonus_logits_indices = torch.empty(batch_size, - dtype=torch.int32, - device=self.device) - _prepare_spec_decode_kernel[(batch_size, )]( - query_start_loc, - cu_num_draft_tokens, - logits_indices, - target_logits_indices, - bonus_logits_indices, - BLOCK_SIZE=triton.next_power_of_2(_MAX_SPEC_LEN + 1), - ) - - draft_token_ids = input_ids[logits_indices] - draft_token_ids = draft_token_ids[target_logits_indices + 1] - return SpecDecodeMetadata( - draft_token_ids=draft_token_ids, - num_draft_tokens=cu_num_draft_tokens_np.tolist(), - cu_num_draft_tokens=cu_num_draft_tokens, - target_logits_indices=target_logits_indices, - bonus_logits_indices=bonus_logits_indices, - logits_indices=logits_indices, - ) - - -@triton.jit -def _prepare_spec_decode_kernel( - query_start_loc, # [B + 1] - cu_num_draft_tokens, # [B] - logits_indices, # [N + B] - target_logits_indices, # [N] - bonus_logits_indices, # [B] - BLOCK_SIZE: tl.constexpr, -): - batch_idx = tl.program_id(0) - - if batch_idx == 0: - draft_start_idx = 0 - else: - draft_start_idx = tl.load(cu_num_draft_tokens + batch_idx - 1) - draft_end_idx = tl.load(cu_num_draft_tokens + batch_idx) - draft_len = draft_end_idx - draft_start_idx - sample_len = draft_len + 1 - - q_end_idx = tl.load(query_start_loc + batch_idx + 1) - - sample_start_idx = draft_start_idx + batch_idx - sample_end_idx = sample_start_idx + 
sample_len - offset = tl.arange(0, BLOCK_SIZE) - tl.store(logits_indices + sample_start_idx + offset, - q_end_idx - sample_len + offset, - mask=offset < sample_len) - tl.store(target_logits_indices + draft_start_idx + offset, - sample_start_idx + offset, - mask=offset < draft_len) - tl.store(bonus_logits_indices + batch_idx, sample_end_idx - 1)