[v1][Bugfix] Only cache blocks that are not in the prefix cache (#14073)

This commit is contained in:
Chen Zhang 2025-03-01 16:25:54 +08:00 committed by GitHub
parent b28246f6ff
commit b9f1d4294e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 22 deletions

View File

@@ -107,34 +107,20 @@ class BlockPool:
assert prev_block.block_hash is not None
prev_block_hash_value = prev_block.block_hash.hash_value
# Find the first uncached block.
# FIXME: num_cached_blocks should be corrected by the caller
# so this should never happen.
offset = 0
for blk in new_full_blocks:
if blk.block_hash is None:
break
else:
prev_block_hash_value = blk.block_hash.hash_value
offset += 1
else:
# All blocks are cached.
return
for i, blk in enumerate(new_full_blocks[offset:]):
blk_idx = num_cached_blocks + offset + i
for i, blk in enumerate(new_full_blocks):
assert blk.block_hash is None
if i + offset < len(new_block_hashes):
if i < len(new_block_hashes):
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# reuse the block hash.
block_hash = new_block_hashes[i + offset]
block_hash = new_block_hashes[i]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
blk_idx = num_cached_blocks + i
start_token_idx = blk_idx * block_size
end_token_idx = (blk_idx + 1) * block_size
block_tokens = request.all_token_ids[

View File

@@ -65,7 +65,7 @@ class KVCacheManager:
# This is used to track the number of cached blocks for each request.
# This is only used to track the RUNNING requests, we do not track the
# data for preempted ones.
self.num_cached_block: Dict[str, int] = defaultdict(int)
self.num_cached_block: Dict[str, int] = {}
self.prefix_cache_stats = PrefixCacheStats()
@property
@@ -224,9 +224,10 @@ class KVCacheManager:
if not self.enable_caching:
return new_blocks
# FIXME: `num_cached_blocks` is not correct when the prefix cache
# of a new request is hit.
num_cached_blocks = self.num_cached_block[request.request_id]
# Use `new_computed_blocks` for a new request, and `num_cached_block`
# for a running request.
num_cached_blocks = self.num_cached_block.get(request.request_id,
len(new_computed_blocks))
# Speculated tokens might be rejected in the future, so we do
# not cache any speculated tokens. We only cache blocks with
# generated (accepted) tokens.