[v1][KVCacheManager] Avoid full cache hit by controlling max_length (#17999)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

Parent: e57e4d6e9e
Commit: f0d610a8ae
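The change replaces the old "pop the last block hash" workaround with an explicit max_length argument to find_longest_cache_hit, so a request can never register a full-prefix cache hit and the last prompt token is always recomputed to produce logits. A minimal standalone sketch of the idea (illustrative names only, not the vLLM API):

# Sketch only: capping the cache-hit length at num_tokens - 1 guarantees the
# last prompt token is recomputed even when every block is cached.
def capped_cache_hit_blocks(num_cached_blocks: int, num_tokens: int,
                            block_size: int) -> int:
    max_cache_hit_length = num_tokens - 1
    max_num_blocks = max_cache_hit_length // block_size
    return min(num_cached_blocks, max_num_blocks)

# Prompt of 32 tokens with block_size 16 and both blocks cached:
# only 1 block (16 tokens) counts as a hit; the last block is recomputed.
assert capped_cache_hit_blocks(2, 32, 16) == 1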
@@ -17,8 +17,9 @@ def get_sliding_window_manager(sliding_window_spec, block_pool):
 
 
 def test_sliding_window_possible_cached_prefix():
+    block_size = 2
     sliding_window_spec = SlidingWindowSpec(
-        block_size=2,
+        block_size=block_size,
         num_kv_heads=1,
         head_size=1,
         dtype=torch.float32,
@@ -44,7 +45,9 @@ def test_sliding_window_possible_cached_prefix():
                 i: block_pool.blocks[i + 10]
             }
 
-        computed_blocks = manager.find_longest_cache_hit(block_hash_list)
+        computed_blocks = manager.find_longest_cache_hit(
+            block_hash_list,
+            len(block_hash_list) * block_size)
         assert len(computed_blocks) == expect_length
 
         assert all(block == block_pool.null_block
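The test passes len(block_hash_list) * block_size as max_length, i.e. the full token length covered by the hashed blocks, so the existing expectations are unchanged. A quick arithmetic check of that choice (values are illustrative):

block_size = 2
num_hashes = 10
assert (num_hashes * block_size) // block_size == num_hashes           # full scan
assert (num_hashes * block_size - 1) // block_size == num_hashes - 1   # last block dropped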
@@ -146,21 +146,16 @@ class KVCacheManager:
             assert self.prefix_cache_stats is not None
             self.prefix_cache_stats.requests += 1
 
-        if len(block_hashes) * self.block_size == request.num_tokens:
-            # When prompt length is divisible by the block size and all
-            # blocks are cached, we need to recompute the last token. This
-            # have to be achieved by re-computing an entire block because
-            # allocate_slots() assumes num_computed_tokens is always a
-            # multiple of the block size. To achieve this, remove the last
-            # block hash from the block_hashes for find_longest_cache_hit
-            # This limitation can potentially be removed in the future to
-            # slightly improve the performance.
-            last_block_hash = block_hashes.pop()
-        else:
-            last_block_hash = None
+        # NOTE: When all tokens hit the cache, we must recompute the last token
+        # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
+        # This can trigger recomputation of an entire block, rather than just
+        # the single last token, because allocate_slots() requires
+        # num_computed_tokens to be block-size aligned. Removing this limitation
+        # could slightly improve performance in the future.
+        max_cache_hit_length = request.num_tokens - 1
 
-        computed_blocks = (
-            self.single_type_manager.find_longest_cache_hit(block_hashes))
+        computed_blocks = self.single_type_manager.find_longest_cache_hit(
+            block_hashes, max_cache_hit_length)
         # NOTE(woosuk): Since incomplete blocks are not eligible for
         # sharing, `num_computed_tokens` is always a multiple of
         # `block_size`.
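A worked example of the new cap (illustrative numbers, not vLLM code): for a block-aligned prompt, num_tokens - 1 drops exactly one block from the hit, so the last block is recomputed and logits for the final token are produced; for non-aligned prompts the cap changes nothing, since the partial last block is not eligible for sharing anyway.

block_size = 16

num_tokens = 48                        # block-aligned prompt, 3 cached blocks
max_cache_hit_length = num_tokens - 1  # 47
assert max_cache_hit_length // block_size == 2   # only 2 of 3 blocks hit

num_tokens = 50                        # not block-aligned
assert (num_tokens - 1) // block_size == num_tokens // block_size == 3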
@@ -171,12 +166,6 @@ class KVCacheManager:
             self.prefix_cache_stats.queries += request.num_tokens
             self.prefix_cache_stats.hits += num_computed_tokens
 
-        if last_block_hash is not None:
-            # Add back the last block hash if it was removed.
-            # NOTE: Because block_hashes is cached in req_to_block_hashes,
-            # we shouldn't modify it directly.
-            block_hashes.append(last_block_hash)
-
         return KVCacheBlocks(computed_blocks), num_computed_tokens
 
     def allocate_slots(
@@ -187,17 +187,19 @@ class SingleTypeKVCacheManager(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def find_longest_cache_hit(
-            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
+    def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
+                               max_length: int) -> list[KVCacheBlock]:
         """
-        Get the longest cache hit prefix of the blocks. If no cache hit is
-        found, return an empty list. if eagle is enabled, drop the last matched
-        block to force recompute the last block to get the required hidden
-        states for eagle drafting head. Need to be customized for each attention
-        type.
+        Get the longest cache hit prefix of the blocks that is not longer than
+        `max_length`. If no cache hit is found, return an empty list.
+        If eagle is enabled, drop the last matched block to force recompute the
+        last block to get the required hidden states for eagle drafting head.
+        Need to be customized for each attention type.
 
         Args:
             block_hashes: The block hashes of the request.
+            max_length: The maximum length of the cache hit prefix.
 
         Returns:
             A list of cached blocks with skipped blocks replaced by null block.
             For example, sliding window manager should return a list like
@@ -226,10 +228,12 @@ class SingleTypeKVCacheManager(ABC):
 
 class FullAttentionManager(SingleTypeKVCacheManager):
 
-    def find_longest_cache_hit(
-            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
+    def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
+                               max_length: int) -> list[KVCacheBlock]:
         computed_blocks: list[KVCacheBlock] = []
-        for block_hash in block_hashes:
+        max_num_blocks = max_length // self.block_size
+        for i in range(max_num_blocks):
+            block_hash = block_hashes[i]
             # block_hashes is a chain of block hashes. If a block hash is not
             # in the cached_block_hash_to_id, the following block hashes are
             # not computed yet for sure.
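A self-contained sketch of the bounded prefix match in FullAttentionManager above, with a plain dict standing in for the block-pool lookup (an assumption for illustration only; the real code consults self.block_pool):

def find_longest_cache_hit_sketch(block_hashes: list[int],
                                  cached_blocks: dict[int, str],
                                  max_length: int,
                                  block_size: int) -> list[str]:
    computed_blocks: list[str] = []
    # Guard with min() for this sketch so the index never overruns the hashes.
    max_num_blocks = min(max_length // block_size, len(block_hashes))
    for i in range(max_num_blocks):
        block = cached_blocks.get(block_hashes[i])
        if block is None:
            break          # hashes form a chain: first miss ends the prefix
        computed_blocks.append(block)
    return computed_blocks

# All three blocks are cached, but max_length = 5 with block_size = 2
# caps the scan at 5 // 2 = 2 blocks.
assert find_longest_cache_hit_sketch([10, 11, 12],
                                     {10: "b0", 11: "b1", 12: "b2"},
                                     max_length=5, block_size=2) == ["b0", "b1"]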
@@ -276,19 +280,20 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
             self.sliding_window_contiguous_blocks += 1
         self._null_block = block_pool.null_block
 
-    def find_longest_cache_hit(
-            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
+    def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
+                               max_length: int) -> list[KVCacheBlock]:
         # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
-        # optimize the time complexity from O(len(block_hashes)) to
-        # O(len(block_hashes) / sliding_window_contiguous_blocks +
+        # optimize the time complexity from O(max_num_blocks) to
+        # O(max_num_blocks / sliding_window_contiguous_blocks +
         # sliding_window_contiguous_blocks),
         # which is good for low cache hit rate scenarios.
-        computed_blocks = [self._null_block] * len(block_hashes)
+        max_num_blocks = max_length // self.block_size
+        computed_blocks = [self._null_block] * max_num_blocks
         num_contiguous_blocks = 0
 
         match_found = False
         # Search from right to left and early stop when a match is found.
-        for i in range(len(block_hashes) - 1, -1, -1):
+        for i in range(max_num_blocks - 1, -1, -1):
             if cached_block := self.block_pool.get_cached_block(
                     block_hashes[i]):
                 computed_blocks[i] = cached_block
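A self-contained sketch of the sliding-window variant above: the scan runs right to left within max_num_blocks, skipped positions keep a null-block placeholder, and the loop can stop early once a long-enough contiguous run of cached blocks is found. The dict-based cache and the early-return shape are simplifications; the hunk ends before the real match bookkeeping, which is not reproduced here.

NULL_BLOCK = "null"

def sliding_window_cache_hit_sketch(block_hashes: list[int],
                                    cached_blocks: dict[int, str],
                                    max_length: int,
                                    block_size: int,
                                    contiguous_needed: int) -> list[str]:
    max_num_blocks = min(max_length // block_size, len(block_hashes))
    computed_blocks = [NULL_BLOCK] * max_num_blocks
    num_contiguous_blocks = 0
    # Search from right to left and stop early once enough contiguous cached
    # blocks are found; blocks further left may stay as the null block.
    for i in range(max_num_blocks - 1, -1, -1):
        block = cached_blocks.get(block_hashes[i])
        if block is not None:
            computed_blocks[i] = block
            num_contiguous_blocks += 1
            if num_contiguous_blocks >= contiguous_needed:
                break
        else:
            num_contiguous_blocks = 0
    return computed_blocks

# With max_length = 8 and block_size = 2 only blocks 0..3 are considered;
# blocks 2 and 3 form the contiguous suffix, block 1 stays null, block 4 falls
# outside max_num_blocks, and block 0, though cached, is left as the null
# block because the window no longer needs it.
cached = {0: "b0", 2: "b2", 3: "b3"}
assert sliding_window_cache_hit_sketch([0, 1, 2, 3, 4], cached,
                                       max_length=8, block_size=2,
                                       contiguous_needed=2) \
       == [NULL_BLOCK, NULL_BLOCK, "b2", "b3"]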