vllm/vllm/block.py
Sage Moore ce4f5a29fb
Add Automatic Prefix Caching (#2762)
Co-authored-by: ElizaWszola <eliza@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
2024-03-02 00:50:01 -08:00

85 lines
2.3 KiB
Python

"""Token blocks."""
from typing import List
from vllm.utils import Device
_BLANK_TOKEN_ID = -1
DEFAULT_LAST_ACCESSED_TIME = -1
class LogicalTokenBlock:
"""A block that stores a contiguous chunk of tokens from left to right.
Logical blocks are used to represent the states of the corresponding
physical blocks in the KV cache.
"""
def __init__(
self,
block_number: int,
block_size: int,
) -> None:
self.block_number = block_number
self.block_size = block_size
self.token_ids = [_BLANK_TOKEN_ID] * block_size
self.num_tokens = 0
def is_empty(self) -> bool:
return self.num_tokens == 0
def get_num_empty_slots(self) -> int:
return self.block_size - self.num_tokens
def is_full(self) -> bool:
return self.num_tokens == self.block_size
def append_tokens(self, token_ids: List[int]) -> None:
assert len(token_ids) <= self.get_num_empty_slots()
curr_idx = self.num_tokens
self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
self.num_tokens += len(token_ids)
def get_token_ids(self) -> List[int]:
return self.token_ids[:self.num_tokens]
def get_last_token_id(self) -> int:
assert self.num_tokens > 0
return self.token_ids[self.num_tokens - 1]
class PhysicalTokenBlock:
"""Represents the state of a block in the KV cache."""
def __init__(
self,
device: Device,
block_number: int,
block_size: int,
block_hash: int,
num_hashed_tokens: int,
) -> None:
self.device = device
self.block_number = block_number
self.block_size = block_size
self.block_hash = block_hash
self.num_hashed_tokens = num_hashed_tokens
self.ref_count = 0
self.last_accessed = DEFAULT_LAST_ACCESSED_TIME
self.computed = False
def __repr__(self) -> str:
return (f'PhysicalTokenBlock(device={self.device}, '
f'block_number={self.block_number}, '
f'num_hashed_tokens={self.num_hashed_tokens}, '
f'ref_count={self.ref_count}, '
f'last_accessed={self.last_accessed}, '
f'computed={self.computed})')
# Mapping: logical block number -> physical block.
BlockTable = List[PhysicalTokenBlock]