mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 05:45:01 +08:00
[Misc][V1] Fix type in v1 prefix caching (#11151)
This commit is contained in:
parent
db6c264a1e
commit
78ed8f57d8
@ -49,7 +49,7 @@ def test_prefill():
|
|||||||
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
|
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
|
||||||
assert manager.block_pool[block_id].block_hash == block_hash
|
assert manager.block_pool[block_id].block_hash == block_hash
|
||||||
assert manager.block_pool[block_id].ref_cnt == 1
|
assert manager.block_pool[block_id].ref_cnt == 1
|
||||||
parent_block_hash = block_hash
|
parent_block_hash = block_hash.hash_value
|
||||||
|
|
||||||
# Check partial/preallocated block metadata
|
# Check partial/preallocated block metadata
|
||||||
for block_id in (3, 4):
|
for block_id in (3, 4):
|
||||||
@ -360,11 +360,15 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
|
|||||||
assert not computed_blocks
|
assert not computed_blocks
|
||||||
# Just ask for 1 block.
|
# Just ask for 1 block.
|
||||||
blocks = manager.allocate_slots(req, block_size, computed_blocks)
|
blocks = manager.allocate_slots(req, block_size, computed_blocks)
|
||||||
|
req.num_computed_tokens = block_size
|
||||||
assert len(blocks) == 1 + num_preallocated_blocks
|
assert len(blocks) == 1 + num_preallocated_blocks
|
||||||
|
|
||||||
# Append slots to the block.
|
# Assume all computed.
|
||||||
req.num_computed_tokens = block_size * len(blocks) # Assume all used.
|
manager.append_slots(req, block_size * (len(blocks) - 1))
|
||||||
blocks = manager.append_slots(req, block_size) # Append 1 block.
|
req.num_computed_tokens = block_size * len(blocks)
|
||||||
|
|
||||||
|
# Append 1 block.
|
||||||
|
blocks = manager.append_slots(req, block_size)
|
||||||
assert len(blocks) == 1 + num_preallocated_blocks
|
assert len(blocks) == 1 + num_preallocated_blocks
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -375,7 +375,7 @@ class KVCacheManager:
|
|||||||
prev_block: The previous block in the chain.
|
prev_block: The previous block in the chain.
|
||||||
"""
|
"""
|
||||||
# Update the new blocks with the block hashes through the chain.
|
# Update the new blocks with the block hashes through the chain.
|
||||||
prev_block_hash = (prev_block.block_hash
|
prev_block_hash_value = (prev_block.block_hash.hash_value
|
||||||
if prev_block is not None else None)
|
if prev_block is not None else None)
|
||||||
for i, blk in enumerate(full_blocks):
|
for i, blk in enumerate(full_blocks):
|
||||||
blk_idx = blk_start_idx + i
|
blk_idx = blk_start_idx + i
|
||||||
@ -390,10 +390,10 @@ class KVCacheManager:
|
|||||||
f"{request.request_id}({request})")
|
f"{request.request_id}({request})")
|
||||||
|
|
||||||
# Compute the hash of the current block.
|
# Compute the hash of the current block.
|
||||||
block_hash = hash_block_tokens(prev_block_hash,
|
block_hash = hash_block_tokens(prev_block_hash_value,
|
||||||
tuple(block_tokens))
|
tuple(block_tokens))
|
||||||
|
|
||||||
# Update and added the full block to the cache.
|
# Update and added the full block to the cache.
|
||||||
blk.block_hash = block_hash
|
blk.block_hash = block_hash
|
||||||
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
|
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
|
||||||
prev_block_hash = block_hash
|
prev_block_hash_value = block_hash.hash_value
|
||||||
|
|||||||
@ -1,12 +1,19 @@
|
|||||||
"""KV-Cache Utilities."""
|
"""KV-Cache Utilities."""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, NamedTuple, Optional, Tuple
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
BlockHashType = Tuple[int, Tuple[int]]
|
|
||||||
|
class BlockHashType(NamedTuple):
|
||||||
|
"""Hash value of a block and the token IDs in the block.
|
||||||
|
The reason we keep a tuple of token IDs is to make sure no hash
|
||||||
|
collision happens when the hash value is the same.
|
||||||
|
"""
|
||||||
|
hash_value: int
|
||||||
|
token_ids: Tuple[int]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -171,8 +178,8 @@ def hash_block_tokens(parent_block_hash: Optional[int],
|
|||||||
The hash value of the block and the token ids in the block.
|
The hash value of the block and the token ids in the block.
|
||||||
The entire tuple is used as the hash key of the block.
|
The entire tuple is used as the hash key of the block.
|
||||||
"""
|
"""
|
||||||
return (hash(
|
return BlockHashType(hash((parent_block_hash, *curr_block_token_ids)),
|
||||||
(parent_block_hash, *curr_block_token_ids)), curr_block_token_ids)
|
curr_block_token_ids)
|
||||||
|
|
||||||
|
|
||||||
def hash_request_tokens(block_size: int,
|
def hash_request_tokens(block_size: int,
|
||||||
@ -188,14 +195,15 @@ def hash_request_tokens(block_size: int,
|
|||||||
The list of computed hash values.
|
The list of computed hash values.
|
||||||
"""
|
"""
|
||||||
ret = []
|
ret = []
|
||||||
parent_block_hash = None
|
parent_block_hash_value = None
|
||||||
for start in range(0, len(token_ids), block_size):
|
for start in range(0, len(token_ids), block_size):
|
||||||
end = start + block_size
|
end = start + block_size
|
||||||
block_token_ids = tuple(token_ids[start:end])
|
block_token_ids = tuple(token_ids[start:end])
|
||||||
# Do not hash the block if it is not full.
|
# Do not hash the block if it is not full.
|
||||||
if len(block_token_ids) < block_size:
|
if len(block_token_ids) < block_size:
|
||||||
break
|
break
|
||||||
block_hash = hash_block_tokens(parent_block_hash, block_token_ids)
|
block_hash = hash_block_tokens(parent_block_hash_value,
|
||||||
|
block_token_ids)
|
||||||
ret.append(block_hash)
|
ret.append(block_hash)
|
||||||
parent_block_hash = block_hash
|
parent_block_hash_value = block_hash.hash_value
|
||||||
return ret
|
return ret
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user