diff --git a/docs/design/v1/prefix_caching.md b/docs/design/v1/prefix_caching.md
index ad041b0059f5..bbdfb255214d 100644
--- a/docs/design/v1/prefix_caching.md
+++ b/docs/design/v1/prefix_caching.md
@@ -104,7 +104,7 @@ class KVCacheBlock:
     block_id: int
 
     # The block hash (will be assigned when the block is full,
    # and will be reset when the block is evicted).
-    block_hash: BlockHashType
+    block_hash: BlockHash
     # The number of requests using this block now.
     ref_cnt: int
diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index d3d62cf09232..61aee8752988 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -100,8 +100,8 @@ def test_kv_cache_block():
     assert block.ref_cnt == 0
 
     # Test block hash setting and resetting
-    block_hash = vllm.v1.core.kv_cache_utils.BlockHashType(hash_value=123,
-                                                           token_ids=(1, 2, 3))
+    block_hash = vllm.v1.core.kv_cache_utils.BlockHash(hash_value=123,
+                                                       token_ids=(1, 2, 3))
     block.block_hash = block_hash
     assert block.block_hash == block_hash
 
@@ -282,7 +282,7 @@ def test_hash_block_tokens(hash_fn):
 
     block_hash = hash_block_tokens(hash_fn, parent_block_hash,
                                    curr_block_token_ids, extra_keys)
-    assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHashType)
+    assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHash)
     assert block_hash.hash_value == hash_fn(
         (parent_block_hash, curr_block_token_ids, extra_keys))
     assert block_hash.token_ids == curr_block_token_ids
@@ -306,10 +306,8 @@ def test_hash_request_tokens(hash_fn):
     block_hashes = hash_request_tokens(hash_fn, block_size, request)
 
     assert len(block_hashes) == 2
-    assert isinstance(block_hashes[0],
-                      vllm.v1.core.kv_cache_utils.BlockHashType)
-    assert isinstance(block_hashes[1],
-                      vllm.v1.core.kv_cache_utils.BlockHashType)
+    assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash)
+    assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash)
 
     # Check the first block
     assert block_hashes[0].token_ids == (0, 1, 2)
diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index ba3c0b3cf316..1a7a31d98506 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -12,7 +12,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.utils import sha256
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
-from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
+from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
                                          hash_block_tokens)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, SlidingWindowSpec)
@@ -547,7 +547,7 @@ def test_cache_blocks(hash_fn):
 
     # Test that blocks are cached correctly for 2 full blocks from the start.
     blocks = [KVCacheBlock(block_id=i) for i in range(2)]
-    block_hashes: list[BlockHashType] = []
+    block_hashes: list[BlockHash] = []
 
     block_pool.cache_full_blocks(
         request=req,
diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py
index 101a2379be37..4217dc37e2df 100644
--- a/tests/v1/core/test_specialized_manager.py
+++ b/tests/v1/core/test_specialized_manager.py
@@ -3,7 +3,7 @@
 import torch
 
 from vllm.v1.core.block_pool import BlockPool
-from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
+from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
 from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager
 from vllm.v1.kv_cache_interface import SlidingWindowSpec
 
@@ -32,7 +32,7 @@ def test_sliding_window_possible_cached_prefix():
 
     def run_one_case(block_is_cached, expect_length):
         block_hash_list = [
-            BlockHashType(i, ()) for i in range(len(block_is_cached))
+            BlockHash(i, ()) for i in range(len(block_is_cached))
         ]
 
         block_pool.cached_block_hash_to_block.clear()
diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py
index f2ed183b68fc..a0a065df9b1c 100644
--- a/vllm/v1/core/block_pool.py
+++ b/vllm/v1/core/block_pool.py
@@ -6,7 +6,7 @@ from typing import Callable, Optional
 from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
                                         BlockStored, KVCacheEvent)
 from vllm.logger import init_logger
-from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
+from vllm.v1.core.kv_cache_utils import (BlockHash, FreeKVCacheBlockQueue,
                                          KVCacheBlock,
                                          generate_block_hash_extra_keys,
                                          hash_block_tokens)
@@ -55,7 +55,7 @@ class BlockPool:
         # if there is already an identical block in the cache. This is because
         # we want to make sure the allocated block IDs won't change so that
         # block tables are append-only.
-        self.cached_block_hash_to_block: dict[BlockHashType, dict[
+        self.cached_block_hash_to_block: dict[BlockHash, dict[
             int, KVCacheBlock]] = defaultdict(dict)
 
         # To represent a placeholder block with block_id=0.
@@ -67,7 +67,7 @@ class BlockPool:
         self.kv_event_queue: list[KVCacheEvent] = []
 
     def get_cached_block(self,
-                         block_hash: BlockHashType) -> Optional[KVCacheBlock]:
+                         block_hash: BlockHash) -> Optional[KVCacheBlock]:
         """Get a cached block by the block hash, or None if cache miss.
 
         If there are duplicated blocks, we return the first block in the cache.
@@ -87,7 +87,7 @@ class BlockPool:
         self,
         request: Request,
         blocks: list[KVCacheBlock],
-        block_hashes: list[BlockHashType],
+        block_hashes: list[BlockHash],
         num_cached_blocks: int,
         num_full_blocks: int,
         block_size: int,
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 0f6098d2b400..59e07382b652 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -8,7 +8,7 @@ from vllm.distributed.kv_events import KVCacheEvent
 from vllm.logger import init_logger
 from vllm.utils import sha256
 from vllm.v1.core.block_pool import BlockPool
-from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
+from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
                                          hash_request_tokens)
 from vllm.v1.core.single_type_kv_cache_manager import (
     get_manager_for_kv_cache_spec)
@@ -92,7 +92,7 @@ class KVCacheManager:
         # This is to avoid recomputing the block hashes for each call of
         # `get_computed_blocks` or `allocate_slots`.
         self.req_to_block_hashes: defaultdict[
-            str, list[BlockHashType]] = defaultdict(list)
+            str, list[BlockHash]] = defaultdict(list)
 
     @property
     def usage(self) -> float:
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index a41fe4881870..3ccad97e9919 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -18,7 +18,7 @@ from vllm.v1.request import Request
 logger = init_logger(__name__)
 
 
-class BlockHashType(NamedTuple):
+class BlockHash(NamedTuple):
     """Hash value of a block (int), the token IDs in the block, and extra keys.
     We keep a tuple of token IDs and extra keys to reduce the likelihood of
     hash collisions when the hash value is the same. By using SHA256 however,
@@ -117,7 +117,7 @@ class KVCacheBlock:
     ref_cnt: int = 0
     # The hash of the block composed of (block hash, tuple of token IDs).
     # It is only available when the block is full.
-    _block_hash: Optional[BlockHashType] = None
+    _block_hash: Optional[BlockHash] = None
 
     # Used to construct a doubly linked list for free blocks.
     # These two attributes should only be manipulated by FreeKVCacheBlockQueue.
@@ -131,11 +131,11 @@ class KVCacheBlock:
         self.ref_cnt -= 1
 
     @property
-    def block_hash(self) -> Optional[BlockHashType]:
+    def block_hash(self) -> Optional[BlockHash]:
         return self._block_hash
 
     @block_hash.setter
-    def block_hash(self, block_hash: BlockHashType):
+    def block_hash(self, block_hash: BlockHash):
         assert self.block_hash is None, (
             "The block already has a hash. This should not happen.")
         self._block_hash = block_hash
@@ -398,7 +398,7 @@ def hash_block_tokens(
         hash_function: Callable,
         parent_block_hash: Optional[int],
         curr_block_token_ids: Sequence[int],
-        extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHashType:
+        extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash:
     """Computes a hash value corresponding to the contents of a block and
     the contents of the preceding block(s). The hash value is used for
     prefix caching. We use LRU cache for this function to avoid recomputing
@@ -419,14 +419,14 @@ def hash_block_tokens(
         parent_block_hash = NONE_HASH
 
     curr_block_token_ids_tuple = tuple(curr_block_token_ids)
-    return BlockHashType(
+    return BlockHash(
         hash_function(
             (parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
         curr_block_token_ids_tuple, extra_keys)
 
 
 def hash_request_tokens(hash_function: Any, block_size: int,
-                        request: Request) -> list[BlockHashType]:
+                        request: Request) -> list[BlockHash]:
     """Computes hash values of a chain of blocks given a sequence of
     token IDs. The hash value is used for prefix caching.
 
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
index 0223c9ceec8d..e69e9ac9f6a3 100644
--- a/vllm/v1/core/single_type_kv_cache_manager.py
+++ b/vllm/v1/core/single_type_kv_cache_manager.py
@@ -5,7 +5,7 @@ from typing import Callable
 
 from vllm.utils import cdiv
 from vllm.v1.core.block_pool import BlockPool
-from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
+from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
                                         SlidingWindowSpec)
 from vllm.v1.request import Request
@@ -133,7 +133,7 @@ class SingleTypeKVCacheManager(ABC):
         req_blocks.extend(new_blocks)
         return new_blocks
 
-    def cache_blocks(self, request: Request, block_hashes: list[BlockHashType],
+    def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
                      num_tokens: int) -> None:
         """
         Cache the blocks for the request.
@@ -187,7 +187,7 @@ class SingleTypeKVCacheManager(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
+    def find_longest_cache_hit(self, block_hashes: list[BlockHash],
                                max_length: int) -> list[KVCacheBlock]:
         """
         Get the longest cache hit prefix of the blocks that is not longer than
@@ -228,7 +228,7 @@ class SingleTypeKVCacheManager(ABC):
 
 class FullAttentionManager(SingleTypeKVCacheManager):
 
-    def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
+    def find_longest_cache_hit(self, block_hashes: list[BlockHash],
                                max_length: int) -> list[KVCacheBlock]:
         computed_blocks: list[KVCacheBlock] = []
         max_num_blocks = max_length // self.block_size
@@ -280,7 +280,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
                 self.sliding_window_contiguous_blocks += 1
         self._null_block = block_pool.null_block
 
-    def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
+    def find_longest_cache_hit(self, block_hashes: list[BlockHash],
                                max_length: int) -> list[KVCacheBlock]:
         # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
         # optimize the time complexity from O(max_num_blocks) to
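For reviewers, here is a minimal, self-contained sketch of the renamed `BlockHash` tuple and the chained hashing it feeds, mirroring the shapes visible in `vllm/v1/core/kv_cache_utils.py` above. It is illustrative only: Python's built-in `hash` stands in for the pluggable hash function (vLLM passes e.g. `sha256`), and `NONE_HASH` is reduced to a plain constant here rather than however vLLM initializes it.

```python
# Sketch only: mirrors the BlockHash NamedTuple and hash_block_tokens from
# the patch, with the hash function and NONE_HASH simplified for brevity.
from typing import Any, NamedTuple, Optional, Sequence


class BlockHash(NamedTuple):
    """Hash value of a block, the token IDs in the block, and extra keys."""
    hash_value: int
    token_ids: tuple[int, ...]
    extra_keys: Optional[tuple[Any, ...]] = None


NONE_HASH = 0  # Placeholder root hash; a stand-in for vLLM's real value.


def hash_block_tokens(
        parent_block_hash: Optional[int],
        curr_block_token_ids: Sequence[int],
        extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash:
    """Fold the parent block's hash into the current block's hash so that
    equal hashes imply an equal token prefix, not just equal block contents."""
    if parent_block_hash is None:
        parent_block_hash = NONE_HASH
    token_ids = tuple(curr_block_token_ids)
    # hash() stands in for the configurable hash function (e.g. sha256).
    return BlockHash(
        hash((parent_block_hash, token_ids, extra_keys)), token_ids,
        extra_keys)


# Hash two consecutive blocks of a request; the second hash depends on the
# first, so a cache hit on block 1 implies blocks 0 and 1 both matched.
h0 = hash_block_tokens(None, (0, 1, 2))
h1 = hash_block_tokens(h0.hash_value, (3, 4, 5))
assert isinstance(h1, BlockHash) and h1.token_ids == (3, 4, 5)
```

Because each hash chains to its parent, two entries in `cached_block_hash_to_block` share a key only when their entire prefixes agree, which is what makes the lookup in `get_cached_block` safe for prefix caching.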