[v1] Refactor KVCacheManager to support hash inputs beyond token IDs (#10507)

Signed-off-by: rickyx <rickyx@anyscale.com>
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
Ricky Xu 2024-11-22 15:27:25 -08:00 committed by GitHub
parent eebad39f26
commit 97814fbf0f
3 changed files with 360 additions and 181 deletions
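The core of the change: block hashes are now computed from the request's token IDs, chained through the parent block's hash, rather than derived from token IDs stored on each KVCacheBlock. That decoupling is what allows future hash inputs beyond token IDs to be folded in. A minimal sketch of the chaining, assuming hash_block_tokens reduces to hashing the (parent_hash, tokens) pair; the real helpers in vllm.v1.core.kv_cache_utils may take additional inputs:

from typing import List, Optional, Tuple

def hash_block_tokens(parent_block_hash: Optional[int],
                      block_token_ids: Tuple[int, ...]) -> int:
    # Chaining the parent hash makes two blocks with identical tokens but
    # different prefixes hash differently.
    return hash((parent_block_hash, block_token_ids))

def hash_request_tokens(block_size: int, token_ids: List[int]) -> List[int]:
    # Hash only the full blocks of a request, left to right.
    hashes: List[int] = []
    parent: Optional[int] = None
    for i in range(len(token_ids) // block_size):
        block = tuple(token_ids[i * block_size:(i + 1) * block_size])
        parent = hash_block_tokens(parent, block)
        hashes.append(parent)
    return hashes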

tests/v1/core/test_prefix_caching.py

@@ -1,8 +1,11 @@
"""Compare the with and without prefix caching."""
import pytest
from vllm.inputs import token_inputs
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import hash_block_tokens
from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens
def make_request(request_id, prompt_token_ids):
@@ -31,7 +34,8 @@ def test_prefill():
# Fully cache miss
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids)
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
blocks = manager.allocate_slots(req0, 55, computed_blocks)
@@ -40,24 +44,16 @@ def test_prefill():
# Check full block metadata
parent_block_hash = None
for block_id in (0, 1, 2):
block_hash = hash_block_tokens(parent_block_hash,
manager.block_pool[block_id].token_ids)
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
assert manager.block_pool[block_id].block_hash == block_hash
assert manager.block_pool[block_id].ref_cnt == 1
assert manager.block_pool[block_id].num_hashed_tokens == 16 * (
block_id + 1)
assert manager.block_pool[block_id].token_ids == tuple([block_id] * 16)
parent_block_hash = block_hash
# Check partial/preallocated block metadata
for block_id in (3, 4):
assert manager.block_pool[block_id].block_hash is None
assert manager.block_pool[block_id].ref_cnt == 1
assert manager.block_pool[block_id].num_hashed_tokens == 0
if block_id == 3:
assert manager.block_pool[block_id].token_ids == [3] * 7
else:
assert not manager.block_pool[block_id].token_ids
# Cache hit in the common prefix when the original block is still in use.
# Incomplete 1 block (5 tokens)
@@ -113,7 +109,7 @@ def test_prefill():
req3 = make_request("3", [99] * (16 * 9))
computed_blocks = manager.get_computed_blocks(req3)
assert not computed_blocks
blocks = manager.allocate_slots(req2, 16 * 9, computed_blocks)
blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
# This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
assert manager.free_block_queue.num_free_blocks == 0
@@ -148,7 +144,7 @@ def test_decode():
req0.append_output_token_ids(8)
new_blocks = manager.append_slots(req0, 4)
assert new_blocks is not None and len(new_blocks) == 0
assert len(manager.block_pool[3].token_ids) == 11
assert manager.req_to_blocks[req0.request_id][-2].block_hash is None
# Append slots without allocating a new block, but start using the
# preallocated block.
@@ -159,8 +155,7 @@ def test_decode():
req0.append_output_token_ids(7)
new_blocks = manager.append_slots(req0, 15)
assert new_blocks is not None and len(new_blocks) == 0
assert len(manager.block_pool[3].token_ids) == 16
assert len(manager.block_pool[4].token_ids) == 10
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
# Append slots with allocating a new block.
req0.num_computed_tokens = 74
@@ -171,9 +166,6 @@ def test_decode():
new_blocks = manager.append_slots(req0, 17)
# Plus one preallocated block.
assert new_blocks is not None and len(new_blocks) == 2
assert len(manager.block_pool[4].token_ids) == 16
assert len(manager.block_pool[5].token_ids) == 11
assert len(manager.block_pool[6].token_ids) == 0
def test_evict():
@@ -217,3 +209,198 @@ def test_evict():
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [6, 5]
assert manager.free_block_queue.num_free_blocks == 6
def test_hash_block_correct_reuse():
"""
This tests that when a previously cached block is reused as a new block,
its hash metadata is correctly reset.
"""
block_size = 16
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=1,
sliding_window=False,
enable_caching=True,
num_preallocate_tokens=0,
)
# Allocate 1 block and cache it.
num_tokens = block_size * 1
req = make_request("0", list(range(num_tokens)))
computed_blocks = manager.get_computed_blocks(req)
assert not computed_blocks
blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
assert len(blocks) == 1
# Deallocate the block.
manager.free(req)
# Allocate a new block that's not full, and make sure the hash info on
# the block is cleared.
req = make_request("1", list(range(num_tokens - 1)))
computed_blocks = manager.get_computed_blocks(req)
assert not computed_blocks
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
assert len(blocks) == 1
assert manager.block_pool[blocks[0].block_id].block_hash is None
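The invariant this test guards, in isolation: when a cached block is handed out again, its hash must be cleared so a stale hash can never match a new, partially filled block. A minimal repro of the reuse path, assuming only the _evict_cached_block / reset_hash semantics shown later in this commit:

from collections import defaultdict

class Block:
    def __init__(self, block_id: int):
        self.block_id = block_id
        self.block_hash = None

cached_block_hash_to_block = defaultdict(dict)

def evict_cached_block(block: Block) -> None:
    # Drop the block from the hash map and clear its hash so the next
    # request reusing this physical block starts from a clean slate.
    h = block.block_hash
    if h is not None and h in cached_block_hash_to_block:
        del cached_block_hash_to_block[h][block.block_id]
        if not cached_block_hash_to_block[h]:
            del cached_block_hash_to_block[h]
    block.block_hash = None

blk = Block(0)
blk.block_hash = 42
cached_block_hash_to_block[42][0] = blk
evict_cached_block(blk)
assert blk.block_hash is None and 42 not in cached_block_hash_to_block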
def test_computed_blocks_not_evicted():
"""
Test that the computed blocks are not evicted when getting new blocks
for a request if there are any other free blocks.
"""
block_size = 16
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=2,
sliding_window=False,
enable_caching=True,
num_preallocate_tokens=0,
)
# Allocate a block and cache it.
num_tokens = block_size * 1
req0 = make_request("0", list(range(num_tokens)))
computed_blocks = manager.get_computed_blocks(req0)
assert not computed_blocks
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 0
# Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
computed_blocks = manager.get_computed_blocks(req1)
assert not computed_blocks
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 1
# Free the blocks.
manager.free(req0)
manager.free(req1)
# Now if we have a cache hit on the first block, we should evict the second
# cached block rather than the first one.
req2 = make_request("2", list(range(num_tokens * 2)))
computed_blocks = manager.get_computed_blocks(req2)
assert len(computed_blocks) == 1
assert computed_blocks[0].block_id == 0
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 1
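What keeps block 0 alive here is the _touch step in allocate_slots: a computed block whose ref_cnt is 0 is pulled out of the free queue before any new block is popped from it. A rough sketch, with a plain list standing in for FreeKVCacheBlockQueue:

class Block:
    def __init__(self):
        self.ref_cnt = 0

def touch(blocks, free_queue):
    for blk in blocks:
        # ref_cnt == 0 means the block sits in the free queue as an
        # eviction candidate; removing it makes it ineligible.
        if blk.ref_cnt == 0:
            free_queue.remove(blk)
        blk.ref_cnt += 1

free_queue = [Block(), Block()]
hit = free_queue[0]
touch([hit], free_queue)
assert hit not in free_queue and hit.ref_cnt == 1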
def test_basic_prefix_caching_disabled():
"""
This tests request handling when prefix caching is disabled.
"""
block_size = 4
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=4,
sliding_window=False,
enable_caching=False,
num_preallocate_tokens=0,
)
req1 = make_request("1", list(range(10))) # 2 blocks and some more
computed_blocks = manager.get_computed_blocks(req1)
assert not computed_blocks
blocks = manager.allocate_slots(req1, 10, computed_blocks)
assert len(blocks) == 3
# Free the blocks.
manager.free(req1)
# No caching.
req2 = make_request("2", list(range(16))) # shared prefix
computed_blocks = manager.get_computed_blocks(req2)
assert not computed_blocks
blocks = manager.allocate_slots(req2, 16, computed_blocks)
assert len(blocks) == 4
# New requests should not have any blocks.
req3 = make_request("3", list(range(4)))
computed_blocks = manager.get_computed_blocks(req3)
assert not computed_blocks
blocks = manager.allocate_slots(req3, 4, computed_blocks)
assert not blocks
@pytest.mark.parametrize("num_preallocate_tokens", list(range(0, 8)))
@pytest.mark.parametrize("block_size", [4])
def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
"""
This tests that the preallocated blocks are correctly added.
"""
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=10,
sliding_window=False,
enable_caching=True,
num_preallocate_tokens=num_preallocate_tokens,
)
num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size)
req = make_request("0", list(range(block_size * 30)))
computed_blocks = manager.get_computed_blocks(req)
assert not computed_blocks
# Just ask for 1 block.
blocks = manager.allocate_slots(req, block_size, computed_blocks)
assert len(blocks) == 1 + num_preallocated_blocks
# Append slots to the block.
req.num_computed_tokens = block_size * len(blocks) # Assume all used.
blocks = manager.append_slots(req, block_size) # Append 1 block.
assert len(blocks) == 1 + num_preallocated_blocks
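The assertion leans on ceiling division: every allocation returns cdiv(num_preallocate_tokens, block_size) extra blocks. Worked out for block_size=4 over the parametrized range:

def cdiv(a: int, b: int) -> int:
    # Ceiling division, as in vllm.utils.
    return -(a // -b)

# num_preallocate_tokens 0..7 -> extra blocks per allocation
assert [cdiv(n, 4) for n in range(8)] == [0, 1, 1, 1, 1, 2, 2, 2]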
def test_cache_blocks():
"""
This is a unit test that tests the correctness of the _cache_full_blocks
function of KVCacheManager.
"""
block_size = 4
manager = KVCacheManager(
block_size=block_size,
num_gpu_blocks=5,
sliding_window=False,
enable_caching=True,
num_preallocate_tokens=0,
)
# Req:
# Block 0: [0, 1, 2, 3]
# Block 1: [4, 5, 6, 7]
# Block 2: [8, 9, 10, 11]
# Block 3: [12, 13]
req = make_request("0", list(range(14)))
# Test that blocks are cached correctly for 2 full blocks from the start.
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
manager._cache_full_blocks(
request=req,
blk_start_idx=0,
full_blocks=blocks,
prev_block=None,
)
assert len(manager.cached_block_hash_to_block) == 2
assert all([block.block_hash is not None for block in blocks])
# Test that blocks that don't start from the beginning are cached correctly.
blocks = [KVCacheBlock(block_id=2)]
manager._cache_full_blocks(
request=req,
blk_start_idx=2,
full_blocks=blocks,
prev_block=None,
)
assert len(manager.cached_block_hash_to_block) == 3
assert blocks[0].block_hash is not None

vllm/v1/core/kv_cache_manager.py

@@ -79,6 +79,9 @@ class KVCacheManager:
return []
computed_blocks = []
# TODO(rickyx): potentially we could cache this so we don't have to
# recompute it every time.
block_hashes = hash_request_tokens(self.block_size,
request.all_token_ids)
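These per-block hashes feed the prefix match that follows. The loop consuming them sits outside this hunk; a hedged sketch of its likely shape, where get_cached_block stands in for the manager's _get_cached_block: walk the hashes in order and stop at the first miss, since each hash already encodes the entire prefix before it.

def match_computed_blocks(block_hashes, get_cached_block):
    computed = []
    for block_hash in block_hashes:
        cached = get_cached_block(block_hash)
        if cached is None:
            break  # a miss implies misses for every later block
        computed.append(cached)
    return computed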
@@ -120,47 +123,45 @@ class KVCacheManager:
# slots, but we cannot allocate new blocks due to the limit.
return None
# When caching is enabled, assign token IDs to already allocated blocks.
new_token_ids = None
parent_block = None
if self.enable_caching:
# Figure out the token IDs to add to the blocks.
new_token_ids = request.all_token_ids[
request.num_computed_tokens:request.num_computed_tokens +
num_tokens]
if num_new_blocks <= 0:
# No new block is needed.
new_blocks = []
else:
# Get new blocks from the free block pool considering
# preallocated blocks.
num_new_blocks = min(
num_new_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks,
)
# Find the last full block index.
# TODO: This may be optimized by calculating the computed tokens.
last_full_block_idx = len(req_blocks) - 1
while (last_full_block_idx >= 0
and req_blocks[last_full_block_idx].block_hash is None):
last_full_block_idx -= 1
new_blocks = self._get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
parent_block = (req_blocks[last_full_block_idx]
if last_full_block_idx >= 0 else None)
token_id_idx = self._add_token_ids_to_blocks(
blocks=req_blocks[last_full_block_idx + 1:],
token_ids=new_token_ids,
parent_block=parent_block)
if not self.enable_caching:
return new_blocks
new_token_ids = new_token_ids[token_id_idx:]
parent_block = req_blocks[-1]
num_computed_full_blocks = (request.num_computed_tokens //
self.block_size)
# No new block is needed. When caching is enabled, we make sure
# token_id_idx is equal to len(new_token_ids), meaning that all tokens
# are added to allocated blocks.
if num_required_blocks <= len(req_blocks):
assert not self.enable_caching or token_id_idx == num_tokens, \
f"{token_id_idx=} != {num_tokens=}"
return []
# NOTE(rickyx): We are assuming the `num_tokens` are actual
# tokens rather than lookahead slots (e.g. for speculative decoding).
# TODO(rickyx): When supporting speculative decoding, we will need to
# differentiate between them so that we can know how many blocks are
# full after appending the actual tokens.
num_full_blocks_after_append = (request.num_computed_tokens +
num_tokens) // self.block_size
assert num_full_blocks_after_append <= len(req_blocks)
new_full_blocks = req_blocks[
num_computed_full_blocks:num_full_blocks_after_append]
self._cache_full_blocks(
request=request,
blk_start_idx=num_computed_full_blocks,
full_blocks=new_full_blocks,
prev_block=req_blocks[num_computed_full_blocks - 1]
if num_computed_full_blocks >= 1 else None,
)
# Allocate new blocks considering preallocated blocks, and
# add token IDs to them if caching is enabled.
num_new_blocks = min(num_new_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks)
new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids,
parent_block)
req_blocks.extend(new_blocks)
return new_blocks
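Worked numbers for the full-block bookkeeping above, matching the req0.num_computed_tokens = 74 step in test_decode (block_size=16):

block_size = 16
num_computed_tokens = 74  # 4 full blocks plus 10 tokens into block 4
num_tokens = 17           # tokens appended by this call

num_computed_full_blocks = num_computed_tokens // block_size  # 4
num_full_blocks_after_append = (num_computed_tokens +
                                num_tokens) // block_size     # 91 // 16 = 5

# req_blocks[4:5] just became full, so it is hashed and cached; the
# trailing 11 tokens land in a partial block whose hash stays None.
assert (num_computed_full_blocks, num_full_blocks_after_append) == (4, 5)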
def allocate_slots(
@@ -184,11 +185,20 @@ class KVCacheManager:
raise ValueError(
f"num_tokens must be greater than 0, got {num_tokens}")
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks = len(
[blk for blk in computed_blocks if blk.ref_cnt == 0])
# Touch the computed blocks to make sure they won't be evicted.
num_evictable_computed_blocks = 0
if self.enable_caching:
self._touch(computed_blocks)
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks = len(
[blk for blk in computed_blocks if blk.ref_cnt == 0])
else:
assert not computed_blocks, (
"Computed blocks should be empty when "
"prefix caching is disabled")
num_required_blocks = cdiv(num_tokens, self.block_size)
if (num_required_blocks > self.free_block_queue.num_free_blocks -
@@ -201,35 +211,28 @@ class KVCacheManager:
num_new_blocks = min(
num_required_blocks + self.num_preallocate_blocks,
self.free_block_queue.num_free_blocks -
num_evictable_computed_blocks)
num_computed_tokens = len(computed_blocks) * self.block_size
# When caching is enabled, get the new token IDs and the parent block
# ID to generate cache keys.
new_token_ids = None
parent_block = None
if self.enable_caching:
# Touch the computed blocks to make sure they won't be evicted.
self._touch(computed_blocks)
# Get the token IDs for the blocks being allocated for hashing.
new_token_ids = request.all_token_ids[
num_computed_tokens:num_computed_tokens + num_tokens]
if not new_token_ids:
raise RuntimeError(
"Failed to infer the token IDs for allocation. "
f"#all_tokens={len(request.all_token_ids)} < "
f"#computed_tokens={num_computed_tokens}")
# Get the parent block ID to construct the block chain.
parent_block = computed_blocks[-1] if computed_blocks else None
new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids,
parent_block)
num_evictable_computed_blocks,
)
# Concatenate the computed block IDs and the new block IDs.
new_blocks = self._get_new_blocks(num_new_blocks)
self.req_to_blocks[request.request_id] = computed_blocks + new_blocks
if not self.enable_caching:
return new_blocks
num_computed_tokens = len(computed_blocks) * self.block_size
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
self._cache_full_blocks(
request=request,
blk_start_idx=len(computed_blocks),
# The new full blocks are the full blocks that are not computed.
full_blocks=self.req_to_blocks[request.request_id]
[len(computed_blocks):num_full_blocks],
prev_block=computed_blocks[-1] if computed_blocks else None,
)
return new_blocks
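The capacity arithmetic above in worked form: a cache-hit block that is still an eviction candidate inflates the free count, so it is subtracted before comparing against cdiv(num_tokens, block_size). A small self-contained check (can_allocate is a hypothetical condensation, not a manager method):

def can_allocate(num_tokens, block_size, num_free_blocks, computed_ref_cnts):
    # Computed blocks with ref_cnt == 0 still sit in the free queue,
    # so they are not truly allocatable.
    num_evictable = sum(1 for c in computed_ref_cnts if c == 0)
    num_required = -(num_tokens // -block_size)  # cdiv
    return num_required <= num_free_blocks - num_evictable

# 10 free blocks, two cache hits that are eviction candidates:
assert can_allocate(17, 16, 10, [0, 0])       # needs 2 <= 10 - 2
assert not can_allocate(129, 16, 10, [0, 0])  # needs 9 >  10 - 2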
def free(self, request: Request) -> None:
@@ -248,24 +251,17 @@ class KVCacheManager:
blocks = reversed(blocks)
for block in blocks:
block.ref_cnt -= 1
block.decr_ref()
if block.ref_cnt == 0:
self.free_block_queue.append(block)
def _get_new_blocks(
self,
num_blocks: int,
token_ids: Optional[List[int]] = None,
parent_block: Optional[int] = None) -> List[KVCacheBlock]:
"""Get new blocks from the free block pool, and add token IDs to
allocated blocks if caching is enabled.
def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
"""Get new blocks from the free block pool.
Note that we do not check block cache in this function.
Args:
num_blocks: The number of blocks to allocate.
token_ids: The token IDs in the blocks. None if caching is disabled.
parent_block: The parent block. Used to include block chain
in the block hash.
Returns:
A list of new blocks.
@@ -274,56 +270,38 @@ class KVCacheManager:
raise ValueError(
f"Cannot get {num_blocks} free blocks from the pool")
# First allocate blocks.
ret: List[KVCacheBlock] = []
idx = 0
while idx < num_blocks:
# First allocate blocks.
curr_block = self.free_block_queue.popleft()
assert curr_block.ref_cnt == 0
# Evict blocks from the cache.
# If the block is cached, evict it.
if self.enable_caching:
block_hash = curr_block.block_hash
if (block_hash is not None
and block_hash in self.cached_block_hash_to_block):
if len(self.cached_block_hash_to_block[block_hash]) == 1:
del self.cached_block_hash_to_block[block_hash]
else:
del self.cached_block_hash_to_block[block_hash][
curr_block.block_id]
curr_block.reset()
self._evict_cached_block(curr_block)
curr_block.ref_cnt = 1
curr_block.incr_ref()
ret.append(curr_block)
idx += 1
# Then assign token IDs to the allocated blocks.
if self.enable_caching:
assert token_ids is not None
token_id_idx = self._add_token_ids_to_blocks(
blocks=ret, token_ids=token_ids, parent_block=parent_block)
assert token_id_idx == len(token_ids)
return ret
def _cache_full_block(self,
block: KVCacheBlock,
parent_block: Optional[KVCacheBlock] = None) -> None:
"""Cache a full block for prefix caching.
def _evict_cached_block(self, block: KVCacheBlock) -> None:
"""
If a block is cached in `cached_block_hash_to_block`, we reset its hash
metadata and evict it from the cache.
Args:
block: The block to cache.
parent_block: The parent block. None if this is the first block.
block: The block to evict.
"""
parent_block_hash = (parent_block.block_hash
if parent_block is not None else None)
assert len(block.token_ids) == self.block_size
block.token_ids = tuple(block.token_ids)
block_hash = hash_block_tokens(parent_block_hash, block.token_ids)
block.block_hash = block_hash
block.num_hashed_tokens = self.block_size + (
parent_block.num_hashed_tokens if parent_block is not None else 0)
self.cached_block_hash_to_block[block_hash][block.block_id] = block
block_hash = block.block_hash
if block_hash and block_hash in self.cached_block_hash_to_block:
block.reset_hash()
del self.cached_block_hash_to_block[block_hash][block.block_id]
if len(self.cached_block_hash_to_block[block_hash]) == 0:
del self.cached_block_hash_to_block[block_hash]
def _get_cached_block(self,
block_hash: BlockHashType) -> Optional[KVCacheBlock]:
@@ -355,43 +333,50 @@ class KVCacheManager:
# candidate), so remove it.
if block.ref_cnt == 0:
self.free_block_queue.remove(block)
block.ref_cnt += 1
block.incr_ref()
def _add_token_ids_to_blocks(
self,
blocks: List[KVCacheBlock],
token_ids: List[int],
parent_block: Optional[KVCacheBlock] = None) -> int:
"""Add token IDs to a list of allocated blocks.
If a block becomes full after adding token IDs, cache it.
Return the token ID index that has not been added to the blocks
if the blocks are not enough to hold all the token IDs.
def _cache_full_blocks(
self,
request: Request,
blk_start_idx: int,
full_blocks: List[KVCacheBlock],
prev_block: Optional[KVCacheBlock],
) -> None:
"""Cache a list of full blocks for prefix caching.
This function takes a list of blocks whose block hash metadata will be
updated and cached. Given a request, it computes the
block hashes for the blocks starting from `blk_start_idx` to the end
of the request's full blocks, updating the metadata for each block
and caching them in the `cached_block_hash_to_block`.
Args:
blocks: A list of blocks to add token IDs.
token_ids: A list of token IDs to add.
parent_block: The parent block. None if this is the
first block.
Returns:
The starting token ID index that has not been added to the blocks
due to insufficient given blocks.
request: The request to cache the blocks.
blk_start_idx: The index of the first block in the request's blocks
to cache.
full_blocks: The list of blocks to update hash metadata.
prev_block: The previous block in the chain.
"""
token_id_start = 0
for curr_block in blocks:
# If all token IDs are added, then the rest of the blocks are
# preallocated blocks, so we only need to update the
# parent_block_id. FIXME
if token_id_start == len(token_ids):
continue
# Update the new blocks with the block hashes through the chain.
prev_block_hash = (prev_block.block_hash
if prev_block is not None else None)
for i, blk in enumerate(full_blocks):
blk_idx = blk_start_idx + i
# Add token IDs to the empty slots in the block.
empty_slots = self.block_size - len(curr_block.token_ids)
token_id_end = min(token_id_start + empty_slots, len(token_ids))
curr_block.token_ids.extend(token_ids[token_id_start:token_id_end])
# Cache the block if it becomes full.
if len(curr_block.token_ids) == self.block_size:
self._cache_full_block(curr_block, parent_block)
parent_block = curr_block
token_id_start = token_id_end
return token_id_start
block_tokens = request.all_token_ids[blk_idx * self.block_size:(blk_idx + 1) * self.block_size]
assert len(block_tokens) == self.block_size, (
f"Expected {self.block_size} tokens, got {len(block_tokens)} "
f"at {blk_idx}th block for request "
f"{request.request_id}({request})")
# Compute the hash of the current block.
block_hash = hash_block_tokens(prev_block_hash,
tuple(block_tokens))
# Update and add the full block to the cache.
blk.block_hash = block_hash
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
prev_block_hash = block_hash
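The slicing above, traced against test_cache_blocks' request (14 tokens, block_size=4): only full blocks ever reach _cache_full_blocks, so the length assert holds for each of them.

block_size = 4
all_token_ids = list(range(14))
for blk_idx in range(3):  # the three full blocks
    block_tokens = all_token_ids[blk_idx * block_size:(blk_idx + 1) *
                                 block_size]
    assert len(block_tokens) == block_size
# Block 3 holds only [12, 13]; it is partial and is never passed in.
assert all_token_ids[3 * block_size:] == [12, 13]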

vllm/v1/core/kv_cache_utils.py

@@ -1,6 +1,6 @@
"""KV-Cache Utilities."""
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union
from dataclasses import dataclass
from typing import List, Optional, Tuple
from vllm.logger import init_logger
@@ -16,27 +16,34 @@ class KVCacheBlock:
block_id: int
# Reference count.
ref_cnt: int = 0
# Token IDs in the block. When the block is full, the type of token_ids
# should be Tuple[int] for fast matching.
token_ids: Union[List[int], Tuple[int]] = field(default_factory=list)
# The hash of the block composed of (block hash, tuple of token IDs).
# It is only available when the block is full.
block_hash: Optional[BlockHashType] = None
# The number of hashed tokens. More hashed tokens means the block
# is closer to the end of a prompt and more likely to be evicted.
num_hashed_tokens: int = 0
_block_hash: Optional[BlockHashType] = None
# Used to construct a doubly linked list for free blocks.
# These two attributes should only be manipulated by FreeKVCacheBlockQueue.
prev_free_block: Optional["KVCacheBlock"] = None
next_free_block: Optional["KVCacheBlock"] = None
def reset(self):
"""Reset the block metadata."""
self.ref_cnt = 0
self.token_ids = []
self.block_hash = None
self.num_hashed_tokens = 0
def incr_ref(self):
self.ref_cnt += 1
def decr_ref(self):
self.ref_cnt -= 1
@property
def block_hash(self) -> Optional[BlockHashType]:
return self._block_hash
@block_hash.setter
def block_hash(self, block_hash: BlockHashType):
assert self.block_hash is None, (
"The block already has a hash. This should not happen.")
self._block_hash = block_hash
def reset_hash(self):
"""Reset the block hash when the block is evicted."""
self._block_hash = None
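The property turns block_hash into a write-once slot: assigning to a block that already has a hash trips the assert, and reset_hash is the only sanctioned way back to None. A standalone repro of that contract (MiniBlock is illustrative, not the real class):

from dataclasses import dataclass
from typing import Optional

@dataclass
class MiniBlock:
    block_id: int
    _block_hash: Optional[int] = None

    @property
    def block_hash(self) -> Optional[int]:
        return self._block_hash

    @block_hash.setter
    def block_hash(self, value: int) -> None:
        assert self._block_hash is None, "hash may only be set once"
        self._block_hash = value

    def reset_hash(self) -> None:
        self._block_hash = None

blk = MiniBlock(block_id=0)
blk.block_hash = 123  # first assignment: allowed
blk.reset_hash()      # eviction path clears the hash
blk.block_hash = 456  # allowed again only after reset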
class FreeKVCacheBlockQueue: