Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-08-22 01:48:35 -07:00
parent 79e5eb3643
commit 64c8cced18

View File

@ -5,6 +5,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
import numpy as np
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
@ -23,9 +24,8 @@ PAD_SLOT_ID = -1
@dataclass @dataclass
class CachedRequestState: class RequestData:
req_id: str
mm_kwargs: list[MultiModalKwargsItem] mm_kwargs: list[MultiModalKwargsItem]
mm_positions: list[PlaceholderRange] mm_positions: list[PlaceholderRange]
sampling_params: Optional[SamplingParams] sampling_params: Optional[SamplingParams]
@ -78,7 +78,7 @@ class PerRequestAttribute:
self.gpu.squeeze_(1) self.gpu.squeeze_(1)
class InputBatch: class RequestState:
def __init__( def __init__(
self, self,
@ -109,9 +109,11 @@ class InputBatch:
self.req_id_to_index: dict[str, int] = {} self.req_id_to_index: dict[str, int] = {}
self.index_to_req_id: dict[int, str] = {} self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_cached_reqs)) self.free_indices = list(range(max_num_cached_reqs))
# Used to construct the input batch.
self._add_scalar_attr("idx_mapping", torch.int32) self._add_scalar_attr("idx_mapping", torch.int32)
# Request states. # Request states.
self.req_data: dict[int, RequestData] = {}
# TODO(woosuk): Because the token_ids tensor can be very big, we only # TODO(woosuk): Because the token_ids tensor can be very big, we only
# initialize it on CPU memory. # initialize it on CPU memory.
self._add_vector_attr_cpu("token_ids", self.max_model_len, torch.int32) self._add_vector_attr_cpu("token_ids", self.max_model_len, torch.int32)
@ -396,6 +398,24 @@ class InputBatch:
return tuple(x[:num_tokens] for x in self.slot_mappings) return tuple(x[:num_tokens] for x in self.slot_mappings)
@dataclass
class InputBatch:
    """Bundle of the per-batch inputs: request ids, index mapping, KV-cache
    block tables / slot mappings, and sampling metadata.

    The shape notes on the fields below are carried over from the original
    inline comments; they are not enforced by this class.
    """

    # batch_idx -> req_id
    # (position in this list is the request's index within the batch).
    req_ids: list[str]

    # batch_idx -> req_state_idx
    idx_mapping: torch.Tensor
    # NOTE(review): presumably a NumPy counterpart of `idx_mapping` holding
    # the same batch_idx -> req_state_idx values -- confirm against the
    # code that constructs InputBatch.
    idx_mapping_np: np.ndarray

    # One tensor per KV cache group:
    # [num_kv_cache_groups, num_reqs, max_num_blocks_per_req]
    block_tables: tuple[torch.Tensor, ...]
    # One tensor per KV cache group:
    # [num_kv_cache_groups, num_tokens]
    slot_mappings: tuple[torch.Tensor, ...]

    # [num_reqs] mostly
    # NOTE(review): "mostly" is from the original comment; which members of
    # SamplingMetadata deviate from [num_reqs] is not visible here.
    sampling_metadata: SamplingMetadata
@triton.jit @triton.jit
def _make_sampling_metadata_kernel( def _make_sampling_metadata_kernel(
batch_idx_to_req_idx, # [batch_size] batch_idx_to_req_idx, # [batch_size]