From 64c8cced182f7d482d521212f165cc96655aeebb Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 22 Aug 2025 01:48:35 -0700 Subject: [PATCH] rename Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_input_batch.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 62ec917276eb8..1cd3bb3f59a85 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from typing import Optional +import numpy as np import torch import triton import triton.language as tl @@ -23,9 +24,8 @@ PAD_SLOT_ID = -1 @dataclass -class CachedRequestState: +class RequestData: - req_id: str mm_kwargs: list[MultiModalKwargsItem] mm_positions: list[PlaceholderRange] sampling_params: Optional[SamplingParams] @@ -78,7 +78,7 @@ class PerRequestAttribute: self.gpu.squeeze_(1) -class InputBatch: +class RequestState: def __init__( self, @@ -109,9 +109,11 @@ class InputBatch: self.req_id_to_index: dict[str, int] = {} self.index_to_req_id: dict[int, str] = {} self.free_indices = list(range(max_num_cached_reqs)) + # Used to construct the input batch. self._add_scalar_attr("idx_mapping", torch.int32) # Request states. + self.req_data: dict[int, RequestData] = {} # TODO(woosuk): Because the token_ids tensor can be very big, we only # initialize it on CPU memory. self._add_vector_attr_cpu("token_ids", self.max_model_len, torch.int32) @@ -396,6 +398,24 @@ class InputBatch: return tuple(x[:num_tokens] for x in self.slot_mappings) +@dataclass +class InputBatch: + + # batch_idx -> req_id + req_ids: list[str] + # batch_idx -> req_state_idx + idx_mapping: torch.Tensor + idx_mapping_np: np.ndarray + + # [num_kv_cache_groups, num_reqs, max_num_blocks_per_req] + block_tables: tuple[torch.Tensor, ...] + # [num_kv_cache_groups, num_tokens] + slot_mappings: tuple[torch.Tensor, ...] + + # [num_reqs] mostly + sampling_metadata: SamplingMetadata + + @triton.jit def _make_sampling_metadata_kernel( batch_idx_to_req_idx, # [batch_size]