[v1][KVCacheManager] Avoid full cache hit by controlling max_length (#17999)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Authored by Chen Zhang on 2025-05-13 14:50:38 +08:00; committed by GitHub
parent e57e4d6e9e
commit f0d610a8ae
3 changed files with 35 additions and 38 deletions
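
In short, the commit stops trimming the last block hash out of block_hashes and instead bounds the cache lookup itself with a maximum hit length. A rough before/after sketch of the idea, using simplified, illustrative helpers rather than the actual vLLM call sites:

def old_style_lookup(block_hashes, block_size, num_prompt_tokens, find_hit):
    # Before: if the whole prompt is cached, temporarily drop the last block
    # hash so the final block (and thus the last token) is recomputed.
    if len(block_hashes) * block_size == num_prompt_tokens:
        last = block_hashes.pop()
        computed = find_hit(block_hashes)
        block_hashes.append(last)  # restore the shared, cached list
    else:
        computed = find_hit(block_hashes)
    return computed

def new_style_lookup(block_hashes, num_prompt_tokens, find_hit_with_cap):
    # After: never mutate block_hashes; cap the hit length one token short of
    # the prompt so the last token always falls outside the cache hit.
    return find_hit_with_cap(block_hashes, max_length=num_prompt_tokens - 1)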


@@ -17,8 +17,9 @@ def get_sliding_window_manager(sliding_window_spec, block_pool):
 def test_sliding_window_possible_cached_prefix():
+    block_size = 2
     sliding_window_spec = SlidingWindowSpec(
-        block_size=2,
+        block_size=block_size,
         num_kv_heads=1,
         head_size=1,
         dtype=torch.float32,
@@ -44,7 +45,9 @@ def test_sliding_window_possible_cached_prefix():
         i: block_pool.blocks[i + 10]
     }
-    computed_blocks = manager.find_longest_cache_hit(block_hash_list)
+    computed_blocks = manager.find_longest_cache_hit(
+        block_hash_list,
+        len(block_hash_list) * block_size)
     assert len(computed_blocks) == expect_length
     assert all(block == block_pool.null_block


@@ -146,21 +146,16 @@ class KVCacheManager:
         assert self.prefix_cache_stats is not None
         self.prefix_cache_stats.requests += 1
-        if len(block_hashes) * self.block_size == request.num_tokens:
-            # When prompt length is divisible by the block size and all
-            # blocks are cached, we need to recompute the last token. This
-            # have to be achieved by re-computing an entire block because
-            # allocate_slots() assumes num_computed_tokens is always a
-            # multiple of the block size. To achieve this, remove the last
-            # block hash from the block_hashes for find_longest_cache_hit
-            # This limitation can potentially be removed in the future to
-            # slightly improve the performance.
-            last_block_hash = block_hashes.pop()
-        else:
-            last_block_hash = None
-        computed_blocks = (
-            self.single_type_manager.find_longest_cache_hit(block_hashes))
+        # NOTE: When all tokens hit the cache, we must recompute the last token
+        # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
+        # This can trigger recomputation of an entire block, rather than just
+        # the single last token, because allocate_slots() requires
+        # num_computed_tokens to be block-size aligned. Removing this limitation
+        # could slightly improve performance in the future.
+        max_cache_hit_length = request.num_tokens - 1
+        computed_blocks = self.single_type_manager.find_longest_cache_hit(
+            block_hashes, max_cache_hit_length)
         # NOTE(woosuk): Since incomplete blocks are not eligible for
         # sharing, `num_computed_tokens` is always a multiple of
         # `block_size`.
@@ -171,12 +166,6 @@ class KVCacheManager:
             self.prefix_cache_stats.queries += request.num_tokens
             self.prefix_cache_stats.hits += num_computed_tokens

-        if last_block_hash is not None:
-            # Add back the last block hash if it was removed.
-            # NOTE: Because block_hashes is cached in req_to_block_hashes,
-            # we shouldn't modify it directly.
-            block_hashes.append(last_block_hash)
-
         return KVCacheBlocks(computed_blocks), num_computed_tokens

     def allocate_slots(
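
To see why capping the hit at request.num_tokens - 1 guarantees that the last token is recomputed, consider toy numbers (illustrative, not taken from the diff):

block_size = 16
num_tokens = 32                                      # prompt fully cached: 2 blocks
max_cache_hit_length = num_tokens - 1                # 31
max_num_blocks = max_cache_hit_length // block_size  # 31 // 16 == 1
# Only 1 of the 2 cached blocks is reused, so the second block, including the
# final token, is recomputed and produces the logits needed for decoding.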


@@ -187,17 +187,19 @@ class SingleTypeKVCacheManager(ABC):
         raise NotImplementedError

     @abstractmethod
-    def find_longest_cache_hit(
-            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
+    def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
+                               max_length: int) -> list[KVCacheBlock]:
         """
-        Get the longest cache hit prefix of the blocks. If no cache hit is
-        found, return an empty list. if eagle is enabled, drop the last matched
-        block to force recompute the last block to get the required hidden
-        states for eagle drafting head. Need to be customized for each attention
-        type.
+        Get the longest cache hit prefix of the blocks that is not longer than
+        `max_length`. If no cache hit is found, return an empty list.
+        If eagle is enabled, drop the last matched block to force recompute the
+        last block to get the required hidden states for eagle drafting head.
+        Need to be customized for each attention type.

         Args:
             block_hashes: The block hashes of the request.
+            max_length: The maximum length of the cache hit prefix.

         Returns:
             A list of cached blocks with skipped blocks replaced by null block.
             For example, sliding window manager should return a list like
@@ -226,10 +228,12 @@ class SingleTypeKVCacheManager(ABC):
 class FullAttentionManager(SingleTypeKVCacheManager):

-    def find_longest_cache_hit(
-            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
+    def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
+                               max_length: int) -> list[KVCacheBlock]:
         computed_blocks: list[KVCacheBlock] = []
-        for block_hash in block_hashes:
+        max_num_blocks = max_length // self.block_size
+        for i in range(max_num_blocks):
+            block_hash = block_hashes[i]
             # block_hashes is a chain of block hashes. If a block hash is not
             # in the cached_block_hash_to_id, the following block hashes are
             # not computed yet for sure.
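
Read as a whole, the full-attention lookup now walks at most max_length // block_size hashes and stops at the first miss. A self-contained sketch under the assumption that the block pool behaves like a dict from hash to block (the real manager calls self.block_pool.get_cached_block); names here are illustrative:

class FullAttentionLookupSketch:
    """Illustrative stand-in for FullAttentionManager.find_longest_cache_hit."""

    def __init__(self, block_size: int, cached: dict):
        self.block_size = block_size
        self.cached = cached  # hash -> cached block; simplified block pool

    def find_longest_cache_hit(self, block_hashes: list, max_length: int) -> list:
        computed_blocks: list = []
        max_num_blocks = max_length // self.block_size
        for i in range(max_num_blocks):
            block = self.cached.get(block_hashes[i])
            if block is None:
                # Hashes form a chain: once one block misses, every later
                # block is guaranteed to be uncached as well.
                break
            computed_blocks.append(block)
        return computed_blocks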
@@ -276,19 +280,20 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
             self.sliding_window_contiguous_blocks += 1
         self._null_block = block_pool.null_block

-    def find_longest_cache_hit(
-            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
+    def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
+                               max_length: int) -> list[KVCacheBlock]:
         # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
-        # optimize the time complexity from O(len(block_hashes)) to
-        # O(len(block_hashes) / sliding_window_contiguous_blocks +
+        # optimize the time complexity from O(max_num_blocks) to
+        # O(max_num_blocks / sliding_window_contiguous_blocks +
         # sliding_window_contiguous_blocks),
         # which is good for low cache hit rate scenarios.
-        computed_blocks = [self._null_block] * len(block_hashes)
+        max_num_blocks = max_length // self.block_size
+        computed_blocks = [self._null_block] * max_num_blocks
         num_contiguous_blocks = 0
         match_found = False
         # Search from right to left and early stop when a match is found.
-        for i in range(len(block_hashes) - 1, -1, -1):
+        for i in range(max_num_blocks - 1, -1, -1):
             if cached_block := self.block_pool.get_cached_block(
                     block_hashes[i]):
                 computed_blocks[i] = cached_block
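
The sliding-window variant only needs the rightmost run of contiguous cached blocks that covers the window; everything to its left can stay as the null block. A hedged sketch of the search after this change, with get_cached_block, null_block, and window_blocks as stand-ins for the manager's internals (the diff above shows only part of the loop body):

def sliding_window_cache_hit(block_hashes, max_length, block_size,
                             get_cached_block, null_block, window_blocks):
    max_num_blocks = max_length // block_size
    computed = [null_block] * max_num_blocks
    num_contiguous = 0
    # Search from right to left and stop early once `window_blocks` contiguous
    # cached blocks are found, since that run already covers the window.
    for i in range(max_num_blocks - 1, -1, -1):
        block = get_cached_block(block_hashes[i])
        if block is None:
            num_contiguous = 0  # a miss breaks the contiguous run
            continue
        computed[i] = block
        num_contiguous += 1
        if num_contiguous >= window_blocks:
            # Keep everything up to the end of the run; blocks to its right
            # are not needed for this hit prefix.
            return computed[:i + num_contiguous]
    # No run long enough was found: only a leading run of cached blocks counts.
    return computed[:num_contiguous]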