commit 64c8cced18
parent 79e5eb3643
Author: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Date:   2025-08-22 01:48:35 -07:00

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>


@@ -5,6 +5,7 @@
 from dataclasses import dataclass
 from typing import Optional
+import numpy as np
 import torch
 import triton
 import triton.language as tl
@@ -23,9 +24,8 @@ PAD_SLOT_ID = -1
 @dataclass
-class CachedRequestState:
+class RequestData:
     req_id: str
     mm_kwargs: list[MultiModalKwargsItem]
     mm_positions: list[PlaceholderRange]
     sampling_params: Optional[SamplingParams]
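
The rename to RequestData matches what the class actually holds: the immutable per-request inputs (multimodal kwargs and positions, sampling parameters), as opposed to mutable batch state. A minimal sketch of how it might be populated, assuming a hypothetical new_req object and state_idx slot; neither appears in this hunk:

    # Hedged sketch; the field sources are assumptions, not code from this commit.
    req_data[state_idx] = RequestData(
        req_id=new_req.req_id,
        mm_kwargs=new_req.mm_kwargs,
        mm_positions=new_req.mm_positions,
        sampling_params=new_req.sampling_params,
    )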
@@ -78,7 +78,7 @@ class PerRequestAttribute:
         self.gpu.squeeze_(1)
-class InputBatch:
+class RequestState:
     def __init__(
         self,
@@ -109,9 +109,11 @@ class InputBatch:
         self.req_id_to_index: dict[str, int] = {}
         self.index_to_req_id: dict[int, str] = {}
         self.free_indices = list(range(max_num_cached_reqs))
+        # Used to construct the input batch.
+        self._add_scalar_attr("idx_mapping", torch.int32)
         # Request states.
         self.req_data: dict[int, RequestData] = {}
         # TODO(woosuk): Because the token_ids tensor can be very big, we only
         # initialize it on CPU memory.
         self._add_vector_attr_cpu("token_ids", self.max_model_len, torch.int32)
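
Two details in this hunk are worth spelling out. First, requests live in persistent slots: free_indices hands out stable integer indices, and req_id_to_index / index_to_req_id translate between request IDs and slots. A hedged sketch of the allocation path (the helper name is hypothetical, not from this commit):

    def _alloc_slot(self, req_id: str) -> int:
        # Hypothetical helper: claim a free persistent slot and record the
        # bidirectional req_id <-> slot mapping shown in __init__ above.
        idx = self.free_indices.pop()
        self.req_id_to_index[req_id] = idx
        self.index_to_req_id[idx] = req_id
        return idx

Second, the TODO about token_ids is a capacity concern: the tensor is [max_num_cached_reqs, max_model_len] int32, so, for example, 1024 cached requests at a 131072-token model length is 1024 * 131072 * 4 bytes = 512 MiB, which is tolerable in CPU RAM but wasteful as a permanently resident GPU allocation.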
@@ -396,6 +398,24 @@ class InputBatch:
         return tuple(x[:num_tokens] for x in self.slot_mappings)
+
+
+@dataclass
+class InputBatch:
+    # batch_idx -> req_id
+    req_ids: list[str]
+    # batch_idx -> req_state_idx
+    idx_mapping: torch.Tensor
+    idx_mapping_np: np.ndarray
+    # [num_kv_cache_groups, num_reqs, max_num_blocks_per_req]
+    block_tables: tuple[torch.Tensor, ...]
+    # [num_kv_cache_groups, num_tokens]
+    slot_mappings: tuple[torch.Tensor, ...]
+    # [num_reqs] mostly
+    sampling_metadata: SamplingMetadata
 @triton.jit
 def _make_sampling_metadata_kernel(
     batch_idx_to_req_idx,  # [batch_size]
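
Taken together, InputBatch is now a thin per-step view over the persistent RequestState storage: req_ids and idx_mapping translate batch positions into state slots, and the Triton kernel above (truncated here) builds the sampling metadata by gathering per-request attributes through that mapping on the GPU. A minimal sketch of the gather pattern, assuming hypothetical buffer names and a single float32 attribute; the real kernel presumably gathers several attributes at once:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _gather_f32_kernel(
        batch_idx_to_req_idx,  # [batch_size] int32: batch position -> state slot
        src_ptr,  # [max_num_cached_reqs] float32, indexed by state slot
        dst_ptr,  # [batch_size] float32, written in batch order
        batch_size,
        BLOCK: tl.constexpr,
    ):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < batch_size
        req_idx = tl.load(batch_idx_to_req_idx + offs, mask=mask)
        val = tl.load(src_ptr + req_idx, mask=mask)
        tl.store(dst_ptr + offs, val, mask=mask)

    # Example launch; `temperature_all` (per-slot values) is illustrative.
    batch_size = idx_mapping.numel()
    temperature = torch.empty(batch_size, dtype=torch.float32, device="cuda")
    grid = (triton.cdiv(batch_size, 256),)
    _gather_f32_kernel[grid](idx_mapping, temperature_all, temperature,
                             batch_size, BLOCK=256)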