[V1] Add uncache_blocks (#12333)

This commit is contained in:
Cody Yu 2025-01-22 20:19:21 -08:00 committed by GitHub
parent 7551a34032
commit f0ef37233e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 61 additions and 2 deletions

View File

@ -626,3 +626,33 @@ def test_reset_prefix_cache():
assert manager.reset_prefix_cache() assert manager.reset_prefix_cache()
assert not manager.cached_block_hash_to_block assert not manager.cached_block_hash_to_block
assert all([blk.block_hash is None for blk in manager.block_pool]) assert all([blk.block_hash is None for blk in manager.block_pool])
def test_uncache_blocks():
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
num_preallocate_tokens=0,
)
req0 = make_request("0", list(range(30)))
blocks = manager.allocate_slots(req0, 30, [])
assert [b.block_id for b in blocks] == [0, 1]
assert len(manager.cached_block_hash_to_block) == 1
req0.num_computed_tokens = 30
# Simulate speculative tokens.
for _ in range(5):
req0.append_output_token_ids(8)
manager.append_slots(req0, 5)
assert len(manager.cached_block_hash_to_block) == 2
# After sampling, assuming only 1 token is accepted.
req0.num_computed_tokens = 31
num_uncached_blocks = manager.uncache_blocks(req0)
assert num_uncached_blocks == 1
assert len(manager.cached_block_hash_to_block) == 1

View File

@ -285,6 +285,29 @@ class KVCacheManager:
if block.ref_cnt == 0: if block.ref_cnt == 0:
self.free_block_queue.append(block) self.free_block_queue.append(block)
def uncache_blocks(self, request: Request) -> int:
"""Uncache the blocks that are no longer full based on the
num_computed_tokens in the given request. This happens when
the blocks were full and cached due to speculative tokens, but the
speculative tokens are not accepted.
Args:
request: The request.
Returns:
The number of uncached blocks.
"""
blocks = self.req_to_blocks[request.request_id]
num_computed_tokens = request.num_computed_tokens
num_full_blocks = num_computed_tokens // self.block_size
num_uncached_blocks = 0
for block in blocks[num_full_blocks:]:
# If the block is not cached, the following blocks are not cached.
if not self._maybe_evict_cached_block(block):
break
num_uncached_blocks += 1
return num_uncached_blocks
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF """Reset prefix cache. This function may be used in RLHF
flows to invalid prefix caching after the weights are updated, flows to invalid prefix caching after the weights are updated,
@ -386,7 +409,7 @@ class KVCacheManager:
# If the block is cached, evict it. # If the block is cached, evict it.
if self.enable_caching: if self.enable_caching:
self._evict_cached_block(curr_block) self._maybe_evict_cached_block(curr_block)
curr_block.incr_ref() curr_block.incr_ref()
ret.append(curr_block) ret.append(curr_block)
@ -394,13 +417,16 @@ class KVCacheManager:
return ret return ret
def _evict_cached_block(self, block: KVCacheBlock) -> None: def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
""" """
If a block is cached in `cached_block_hash_to_block`, we reset its hash If a block is cached in `cached_block_hash_to_block`, we reset its hash
metadata and evict it from the cache. metadata and evict it from the cache.
Args: Args:
block: The block to evict. block: The block to evict.
Returns:
True if the block is evicted, False otherwise.
""" """
block_hash = block.block_hash block_hash = block.block_hash
if block_hash and block_hash in self.cached_block_hash_to_block: if block_hash and block_hash in self.cached_block_hash_to_block:
@ -410,6 +436,9 @@ class KVCacheManager:
if len(self.cached_block_hash_to_block[block_hash]) == 0: if len(self.cached_block_hash_to_block[block_hash]) == 0:
del self.cached_block_hash_to_block[block_hash] del self.cached_block_hash_to_block[block_hash]
return True
return False
def _get_cached_block(self, def _get_cached_block(self,
block_hash: BlockHashType) -> Optional[KVCacheBlock]: block_hash: BlockHashType) -> Optional[KVCacheBlock]:
"""Get a cached block by the block hash, or None if cache miss. """Get a cached block by the block hash, or None if cache miss.