mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:05:44 +08:00
[Misc][V1] Fix type in v1 prefix caching (#11151)
This commit is contained in:
parent
db6c264a1e
commit
78ed8f57d8
@ -49,7 +49,7 @@ def test_prefill():
|
||||
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
|
||||
assert manager.block_pool[block_id].block_hash == block_hash
|
||||
assert manager.block_pool[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash
|
||||
parent_block_hash = block_hash.hash_value
|
||||
|
||||
# Check partial/preallocated block metadata
|
||||
for block_id in (3, 4):
|
||||
@ -360,11 +360,15 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
|
||||
assert not computed_blocks
|
||||
# Just ask for 1 block.
|
||||
blocks = manager.allocate_slots(req, block_size, computed_blocks)
|
||||
req.num_computed_tokens = block_size
|
||||
assert len(blocks) == 1 + num_preallocated_blocks
|
||||
|
||||
# Append slots to the block.
|
||||
req.num_computed_tokens = block_size * len(blocks) # Assume all used.
|
||||
blocks = manager.append_slots(req, block_size) # Append 1 block.
|
||||
# Assume all computed.
|
||||
manager.append_slots(req, block_size * (len(blocks) - 1))
|
||||
req.num_computed_tokens = block_size * len(blocks)
|
||||
|
||||
# Append 1 block.
|
||||
blocks = manager.append_slots(req, block_size)
|
||||
assert len(blocks) == 1 + num_preallocated_blocks
|
||||
|
||||
|
||||
|
||||
@ -375,8 +375,8 @@ class KVCacheManager:
|
||||
prev_block: The previous block in the chain.
|
||||
"""
|
||||
# Update the new blocks with the block hashes through the chain.
|
||||
prev_block_hash = (prev_block.block_hash
|
||||
if prev_block is not None else None)
|
||||
prev_block_hash_value = (prev_block.block_hash.hash_value
|
||||
if prev_block is not None else None)
|
||||
for i, blk in enumerate(full_blocks):
|
||||
blk_idx = blk_start_idx + i
|
||||
|
||||
@ -390,10 +390,10 @@ class KVCacheManager:
|
||||
f"{request.request_id}({request})")
|
||||
|
||||
# Compute the hash of the current block.
|
||||
block_hash = hash_block_tokens(prev_block_hash,
|
||||
block_hash = hash_block_tokens(prev_block_hash_value,
|
||||
tuple(block_tokens))
|
||||
|
||||
# Update and added the full block to the cache.
|
||||
blk.block_hash = block_hash
|
||||
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
|
||||
prev_block_hash = block_hash
|
||||
prev_block_hash_value = block_hash.hash_value
|
||||
|
||||
@ -1,12 +1,19 @@
|
||||
"""KV-Cache Utilities."""
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, NamedTuple, Optional, Tuple
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
BlockHashType = Tuple[int, Tuple[int]]
|
||||
|
||||
class BlockHashType(NamedTuple):
|
||||
"""Hash value of a block and the token IDs in the block.
|
||||
The reason we keep a tuple of token IDs is to make sure no hash
|
||||
collision happens when the hash value is the same.
|
||||
"""
|
||||
hash_value: int
|
||||
token_ids: Tuple[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -171,8 +178,8 @@ def hash_block_tokens(parent_block_hash: Optional[int],
|
||||
The hash value of the block and the token ids in the block.
|
||||
The entire tuple is used as the hash key of the block.
|
||||
"""
|
||||
return (hash(
|
||||
(parent_block_hash, *curr_block_token_ids)), curr_block_token_ids)
|
||||
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
|
||||
curr_block_token_ids)
|
||||
|
||||
|
||||
def hash_request_tokens(block_size: int,
|
||||
@ -188,14 +195,15 @@ def hash_request_tokens(block_size: int,
|
||||
The list of computed hash values.
|
||||
"""
|
||||
ret = []
|
||||
parent_block_hash = None
|
||||
parent_block_hash_value = None
|
||||
for start in range(0, len(token_ids), block_size):
|
||||
end = start + block_size
|
||||
block_token_ids = tuple(token_ids[start:end])
|
||||
# Do not hash the block if it is not full.
|
||||
if len(block_token_ids) < block_size:
|
||||
break
|
||||
block_hash = hash_block_tokens(parent_block_hash, block_token_ids)
|
||||
block_hash = hash_block_tokens(parent_block_hash_value,
|
||||
block_token_ids)
|
||||
ret.append(block_hash)
|
||||
parent_block_hash = block_hash
|
||||
parent_block_hash_value = block_hash.hash_value
|
||||
return ret
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user