From 78ed8f57d8815cdd5567533f7d3e25b959d861ab Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Thu, 12 Dec 2024 16:57:40 -0800
Subject: [PATCH] [Misc][V1] Fix type in v1 prefix caching (#11151)

---
 tests/v1/core/test_prefix_caching.py | 12 ++++++++----
 vllm/v1/core/kv_cache_manager.py     |  8 ++++----
 vllm/v1/core/kv_cache_utils.py       | 22 +++++++++++++++-------
 3 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index b44d3e5cb067..00f7b0fcfe1d 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -49,7 +49,7 @@ def test_prefill():
         block_hash = hash_block_tokens(parent_block_hash, block_tokens)
         assert manager.block_pool[block_id].block_hash == block_hash
         assert manager.block_pool[block_id].ref_cnt == 1
-        parent_block_hash = block_hash
+        parent_block_hash = block_hash.hash_value
 
     # Check partial/preallocated block metadata
     for block_id in (3, 4):
@@ -360,11 +360,15 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
     assert not computed_blocks
     # Just ask for 1 block.
     blocks = manager.allocate_slots(req, block_size, computed_blocks)
+    req.num_computed_tokens = block_size
     assert len(blocks) == 1 + num_preallocated_blocks
 
-    # Append slots to the block.
-    req.num_computed_tokens = block_size * len(blocks)  # Assume all used.
-    blocks = manager.append_slots(req, block_size)  # Append 1 block.
+    # Assume all computed.
+    manager.append_slots(req, block_size * (len(blocks) - 1))
+    req.num_computed_tokens = block_size * len(blocks)
+
+    # Append 1 block.
+    blocks = manager.append_slots(req, block_size)
     assert len(blocks) == 1 + num_preallocated_blocks
 
 
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index b492a755e6dd..03cbb958237d 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -375,8 +375,8 @@ class KVCacheManager:
             prev_block: The previous block in the chain.
         """
         # Update the new blocks with the block hashes through the chain.
-        prev_block_hash = (prev_block.block_hash
-                           if prev_block is not None else None)
+        prev_block_hash_value = (prev_block.block_hash.hash_value
+                                 if prev_block is not None else None)
         for i, blk in enumerate(full_blocks):
             blk_idx = blk_start_idx + i
 
@@ -390,10 +390,10 @@ class KVCacheManager:
                 f"{request.request_id}({request})")
 
             # Compute the hash of the current block.
-            block_hash = hash_block_tokens(prev_block_hash,
+            block_hash = hash_block_tokens(prev_block_hash_value,
                                            tuple(block_tokens))
 
             # Update and added the full block to the cache.
             blk.block_hash = block_hash
             self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
-            prev_block_hash = block_hash
+            prev_block_hash_value = block_hash.hash_value
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index fb666c364bfb..814e462a91fe 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -1,12 +1,19 @@
 """KV-Cache Utilities."""
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import List, NamedTuple, Optional, Tuple
 
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
-BlockHashType = Tuple[int, Tuple[int]]
+
+class BlockHashType(NamedTuple):
+    """Hash value of a block and the token IDs in the block.
+    The reason we keep a tuple of token IDs is to make sure no hash
+    collision happens when the hash value is the same.
+    """
+    hash_value: int
+    token_ids: Tuple[int]
 
 
 @dataclass
@@ -171,8 +178,8 @@ def hash_block_tokens(parent_block_hash: Optional[int],
         The hash value of the block and the token ids in the block.
         The entire tuple is used as the hash key of the block.
     """
-    return (hash(
-        (parent_block_hash, *curr_block_token_ids)), curr_block_token_ids)
+    return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
+                         curr_block_token_ids)
 
 
 def hash_request_tokens(block_size: int,
@@ -188,14 +195,15 @@ def hash_request_tokens(block_size: int,
         The list of computed hash values.
     """
     ret = []
-    parent_block_hash = None
+    parent_block_hash_value = None
     for start in range(0, len(token_ids), block_size):
         end = start + block_size
         block_token_ids = tuple(token_ids[start:end])
         # Do not hash the block if it is not full.
         if len(block_token_ids) < block_size:
             break
-        block_hash = hash_block_tokens(parent_block_hash, block_token_ids)
+        block_hash = hash_block_tokens(parent_block_hash_value,
+                                       block_token_ids)
         ret.append(block_hash)
-        parent_block_hash = block_hash
+        parent_block_hash_value = block_hash.hash_value
     return ret