From 33a3a26ca539ebb85cdf790f1e1113d09c2c1326 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 17 Aug 2025 14:38:24 -0700 Subject: [PATCH] wip Signed-off-by: Woosuk Kwon --- vllm/v1/sample/metadata.py | 9 +- vllm/v1/worker/gpu_input_batch.py | 948 ++++++++--------------------- vllm/v1/worker/gpu_model_runner.py | 65 +- 3 files changed, 269 insertions(+), 753 deletions(-) diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py index 9d6a87cea3d07..b62b6e5c331ce 100644 --- a/vllm/v1/sample/metadata.py +++ b/vllm/v1/sample/metadata.py @@ -12,12 +12,12 @@ from vllm.v1.sample.logits_processor import LogitsProcessors @dataclass class SamplingMetadata: - temperature: Optional[torch.Tensor] + temperature: torch.Tensor all_greedy: bool all_random: bool - top_p: Optional[torch.Tensor] - top_k: Optional[torch.Tensor] + top_p: torch.Tensor + top_k: torch.Tensor generators: dict[int, torch.Generator] @@ -25,12 +25,11 @@ class SamplingMetadata: max_num_logprobs: Optional[int] no_penalties: bool - prompt_token_ids: Optional[torch.Tensor] frequency_penalties: torch.Tensor presence_penalties: torch.Tensor repetition_penalties: torch.Tensor - output_token_ids: list[list[int]] + token_ids: Optional[torch.Tensor] # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size, # vocab size). diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 3d4cf27a6ccf3..53ad0cbaf2cc4 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -3,9 +3,8 @@ # Datastructures defining a GPU input batch from dataclasses import dataclass -from typing import Optional, cast +from typing import Optional -import numpy as np import torch from typing_extensions import deprecated @@ -14,45 +13,25 @@ from vllm.multimodal.inputs import (MultiModalKwargs, MultiModalKwargsItem, PlaceholderRange) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import swap_dict_values -from vllm.v1.outputs import LogprobsTensors -from vllm.v1.pool.metadata import PoolingMetadata -from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, - LogitsProcessors, - MoveDirectionality) +from vllm.utils import cdiv, get_cuda_view_from_cpu_tensor, is_uva_available +from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.utils import is_spec_decode_unsupported -from vllm.v1.utils import copy_slice -from vllm.v1.worker.block_table import MultiGroupBlockTable @dataclass class CachedRequestState: req_id: str - prompt_token_ids: list[int] mm_kwargs: list[MultiModalKwargsItem] mm_positions: list[PlaceholderRange] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] - generator: Optional[torch.Generator] - - block_ids: tuple[list[int], ...] 
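+    # NOTE: prompt/output token IDs, the per-request generator, block IDs,
+    # and num_computed_tokens are no longer stored here; they now live in
+    # InputBatch's per-request buffers.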
- num_computed_tokens: int - output_token_ids: list[int] mrope_positions: Optional[torch.Tensor] = None mrope_position_delta: Optional[int] = None lora_request: Optional[LoRARequest] = None - def __post_init__(self): - self.num_prompt_tokens = len(self.prompt_token_ids) - - @property - def num_tokens(self) -> int: - return self.num_prompt_tokens + len(self.output_token_ids) - # Temporary back-compatibility for plugins that define model runner @property @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be " @@ -60,11 +39,26 @@ class CachedRequestState: def mm_inputs(self) -> list[MultiModalKwargs]: return [MultiModalKwargs([item]) for item in self.mm_kwargs] - def get_token_id(self, idx: int) -> int: - if idx < self.num_prompt_tokens: - return self.prompt_token_ids[idx] - else: - return self.output_token_ids[idx - self.num_prompt_tokens] + +class PerRequestAttribute: + + def __init__( + self, + N: int, + M: int, + K: int, + dtype: torch.dtype, + device: torch.device, + ): + assert is_uva_available(), "UVA is not available." + self.cpu_tensor = torch.zeros(N, + M, + dtype=dtype, + device="cpu", + pin_memory=True) + self.np = self.cpu_tensor.numpy() + self.uva_tensor = get_cuda_view_from_cpu_tensor(self.cpu_tensor) + self.gpu_tensor = torch.zeros(K, M, dtype=dtype, device=device) class InputBatch: @@ -74,6 +68,7 @@ class InputBatch: max_num_reqs: int, max_model_len: int, max_num_batched_tokens: int, + max_num_cached_reqs: int, device: torch.device, pin_memory: bool, vocab_size: int, @@ -82,715 +77,300 @@ class InputBatch: is_spec_decode: bool = False, is_pooling_model: bool = False, ): - self.is_pooling_model = is_pooling_model - self.is_spec_decode = is_spec_decode self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len self.max_num_batched_tokens = max_num_batched_tokens + self.max_num_cached_reqs = max_num_cached_reqs self.device = device self.pin_memory = pin_memory self.vocab_size = vocab_size + self.is_spec_decode = is_spec_decode - self._req_ids: list[Optional[str]] = [] self.req_id_to_index: dict[str, int] = {} + self.index_to_req_id: dict[int, str] = {} + self.free_indices = list(range(max_num_cached_reqs)) + # Request states. # TODO(woosuk): This buffer could be too large if max_model_len is big. - # Find a way to reduce the CPU memory usage. - # This buffer is not directly transferred to the GPU, so it does not - # need to be pinned. - self.token_ids_cpu_tensor = torch.zeros( - (max_num_reqs, max_model_len), - device="cpu", - dtype=torch.int32, - pin_memory=False, - ) - self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() - self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) - self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) - self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) - self.num_computed_tokens_cpu_tensor = torch.zeros( - (max_num_reqs, ), - device="cpu", - dtype=torch.int32, - pin_memory=pin_memory, - ) - self.num_computed_tokens_cpu = \ - self.num_computed_tokens_cpu_tensor.numpy() - - # Block table. - self.block_table = MultiGroupBlockTable( - max_num_reqs=max_num_reqs, - max_model_len=max_model_len, - max_num_batched_tokens=max_num_batched_tokens, - pin_memory=pin_memory, - device=device, - block_sizes=block_sizes, - ) + # Find a way to reduce the memory usage. 
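+        # Each attribute below is a PerRequestAttribute: a pinned CPU
+        # buffer of shape (max_num_cached_reqs, width) updated in place
+        # via .np, a zero-copy UVA view of the same memory that the GPU
+        # can read directly (.uva_tensor), and a (max_num_reqs, width)
+        # GPU buffer (.gpu_tensor) that gather kernels fill in batch
+        # order each step.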
+ self._add_vector_attr("token_ids", self.max_model_len, torch.int32) + self._add_scalar_attr("num_prompt_tokens", torch.int32) + self._add_scalar_attr("num_tokens", torch.int32) + self._add_scalar_attr("num_computed_tokens", torch.int32) # Sampling-related. - self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.temperature_cpu = self.temperature_cpu_tensor.numpy() + self._add_scalar_attr("temperature", torch.float32) self.greedy_reqs: set[str] = set() self.random_reqs: set[str] = set() - - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.top_p_cpu = self.top_p_cpu_tensor.numpy() + self._add_scalar_attr("top_p", torch.float32) self.top_p_reqs: set[str] = set() - - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.top_k_cpu = self.top_k_cpu_tensor.numpy() + self._add_scalar_attr("top_k", torch.int32) self.top_k_reqs: set[str] = set() - - # IDs of requests which do not support spec decoding - self.spec_decode_unsupported_reqs: set[str] = set() - - # Frequency penalty related data structures - self.frequency_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) - self.frequency_penalties_cpu_tensor = torch.empty( - (max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.frequency_penalties_cpu = \ - self.frequency_penalties_cpu_tensor.numpy() + self._add_scalar_attr("frequency_penalties", torch.float32) self.frequency_penalties_reqs: set[str] = set() - - # Presence penalty related data structures - self.presence_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) - self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy( - ) + self._add_scalar_attr("presence_penalties", torch.float32) self.presence_penalties_reqs: set[str] = set() - - # Repetition penalty related data structures - self.repetition_penalties = torch.empty((max_num_reqs, ), - dtype=torch.float, - device=device) - self.repetition_penalties_cpu_tensor = torch.empty( - (max_num_reqs, ), - dtype=torch.float, - device="cpu", - pin_memory=pin_memory) - self.repetition_penalties_cpu = \ - self.repetition_penalties_cpu_tensor.numpy() + self._add_scalar_attr("repetition_penalties", torch.float32) self.repetition_penalties_reqs: set[str] = set() - # lora related - self.request_lora_mapping = np.zeros((self.max_num_reqs, ), - dtype=np.int32) - self.lora_id_to_request_ids: dict[int, set[str]] = {} - self.lora_id_to_lora_request: dict[int, LoRARequest] = {} - - # req_index -> generator - # NOTE(woosuk): The indices of the requests that do not have their own - # generator should not be included in the dictionary. + # req_idx -> generator self.generators: dict[int, torch.Generator] = {} - self.num_logprobs: dict[str, int] = {} - # NOTE(rob): num_prompt_logprobs only includes reqs - # that are currently in the prefill phase. 
- self.num_prompt_logprobs: dict[str, int] = {} - - # To accumulate prompt logprobs tensor chunks across prefill steps. - self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {} - - # Internal representation of per-step batch state changes, used for - # reordering persistent batch and generating logitsprocs batch state - # updates. Should reset each step. - self.batch_update_builder = BatchUpdateBuilder() - - # TODO convert this to LogitsProcessor - self.has_allowed_token_ids: set[str] = set() - # NOTE(lufang): In the mask tensor, if the corresponding token allowed, - # the value is False. Since we use masked_fill_ to set -inf. - self.allowed_token_ids_mask: Optional[torch.Tensor] = None - self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None - - # req_index -> bad_words_token_ids - self.bad_words_token_ids: dict[int, list[list[int]]] = {} - - self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, - dtype=bool) - - self.req_output_token_ids: list[Optional[list[int]]] = [] - - # Store provided logitsprocs. If none are provided, initialize empty - # data structure - self.logitsprocs = logitsprocs or LogitsProcessors() - - # This is updated each time the batch constituents change. - self.sampling_metadata = self._make_sampling_metadata() - - self.pooling_params: dict[str, PoolingParams] = {} - - @property - def req_ids(self) -> list[str]: - # None elements should only be present transiently - # while performing state updates to the batch. - return cast(list[str], self._req_ids) - - def _register_add_request(self, request: "CachedRequestState") -> int: - """Track add-request operations for logits processors. - Not applicable to pooling models. - """ - - # Detailed added request metadata is only required for non-pooling - # models, to support logitsprocs - assert request.sampling_params - - # Fill the next empty index if there is one. - if (new_req_index := self.batch_update_builder.pop_removed()) is None: - # Append to end otherwise. - new_req_index = self.num_reqs - - assert new_req_index < self.max_num_reqs - self.batch_update_builder.added.append( - (new_req_index, request.sampling_params, request.prompt_token_ids, - request.output_token_ids)) - return new_req_index + # Block table(s). + self.block_tables = [] + self.num_blocks = [] + for block_size in block_sizes: + max_num_blocks = cdiv(max_model_len, block_size) + block_table = PerRequestAttribute(self.max_num_cached_reqs, + max_num_blocks, + self.max_num_reqs, torch.int32, + self.device) + self.block_tables.append(block_table) + num_blocks = PerRequestAttribute(self.max_num_cached_reqs, 1, + self.max_num_reqs, torch.int32, + self.device) + self.num_blocks.append(num_blocks) + self.num_block_tables = len(block_sizes) def add_request( self, - request: "CachedRequestState", - ) -> int: - if not self.is_pooling_model: - # New request index bookkeeping for autoregressive models. 
-            req_index = self._register_add_request(request)
+        req_id: str,
+        prompt_token_ids: list[int],
+        num_computed_tokens: int,
+        block_ids: tuple[list[int], ...],
+        sampling_params: SamplingParams,
+    ) -> None:
+        req_idx = self.free_indices.pop()
+        self.req_id_to_index[req_id] = req_idx
+        self.index_to_req_id[req_idx] = req_id
+
+        num_prompt_tokens = len(prompt_token_ids)
+        self.token_ids.np[req_idx, :num_prompt_tokens] = prompt_token_ids
+        self.num_prompt_tokens.np[req_idx] = num_prompt_tokens
+        self.num_tokens.np[req_idx] = num_prompt_tokens
+        self.num_computed_tokens.np[req_idx] = num_computed_tokens
+
+        self.temperature.np[req_idx] = sampling_params.temperature
+        if sampling_params.sampling_type == SamplingType.GREEDY:
+            # NOTE: the temperature is stored as-is (it may be 0.0 for
+            # greedy requests), so dividing logits by it must be guarded
+            # downstream.
+            self.greedy_reqs.add(req_id)
+        elif sampling_params.sampling_type == SamplingType.RANDOM:
+            self.random_reqs.add(req_id)
+
+        self.top_p.np[req_idx] = sampling_params.top_p
+        if sampling_params.top_p < 1.0:
+            self.top_p_reqs.add(req_id)
+
+        top_k = sampling_params.top_k
+        if 0 < top_k < self.vocab_size:
+            self.top_k_reqs.add(req_id)
         else:
-            req_index = self.num_reqs
+            top_k = self.vocab_size
+        self.top_k.np[req_idx] = top_k
 
-        req_id = request.req_id
-        if req_index == len(self._req_ids):
-            self._req_ids.append(req_id)
-            self.req_output_token_ids.append(request.output_token_ids)
-        else:
-            self._req_ids[req_index] = req_id
-            self.req_output_token_ids[req_index] = request.output_token_ids
+        self.frequency_penalties.np[
+            req_idx] = sampling_params.frequency_penalty
+        if sampling_params.frequency_penalty != 0.0:
+            self.frequency_penalties_reqs.add(req_id)
+        self.presence_penalties.np[
+            req_idx] = sampling_params.presence_penalty
+        if sampling_params.presence_penalty != 0.0:
+            self.presence_penalties_reqs.add(req_id)
+        self.repetition_penalties.np[
+            req_idx] = sampling_params.repetition_penalty
+        if sampling_params.repetition_penalty != 1.0:
+            self.repetition_penalties_reqs.add(req_id)
 
-        self.req_id_to_index[req_id] = req_index
+        if sampling_params.seed is not None:
+            generator = torch.Generator(device=self.device)
+            generator.manual_seed(sampling_params.seed)
+            self.generators[req_idx] = generator
 
-        # Copy the prompt token ids and output token ids.
-        num_prompt_tokens = len(request.prompt_token_ids)
-        self.num_prompt_tokens[req_index] = num_prompt_tokens
-        self.token_ids_cpu[
-            req_index, :num_prompt_tokens] = request.prompt_token_ids
-        start_idx = num_prompt_tokens
-        end_idx = start_idx + len(request.output_token_ids)
-        self.token_ids_cpu[req_index,
-                           start_idx:end_idx] = request.output_token_ids
-        # Number of token ids in token_ids_cpu.
-        # NOTE(woosuk): This may include spec decode tokens.
-        self.num_tokens[req_index] = request.num_tokens
-        # Number of tokens without spec decode tokens.
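+        # Initialize the block table row(s) for this request, one per
+        # KV cache group.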
- self.num_tokens_no_spec[req_index] = request.num_tokens + for i in range(self.num_block_tables): + self.block_tables[i].np[req_idx, :len(block_ids[i])] = block_ids[i] + self.num_blocks[i].np[req_idx] = len(block_ids[i]) - self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - self.block_table.add_row(request.block_ids, req_index) + def append_token_ids(self, req_id: str, token_ids: list[int]) -> None: + req_idx = self.req_id_to_index.get(req_id) + assert req_idx is not None + start_idx = self.num_tokens.np[req_idx] + end_idx = start_idx + len(token_ids) + self.token_ids.np[req_idx, start_idx:end_idx] = token_ids + self.num_tokens.np[req_idx] = end_idx - if sampling_params := request.sampling_params: - if (self.is_spec_decode - and is_spec_decode_unsupported(sampling_params)): - self.spec_decode_unsupported_reqs.add(req_id) - if sampling_params.sampling_type == SamplingType.GREEDY: - # Avoid later division by zero. - self.temperature_cpu[req_index] = -1.0 - self.greedy_reqs.add(req_id) + def append_block_ids( + self, + req_id: str, + new_block_ids: tuple[list[int], ...], + overwrite: bool, + ) -> None: + req_idx = self.req_id_to_index.get(req_id) + assert req_idx is not None + for i in range(self.num_block_tables): + block_table = self.block_tables[i] + num_blocks = self.num_blocks[i] + if overwrite: + # Replace the existing block IDs with the new ones. + # This happens when the request is resumed from preemption. + block_table.np[ + req_idx, :len(new_block_ids[i])] = new_block_ids[i] + num_blocks.np[req_idx] = len(new_block_ids[i]) else: - self.temperature_cpu[req_index] = sampling_params.temperature - self.random_reqs.add(req_id) + # Append the new block IDs to the existing ones (common case). + start_idx = num_blocks.np[req_idx] + end_idx = start_idx + len(new_block_ids[i]) + block_table.np[req_idx, start_idx:end_idx] = new_block_ids[i] + num_blocks.np[req_idx] = end_idx - self.top_p_cpu[req_index] = sampling_params.top_p - if sampling_params.top_p < 1: - self.top_p_reqs.add(req_id) - top_k = sampling_params.top_k - if 0 < top_k < self.vocab_size: - self.top_k_reqs.add(req_id) - else: - top_k = self.vocab_size - self.top_k_cpu[req_index] = top_k - self.frequency_penalties_cpu[ - req_index] = sampling_params.frequency_penalty - if sampling_params.frequency_penalty != 0.0: - self.frequency_penalties_reqs.add(req_id) - self.presence_penalties_cpu[ - req_index] = sampling_params.presence_penalty - if sampling_params.presence_penalty != 0.0: - self.presence_penalties_reqs.add(req_id) - self.repetition_penalties_cpu[ - req_index] = sampling_params.repetition_penalty - if sampling_params.repetition_penalty != 1.0: - self.repetition_penalties_reqs.add(req_id) - - # NOTE(woosuk): self.generators should not include the requests that - # do not have their own generator. - if request.generator is not None: - self.generators[req_index] = request.generator - - if sampling_params.logprobs is not None: - self.num_logprobs[req_id] = (self.vocab_size - if sampling_params.logprobs == -1 - else sampling_params.logprobs) - if sampling_params.prompt_logprobs is not None: - self.num_prompt_logprobs[ - req_id] = sampling_params.prompt_logprobs - - if sampling_params.allowed_token_ids: - self.has_allowed_token_ids.add(req_id) - if self.allowed_token_ids_mask_cpu_tensor is None: - # Lazy allocation for this tensor, which can be large. - # False means we don't fill with -inf. 
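+    # NOTE: remove_request only returns the slot to the free list; the
+    # per-request buffers are overwritten lazily when add_request reuses
+    # the slot.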
- self.allowed_token_ids_mask = torch.zeros( - self.max_num_reqs, - self.vocab_size, - dtype=torch.bool, - device=self.device) - self.allowed_token_ids_mask_cpu_tensor = torch.zeros( - self.max_num_reqs, - self.vocab_size, - dtype=torch.bool, - device="cpu") - self.allowed_token_ids_mask_cpu_tensor[req_index] = True - # False means we don't fill with -inf. - self.allowed_token_ids_mask_cpu_tensor[req_index][ - sampling_params.allowed_token_ids] = False - - if sampling_params.bad_words_token_ids: - self.bad_words_token_ids[ - req_index] = sampling_params.bad_words_token_ids - elif pooling_params := request.pooling_params: - self.pooling_params[req_id] = pooling_params - self.logits_processing_needs_token_ids[req_index] = ( - pooling_params.requires_token_ids) - else: - raise NotImplementedError(request) - - # Add request lora ID - if request.lora_request: - lora_id = request.lora_request.lora_int_id - if lora_id not in self.lora_id_to_request_ids: - self.lora_id_to_request_ids[lora_id] = set() - - self.request_lora_mapping[req_index] = lora_id - self.lora_id_to_request_ids[lora_id].add(request.req_id) - self.lora_id_to_lora_request[lora_id] = request.lora_request - else: - # No LoRA - self.request_lora_mapping[req_index] = 0 - - return req_index - - def remove_request(self, req_id: str) -> Optional[int]: - """This method must always be followed by a call to condense(). - - Args: - req_id: request to remove - - Returns: - Removed request index, or `None` if `req_id` not recognized - """ - - req_index = self.req_id_to_index.pop(req_id, None) - if req_index is None: - return None - if not self.is_pooling_model: - # Autoregressive models require bookkeeping of removed requests to - # support logitsprocs. - self.batch_update_builder.removed_append(req_index) - self._req_ids[req_index] = None - self.req_output_token_ids[req_index] = None + def remove_request(self, req_id: str) -> None: + req_idx = self.req_id_to_index.pop(req_id, None) + if req_idx is None: + # Request not found. + return + self.index_to_req_id.pop(req_idx, None) + self.free_indices.append(req_idx) self.greedy_reqs.discard(req_id) self.random_reqs.discard(req_id) self.top_p_reqs.discard(req_id) self.top_k_reqs.discard(req_id) - self.spec_decode_unsupported_reqs.discard(req_id) self.frequency_penalties_reqs.discard(req_id) self.presence_penalties_reqs.discard(req_id) self.repetition_penalties_reqs.discard(req_id) - self.generators.pop(req_index, None) - self.num_logprobs.pop(req_id, None) - self.num_prompt_logprobs.pop(req_id, None) - self.in_progress_prompt_logprobs_cpu.pop(req_id, None) + self.generators.pop(req_idx, None) - # LoRA - lora_id = self.request_lora_mapping[req_index] - if lora_id != 0: - self.lora_id_to_request_ids[lora_id].discard(req_id) - if len(self.lora_id_to_request_ids[lora_id]) == 0: - self.lora_id_to_request_ids.pop(lora_id) - self.lora_id_to_lora_request.pop(lora_id) - self.request_lora_mapping[req_index] = 0 - - self.has_allowed_token_ids.discard(req_id) - if self.allowed_token_ids_mask_cpu_tensor is not None: - # False means we don't fill with -inf. 
- self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) - self.bad_words_token_ids.pop(req_index, None) - self.pooling_params.pop(req_id, None) - return req_index - - def swap_states(self, i1: int, i2: int) -> None: - # For autoregressive models, track detailed request reordering info - # to support logitsprocs - self.batch_update_builder.moved.append( - (i1, i2, MoveDirectionality.SWAP)) - old_id_i1 = self._req_ids[i1] - old_id_i2 = self._req_ids[i2] - self._req_ids[i1], self._req_ids[i2] =\ - self._req_ids[i2], self._req_ids[i1] # noqa - self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\ - self.req_output_token_ids[i2], self.req_output_token_ids[i1] - assert old_id_i1 is not None and old_id_i2 is not None - self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\ - self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1] - self.num_tokens[i1], self.num_tokens[i2] =\ - self.num_tokens[i2], self.num_tokens[i1] - self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\ - self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1] - self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\ - self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] - self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ - self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] - self.temperature_cpu[i1], self.temperature_cpu[i2] =\ - self.temperature_cpu[i2], self.temperature_cpu[i1] - self.top_p_cpu[i1], self.top_p_cpu[i2] =\ - self.top_p_cpu[i2], self.top_p_cpu[i1] - self.top_k_cpu[i1], self.top_k_cpu[i2] =\ - self.top_k_cpu[i2], self.top_k_cpu[i1] - self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\ - self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] - self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\ - self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] - self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\ - self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] - - # NOTE: the following is unsafe - # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ - # self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...] - # instead, we need to temporiarily copy the data for one of the indices - # TODO(lucas): optimize this by only copying valid indices - tmp = self.token_ids_cpu[i1, ...].copy() - self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] - self.token_ids_cpu[i2, ...] = tmp - - swap_dict_values(self.generators, i1, i2) - swap_dict_values(self.bad_words_token_ids, i1, i2) - - self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\ - self.request_lora_mapping[i2], self.request_lora_mapping[i1] - - if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[i1], \ - self.allowed_token_ids_mask_cpu_tensor[i2] =\ - self.allowed_token_ids_mask_cpu_tensor[i2], \ - self.allowed_token_ids_mask_cpu_tensor[i1] - self.block_table.swap_row(i1, i2) - - def condense(self) -> None: - """Slide non-empty requests down into lower, empty indices. - - Any consecutive empty indices at the very end of the list are not - filled. - - Args: - empty_req_indices: empty indices which may be filled. - - Returns: - swaps: list of (from,to) swap tuples for moved requests - empty_req_indices: indices not filled by condensation - """ - num_reqs = self.num_reqs - - if self.is_pooling_model: - # Will be contiguous in pooling case, just trim the lists. 
- del self._req_ids[num_reqs:] - del self.req_output_token_ids[num_reqs:] - return - - if not (empty_req_indices := self.batch_update_builder.removed): - # All removed requests were replaced by added requests, or else no - # requests were removed at all. No condense() needed - return - if num_reqs == 0: - # The batched states are empty. - self._req_ids.clear() - self.req_output_token_ids.clear() - return - - # NOTE(woosuk): This function assumes that the empty_req_indices - # is sorted in descending order. - last_req_index = num_reqs + len(empty_req_indices) - 1 - while empty_req_indices: - # Find the largest non-empty index. - while last_req_index in empty_req_indices: - last_req_index -= 1 - - # Find the smallest empty index. - empty_index = self.batch_update_builder.peek_removed() - assert empty_index is not None - if empty_index >= last_req_index: - break - - # Move active request down into empty request - # index. - self.batch_update_builder.pop_removed() - # Autoregressive models require detailed tracking of condense - # operations to support logitsprocs - self.batch_update_builder.moved.append( - (last_req_index, empty_index, - MoveDirectionality.UNIDIRECTIONAL)) - req_id = self._req_ids[last_req_index] - output_token_ids = self.req_output_token_ids[last_req_index] - assert req_id is not None - self._req_ids[empty_index] = req_id - self._req_ids[last_req_index] = None - self.req_output_token_ids[empty_index] = output_token_ids - self.req_output_token_ids[last_req_index] = None - self.req_id_to_index[req_id] = empty_index - - num_tokens = self.num_tokens[last_req_index] - self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ - last_req_index, :num_tokens] - self.num_tokens[empty_index] = num_tokens - self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ - last_req_index] - self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] - self.block_table.move_row(last_req_index, empty_index) - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] - self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] - self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - self.frequency_penalties_cpu[ - empty_index] = self.frequency_penalties_cpu[last_req_index] - self.presence_penalties_cpu[ - empty_index] = self.presence_penalties_cpu[last_req_index] - self.repetition_penalties_cpu[ - empty_index] = self.repetition_penalties_cpu[last_req_index] - generator = self.generators.pop(last_req_index, None) - if generator is not None: - self.generators[empty_index] = generator - - self.request_lora_mapping[empty_index] = self.request_lora_mapping[ - last_req_index] - - # TODO convert these to LogitsProcessors - if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[ - empty_index] = self.allowed_token_ids_mask_cpu_tensor[ - last_req_index] - - bad_words_token_ids = self.bad_words_token_ids.pop( - last_req_index, None) - if bad_words_token_ids is not None: - self.bad_words_token_ids[empty_index] = bad_words_token_ids - - # Decrement last_req_index since it is now empty. - last_req_index -= 1 - - # Trim lists to the batch size. - del self._req_ids[num_reqs:] - del self.req_output_token_ids[num_reqs:] - - def refresh_metadata(self): - """Apply any batch updates to sampling metadata.""" - - if self.is_pooling_model: - # Batch changes every step for pooling models. 
-            self.sampling_metadata = self._make_sampling_metadata()
-            return
-
-        # For non-pooling models - generate and apply logitsprocs update;
-        # reset batch update tracking.
-        # Update sampling metadata if batch state is changed.
-        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
-        for logit_proc in self.logitsprocs.all:
-            logit_proc.update_state(batch_update)
-        if batch_update:
-            self.sampling_metadata = self._make_sampling_metadata()
-
-    def _make_sampling_metadata(self) -> SamplingMetadata:
-        num_reqs = self.num_reqs
-        if not self.all_greedy:
-            temperature = copy_slice(self.temperature_cpu_tensor,
-                                     self.temperature, num_reqs)
-        else:
-            temperature = None
-        if not self.no_top_p:
-            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
-        if not self.no_top_k:
-            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
-
-        if not self.no_penalties:
-            # Since syncing these tensors is expensive only copy them
-            # if necessary i.e. if there are requests which require
-            # penalties to be applied during sampling.
-            copy_slice(self.frequency_penalties_cpu_tensor,
-                       self.frequency_penalties, num_reqs)
-            copy_slice(self.presence_penalties_cpu_tensor,
-                       self.presence_penalties, num_reqs)
-            copy_slice(self.repetition_penalties_cpu_tensor,
-                       self.repetition_penalties, num_reqs)
-
-        needs_prompt_token_ids = (
-            not self.no_penalties
-            or self.logits_processing_needs_token_ids[:num_reqs].any())
-        if needs_prompt_token_ids:
-            # The prompt tokens are used only for applying penalties or
-            # step pooling during the sampling/pooling process.
-            # Hence copy these tensors only when there are requests which
-            # need penalties/step_pooler to be applied.
-            prompt_token_ids = self._make_prompt_token_ids_tensor()
-        else:
-            prompt_token_ids = None
-
-        allowed_token_ids_mask: Optional[torch.Tensor] = None
-        if not self.no_allowed_token_ids:
-            assert self.allowed_token_ids_mask is not None
-            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
-                       self.allowed_token_ids_mask, num_reqs)
-            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
+    def make_block_table(self, req_idx: int) -> tuple[torch.Tensor, ...]:
+        pass
+
+    def make_sampling_metadata(self,
+                               req_indices: list[int]) -> SamplingMetadata:
+        batch_size = len(req_indices)
+        # Triton kernels take tensors (pointers), not Python lists, so the
+        # batch's request indices are moved to the device first.
+        req_indices_gpu = torch.tensor(req_indices,
+                                       dtype=torch.int32,
+                                       device=self.device)
+        _make_sampling_metadata_kernel[(batch_size, )](
+            req_indices_gpu,
+            self.temperature.uva_tensor,
+            self.temperature.gpu_tensor,
+            self.top_p.uva_tensor,
+            self.top_p.gpu_tensor,
+            self.top_k.uva_tensor,
+            self.top_k.gpu_tensor,
+            self.frequency_penalties.uva_tensor,
+            self.frequency_penalties.gpu_tensor,
+            self.presence_penalties.uva_tensor,
+            self.presence_penalties.gpu_tensor,
+            self.repetition_penalties.uva_tensor,
+            self.repetition_penalties.gpu_tensor,
+            num_warps=1,
+            num_stages=1,
+        )
+        generators = {}
+        if self.generators:
+            for i, req_idx in enumerate(req_indices):
+                generator = self.generators.get(req_idx)
+                if generator is not None:
+                    generators[i] = generator
+        no_penalties = not (self.frequency_penalties_reqs
+                            or self.presence_penalties_reqs
+                            or self.repetition_penalties_reqs)
         return SamplingMetadata(
-            temperature=temperature,
-            all_greedy=self.all_greedy,
-            all_random=self.all_random,
-            top_p=None if self.no_top_p else self.top_p[:num_reqs],
-            top_k=None if self.no_top_k else self.top_k[:num_reqs],
-            generators=self.generators,
-            max_num_logprobs=self.max_num_logprobs,
-            prompt_token_ids=prompt_token_ids,
-            frequency_penalties=self.frequency_penalties[:num_reqs],
-            presence_penalties=self.presence_penalties[:num_reqs],
-            repetition_penalties=self.repetition_penalties[:num_reqs],
-            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
-            no_penalties=self.no_penalties,
-            allowed_token_ids_mask=allowed_token_ids_mask,
-            bad_words_token_ids=self.bad_words_token_ids,
-            logitsprocs=self.logitsprocs,
+            temperature=self.temperature.gpu_tensor[:batch_size],
+            all_greedy=not self.random_reqs,
+            all_random=not self.greedy_reqs,
+            top_p=self.top_p.gpu_tensor[:batch_size],
+            top_k=self.top_k.gpu_tensor[:batch_size],
+            frequency_penalties=self.frequency_penalties.
+            gpu_tensor[:batch_size],
+            presence_penalties=self.presence_penalties.gpu_tensor[:batch_size],
+            repetition_penalties=self.repetition_penalties.
+            gpu_tensor[:batch_size],
+            no_penalties=no_penalties,
+            generators=generators,
+            # TODO: token_ids.gpu_tensor is not yet populated by the gather
+            # kernel; penalty computation will need it filled in.
+            token_ids=self.token_ids.gpu_tensor[:batch_size],
+            max_num_logprobs=None,
+            allowed_token_ids_mask=None,
+            bad_words_token_ids={},
+            logitsprocs=None,
         )
 
-    @property
-    def pooling_metadata(self) -> PoolingMetadata:
-        if len(self.pooling_params) == 0:
-            pooling_params = []
-        else:
-            # Note, for now this assumes that all request in the batch
-            # are either sampling or pooling requests
-            assert len(self.req_ids) == len(self.pooling_params)
-            pooling_params = [
-                self.pooling_params[req_id] for req_id in self.req_ids
-            ]
-
-        return PoolingMetadata(
-            prompt_lens=torch.from_numpy(
-                self.num_prompt_tokens[:self.num_reqs]).to(self.device),
-            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
-            pooling_params=pooling_params,
-        )
-
-    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
-        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
-        prompt_token_ids_cpu_tensor = torch.empty(
-            (self.num_reqs, max_prompt_len),
-            device="cpu",
-            dtype=torch.int64,
-            pin_memory=self.pin_memory,
-        )
-        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
-        prompt_token_ids[:] = self.token_ids_cpu[:self.
-                                                 num_reqs, :max_prompt_len]
-        # Use the value of vocab_size as a pad since we don't have a
-        # token_id of this value.
-        for i in range(self.num_reqs):
-            prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
-        return prompt_token_ids_cpu_tensor.to(device=self.device,
-                                              non_blocking=True)
-
-    def make_lora_inputs(
-        self, num_scheduled_tokens: np.ndarray
-    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
-        """
-        Given the num_scheduled_tokens for each request in the batch, return
-        datastructures used to activate the current LoRAs.
-        Returns:
-            1. prompt_lora_mapping: A tuple of size self.num_reqs where,
-               prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
-            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
-               where, token_lora_mapping[i] is the LoRA id to use for ith token.
-            3. lora_requests: Set of relevant LoRA requests.
- """ - - req_lora_mapping = self.request_lora_mapping[:self.num_reqs] - prompt_lora_mapping = tuple(req_lora_mapping) - token_lora_mapping = tuple( - req_lora_mapping.repeat(num_scheduled_tokens)) - active_lora_requests: set[LoRARequest] = set( - self.lora_id_to_lora_request.values()) - - return prompt_lora_mapping, token_lora_mapping, active_lora_requests - @property def num_reqs(self) -> int: return len(self.req_id_to_index) - @property - def all_greedy(self) -> bool: - return len(self.random_reqs) == 0 + def _add_vector_attr(self, name: str, max_len: int, dtype: torch.dtype): + attr = PerRequestAttribute(self.max_num_cached_reqs, max_len, + self.max_num_reqs, dtype, self.device) + setattr(self, name, attr) - @property - def all_random(self) -> bool: - return len(self.greedy_reqs) == 0 + def _add_scalar_attr(self, name: str, dtype: torch.dtype): + self._add_vector_attr(name, max_len=1, dtype=dtype) - @property - def no_top_p(self) -> bool: - return len(self.top_p_reqs) == 0 - @property - def no_top_k(self) -> bool: - return len(self.top_k_reqs) == 0 +import triton +import triton.language as tl - @property - def no_penalties(self) -> bool: - return (len(self.presence_penalties_reqs) == 0 - and len(self.frequency_penalties_reqs) == 0 - and len(self.repetition_penalties_reqs) == 0) - @property - def max_num_logprobs(self) -> Optional[int]: - return max(self.num_logprobs.values()) if self.num_logprobs else None +@triton.jit +def _make_sampling_metadata_kernel( + req_indices, # [batch_size] + src_temperature, + dst_temperature, + src_top_p, + dst_top_p, + src_top_k, + dst_top_k, + src_frequency_penalties, + dst_frequency_penalties, + src_presence_penalties, + dst_presence_penalties, + src_repetition_penalties, + dst_repetition_penalties, +): + batch_idx = tl.program_id(0) + req_index = tl.load(req_indices + batch_idx) - @property - def no_prompt_logprob(self) -> bool: - return not self.num_prompt_logprobs + temperature = tl.load(src_temperature + req_index) + tl.store(dst_temperature + req_index, temperature) - @property - def no_allowed_token_ids(self) -> bool: - return len(self.has_allowed_token_ids) == 0 + top_p = tl.load(src_top_p + req_index) + tl.store(dst_top_p + req_index, top_p) + + top_k = tl.load(src_top_k + req_index) + tl.store(dst_top_k + req_index, top_k) + + frequency_penalties = tl.load(src_frequency_penalties + req_index) + tl.store(dst_frequency_penalties + req_index, frequency_penalties) + + presence_penalties = tl.load(src_presence_penalties + req_index) + tl.store(dst_presence_penalties + req_index, presence_penalties) + + repetition_penalties = tl.load(src_repetition_penalties + req_index) + tl.store(dst_repetition_penalties + req_index, repetition_penalties) + + +@triton.jit +def _make_block_table_kernel( + req_indices, # [batch_size] + src_block_table_ptrs, + dst_block_table_ptrs, + src_num_blocks_ptrs, + dst_num_blocks_ptrs, + num_block_tables: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + batch_idx = tl.program_id(0) + req_index = tl.load(req_indices + batch_idx) + + for i in tl.range(num_block_tables): + src_num_blocks_ptr = tl.load(src_num_blocks_ptrs + i) + dst_num_blocks_ptr = tl.load(dst_num_blocks_ptrs + i) + num_blocks = tl.load(src_num_blocks_ptr + req_index) + tl.store(dst_num_blocks_ptr + req_index, num_blocks) + + src_block_table_ptr = tl.load(src_block_table_ptrs + i) + dst_block_table_ptr = tl.load(dst_block_table_ptrs + i) + for j in tl.range(num_blocks, BLOCK_SIZE): + offset = tl.arange(0, BLOCK_SIZE) + block_ids = tl.load(src_block_table_ptr + j * 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4219d9147ada2..e0c85f025f63a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -57,8 +57,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
 from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
-    make_kv_sharing_fast_prefill_attention_metadata,
-    reorder_batch_to_split_decodes_and_prefills)
+    make_kv_sharing_fast_prefill_attention_metadata)
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         ChunkedLocalAttentionSpec,
@@ -288,35 +287,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                          dtype=self.dtype,
                                          device=self.device)
 
-        # OPTIMIZATION: Cache the tensors rather than creating them every step.
-        # Keep in int64 to avoid overflow with long context
-        self.arange_np = np.arange(max(self.max_num_reqs + 1,
-                                       self.max_model_len,
-                                       self.max_num_tokens),
-                                   dtype=np.int64)
-        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
-        # a faster version of creating a new tensor every time. Thus, we should
-        # not make any assumptions about the values in these tensors.
-        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
-                                         dtype=torch.int32,
-                                         device="cpu",
-                                         pin_memory=self.pin_memory)
-        self.positions_cpu = torch.zeros(self.max_num_tokens,
-                                         dtype=torch.int64,
-                                         device="cpu",
-                                         pin_memory=self.pin_memory)
-        self.positions_np = self.positions_cpu.numpy()
-        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
-                                               dtype=torch.int32,
-                                               device="cpu",
-                                               pin_memory=self.pin_memory)
-        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
-        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
-                                        dtype=torch.int32,
-                                        device="cpu",
-                                        pin_memory=self.pin_memory)
-        self.seq_lens_np = self.seq_lens_cpu.numpy()
-
         # Layer pairings for cross-layer KV sharing.
         # If an Attention layer `layer_name` is in the keys of this dict, it
         # means this layer will perform attention using the keys and values
@@ -344,8 +314,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             ) if self.supports_mm_inputs \
                 else None)
 
-        self.reorder_batch_threshold: Optional[int] = None
-
     def _init_model_kwargs(self, num_tokens: int):
         model_kwargs = dict[str, Any]()
         num_reqs = self.input_batch.num_reqs
@@ -381,30 +349,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                 device=self.device)
         return model_kwargs
 
-    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
-        """
-        Update the order of requests in the batch based on the attention
-        backend's needs. For example, some attention backends (namely MLA) may
-        want to separate requests based on if the attention computation will be
-        compute-bound or memory-bound.
-
-        Args:
-            scheduler_output: The scheduler output.
-        """
-        # Attention free models have zero kv_cache_goups, however models
-        # like Mamba are also attention free but use the kv_cache for
-        # keeping its internal state. This is why we check the number
-        # of kv_cache groups instead of solely checking
-        # for self.model_config.is_attention_free.
-        if len(self.kv_cache_config.kv_cache_groups) == 0:
-            return
-
-        if self.reorder_batch_threshold is not None:
-            reorder_batch_to_split_decodes_and_prefills(
-                self.input_batch,
-                scheduler_output,
-                decode_threshold=self.reorder_batch_threshold)
-
     # Note: used for model runner override.
     def _init_device_properties(self) -> None:
         """Initialize attributes from torch.cuda.get_device_properties
@@ -621,13 +565,6 @@
             req_state = self.requests[req_id]
+            # TODO: this call site still passes a CachedRequestState, but
+            # InputBatch.add_request now takes (req_id, prompt_token_ids,
+            # num_computed_tokens, block_ids, sampling_params).
             self.input_batch.add_request(req_state)
 
-        # Condense the batched states if there are gaps left by removed requests
-        self.input_batch.condense()
-        # Allow attention backend to reorder the batch, potentially
-        self._may_reorder_batch(scheduler_output)
-        # Refresh batch metadata with any pending updates.
-        self.input_batch.refresh_metadata()
-
     def _extract_mm_kwargs(
         self,
         scheduler_output: "SchedulerOutput",