From 6283995a6cd91fb5d8720fa31e09024c270b0008 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Sat, 6 Sep 2025 21:18:16 -0700
Subject: [PATCH] minor

Signed-off-by: Woosuk Kwon
---
 vllm/v1/worker/gpu_input_batch.py   |  57 ++++++++++
 vllm/v1/worker/gpu_model_runner.py  |   4 +-
 vllm/v1/worker/gpu_worker_states.py | 168 ++++++++++++----------------
 3 files changed, 128 insertions(+), 101 deletions(-)

diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index c91236f63e9e0..250b390422c64 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -3,8 +3,10 @@
 from dataclasses import dataclass
 from typing import Any, Optional
 
+import numba
 import numpy as np
 import torch
+from numba import types
 
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
@@ -33,3 +35,58 @@ class InputBatch:
 
     spec_decode_metadata: Optional[SpecDecodeMetadata]
     logits_indices: torch.Tensor
+
+
+# NOTE: With the type annotations, this function is pre-compiled
+# before the first call.
+@numba.jit(
+    [
+        types.none(
+            types.int32[:],  # idx_mapping
+            types.int32[:, :],  # token_ids
+            types.int32[:],  # num_computed_tokens
+            types.int32[:],  # num_scheduled_tokens
+            types.int32[:],  # input_ids
+            types.int32[:],  # query_start_loc
+            types.int32[:],  # seq_lens
+            types.int64[:],  # positions
+        )
+    ],
+    nopython=True,
+    cache=True,
+)
+def prepare_inputs(
+    idx_mapping: np.ndarray,  # batch_idx -> req_idx
+    token_ids: np.ndarray,  # [N, max_model_len]
+    num_computed_tokens: np.ndarray,  # [N]
+    num_scheduled_tokens: np.ndarray,  # [B]
+    input_ids: np.ndarray,  # [num_input_tokens]
+    query_start_loc: np.ndarray,  # [B + 1]
+    seq_lens: np.ndarray,  # [B]
+    positions: np.ndarray,  # [num_input_tokens]
+) -> None:
+    num_reqs = num_scheduled_tokens.shape[0]
+    query_start_loc[0] = 0
+
+    cu_num_tokens = 0
+    for i in range(num_reqs):
+        req_idx = idx_mapping[i]
+        query_len = num_scheduled_tokens[i]
+        start = num_computed_tokens[req_idx]
+        end = start + query_len
+        seq_lens[i] = end
+
+        start_idx = cu_num_tokens
+        end_idx = start_idx + query_len
+        input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
+        positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
+
+        cu_num_tokens = end_idx
+        query_start_loc[i + 1] = cu_num_tokens
+
+    # Pad the inputs for CUDA graphs.
+    # Note: pad query_start_loc to be non-decreasing, as kernels
+    # like FlashAttention requires that
+    query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
+    # Fill unused with 0 for full cuda graph mode.
+    seq_lens[num_reqs:].fill(0)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 9b101c770083f..baa51eabd57d3 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -80,8 +80,8 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
 from vllm.v1.worker.gpu_block_table import BlockTables
-from vllm.v1.worker.gpu_input_batch import InputBatch
-from vllm.v1.worker.gpu_worker_states import RequestState, prepare_inputs
+from vllm.v1.worker.gpu_input_batch import InputBatch, prepare_inputs
+from vllm.v1.worker.gpu_worker_states import RequestState
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
     KVConnectorModelRunnerMixin, KVConnectorOutput)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
diff --git a/vllm/v1/worker/gpu_worker_states.py b/vllm/v1/worker/gpu_worker_states.py
index 38e0cc3921bc0..1612755f669e4 100644
--- a/vllm/v1/worker/gpu_worker_states.py
+++ b/vllm/v1/worker/gpu_worker_states.py
@@ -4,12 +4,10 @@
 from dataclasses import dataclass
 from typing import Optional, Union
 
-import numba
 import numpy as np
 import torch
 import triton
 import triton.language as tl
-from numba import types
 from typing_extensions import deprecated
 
 from vllm.lora.request import LoRARequest
@@ -49,38 +47,76 @@ class RequestData:
     ]
 
 
-class Param:
+class SamplingStates:
 
     def __init__(
        self,
-        num_rows_cpu: int,
-        num_cols: int,
-        num_rows_gpu: int,
-        dtype: torch.dtype,
+        max_num_reqs: int,
+        max_model_len: int,
+        max_num_cached_reqs: int,
+        vocab_size: int,
         device: torch.device,
-        pin_memory: bool,
-        is_scalar: bool = False,
     ):
-        self.cpu = torch.zeros(num_rows_cpu,
-                               num_cols,
-                               dtype=dtype,
-                               device="cpu",
-                               pin_memory=pin_memory)
-        self.np = self.cpu.numpy()
-        self.gpu = torch.zeros(num_rows_gpu,
-                               num_cols,
-                               dtype=dtype,
-                               device=device)
-        if is_scalar:
-            self.cpu.squeeze_(1)
-            self.np = self.cpu.numpy()
-            self.gpu.squeeze_(1)
+        self.max_num_reqs = max_num_reqs
+        self.max_model_len = max_model_len
+        self.max_num_cached_reqs = max_num_cached_reqs
+        self.vocab_size = vocab_size
+        self.device = device
 
-        # TODO(woosuk): Optimize this.
-        self.gpu_buffer = self.cpu.to(device)
+        self.temperature = self._make_param(torch.float32)
+        self.greedy_req_indices: set[int] = set()
+        self.top_p = self._make_param(torch.float32)
+        self.top_p_req_indices: set[int] = set()
+        self.top_k = self._make_param(torch.int32)
+        self.top_k_req_indices: set[int] = set()
 
-    def mirror_to_gpu(self) -> torch.Tensor:
-        return self.gpu_buffer.copy_(self.cpu, non_blocking=True)
+        self.frequency_penalties = self._make_param(torch.float32)
+        self.presence_penalties = self._make_param(torch.float32)
+        self.repetition_penalties = self._make_param(torch.float32)
+        self.penalty_req_indices: set[int] = set()
+
+        self.generators: dict[int, torch.Generator] = {}
+
+    def _make_param(self, dtype: torch.dtype) -> torch.Tensor:
+        return torch.zeros(self.max_num_reqs, dtype=dtype, device=self.device)
+
+    def add_requests(
+        self,
+        req_indices: list[int],
+        sampling_params: list[SamplingParams],
+    ) -> None:
+        num_reqs = len(req_indices)
+        for i in range(num_reqs):
+            req_idx = req_indices[i]
+            sampling_param = sampling_params[i]
+
+            temp = sampling_param.temperature
+            if temp == 0.0:
+                self.greedy_req_indices.add(req_idx)
+
+            top_p = sampling_param.top_p
+            if top_p < 1.0:
+                self.top_p_req_indices.add(req_idx)
+            top_k = sampling_param.top_k
+            if 0 < top_k < self.vocab_size:
+                self.top_k_req_indices.add(req_idx)
+            else:
+                top_k = self.vocab_size
+
+            if sampling_param.frequency_penalty != 0.0 or sampling_param.presence_penalty != 0.0 or sampling_param.repetition_penalty != 1.0:
+                self.penalty_req_indices.add(req_idx)
+
+            if sampling_param.sampling_type == SamplingType.RANDOM_SEED:
+                generator = torch.Generator(device=self.device)
+                generator.manual_seed(sampling_param.seed)
+                self.generators[req_idx] = generator
+
+    def remove_request(self, req_idx: int) -> None:
+        self.greedy_req_indices.discard(req_idx)
+        self.top_p_req_indices.discard(req_idx)
+        self.top_k_req_indices.discard(req_idx)
+        self.penalty_req_indices.discard(req_idx)
+        self.generators.pop(req_idx, None)
 
 
 class RequestState:
@@ -128,23 +164,12 @@ class RequestState:
         self.num_tokens = self._make_param(torch.int32)
         self.num_computed_tokens = self._make_param(torch.int32)
 
-        # Sampling-related.
-        self.temperature = self._make_param(torch.float32)
-        self.greedy_reqs: set[str] = set()
-        self.random_reqs: set[str] = set()
-        self.top_p = self._make_param(torch.float32)
-        self.top_p_reqs: set[str] = set()
-        self.top_k = self._make_param(torch.int32)
-        self.top_k_reqs: set[str] = set()
-        self.frequency_penalties = self._make_param(torch.float32)
-        self.frequency_penalties_reqs: set[str] = set()
-        self.presence_penalties = self._make_param(torch.float32)
-        self.presence_penalties_reqs: set[str] = set()
-        self.repetition_penalties = self._make_param(torch.float32)
-        self.repetition_penalties_reqs: set[str] = set()
-
-        # req_idx -> generator
-        self.generators: dict[int, torch.Generator] = {}
+        self.sampling_states = SamplingStates(
+            max_num_reqs=max_num_reqs,
+            max_model_len=max_model_len,
+            max_num_cached_reqs=max_num_cached_reqs,
+            device=device,
+        )
 
     def _make_param(
         self,
@@ -413,58 +438,3 @@ def _prepare_spec_decode_kernel(
                  sample_start_idx + offset,
                  mask=offset < draft_len)
     tl.store(bonus_logits_indices + batch_idx, sample_end_idx - 1)
-
-
-# NOTE: With the type annotations, this function is pre-compiled
-# before the first call.
-@numba.jit(
-    [
-        types.none(
-            types.int32[:],  # idx_mapping
-            types.int32[:, :],  # token_ids
-            types.int32[:],  # num_computed_tokens
-            types.int32[:],  # num_scheduled_tokens
-            types.int32[:],  # input_ids
-            types.int32[:],  # query_start_loc
-            types.int32[:],  # seq_lens
-            types.int64[:],  # positions
-        )
-    ],
-    nopython=True,
-    cache=True,
-)
-def prepare_inputs(
-    idx_mapping: np.ndarray,  # batch_idx -> req_idx
-    token_ids: np.ndarray,  # [N, max_model_len]
-    num_computed_tokens: np.ndarray,  # [N]
-    num_scheduled_tokens: np.ndarray,  # [B]
-    input_ids: np.ndarray,  # [num_input_tokens]
-    query_start_loc: np.ndarray,  # [B + 1]
-    seq_lens: np.ndarray,  # [B]
-    positions: np.ndarray,  # [num_input_tokens]
-) -> None:
-    num_reqs = num_scheduled_tokens.shape[0]
-    query_start_loc[0] = 0
-
-    cu_num_tokens = 0
-    for i in range(num_reqs):
-        req_idx = idx_mapping[i]
-        query_len = num_scheduled_tokens[i]
-        start = num_computed_tokens[req_idx]
-        end = start + query_len
-        seq_lens[i] = end
-
-        start_idx = cu_num_tokens
-        end_idx = start_idx + query_len
-        input_ids[start_idx:end_idx] = token_ids[req_idx, start:end]
-        positions[start_idx:end_idx] = np.arange(start, end, dtype=np.int64)
-
-        cu_num_tokens = end_idx
-        query_start_loc[i + 1] = cu_num_tokens
-
-    # Pad the inputs for CUDA graphs.
-    # Note: pad query_start_loc to be non-decreasing, as kernels
-    # like FlashAttention requires that
-    query_start_loc[num_reqs + 1:].fill(cu_num_tokens)
-    # Fill unused with 0 for full cuda graph mode.
-    seq_lens[num_reqs:].fill(0)
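
Reviewer note (not part of the patch): the relocated prepare_inputs only shuffles data between preallocated CPU buffers, so it can be exercised in isolation. Below is a minimal sketch with made-up shapes and values (two scheduled requests, a batch padded to three slots); the dtypes mirror the numba signature above, which is compiled for int32 buffers and int64 positions.

    import numpy as np

    from vllm.v1.worker.gpu_input_batch import prepare_inputs

    # Persistent per-request state: 4 cached requests, max_model_len of 16.
    token_ids = np.arange(4 * 16, dtype=np.int32).reshape(4, 16)
    num_computed_tokens = np.array([3, 0, 5, 0], dtype=np.int32)

    # This step schedules cached requests 0 and 2 for 2 and 3 tokens.
    idx_mapping = np.array([0, 2], dtype=np.int32)
    num_scheduled_tokens = np.array([2, 3], dtype=np.int32)

    # Flat output buffers, padded to a batch of 3 requests / 8 input tokens.
    input_ids = np.zeros(8, dtype=np.int32)
    positions = np.zeros(8, dtype=np.int64)
    query_start_loc = np.zeros(3 + 1, dtype=np.int32)
    seq_lens = np.zeros(3, dtype=np.int32)

    prepare_inputs(idx_mapping, token_ids, num_computed_tokens,
                   num_scheduled_tokens, input_ids, query_start_loc,
                   seq_lens, positions)

    assert query_start_loc.tolist() == [0, 2, 5, 5]  # padded, non-decreasing
    assert seq_lens.tolist() == [5, 8, 0]            # unused slot zeroed
    assert positions[:5].tolist() == [3, 4, 5, 6, 7]

The tail of the output is what the CUDA-graph comments in the function refer to: slots past the real batch keep query_start_loc monotonic and zero out seq_lens.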
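Reviewer note (not part of the patch): a rough sketch of how the new SamplingStates bookkeeping behaves as requests come and go, assuming this branch's modules are importable. The slot indices and SamplingParams values are invented for illustration, and the sketch assumes a seeded, non-greedy request reports sampling_type RANDOM_SEED, which is what triggers the per-request generator in add_requests.

    import torch

    from vllm import SamplingParams
    from vllm.v1.worker.gpu_worker_states import SamplingStates

    states = SamplingStates(max_num_reqs=8,
                            max_model_len=1024,
                            max_num_cached_reqs=16,
                            vocab_size=32000,
                            device=torch.device("cpu"))

    # Slot 0: greedy decoding. Slot 1: seeded sampling with top-p/top-k
    # and a repetition penalty.
    states.add_requests(
        req_indices=[0, 1],
        sampling_params=[
            SamplingParams(temperature=0.0),
            SamplingParams(temperature=0.8, top_p=0.9, top_k=40,
                           repetition_penalty=1.1, seed=42),
        ],
    )
    assert states.greedy_req_indices == {0}
    assert states.top_p_req_indices == {1}
    assert states.top_k_req_indices == {1}
    assert states.penalty_req_indices == {1}
    assert 1 in states.generators  # seeded request -> dedicated generator

    # Freeing slot 1 clears all of its per-index sampling state.
    states.remove_request(1)
    assert 1 not in states.generators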