mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-28 18:18:43 +08:00
[V1] Add uncache_blocks (#12333)
This commit is contained in:
parent
7551a34032
commit
f0ef37233e
@ -626,3 +626,33 @@ def test_reset_prefix_cache():
|
||||
assert manager.reset_prefix_cache()
|
||||
assert not manager.cached_block_hash_to_block
|
||||
assert all([blk.block_hash is None for blk in manager.block_pool])
|
||||
|
||||
|
||||
def test_uncache_blocks():
    """uncache_blocks() must evict trailing cached blocks that are no
    longer full after speculative tokens are rejected."""
    manager = KVCacheManager(
        block_size=16,
        num_gpu_blocks=10,
        max_model_len=8192,
        sliding_window=None,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    # 30 prompt tokens -> two blocks allocated; only the first (full)
    # block is hashed and cached.
    req = make_request("0", list(range(30)))
    allocated = manager.allocate_slots(req, 30, [])
    assert [blk.block_id for blk in allocated] == [0, 1]
    assert len(manager.cached_block_hash_to_block) == 1

    req.num_computed_tokens = 30

    # Simulate speculative tokens: 5 extra tokens fill the second
    # block, so it becomes cached as well.
    for _ in range(5):
        req.append_output_token_ids(8)
    manager.append_slots(req, 5)
    assert len(manager.cached_block_hash_to_block) == 2

    # After sampling, assuming only 1 token is accepted: the second
    # block is no longer full, so exactly one block is uncached.
    req.num_computed_tokens = 31
    assert manager.uncache_blocks(req) == 1
    assert len(manager.cached_block_hash_to_block) == 1
|
||||
|
||||
@ -285,6 +285,29 @@ class KVCacheManager:
|
||||
if block.ref_cnt == 0:
|
||||
self.free_block_queue.append(block)
|
||||
|
||||
def uncache_blocks(self, request: Request) -> int:
    """Uncache the blocks that are no longer full based on the
    request's num_computed_tokens.

    This happens when blocks were filled and cached because of
    speculative tokens, but some of those tokens end up rejected,
    shrinking the number of actually-computed tokens.

    Args:
        request: The request whose trailing blocks should be uncached.

    Returns:
        The number of uncached blocks.
    """
    request_blocks = self.req_to_blocks[request.request_id]
    # Index of the first block that is not completely covered by the
    # accepted (computed) tokens.
    first_partial_idx = request.num_computed_tokens // self.block_size
    evicted = 0
    idx = first_partial_idx
    # Cached blocks form a contiguous prefix of the block list, so we
    # stop at the first block that was not cached.
    while idx < len(request_blocks) and self._maybe_evict_cached_block(
            request_blocks[idx]):
        evicted += 1
        idx += 1
    return evicted
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""Reset prefix cache. This function may be used in RLHF
|
||||
flows to invalid prefix caching after the weights are updated,
|
||||
@ -386,7 +409,7 @@ class KVCacheManager:
|
||||
|
||||
# If the block is cached, evict it.
|
||||
if self.enable_caching:
|
||||
self._evict_cached_block(curr_block)
|
||||
self._maybe_evict_cached_block(curr_block)
|
||||
|
||||
curr_block.incr_ref()
|
||||
ret.append(curr_block)
|
||||
@ -394,13 +417,16 @@ class KVCacheManager:
|
||||
|
||||
return ret
|
||||
|
||||
def _evict_cached_block(self, block: KVCacheBlock) -> None:
|
||||
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
|
||||
"""
|
||||
If a block is cached in `cached_block_hash_to_block`, we reset its hash
|
||||
metadata and evict it from the cache.
|
||||
|
||||
Args:
|
||||
block: The block to evict.
|
||||
|
||||
Returns:
|
||||
True if the block is evicted, False otherwise.
|
||||
"""
|
||||
block_hash = block.block_hash
|
||||
if block_hash and block_hash in self.cached_block_hash_to_block:
|
||||
@ -410,6 +436,9 @@ class KVCacheManager:
|
||||
if len(self.cached_block_hash_to_block[block_hash]) == 0:
|
||||
del self.cached_block_hash_to_block[block_hash]
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_cached_block(self,
|
||||
block_hash: BlockHashType) -> Optional[KVCacheBlock]:
|
||||
"""Get a cached block by the block hash, or None if cache miss.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user