[V1] Revert uncache_blocks and support recaching full blocks (#12415)

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
Cody Yu 2025-02-03 15:04:53 -08:00 committed by GitHub
parent cf58b9c4ca
commit 5095e96606
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 55 deletions

View File

@@ -629,33 +629,3 @@ def test_reset_prefix_cache():
assert manager.reset_prefix_cache()
assert not manager.cached_block_hash_to_block
assert all([blk.block_hash is None for blk in manager.block_pool])
def test_uncache_blocks():
    """Blocks that were cached because of speculative tokens must be
    uncached again once those tokens are rejected after sampling."""
    manager = KVCacheManager(
        block_size=16,
        num_gpu_blocks=10,
        max_model_len=8192,
        sliding_window=None,
        enable_caching=True,
        num_preallocate_tokens=0,
    )
    req0 = make_request("0", list(range(30)))
    allocated = manager.allocate_slots(req0, 30)
    assert [b.block_id for b in allocated] == [0, 1]
    assert len(manager.cached_block_hash_to_block) == 1
    req0.num_computed_tokens = 30

    # Simulate speculative tokens: the extra tokens fill (and cache)
    # a second block before the tokens are verified.
    for _ in range(5):
        req0.append_output_token_ids(8)
    manager.allocate_slots(req0, 5)
    assert len(manager.cached_block_hash_to_block) == 2

    # After sampling, assuming only 1 token is accepted: the second
    # block is no longer full, so it must be uncached.
    req0.num_computed_tokens = 31
    assert manager.uncache_blocks(req0) == 1
    assert len(manager.cached_block_hash_to_block) == 1

View File

@@ -252,29 +252,6 @@ class KVCacheManager:
if block.ref_cnt == 0:
self.free_block_queue.append(block)
def uncache_blocks(self, request: Request) -> int:
    """Uncache blocks that are no longer full, based on the request's
    num_computed_tokens.

    This happens when blocks were filled (and cached) by speculative
    tokens that were subsequently not accepted.

    Args:
        request: The request whose trailing blocks should be checked.

    Returns:
        The number of blocks that were uncached.
    """
    # Everything up to num_full_blocks is genuinely full; only the
    # blocks after that point may have been cached speculatively.
    num_full_blocks = request.num_computed_tokens // self.block_size
    evicted = 0
    for candidate in self.req_to_blocks[request.request_id][num_full_blocks:]:
        # If a block is not cached, no later block is cached either,
        # so we can stop scanning.
        if not self._maybe_evict_cached_block(candidate):
            break
        evicted += 1
    return evicted
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalid prefix caching after the weights are updated,
@@ -470,8 +447,22 @@ class KVCacheManager:
assert prev_block.block_hash is not None
prev_block_hash_value = prev_block.block_hash.hash_value
for i, blk in enumerate(full_blocks):
blk_idx = blk_start_idx + i
# Find the first uncached block. This case should only happen when
# speculative decoding is used.
offset = 0
for blk in full_blocks:
if blk.block_hash is None:
break
else:
prev_block_hash_value = blk.block_hash.hash_value
offset += 1
else:
# All blocks are cached.
return
for i, blk in enumerate(full_blocks[offset:]):
blk_idx = blk_start_idx + offset + i
assert blk.block_hash is None
if blk_idx < num_cached_block_hashes:
# The block hash may already be computed in