[mypy][6/N] Fix all the core subdirectory typing (#4450)

Co-authored-by: Cade Daniel <edacih@gmail.com>
This commit is contained in:
SangBin Cho 2024-05-02 12:01:00 +09:00 committed by GitHub
parent 5e401bce17
commit cf8cac8c70
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 275 additions and 83 deletions

View File

@ -33,6 +33,7 @@ jobs:
- name: Mypy - name: Mypy
run: | run: |
mypy vllm/attention --config-file pyproject.toml mypy vllm/attention --config-file pyproject.toml
mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml
@ -42,9 +43,6 @@ jobs:
mypy vllm/engine --config-file pyproject.toml mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml

View File

@ -95,7 +95,7 @@ echo 'vLLM yapf: Done'
# Run mypy # Run mypy
echo 'vLLM mypy:' echo 'vLLM mypy:'
mypy vllm/attention --config-file pyproject.toml mypy vllm/attention --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml

View File

@ -40,7 +40,9 @@ class BlockTable:
): ):
self._block_size = block_size self._block_size = block_size
self._allocator = block_allocator self._allocator = block_allocator
self._blocks: Optional[List[Block]] = _blocks if _blocks is None:
_blocks = []
self._blocks: List[Block] = _blocks
# Use helper method instead of directly calculating, as blocks # Use helper method instead of directly calculating, as blocks
# may not be allocated. # may not be allocated.
@ -104,7 +106,7 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be appended. token_ids (List[int]): The sequence of token IDs to be appended.
""" """
assert self._is_allocated assert self._is_allocated
assert self._blocks is not None assert len(self._blocks) > 0
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots) num_lookahead_slots)
@ -141,6 +143,7 @@ class BlockTable:
blocks_to_allocate = cdiv(slots_to_allocate, self._block_size) blocks_to_allocate = cdiv(slots_to_allocate, self._block_size)
for _ in range(blocks_to_allocate): for _ in range(blocks_to_allocate):
assert len(self._blocks) > 0
self._blocks.append( self._blocks.append(
self._allocator.allocate_mutable(prev_block=self._blocks[-1], self._allocator.allocate_mutable(prev_block=self._blocks[-1],
device=device)) device=device))
@ -159,6 +162,7 @@ class BlockTable:
the current instance. the current instance.
""" """
assert self._is_allocated assert self._is_allocated
assert len(self._blocks) > 0
forked_blocks = self._allocator.fork(self._blocks[-1]) forked_blocks = self._allocator.fork(self._blocks[-1])
return BlockTable( return BlockTable(
block_size=self._block_size, block_size=self._block_size,
@ -177,10 +181,10 @@ class BlockTable:
assert self._is_allocated assert self._is_allocated
for block in self._blocks: for block in self._blocks:
self._allocator.free(block) self._allocator.free(block)
self._blocks = None self._blocks = []
@property @property
def physical_block_ids(self) -> List[int]: def physical_block_ids(self) -> List[Optional[int]]:
"""Returns a list of physical block indices for the blocks in the """Returns a list of physical block indices for the blocks in the
BlockTable. BlockTable.
@ -235,7 +239,7 @@ class BlockTable:
def _get_all_token_ids(self) -> List[int]: def _get_all_token_ids(self) -> List[int]:
# NOTE: This function is O(seq_len); use sparingly. # NOTE: This function is O(seq_len); use sparingly.
token_ids = [] token_ids: List[int] = []
if not self._is_allocated: if not self._is_allocated:
return token_ids return token_ids
@ -247,7 +251,7 @@ class BlockTable:
@property @property
def _is_allocated(self) -> bool: def _is_allocated(self) -> bool:
return self._blocks is not None return len(self._blocks) > 0
@property @property
def _num_empty_slots(self) -> int: def _num_empty_slots(self) -> int:

View File

@ -1,5 +1,5 @@
from collections import defaultdict from collections import defaultdict
from typing import Dict, Iterable, List, Optional from typing import Dict, Iterable, List, Optional, Protocol
from vllm.core.block.interfaces import Block, BlockAllocator from vllm.core.block.interfaces import Block, BlockAllocator
@ -7,7 +7,19 @@ BlockId = int
RefCount = int RefCount = int
class RefCounter: class RefCounterProtocol(Protocol):
def incr(self, block_id: BlockId) -> RefCount:
raise NotImplementedError
def decr(self, block_id: BlockId) -> RefCount:
raise NotImplementedError
def get(self, block_id: BlockId) -> RefCount:
raise NotImplementedError
class RefCounter(RefCounterProtocol):
"""A class for managing reference counts for a set of block indices. """A class for managing reference counts for a set of block indices.
The RefCounter class maintains a dictionary that maps block indices to their The RefCounter class maintains a dictionary that maps block indices to their
@ -54,7 +66,7 @@ class RefCounter:
return ReadOnlyRefCounter(self) return ReadOnlyRefCounter(self)
class ReadOnlyRefCounter: class ReadOnlyRefCounter(RefCounterProtocol):
"""A read-only view of the RefCounter class. """A read-only view of the RefCounter class.
The ReadOnlyRefCounter class provides a read-only interface to access the The ReadOnlyRefCounter class provides a read-only interface to access the
@ -96,7 +108,7 @@ class CopyOnWriteTracker:
def __init__( def __init__(
self, self,
refcounter: RefCounter, refcounter: RefCounterProtocol,
allocator: BlockAllocator, allocator: BlockAllocator,
): ):
self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list) self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)

View File

@ -1,6 +1,6 @@
from typing import Dict, List, Optional from typing import Dict, FrozenSet, List, Optional
from vllm.core.block.interfaces import (Block, BlockAllocator, from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId,
DeviceAwareBlockAllocator) DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
@ -57,15 +57,15 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
cpu_block_ids = block_ids[num_gpu_blocks:] cpu_block_ids = block_ids[num_gpu_blocks:]
if allocator_type == "naive": if allocator_type == "naive":
gpu_allocator = NaiveBlockAllocator( gpu_allocator: BlockAllocator = NaiveBlockAllocator(
create_block=NaiveBlock, create_block=NaiveBlock, # type: ignore
num_blocks=num_gpu_blocks, num_blocks=num_gpu_blocks,
block_size=block_size, block_size=block_size,
block_ids=gpu_block_ids, block_ids=gpu_block_ids,
) )
cpu_allocator = NaiveBlockAllocator( cpu_allocator: BlockAllocator = NaiveBlockAllocator(
create_block=NaiveBlock, create_block=NaiveBlock, # type: ignore
num_blocks=num_cpu_blocks, num_blocks=num_cpu_blocks,
block_size=block_size, block_size=block_size,
block_ids=cpu_block_ids, block_ids=cpu_block_ids,
@ -105,13 +105,14 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Device.GPU: gpu_block_allocator, Device.GPU: gpu_block_allocator,
} }
self._block_ids_to_allocator = {} self._block_ids_to_allocator: Dict[int, BlockAllocator] = {}
for _, allocator in self._allocators.items(): for _, allocator in self._allocators.items():
for block_id in allocator.all_block_ids: for block_id in allocator.all_block_ids:
self._block_ids_to_allocator[block_id] = allocator self._block_ids_to_allocator[block_id] = allocator
def allocate_mutable(self, prev_block: Optional[Block], def allocate_mutable(self,
device: Device) -> Block: prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a new mutable block on the specified device. """Allocates a new mutable block on the specified device.
Args: Args:
@ -122,10 +123,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Returns: Returns:
Block: The newly allocated mutable block. Block: The newly allocated mutable block.
""" """
assert device is not None
return self._allocators[device].allocate_mutable(prev_block) return self._allocators[device].allocate_mutable(prev_block)
def allocate_immutable(self, prev_block: Optional[Block], def allocate_immutable(self,
token_ids: List[int], device: Device) -> Block: prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
"""Allocates a new immutable block with the provided token IDs on the """Allocates a new immutable block with the provided token IDs on the
specified device. specified device.
@ -140,6 +144,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Block: The newly allocated immutable block containing the provided Block: The newly allocated immutable block containing the provided
token IDs. token IDs.
""" """
assert device is not None
return self._allocators[device].allocate_immutable( return self._allocators[device].allocate_immutable(
prev_block, token_ids) prev_block, token_ids)
@ -149,7 +154,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Args: Args:
block (Block): The block to be freed. block (Block): The block to be freed.
""" """
allocator = self._block_ids_to_allocator[block.block_id] block_id = block.block_id
assert block_id is not None
allocator = self._block_ids_to_allocator[block_id]
return allocator.free(block) return allocator.free(block)
def fork(self, last_block: Block) -> List[Block]: def fork(self, last_block: Block) -> List[Block]:
@ -163,19 +170,22 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
List[Block]: A new list of blocks that shares the same memory as the List[Block]: A new list of blocks that shares the same memory as the
original sequence. original sequence.
""" """
allocator = self._block_ids_to_allocator[last_block.block_id] block_id = last_block.block_id
assert block_id is not None
allocator = self._block_ids_to_allocator[block_id]
return allocator.fork(last_block) return allocator.fork(last_block)
def get_num_free_blocks(self, device: Device) -> int: def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
"""Returns the number of free blocks available on the specified device. """Returns the number of free blocks available on the specified device.
Args: Args:
device (Device): The device for which to query the number of free device (Device): The device for which to query the number of free
blocks. blocks. AssertionError is raised if None is passed.
Returns: Returns:
int: The number of free blocks available on the specified device. int: The number of free blocks available on the specified device.
""" """
assert device is not None
return self._allocators[device].get_num_free_blocks() return self._allocators[device].get_num_free_blocks()
def clear_copy_on_writes(self) -> Dict[int, List[int]]: def clear_copy_on_writes(self) -> Dict[int, List[int]]:
@ -210,5 +220,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
return self._allocators[device].get_common_computed_block_ids( return self._allocators[device].get_common_computed_block_ids(
seq_block_ids) seq_block_ids)
def all_block_ids(self) -> frozenset[int]: @property
def all_block_ids(self) -> FrozenSet[int]:
return frozenset(self._block_ids_to_allocator.keys()) return frozenset(self._block_ids_to_allocator.keys())
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
raise NotImplementedError

View File

@ -3,6 +3,8 @@ from typing import Dict, FrozenSet, List, Optional, Protocol
from vllm.utils import Device from vllm.utils import Device
BlockId = int
class Block(ABC): class Block(ABC):
@ -15,6 +17,12 @@ class Block(ABC):
def block_id(self) -> Optional[int]: def block_id(self) -> Optional[int]:
pass pass
@block_id.setter
@abstractmethod
def block_id(self, value: Optional[int]) -> None:
"""NOTE: Do not use this API outside Block."""
self._block_id = value
@property @property
@abstractmethod @abstractmethod
def token_ids(self) -> List[int]: def token_ids(self) -> List[int]:
@ -35,6 +43,27 @@ class Block(ABC):
def prev_block(self) -> Optional["Block"]: def prev_block(self) -> Optional["Block"]:
pass pass
@property
@abstractmethod
def computed(self) -> bool:
raise NotImplementedError
@computed.setter
@abstractmethod
def computed(self, value) -> bool:
"""Should only be used by PrefixCachingBlockAllocator."""
raise NotImplementedError
@property
@abstractmethod
def last_accessed(self) -> float:
raise NotImplementedError
@last_accessed.setter
@abstractmethod
def last_accessed(self, last_accessed_ts: float):
raise NotImplementedError
class Factory(Protocol): class Factory(Protocol):
@abstractmethod @abstractmethod
@ -48,6 +77,17 @@ class Block(ABC):
) -> "Block": ) -> "Block":
pass pass
@property
@abstractmethod
def content_hash(self) -> Optional[int]:
"""Return the content-based hash of the current block, or None if it is
not yet defined or not supported.
For the content-based hash to be defined, the current block must be
full.
"""
return None
class BlockAllocator(ABC): class BlockAllocator(ABC):
@ -57,7 +97,7 @@ class BlockAllocator(ABC):
@abstractmethod @abstractmethod
def allocate_immutable(self, prev_block: Optional[Block], def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int], device: Device) -> Block: token_ids: List[int]) -> Block:
pass pass
@abstractmethod @abstractmethod
@ -69,7 +109,7 @@ class BlockAllocator(ABC):
pass pass
@abstractmethod @abstractmethod
def get_num_free_blocks(self, device: Device) -> int: def get_num_free_blocks(self) -> int:
pass pass
@property @property
@ -82,11 +122,12 @@ class BlockAllocator(ABC):
pass pass
@abstractmethod @abstractmethod
def mark_blocks_as_accessed(self) -> None: def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
pass pass
@abstractmethod @abstractmethod
def mark_blocks_as_computed(self) -> None: def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass pass
@abstractmethod @abstractmethod
@ -94,21 +135,66 @@ class BlockAllocator(ABC):
self, seq_block_ids: List[List[int]]) -> List[int]: self, seq_block_ids: List[List[int]]) -> List[int]:
pass pass
@abstractmethod
def cow_block_if_not_appendable(self, block: Block) -> Optional["BlockId"]:
"""NOTE: This should not be used outside of Block."""
pass
@abstractmethod
def promote_to_immutable_block(self, block: Block) -> BlockId:
"""NOTE: This should not be used outside of Block."""
pass
class NoFreeBlocksError(ValueError): class NoFreeBlocksError(ValueError):
pass pass
class DeviceAwareBlockAllocator(BlockAllocator): class DeviceAwareBlockAllocator(ABC):
@abstractmethod @abstractmethod
def allocate_mutable(self, prev_block: Optional[Block]) -> Block: def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
pass pass
@abstractmethod @abstractmethod
def allocate_immutable(self, prev_block: Optional[Block], def allocate_immutable(self,
token_ids: List[int], device: Device) -> Block: prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
pass pass
@abstractmethod @abstractmethod
def get_num_free_blocks(self, device: Device) -> int: def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
pass
@abstractmethod
def free(self, block: Block) -> None:
pass
@abstractmethod
def fork(self, last_block: Block) -> List[Block]:
pass
@property
@abstractmethod
def all_block_ids(self) -> FrozenSet[int]:
pass
@abstractmethod
def clear_copy_on_writes(self) -> Dict[int, List[int]]:
pass
@abstractmethod
def mark_blocks_as_accessed(self, block_ids: List[int],
now: float) -> None:
pass
@abstractmethod
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
pass
@abstractmethod
def get_common_computed_block_ids(
self, seq_block_ids: List[List[int]]) -> List[int]:
pass pass

View File

@ -1,10 +1,9 @@
from typing import Dict, Iterable, List, Optional, Set from typing import Dict, FrozenSet, Iterable, List, Optional, Set
from vllm.core.block.common import (CopyOnWriteTracker, RefCounter, from vllm.core.block.common import (CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively) get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
BlockId = int
Refcount = int Refcount = int
@ -49,8 +48,10 @@ class NaiveBlockAllocator(BlockAllocator):
allocator=self, allocator=self,
) )
def allocate_immutable(self, prev_block: Optional[Block], def allocate_immutable(self,
token_ids: List[int]) -> Block: prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
"""Allocates a new immutable block with the given token IDs, linked to """Allocates a new immutable block with the given token IDs, linked to
the previous block. the previous block.
@ -63,11 +64,14 @@ class NaiveBlockAllocator(BlockAllocator):
Returns: Returns:
Block: The newly allocated immutable block. Block: The newly allocated immutable block.
""" """
assert device is None
block = self.allocate_mutable(prev_block=prev_block) block = self.allocate_mutable(prev_block=prev_block)
block.append_token_ids(token_ids) block.append_token_ids(token_ids)
return block return block
def allocate_mutable(self, prev_block: Optional[Block]) -> Block: def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a new mutable block, linked to the previous block. """Allocates a new mutable block, linked to the previous block.
Args: Args:
@ -78,6 +82,7 @@ class NaiveBlockAllocator(BlockAllocator):
Returns: Returns:
Block: The newly allocated mutable block. Block: The newly allocated mutable block.
""" """
assert device is None
block_id = self._allocate_new_block_id() block_id = self._allocate_new_block_id()
return self._create_block( return self._create_block(
prev_block=prev_block, prev_block=prev_block,
@ -88,6 +93,7 @@ class NaiveBlockAllocator(BlockAllocator):
) )
def free(self, block: Block) -> None: def free(self, block: Block) -> None:
assert block.block_id is not None
self._free_block_id(block.block_id) self._free_block_id(block.block_id)
# Mark the block as having no allocation. # Mark the block as having no allocation.
@ -111,6 +117,7 @@ class NaiveBlockAllocator(BlockAllocator):
for block in source_blocks: for block in source_blocks:
# Increment refcount for each block. # Increment refcount for each block.
assert block.block_id is not None
refcount = self._refcounter.incr(block.block_id) refcount = self._refcounter.incr(block.block_id)
assert refcount != 1, "can't fork free'd block" assert refcount != 1, "can't fork free'd block"
@ -126,7 +133,8 @@ class NaiveBlockAllocator(BlockAllocator):
return forked_blocks return forked_blocks
def get_num_free_blocks(self) -> int: def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
assert device is None
return len(self._free_block_indices) return len(self._free_block_indices)
def _allocate_new_block_id(self) -> BlockId: def _allocate_new_block_id(self) -> BlockId:
@ -148,7 +156,7 @@ class NaiveBlockAllocator(BlockAllocator):
return self._refcounter return self._refcounter
@property @property
def all_block_ids(self): def all_block_ids(self) -> FrozenSet[int]:
return self._all_block_indices return self._all_block_indices
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
@ -200,6 +208,9 @@ class NaiveBlockAllocator(BlockAllocator):
""" """
return [] return []
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError
class NaiveBlock(Block): class NaiveBlock(Block):
"""An implementation of the Block class that does not support prefix """An implementation of the Block class that does not support prefix
@ -224,13 +235,13 @@ class NaiveBlock(Block):
""" """
def __init__(self, def __init__(self,
prev_block: Block, prev_block: Optional[Block],
token_ids: List[int], token_ids: List[int],
block_size: int, block_size: int,
allocator: BlockAllocator, allocator: BlockAllocator,
block_id: Optional[int] = None, block_id: Optional[int] = None,
_cow_target: Optional[Block] = None): _cow_target: Optional[Block] = None):
self._token_ids = [] self._token_ids: List[int] = []
self._block_size = block_size self._block_size = block_size
self._prev_block = prev_block self._prev_block = prev_block
self._block_id = block_id self._block_id = block_id
@ -256,6 +267,22 @@ class NaiveBlock(Block):
assert self.num_empty_slots >= len(token_ids) assert self.num_empty_slots >= len(token_ids)
self._token_ids.extend(token_ids) self._token_ids.extend(token_ids)
@property
def computed(self) -> bool:
raise NotImplementedError
@computed.setter
def computed(self, value) -> None:
raise NotImplementedError
@property
def last_accessed(self) -> float:
raise NotImplementedError
@last_accessed.setter
def last_accessed(self, last_accessed_ts: float):
raise NotImplementedError
@property @property
def block_id(self) -> Optional[int]: def block_id(self) -> Optional[int]:
return self._block_id return self._block_id
@ -276,9 +303,14 @@ class NaiveBlock(Block):
def token_ids(self) -> List[int]: def token_ids(self) -> List[int]:
return self._token_ids return self._token_ids
@property
def block_size(self) -> int: def block_size(self) -> int:
return self._block_size return self._block_size
@property @property
def prev_block(self) -> Optional["Block"]: def prev_block(self) -> Optional["Block"]:
return self._prev_block return self._prev_block
@property
def content_hash(self) -> Optional[int]:
return None

View File

@ -1,16 +1,15 @@
"""Token blocks.""" """Token blocks."""
from itertools import takewhile from itertools import takewhile
from os.path import commonprefix from os.path import commonprefix
from typing import Dict, Iterable, List, Optional from typing import Dict, FrozenSet, Iterable, List, Optional
from vllm.core.block.common import (CopyOnWriteTracker, from vllm.core.block.common import (CopyOnWriteTracker,
get_all_blocks_recursively) get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor
PrefixHash = int PrefixHash = int
BlockId = int
# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME # By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME, # so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME,
@ -38,7 +37,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
num_blocks: int, num_blocks: int,
block_size: int, block_size: int,
block_ids: Optional[Iterable[int]] = None, block_ids: Optional[Iterable[int]] = None,
eviction_policy: Optional[EvictionPolicy] = EvictionPolicy.LRU, eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
): ):
# A mapping of prefix hash to block index. All blocks which have a # A mapping of prefix hash to block index. All blocks which have a
# prefix hash will be in this dict, even if they have refcount 0. # prefix hash will be in this dict, even if they have refcount 0.
@ -49,7 +48,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# An allocator for blocks that do not have prefix hashes. # An allocator for blocks that do not have prefix hashes.
self._hashless_allocator = NaiveBlockAllocator( self._hashless_allocator = NaiveBlockAllocator(
create_block=self._create_block, create_block=self._create_block, # type: ignore
num_blocks=num_blocks, num_blocks=num_blocks,
block_size=block_size, block_size=block_size,
block_ids=block_ids, block_ids=block_ids,
@ -79,7 +78,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block_size: int, block_size: int,
allocator: BlockAllocator, allocator: BlockAllocator,
block_id: Optional[int] = None, block_id: Optional[int] = None,
computed: Optional[bool] = False, computed: bool = False,
) -> Block: ) -> Block:
# Bind block to self. # Bind block to self.
allocator = self allocator = self
@ -93,8 +92,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
computed=computed, computed=computed,
) )
def allocate_immutable(self, prev_block: Optional[Block], def allocate_immutable(self,
token_ids: List[int]) -> Block: prev_block: Optional[Block],
token_ids: List[int],
device: Optional[Device] = None) -> Block:
"""Allocates an immutable block with the given token IDs, reusing cached """Allocates an immutable block with the given token IDs, reusing cached
blocks if possible. blocks if possible.
@ -105,6 +106,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
Returns: Returns:
Block: The allocated immutable block. Block: The allocated immutable block.
""" """
assert device is None
assert_prefix_caching_block_or_none(prev_block) assert_prefix_caching_block_or_none(prev_block)
block = self._create_block( block = self._create_block(
@ -127,16 +129,20 @@ class PrefixCachingBlockAllocator(BlockAllocator):
return block return block
def allocate_mutable(self, prev_block: Block) -> Block: def allocate_mutable(self,
prev_block: Optional[Block],
device: Optional[Device] = None) -> Block:
"""Allocates a mutable block. If there are no free blocks, this will """Allocates a mutable block. If there are no free blocks, this will
evict unused cached blocks. evict unused cached blocks.
Args: Args:
prev_block (Block): The previous block in the sequence. prev_block (Block): The previous block in the sequence.
Unlike in the superclass, None is not allowed.
Returns: Returns:
Block: The allocated mutable block. Block: The allocated mutable block.
""" """
assert device is None
assert_prefix_caching_block_or_none(prev_block) assert_prefix_caching_block_or_none(prev_block)
try: try:
@ -144,6 +150,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
prev_block=prev_block) prev_block=prev_block)
assert block.block_id not in self._blocks assert block.block_id not in self._blocks
assert block.block_id is not None
self._blocks[block.block_id] = block self._blocks[block.block_id] = block
return block return block
except BlockAllocator.NoFreeBlocksError: except BlockAllocator.NoFreeBlocksError:
@ -183,6 +190,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert block.content_hash is None assert block.content_hash is None
assert block.block_id not in self._blocks assert block.block_id not in self._blocks
assert block.block_id is not None
self._blocks[block.block_id] = block self._blocks[block.block_id] = block
return block return block
@ -225,6 +233,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# We have fork case where block would get more than one ref, # We have fork case where block would get more than one ref,
# so we cannot free it from tracking if ref cnt large than 1 # so we cannot free it from tracking if ref cnt large than 1
if refcount <= 1: if refcount <= 1:
assert block.block_id is not None
del self._blocks[block.block_id] del self._blocks[block.block_id]
return self._hashless_allocator.free(block) return self._hashless_allocator.free(block)
@ -233,6 +242,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# If no longer used, add the block to the evictor. # If no longer used, add the block to the evictor.
if refcount == 0: if refcount == 0:
assert block.content_hash in self._cached_blocks assert block.content_hash in self._cached_blocks
assert block.block_id is not None
del self._blocks[block.block_id] del self._blocks[block.block_id]
self.evictor.add(block.block_id, block.content_hash, self.evictor.add(block.block_id, block.content_hash,
block.num_tokens_total, block.last_accessed) block.num_tokens_total, block.last_accessed)
@ -268,18 +278,18 @@ class PrefixCachingBlockAllocator(BlockAllocator):
return forked_blocks return forked_blocks
def get_num_free_blocks(self) -> int: def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
assert device is None
# The number of free blocks is the number of hashless free blocks # The number of free blocks is the number of hashless free blocks
# plus the number of blocks evictor could free from its list. # plus the number of blocks evictor could free from its list.
return self._hashless_allocator.get_num_free_blocks( return self._hashless_allocator.get_num_free_blocks(
) + self.evictor.num_blocks ) + self.evictor.num_blocks
@property @property
def all_block_ids(self) -> frozenset[int]: def all_block_ids(self) -> FrozenSet[int]:
return self._hashless_allocator.all_block_ids return self._hashless_allocator.all_block_ids
def promote_to_immutable_block(self, def promote_to_immutable_block(self, block: Block) -> BlockId:
block: "PrefixCachingBlock") -> BlockId:
"""Once a mutable block is full, it can be promoted to an immutable """Once a mutable block is full, it can be promoted to an immutable
block. This means that its content can be referenced by future blocks block. This means that its content can be referenced by future blocks
having the same prefix. having the same prefix.
@ -289,7 +299,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block. block.
Args: Args:
block (PrefixCachingBlock): The mutable block to be promoted. block: The mutable block to be promoted.
Returns: Returns:
BlockId: Either the original block index, or the block index of BlockId: Either the original block index, or the block index of
@ -385,8 +395,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
takewhile(lambda block_id: self.block_is_computed(block_id), takewhile(lambda block_id: self.block_is_computed(block_id),
seq[:-1])) for seq in seq_block_ids seq[:-1])) for seq in seq_block_ids
] ]
res = commonprefix([ids for ids in ids_list if ids != []]) # It returns a list of int although type annotation says list of string.
return res return commonprefix([
ids for ids in ids_list # type: ignore
if ids != []
])
class PrefixCachingBlock(Block): class PrefixCachingBlock(Block):
@ -403,7 +416,7 @@ class PrefixCachingBlock(Block):
token_ids (List[int]): The initial token IDs to be stored in the block. token_ids (List[int]): The initial token IDs to be stored in the block.
block_size (int): The maximum number of token IDs that can be stored in block_size (int): The maximum number of token IDs that can be stored in
the block. the block.
prefix_caching_allocator (PrefixCachingBlockAllocator): The prefix prefix_caching_allocator (BlockAllocator): The prefix
caching block allocator associated with this block. caching block allocator associated with this block.
block_id (Optional[int], optional): The physical block index block_id (Optional[int], optional): The physical block index
of this block. Defaults to None. of this block. Defaults to None.
@ -411,21 +424,25 @@ class PrefixCachingBlock(Block):
def __init__( def __init__(
self, self,
prev_block: Optional["PrefixCachingBlock"], prev_block: Optional[Block],
token_ids: List[int], token_ids: List[int],
block_size: int, block_size: int,
prefix_caching_allocator: PrefixCachingBlockAllocator, prefix_caching_allocator: BlockAllocator,
block_id: Optional[int] = None, block_id: Optional[int] = None,
computed: Optional[bool] = False, computed: bool = False,
): ):
assert isinstance(prefix_caching_allocator,
PrefixCachingBlockAllocator), (
"Currently this class is only tested with "
"PrefixCachingBlockAllocator.")
assert_prefix_caching_block_or_none(prev_block) assert_prefix_caching_block_or_none(prev_block)
self._prev_block = prev_block self._prev_block = prev_block
self._cached_content_hash: Optional[int] = None self._cached_content_hash: Optional[int] = None
self._cached_num_tokens_total: Optional[int] = None self._cached_num_tokens_total: Optional[int] = None
self._prefix_caching_allocator = prefix_caching_allocator self._prefix_caching_allocator = prefix_caching_allocator
self.last_accessed = _DEFAULT_LAST_ACCESSED_TIME self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
self.computed = computed self._computed = computed
self._block = NaiveBlock( self._block = NaiveBlock(
prev_block=prev_block, prev_block=prev_block,
@ -436,6 +453,22 @@ class PrefixCachingBlock(Block):
_cow_target=self, _cow_target=self,
) )
@property
def computed(self) -> bool:
return self._computed
@computed.setter
def computed(self, value) -> None:
self._computed = value
@property
def last_accessed(self) -> float:
return self._last_accessed
@last_accessed.setter
def last_accessed(self, last_accessed_ts: float):
self._last_accessed = last_accessed_ts
def append_token_ids(self, token_ids: List[int]) -> None: def append_token_ids(self, token_ids: List[int]) -> None:
"""Appends the given token IDs to the block and registers the block as """Appends the given token IDs to the block and registers the block as
immutable if the block becomes full. immutable if the block becomes full.
@ -483,7 +516,7 @@ class PrefixCachingBlock(Block):
if self._cached_num_tokens_total is not None: if self._cached_num_tokens_total is not None:
return self._cached_num_tokens_total return self._cached_num_tokens_total
_block = self _block: Optional[Block] = self
self._cached_num_tokens_total = 0 self._cached_num_tokens_total = 0
# TODO: current implement here take O(N^2), we expect future # TODO: current implement here take O(N^2), we expect future
@ -524,8 +557,10 @@ class PrefixCachingBlock(Block):
return None return None
is_first_block = self._prev_block is None is_first_block = self._prev_block is None
prev_block_hash = (None if is_first_block else prev_block_hash = (
self._prev_block.content_hash) None if is_first_block else
self._prev_block.content_hash # type: ignore
)
# Previous block exists but does not yet have a hash. # Previous block exists but does not yet have a hash.
# Return no hash in this case. # Return no hash in this case.

View File

@ -190,7 +190,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
assert seq.seq_id in self.block_tables assert seq.seq_id in self.block_tables
block_ids = self.block_tables[seq.seq_id].physical_block_ids block_ids = self.block_tables[seq.seq_id].physical_block_ids
assert all(b is not None for b in block_ids) assert all(b is not None for b in block_ids)
return block_ids return block_ids # type: ignore
def access_all_blocks_in_seq(self, seq: Sequence, now: float): def access_all_blocks_in_seq(self, seq: Sequence, now: float):
# Update the last accessed time of all the blocks accessed # Update the last accessed time of all the blocks accessed
@ -204,7 +204,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
block_ids = [] block_ids = []
for block_id in block_table.physical_block_ids: for block_id in block_table.physical_block_ids:
block_ids.append(block_id) block_ids.append(block_id)
self.block_allocator.mark_blocks_as_accessed(block_ids, now) self.block_allocator.mark_blocks_as_accessed(
block_ids, # type: ignore
now)
def mark_blocks_as_computed(self, seq_group: SequenceGroup): def mark_blocks_as_computed(self, seq_group: SequenceGroup):
# The only need for mark block as computed is for prefix caching, # The only need for mark block as computed is for prefix caching,
@ -227,8 +229,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
seq_block_ids = [ seq_block_ids = [
self.block_tables[seq.seq_id].physical_block_ids for seq in seqs self.block_tables[seq.seq_id].physical_block_ids for seq in seqs
] ]
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return self.block_allocator.get_common_computed_block_ids( return self.block_allocator.get_common_computed_block_ids(
seq_block_ids) seq_block_ids) # type: ignore
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
src_block_table = self.block_tables[parent_seq.seq_id] src_block_table = self.block_tables[parent_seq.seq_id]

View File

@ -32,15 +32,20 @@ class Evictor(ABC):
@abstractmethod @abstractmethod
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
last_accessed: int): last_accessed: float):
"""Adds block to the evictor, making it a candidate for eviction""" """Adds block to the evictor, making it a candidate for eviction"""
pass pass
@abstractmethod @abstractmethod
def update(self, block_id: int, last_accessed: int): def update(self, block_id: int, last_accessed: float):
"""Update corresponding block's access time in metadata""" """Update corresponding block's access time in metadata"""
pass pass
@abstractmethod
def remove(self, block_id: int):
"""Remove a given block id from the cache."""
pass
@abstractproperty @abstractproperty
def num_blocks(self) -> int: def num_blocks(self) -> int:
pass pass
@ -55,7 +60,7 @@ class BlockMetaData():
""" """
def __init__(self, content_hash: int, num_hashed_tokens: int, def __init__(self, content_hash: int, num_hashed_tokens: int,
last_accessed: int): last_accessed: float):
self.content_hash = content_hash self.content_hash = content_hash
self.num_hashed_tokens = num_hashed_tokens self.num_hashed_tokens = num_hashed_tokens
self.last_accessed = last_accessed self.last_accessed = last_accessed
@ -96,12 +101,12 @@ class LRUEvictor(Evictor):
return evicted_block_id, evicted_block.content_hash return evicted_block_id, evicted_block.content_hash
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
last_accessed: int): last_accessed: float):
self.free_table[block_id] = BlockMetaData(content_hash, self.free_table[block_id] = BlockMetaData(content_hash,
num_hashed_tokens, num_hashed_tokens,
last_accessed) last_accessed)
def update(self, block_id: int, last_accessed: int): def update(self, block_id: int, last_accessed: float):
self.free_table[block_id].last_accessed = last_accessed self.free_table[block_id].last_accessed = last_accessed
def remove(self, block_id: int): def remove(self, block_id: int):