mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-20 05:07:03 +08:00
rename
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
79e5eb3643
commit
64c8cced18
@ -5,6 +5,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
@ -23,9 +24,8 @@ PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedRequestState:
|
||||
class RequestData:
|
||||
|
||||
req_id: str
|
||||
mm_kwargs: list[MultiModalKwargsItem]
|
||||
mm_positions: list[PlaceholderRange]
|
||||
sampling_params: Optional[SamplingParams]
|
||||
@ -78,7 +78,7 @@ class PerRequestAttribute:
|
||||
self.gpu.squeeze_(1)
|
||||
|
||||
|
||||
class InputBatch:
|
||||
class RequestState:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -109,9 +109,11 @@ class InputBatch:
|
||||
self.req_id_to_index: dict[str, int] = {}
|
||||
self.index_to_req_id: dict[int, str] = {}
|
||||
self.free_indices = list(range(max_num_cached_reqs))
|
||||
# Used to construct the input batch.
|
||||
self._add_scalar_attr("idx_mapping", torch.int32)
|
||||
|
||||
# Request states.
|
||||
self.req_data: dict[int, RequestData] = {}
|
||||
# TODO(woosuk): Because the token_ids tensor can be very big, we only
|
||||
# initialize it on CPU memory.
|
||||
self._add_vector_attr_cpu("token_ids", self.max_model_len, torch.int32)
|
||||
@ -396,6 +398,24 @@ class InputBatch:
|
||||
return tuple(x[:num_tokens] for x in self.slot_mappings)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputBatch:
|
||||
|
||||
# batch_idx -> req_id
|
||||
req_ids: list[str]
|
||||
# batch_idx -> req_state_idx
|
||||
idx_mapping: torch.Tensor
|
||||
idx_mapping_np: np.ndarray
|
||||
|
||||
# [num_kv_cache_groups, num_reqs, max_num_blocks_per_req]
|
||||
block_tables: tuple[torch.Tensor, ...]
|
||||
# [num_kv_cache_groups, num_tokens]
|
||||
slot_mappings: tuple[torch.Tensor, ...]
|
||||
|
||||
# [num_reqs] mostly
|
||||
sampling_metadata: SamplingMetadata
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _make_sampling_metadata_kernel(
|
||||
batch_idx_to_req_idx, # [batch_size]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user