[v1] Refactor KVCacheManager for more hash input than token ids (#10507)

Signed-off-by: rickyx <rickyx@anyscale.com>
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
Ricky Xu 2024-11-22 15:27:25 -08:00 committed by GitHub
parent eebad39f26
commit 97814fbf0f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 360 additions and 181 deletions

View File

@ -1,8 +1,11 @@
"""Compare the with and without prefix caching.""" """Compare the with and without prefix caching."""
import pytest
from vllm.inputs import token_inputs from vllm.inputs import token_inputs
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import hash_block_tokens from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens
def make_request(request_id, prompt_token_ids): def make_request(request_id, prompt_token_ids):
@ -31,7 +34,8 @@ def test_prefill():
# Fully cache miss # Fully cache miss
# Incomplete 1 block (7 tokens) # Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7 unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids) all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
computed_blocks = manager.get_computed_blocks(req0) computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks assert not computed_blocks
blocks = manager.allocate_slots(req0, 55, computed_blocks) blocks = manager.allocate_slots(req0, 55, computed_blocks)
@ -40,24 +44,16 @@ def test_prefill():
# Check full block metadata # Check full block metadata
parent_block_hash = None parent_block_hash = None
for block_id in (0, 1, 2): for block_id in (0, 1, 2):
block_hash = hash_block_tokens(parent_block_hash, block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
manager.block_pool[block_id].token_ids) block_hash = hash_block_tokens(parent_block_hash, block_tokens)
assert manager.block_pool[block_id].block_hash == block_hash assert manager.block_pool[block_id].block_hash == block_hash
assert manager.block_pool[block_id].ref_cnt == 1 assert manager.block_pool[block_id].ref_cnt == 1
assert manager.block_pool[block_id].num_hashed_tokens == 16 * (
block_id + 1)
assert manager.block_pool[block_id].token_ids == tuple([block_id] * 16)
parent_block_hash = block_hash parent_block_hash = block_hash
# Check partial/preallocated block metadata # Check partial/preallocated block metadata
for block_id in (3, 4): for block_id in (3, 4):
assert manager.block_pool[block_id].block_hash is None assert manager.block_pool[block_id].block_hash is None
assert manager.block_pool[block_id].ref_cnt == 1 assert manager.block_pool[block_id].ref_cnt == 1
assert manager.block_pool[block_id].num_hashed_tokens == 0
if block_id == 3:
assert manager.block_pool[block_id].token_ids == [3] * 7
else:
assert not manager.block_pool[block_id].token_ids
# Cache hit in the common prefix when the original block is still in use. # Cache hit in the common prefix when the original block is still in use.
# Incomplete 1 block (5 tokens) # Incomplete 1 block (5 tokens)
@ -113,7 +109,7 @@ def test_prefill():
req3 = make_request("3", [99] * (16 * 9)) req3 = make_request("3", [99] * (16 * 9))
computed_blocks = manager.get_computed_blocks(req3) computed_blocks = manager.get_computed_blocks(req3)
assert not computed_blocks assert not computed_blocks
blocks = manager.allocate_slots(req2, 16 * 9, computed_blocks) blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
# This block ID order also checks the eviction order. # This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0] assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
assert manager.free_block_queue.num_free_blocks == 0 assert manager.free_block_queue.num_free_blocks == 0
@ -148,7 +144,7 @@ def test_decode():
req0.append_output_token_ids(8) req0.append_output_token_ids(8)
new_blocks = manager.append_slots(req0, 4) new_blocks = manager.append_slots(req0, 4)
assert new_blocks is not None and len(new_blocks) == 0 assert new_blocks is not None and len(new_blocks) == 0
assert len(manager.block_pool[3].token_ids) == 11 assert manager.req_to_blocks[req0.request_id][-2].block_hash is None
# Append slots without allocating a new block, but start using the # Append slots without allocating a new block, but start using the
# preallocated block. # preallocated block.
@ -159,8 +155,7 @@ def test_decode():
req0.append_output_token_ids(7) req0.append_output_token_ids(7)
new_blocks = manager.append_slots(req0, 15) new_blocks = manager.append_slots(req0, 15)
assert new_blocks is not None and len(new_blocks) == 0 assert new_blocks is not None and len(new_blocks) == 0
assert len(manager.block_pool[3].token_ids) == 16 assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
assert len(manager.block_pool[4].token_ids) == 10
# Append slots with allocating a new block. # Append slots with allocating a new block.
req0.num_computed_tokens = 74 req0.num_computed_tokens = 74
@ -171,9 +166,6 @@ def test_decode():
new_blocks = manager.append_slots(req0, 17) new_blocks = manager.append_slots(req0, 17)
# Plus one preallocated block. # Plus one preallocated block.
assert new_blocks is not None and len(new_blocks) == 2 assert new_blocks is not None and len(new_blocks) == 2
assert len(manager.block_pool[4].token_ids) == 16
assert len(manager.block_pool[5].token_ids) == 11
assert len(manager.block_pool[6].token_ids) == 0
def test_evict(): def test_evict():
@ -217,3 +209,198 @@ def test_evict():
blocks = manager.allocate_slots(req2, 3, computed_blocks) blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [6, 5] assert [b.block_id for b in blocks] == [6, 5]
assert manager.free_block_queue.num_free_blocks == 6 assert manager.free_block_queue.num_free_blocks == 6
def test_hash_block_correct_reuse():
    """Verify that reusing a previously cached block clears its hash.

    A block that was full (and therefore hashed/cached) may later be
    handed out again as a fresh, partially-filled block; its stale hash
    metadata must not survive that reuse.
    """
    block_size = 16
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=1,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    # Fill exactly one block so it gets hashed and cached.
    full_len = block_size * 1
    first_req = make_request("0", list(range(full_len)))
    cached = manager.get_computed_blocks(first_req)
    assert not cached
    allocated = manager.allocate_slots(first_req, full_len, cached)
    assert len(allocated) == 1

    # Return the block to the free pool.
    manager.free(first_req)

    # Re-allocate it for a request that does NOT fill the block; the
    # recycled block must come back with no hash attached.
    second_req = make_request("1", list(range(full_len - 1)))
    cached = manager.get_computed_blocks(second_req)
    assert not cached
    allocated = manager.allocate_slots(second_req, full_len - 1, cached)
    assert len(allocated) == 1
    assert manager.block_pool[allocated[0].block_id].block_hash is None
def test_computed_blocks_not_evicted():
    """Verify eviction skips a request's own computed (prefix-hit) blocks.

    When a new request hits the prefix cache, the blocks it matched must
    not be chosen as eviction victims while other free blocks exist.
    """
    block_size = 16
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=2,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    # Fill and cache the first block (block id 0).
    one_block_tokens = block_size * 1
    req0 = make_request("0", list(range(one_block_tokens)))
    hits = manager.get_computed_blocks(req0)
    assert not hits
    got = manager.allocate_slots(req0, one_block_tokens, hits)
    assert len(got) == 1
    assert got[0].block_id == 0

    # Fill and cache the second block (block id 1) with different tokens.
    req1 = make_request("1", list(range(one_block_tokens, one_block_tokens * 2)))
    hits = manager.get_computed_blocks(req1)
    assert not hits
    got = manager.allocate_slots(req1, one_block_tokens, hits)
    assert len(got) == 1
    assert got[0].block_id == 1

    # Release both so they become eviction candidates.
    manager.free(req0)
    manager.free(req1)

    # req2's prefix matches block 0; the allocator must evict block 1
    # instead of the computed block it just matched.
    req2 = make_request("2", list(range(one_block_tokens * 2)))
    hits = manager.get_computed_blocks(req2)
    assert len(hits) == 1
    assert hits[0].block_id == 0

    got = manager.allocate_slots(req2, one_block_tokens * 2 - one_block_tokens,
                                 hits)
    assert len(got) == 1
    assert got[0].block_id == 1
def test_basic_prefix_caching_disabled():
    """Verify the manager performs no prefix caching when it is disabled.

    With ``enable_caching=False`` there should never be a computed-block
    hit, even for token sequences that share a prefix with freed blocks.
    """
    block_size = 4
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=4,
        sliding_window=False,
        enable_caching=False,
        num_preallocate_tokens=0,
    )

    # 10 tokens -> 2 full blocks plus a partial third.
    req1 = make_request("1", list(range(10)))
    prefix_hits = manager.get_computed_blocks(req1)
    assert not prefix_hits
    allocated = manager.allocate_slots(req1, 10, prefix_hits)
    assert len(allocated) == 3

    # Give the blocks back to the pool.
    manager.free(req1)

    # Same leading tokens, yet no cache hit because caching is off.
    req2 = make_request("2", list(range(16)))
    prefix_hits = manager.get_computed_blocks(req2)
    assert not prefix_hits
    allocated = manager.allocate_slots(req2, 16, prefix_hits)
    assert len(allocated) == 4

    # The pool is now exhausted, so this request gets nothing.
    req3 = make_request("3", list(range(4)))
    prefix_hits = manager.get_computed_blocks(req3)
    assert not prefix_hits
    allocated = manager.allocate_slots(req3, 4, prefix_hits)
    assert not allocated
@pytest.mark.parametrize("num_preallocate_tokens", list(range(0, 8)))
@pytest.mark.parametrize("block_size", [4])
def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
    """Verify preallocated blocks are granted on allocate and append paths.

    For every requested block, the manager should hand back one extra
    block per ``cdiv(num_preallocate_tokens, block_size)``.
    """
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=10,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=num_preallocate_tokens,
    )
    # How many extra blocks each allocation should carry.
    expected_extra = cdiv(num_preallocate_tokens, block_size)

    req = make_request("0", list(range(block_size * 30)))
    prefix_hits = manager.get_computed_blocks(req)
    assert not prefix_hits
    # Request exactly one block's worth of tokens.
    granted = manager.allocate_slots(req, block_size, prefix_hits)
    assert len(granted) == 1 + expected_extra

    # Pretend all granted blocks are consumed, then append one more block.
    req.num_computed_tokens = block_size * len(granted)  # Assume all used.
    granted = manager.append_slots(req, block_size)  # Append 1 block.
    assert len(granted) == 1 + expected_extra
def test_cache_blocks():
    """Unit-test ``KVCacheManager._cache_full_blocks``.

    Checks that the helper assigns a hash to every block it is given and
    registers each one in ``cached_block_hash_to_block``, both when the
    blocks start at the beginning of the request and when they start at a
    later block index.
    """
    block_size = 4
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=5,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=0,
    )
    # Req:
    # Block 0: [0, 1, 2, 3]
    # Block 1: [4, 5, 6, 7]
    # Block 2: [8, 9, 10, 11]
    # Block 3: [12, 13]
    req = make_request("0", list(range(14)))

    # Test that blocks are cached correctly for 2 full blocks from the start.
    blocks = [KVCacheBlock(block_id=i) for i in range(2)]

    manager._cache_full_blocks(
        request=req,
        blk_start_idx=0,
        full_blocks=blocks,
        prev_block=None,
    )

    assert len(manager.cached_block_hash_to_block) == 2
    # Generator instead of a list inside all(): short-circuits and avoids
    # building a throwaway list (flake8-comprehensions C419).
    assert all(block.block_hash is not None for block in blocks)

    # Test that blocks that don't start from the beginning are cached correctly.
    blocks = [KVCacheBlock(block_id=2)]
    manager._cache_full_blocks(
        request=req,
        blk_start_idx=2,
        full_blocks=blocks,
        prev_block=None,
    )
    assert len(manager.cached_block_hash_to_block) == 3
    assert blocks[0].block_hash is not None

View File

@ -79,6 +79,9 @@ class KVCacheManager:
return [] return []
computed_blocks = [] computed_blocks = []
# TODO(rickyx): potentially we could cache this so we don't have to
# recompute it every time.
block_hashes = hash_request_tokens(self.block_size, block_hashes = hash_request_tokens(self.block_size,
request.all_token_ids) request.all_token_ids)
@ -120,47 +123,45 @@ class KVCacheManager:
# slots, but we cannot allocate new blocks due to the limit. # slots, but we cannot allocate new blocks due to the limit.
return None return None
# When caching is enabled, assign token IDs to already allocated blocks. if num_new_blocks <= 0:
new_token_ids = None # No new block is needed.
parent_block = None new_blocks = []
if self.enable_caching: else:
# Figure out the token IDs to add to the blocks. # Get new blocks from the free block pool considering
new_token_ids = request.all_token_ids[ # preallocated blocks.
request.num_computed_tokens:request.num_computed_tokens + num_new_blocks = min(
num_tokens] num_new_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks,
)
# Find the last full block index. new_blocks = self._get_new_blocks(num_new_blocks)
# TODO: This may be optimized by calculating the computed tokens. req_blocks.extend(new_blocks)
last_full_block_idx = len(req_blocks) - 1
while (last_full_block_idx >= 0
and req_blocks[last_full_block_idx].block_hash is None):
last_full_block_idx -= 1
parent_block = (req_blocks[last_full_block_idx] if not self.enable_caching:
if last_full_block_idx >= 0 else None) return new_blocks
token_id_idx = self._add_token_ids_to_blocks(
blocks=req_blocks[last_full_block_idx + 1:],
token_ids=new_token_ids,
parent_block=parent_block)
new_token_ids = new_token_ids[token_id_idx:] num_computed_full_blocks = (request.num_computed_tokens //
parent_block = req_blocks[-1] self.block_size)
# No new block is needed. When caching is enabled, we make sure # NOTE(rickyx): We are assuming the `num_tokens` are actual
# token_id_idx is equal to len(new_token_ids), meaning that all tokens # tokens rather than lookahead slots (e.g. for speculative decoding).
# are added to allocated blocks. # TODO(rickyx): When supporting speculative decoding, we will need to
if num_required_blocks <= len(req_blocks): # differentiate between them so that we can know how many blocks are
assert not self.enable_caching or token_id_idx == num_tokens, \ # full after appending the actual tokens.
f"{token_id_idx=} != {num_tokens=}" num_full_blocks_after_append = (request.num_computed_tokens +
return [] num_tokens) // self.block_size
assert num_full_blocks_after_append <= len(req_blocks)
new_full_blocks = req_blocks[
num_computed_full_blocks:num_full_blocks_after_append]
self._cache_full_blocks(
request=request,
blk_start_idx=num_computed_full_blocks,
full_blocks=new_full_blocks,
prev_block=req_blocks[num_computed_full_blocks - 1]
if num_computed_full_blocks >= 1 else None,
)
# Allocate new blocks considering preallocated blocks, and
# add token IDs to them if caching is enabled.
num_new_blocks = min(num_new_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks)
new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids,
parent_block)
req_blocks.extend(new_blocks)
return new_blocks return new_blocks
def allocate_slots( def allocate_slots(
@ -184,11 +185,20 @@ class KVCacheManager:
raise ValueError( raise ValueError(
f"num_tokens must be greater than 0, got {num_tokens}") f"num_tokens must be greater than 0, got {num_tokens}")
# If a computed block of a request is an eviction candidate (in the # Touch the computed blocks to make sure they won't be evicted.
# free queue and ref_cnt == 0), it cannot be counted as a free block num_evictable_computed_blocks = 0
# when allocating this request. if self.enable_caching:
num_evictable_computed_blocks = len( self._touch(computed_blocks)
[blk for blk in computed_blocks if blk.ref_cnt == 0])
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks = len(
[blk for blk in computed_blocks if blk.ref_cnt == 0])
else:
assert not computed_blocks, (
"Computed blocks should be empty when "
"prefix caching is disabled")
num_required_blocks = cdiv(num_tokens, self.block_size) num_required_blocks = cdiv(num_tokens, self.block_size)
if (num_required_blocks > self.free_block_queue.num_free_blocks - if (num_required_blocks > self.free_block_queue.num_free_blocks -
@ -201,35 +211,28 @@ class KVCacheManager:
num_new_blocks = min( num_new_blocks = min(
num_required_blocks + self.num_preallocate_blocks, num_required_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks - self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks) num_evictable_computed_blocks,
)
num_computed_tokens = len(computed_blocks) * self.block_size
# When caching is enabled, get the new token IDs and the parent block
# ID to generate cache keys.
new_token_ids = None
parent_block = None
if self.enable_caching:
# Touch the computed blocks to make sure they won't be evicted.
self._touch(computed_blocks)
# Get the token IDs for the blocks being allocated for hashing.
new_token_ids = request.all_token_ids[
num_computed_tokens:num_computed_tokens + num_tokens]
if not new_token_ids:
raise RuntimeError(
"Failed to infer the token IDs for allocation. "
f"#all_tokens={len(request.all_token_ids)} < "
f"#computed_tokens={num_computed_tokens}")
# Get the parent block ID to construct the block chain.
parent_block = computed_blocks[-1] if computed_blocks else None
new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids,
parent_block)
# Concatenate the computed block IDs and the new block IDs. # Concatenate the computed block IDs and the new block IDs.
new_blocks = self._get_new_blocks(num_new_blocks)
self.req_to_blocks[request.request_id] = computed_blocks + new_blocks self.req_to_blocks[request.request_id] = computed_blocks + new_blocks
if not self.enable_caching:
return new_blocks
num_computed_tokens = len(computed_blocks) * self.block_size
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
self._cache_full_blocks(
request=request,
blk_start_idx=len(computed_blocks),
# The new full blocks are the full blocks that are not computed.
full_blocks=self.req_to_blocks[request.request_id]
[len(computed_blocks):num_full_blocks],
prev_block=computed_blocks[-1] if computed_blocks else None,
)
return new_blocks return new_blocks
def free(self, request: Request) -> None: def free(self, request: Request) -> None:
@ -248,24 +251,17 @@ class KVCacheManager:
blocks = reversed(blocks) blocks = reversed(blocks)
for block in blocks: for block in blocks:
block.ref_cnt -= 1 block.decr_ref()
if block.ref_cnt == 0: if block.ref_cnt == 0:
self.free_block_queue.append(block) self.free_block_queue.append(block)
def _get_new_blocks( def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
self, """Get new blocks from the free block pool.
num_blocks: int,
token_ids: Optional[List[int]] = None,
parent_block: Optional[int] = None) -> List[KVCacheBlock]:
"""Get new blocks from the free block pool, and add token IDs to
allocated blocks if caching is enabled.
Note that we do not check block cache in this function. Note that we do not check block cache in this function.
Args: Args:
num_blocks: The number of blocks to allocate. num_blocks: The number of blocks to allocate.
token_ids: The token IDs in the blocks. None if caching is disabled.
parent_block: The parent block. Used to include block chain
in the block hash.
Returns: Returns:
A list of new block. A list of new block.
@ -274,56 +270,38 @@ class KVCacheManager:
raise ValueError( raise ValueError(
f"Cannot get {num_blocks} free blocks from the pool") f"Cannot get {num_blocks} free blocks from the pool")
# First allocate blocks.
ret: List[KVCacheBlock] = [] ret: List[KVCacheBlock] = []
idx = 0 idx = 0
while idx < num_blocks: while idx < num_blocks:
# First allocate blocks.
curr_block = self.free_block_queue.popleft() curr_block = self.free_block_queue.popleft()
assert curr_block.ref_cnt == 0 assert curr_block.ref_cnt == 0
# Evict blocks from the cache. # If the block is cached, evict it.
if self.enable_caching: if self.enable_caching:
block_hash = curr_block.block_hash self._evict_cached_block(curr_block)
if (block_hash is not None
and block_hash in self.cached_block_hash_to_block):
if len(self.cached_block_hash_to_block[block_hash]) == 1:
del self.cached_block_hash_to_block[block_hash]
else:
del self.cached_block_hash_to_block[block_hash][
curr_block.block_id]
curr_block.reset()
curr_block.ref_cnt = 1 curr_block.incr_ref()
ret.append(curr_block) ret.append(curr_block)
idx += 1 idx += 1
# Then assign token IDs to the allocated blocks.
if self.enable_caching:
assert token_ids is not None
token_id_idx = self._add_token_ids_to_blocks(
blocks=ret, token_ids=token_ids, parent_block=parent_block)
assert token_id_idx == len(token_ids)
return ret return ret
def _cache_full_block(self, def _evict_cached_block(self, block: KVCacheBlock) -> None:
block: KVCacheBlock, """
parent_block: Optional[KVCacheBlock] = None) -> None: If a block is cached in `cached_block_hash_to_block`, we reset its hash
"""Cache a full block for prefix caching. metadata and evict it from the cache.
Args: Args:
block: The block to cache. block: The block to evict.
parent_block: The parent block. None if this is the first block.
""" """
parent_block_hash = (parent_block.block_hash block_hash = block.block_hash
if parent_block is not None else None) if block_hash and block_hash in self.cached_block_hash_to_block:
assert len(block.token_ids) == self.block_size block.reset_hash()
block.token_ids = tuple(block.token_ids) del self.cached_block_hash_to_block[block_hash][block.block_id]
block_hash = hash_block_tokens(parent_block_hash, block.token_ids)
block.block_hash = block_hash if len(self.cached_block_hash_to_block[block_hash]) == 0:
block.num_hashed_tokens = self.block_size + ( del self.cached_block_hash_to_block[block_hash]
parent_block.num_hashed_tokens if parent_block is not None else 0)
self.cached_block_hash_to_block[block_hash][block.block_id] = block
def _get_cached_block(self, def _get_cached_block(self,
block_hash: BlockHashType) -> Optional[KVCacheBlock]: block_hash: BlockHashType) -> Optional[KVCacheBlock]:
@ -355,43 +333,50 @@ class KVCacheManager:
# candidate), so remove it. # candidate), so remove it.
if block.ref_cnt == 0: if block.ref_cnt == 0:
self.free_block_queue.remove(block) self.free_block_queue.remove(block)
block.ref_cnt += 1 block.incr_ref()
def _add_token_ids_to_blocks( def _cache_full_blocks(
self, self,
blocks: List[KVCacheBlock], request: Request,
token_ids: List[int], blk_start_idx: int,
parent_block: Optional[KVCacheBlock] = None) -> int: full_blocks: List[KVCacheBlock],
"""Add token IDs to a list of allocated blocks. prev_block: Optional[KVCacheBlock],
If a block becomes full after adding token IDs, cache it. ) -> None:
Return the token ID index that has not been added to the blocks """Cache a list of full blocks for prefix caching.
if the blocks are not enough to hold all the token IDs.
This function takes a list of blocks that will have their block hash
metadata to be updated and cached. Given a request, it computes the
block hashes for the blocks starting from `blk_start_idx` to the end
of the request's full blocks, updating the metadata for each block
and caching them in the `cached_block_hash_to_block`.
Args: Args:
blocks: A list of blocks to add token IDs. request: The request to cache the blocks.
token_ids: A list of token IDs to add. blk_start_idx: The index of the first block in the request's blocks
parent_block: The parent block. None if this is the to cache.
first block. full_blocks: The list of blocks to update hash metadata.
prev_block: The previous block in the chain.
Returns:
The starting token ID index that has not been added to the blocks
due to insufficient given blocks.
""" """
token_id_start = 0 # Update the new blocks with the block hashes through the chain.
for curr_block in blocks: prev_block_hash = (prev_block.block_hash
# If all token IDs are added, then the rest of the blocks are if prev_block is not None else None)
# preallocated blocks, so we only need to update the for i, blk in enumerate(full_blocks):
# parent_block_id. FIXME blk_idx = blk_start_idx + i
if token_id_start == len(token_ids):
continue
# Add token IDs to the empty slots in the block. block_tokens = request.all_token_ids[blk_idx *
empty_slots = self.block_size - len(curr_block.token_ids) self.block_size:(blk_idx +
token_id_end = min(token_id_start + empty_slots, len(token_ids)) 1) *
curr_block.token_ids.extend(token_ids[token_id_start:token_id_end]) self.block_size]
# Cache the block if it becomes full. assert len(block_tokens) == self.block_size, (
if len(curr_block.token_ids) == self.block_size: f"Expected {self.block_size} tokens, got {len(block_tokens)} "
self._cache_full_block(curr_block, parent_block) f"at {blk_idx}th block for request "
parent_block = curr_block f"{request.request_id}({request})")
token_id_start = token_id_end
return token_id_start # Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash,
tuple(block_tokens))
# Update and add the full block to the cache.
blk.block_hash = block_hash
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
prev_block_hash = block_hash

View File

@ -1,6 +1,6 @@
"""KV-Cache Utilities.""" """KV-Cache Utilities."""
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple
from vllm.logger import init_logger from vllm.logger import init_logger
@ -16,27 +16,34 @@ class KVCacheBlock:
block_id: int block_id: int
# Reference count. # Reference count.
ref_cnt: int = 0 ref_cnt: int = 0
# Token IDs in the block. When the block is full, the type of token_ids
# should be Tuple[int] for fast matching.
token_ids: Union[List[int], Tuple[int]] = field(default_factory=list)
# The hash of the block composed of (block hash, tuple of token IDs). # The hash of the block composed of (block hash, tuple of token IDs).
# It is only available when the block is full. # It is only available when the block is full.
block_hash: Optional[BlockHashType] = None _block_hash: Optional[BlockHashType] = None
# The number of hashed tokens. More hashed tokens means the block
# is closer to the end of a prompt and more likely to be evicted.
num_hashed_tokens: int = 0
# Used to construct a doubly linked list for free blocks. # Used to construct a doubly linked list for free blocks.
# These two attributes should only be manipulated by FreeKVCacheBlockQueue. # These two attributes should only be manipulated by FreeKVCacheBlockQueue.
prev_free_block: Optional["KVCacheBlock"] = None prev_free_block: Optional["KVCacheBlock"] = None
next_free_block: Optional["KVCacheBlock"] = None next_free_block: Optional["KVCacheBlock"] = None
def reset(self): def incr_ref(self):
"""Reset the block metadata.""" self.ref_cnt += 1
self.ref_cnt = 0
self.token_ids = [] def decr_ref(self):
self.block_hash = None self.ref_cnt -= 1
self.num_hashed_tokens = 0
@property
def block_hash(self) -> Optional[BlockHashType]:
return self._block_hash
@block_hash.setter
def block_hash(self, block_hash: BlockHashType):
assert self.block_hash is None, (
"The block already has a hash. This should not happen.")
self._block_hash = block_hash
def reset_hash(self):
"""Reset the block hash when the block is evicted."""
self._block_hash = None
class FreeKVCacheBlockQueue: class FreeKVCacheBlockQueue: