[Misc][V1] Fix type in v1 prefix caching (#11151)

This commit is contained in:
Cody Yu 2024-12-12 16:57:40 -08:00 committed by GitHub
parent db6c264a1e
commit 78ed8f57d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 27 additions and 15 deletions

View File

@ -49,7 +49,7 @@ def test_prefill():
block_hash = hash_block_tokens(parent_block_hash, block_tokens) block_hash = hash_block_tokens(parent_block_hash, block_tokens)
assert manager.block_pool[block_id].block_hash == block_hash assert manager.block_pool[block_id].block_hash == block_hash
assert manager.block_pool[block_id].ref_cnt == 1 assert manager.block_pool[block_id].ref_cnt == 1
parent_block_hash = block_hash parent_block_hash = block_hash.hash_value
# Check partial/preallocated block metadata # Check partial/preallocated block metadata
for block_id in (3, 4): for block_id in (3, 4):
@ -360,11 +360,15 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
assert not computed_blocks assert not computed_blocks
# Just ask for 1 block. # Just ask for 1 block.
blocks = manager.allocate_slots(req, block_size, computed_blocks) blocks = manager.allocate_slots(req, block_size, computed_blocks)
req.num_computed_tokens = block_size
assert len(blocks) == 1 + num_preallocated_blocks assert len(blocks) == 1 + num_preallocated_blocks
# Append slots to the block. # Assume all computed.
req.num_computed_tokens = block_size * len(blocks) # Assume all used. manager.append_slots(req, block_size * (len(blocks) - 1))
blocks = manager.append_slots(req, block_size) # Append 1 block. req.num_computed_tokens = block_size * len(blocks)
# Append 1 block.
blocks = manager.append_slots(req, block_size)
assert len(blocks) == 1 + num_preallocated_blocks assert len(blocks) == 1 + num_preallocated_blocks

View File

@ -375,7 +375,7 @@ class KVCacheManager:
prev_block: The previous block in the chain. prev_block: The previous block in the chain.
""" """
# Update the new blocks with the block hashes through the chain. # Update the new blocks with the block hashes through the chain.
prev_block_hash = (prev_block.block_hash prev_block_hash_value = (prev_block.block_hash.hash_value
if prev_block is not None else None) if prev_block is not None else None)
for i, blk in enumerate(full_blocks): for i, blk in enumerate(full_blocks):
blk_idx = blk_start_idx + i blk_idx = blk_start_idx + i
@ -390,10 +390,10 @@ class KVCacheManager:
f"{request.request_id}({request})") f"{request.request_id}({request})")
# Compute the hash of the current block. # Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash, block_hash = hash_block_tokens(prev_block_hash_value,
tuple(block_tokens)) tuple(block_tokens))
# Update and added the full block to the cache. # Update and added the full block to the cache.
blk.block_hash = block_hash blk.block_hash = block_hash
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
prev_block_hash = block_hash prev_block_hash_value = block_hash.hash_value

View File

@ -1,12 +1,19 @@
"""KV-Cache Utilities.""" """KV-Cache Utilities."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, NamedTuple, Optional, Tuple
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
BlockHashType = Tuple[int, Tuple[int]]
class BlockHashType(NamedTuple):
"""Hash value of a block and the token IDs in the block.
The reason we keep a tuple of token IDs is to make sure no hash
collision happens when the hash value is the same.
"""
hash_value: int
token_ids: Tuple[int]
@dataclass @dataclass
@ -171,8 +178,8 @@ def hash_block_tokens(parent_block_hash: Optional[int],
The hash value of the block and the token ids in the block. The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block. The entire tuple is used as the hash key of the block.
""" """
return (hash( return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
(parent_block_hash, *curr_block_token_ids)), curr_block_token_ids) curr_block_token_ids)
def hash_request_tokens(block_size: int, def hash_request_tokens(block_size: int,
@ -188,14 +195,15 @@ def hash_request_tokens(block_size: int,
The list of computed hash values. The list of computed hash values.
""" """
ret = [] ret = []
parent_block_hash = None parent_block_hash_value = None
for start in range(0, len(token_ids), block_size): for start in range(0, len(token_ids), block_size):
end = start + block_size end = start + block_size
block_token_ids = tuple(token_ids[start:end]) block_token_ids = tuple(token_ids[start:end])
# Do not hash the block if it is not full. # Do not hash the block if it is not full.
if len(block_token_ids) < block_size: if len(block_token_ids) < block_size:
break break
block_hash = hash_block_tokens(parent_block_hash, block_token_ids) block_hash = hash_block_tokens(parent_block_hash_value,
block_token_ids)
ret.append(block_hash) ret.append(block_hash)
parent_block_hash = block_hash parent_block_hash_value = block_hash.hash_value
return ret return ret