[Misc][V1] Fix type in v1 prefix caching (#11151)

This commit is contained in:
Cody Yu 2024-12-12 16:57:40 -08:00 committed by GitHub
parent db6c264a1e
commit 78ed8f57d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 27 additions and 15 deletions

View File

@ -49,7 +49,7 @@ def test_prefill():
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
assert manager.block_pool[block_id].block_hash == block_hash
assert manager.block_pool[block_id].ref_cnt == 1
parent_block_hash = block_hash
parent_block_hash = block_hash.hash_value
# Check partial/preallocated block metadata
for block_id in (3, 4):
@ -360,11 +360,15 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
assert not computed_blocks
# Just ask for 1 block.
blocks = manager.allocate_slots(req, block_size, computed_blocks)
req.num_computed_tokens = block_size
assert len(blocks) == 1 + num_preallocated_blocks
# Append slots to the block.
req.num_computed_tokens = block_size * len(blocks) # Assume all used.
blocks = manager.append_slots(req, block_size) # Append 1 block.
# Assume all computed.
manager.append_slots(req, block_size * (len(blocks) - 1))
req.num_computed_tokens = block_size * len(blocks)
# Append 1 block.
blocks = manager.append_slots(req, block_size)
assert len(blocks) == 1 + num_preallocated_blocks

View File

@ -375,8 +375,8 @@ class KVCacheManager:
prev_block: The previous block in the chain.
"""
# Update the new blocks with the block hashes through the chain.
prev_block_hash = (prev_block.block_hash
if prev_block is not None else None)
prev_block_hash_value = (prev_block.block_hash.hash_value
if prev_block is not None else None)
for i, blk in enumerate(full_blocks):
blk_idx = blk_start_idx + i
@ -390,10 +390,10 @@ class KVCacheManager:
f"{request.request_id}({request})")
# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash,
block_hash = hash_block_tokens(prev_block_hash_value,
tuple(block_tokens))
# Update and added the full block to the cache.
blk.block_hash = block_hash
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
prev_block_hash = block_hash
prev_block_hash_value = block_hash.hash_value

View File

@ -1,12 +1,19 @@
"""KV-Cache Utilities."""
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import List, NamedTuple, Optional, Tuple
from vllm.logger import init_logger
logger = init_logger(__name__)
BlockHashType = Tuple[int, Tuple[int]]
class BlockHashType(NamedTuple):
"""Hash value of a block and the token IDs in the block.
The reason we keep a tuple of token IDs is to make sure no hash
collision happens when the hash value is the same.
"""
hash_value: int
token_ids: Tuple[int]
@dataclass
@ -171,8 +178,8 @@ def hash_block_tokens(parent_block_hash: Optional[int],
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
return (hash(
(parent_block_hash, *curr_block_token_ids)), curr_block_token_ids)
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
curr_block_token_ids)
def hash_request_tokens(block_size: int,
@ -188,14 +195,15 @@ def hash_request_tokens(block_size: int,
The list of computed hash values.
"""
ret = []
parent_block_hash = None
parent_block_hash_value = None
for start in range(0, len(token_ids), block_size):
end = start + block_size
block_token_ids = tuple(token_ids[start:end])
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break
block_hash = hash_block_tokens(parent_block_hash, block_token_ids)
block_hash = hash_block_tokens(parent_block_hash_value,
block_token_ids)
ret.append(block_hash)
parent_block_hash = block_hash
parent_block_hash_value = block_hash.hash_value
return ret