mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-18 01:35:48 +08:00
[V1] Simplify prefix caching logic by removing num_evictable_computed_blocks (#11310)
This commit is contained in:
parent
a30482f054
commit
c6b0a7d3ba
@ -201,23 +201,15 @@ class KVCacheManager:
|
|||||||
f"num_tokens must be greater than 0, got {num_tokens}")
|
f"num_tokens must be greater than 0, got {num_tokens}")
|
||||||
|
|
||||||
# Touch the computed blocks to make sure they won't be evicted.
|
# Touch the computed blocks to make sure they won't be evicted.
|
||||||
num_evictable_computed_blocks = 0
|
|
||||||
if self.enable_caching:
|
if self.enable_caching:
|
||||||
self._touch(computed_blocks)
|
self._touch(computed_blocks)
|
||||||
|
|
||||||
# If a computed block of a request is an eviction candidate (in the
|
|
||||||
# free queue and ref_cnt == 0), it cannot be counted as a free block
|
|
||||||
# when allocating this request.
|
|
||||||
num_evictable_computed_blocks = len(
|
|
||||||
[blk for blk in computed_blocks if blk.ref_cnt == 0])
|
|
||||||
else:
|
else:
|
||||||
assert not computed_blocks, (
|
assert not computed_blocks, (
|
||||||
"Computed blocks should be empty when "
|
"Computed blocks should be empty when "
|
||||||
"prefix caching is disabled")
|
"prefix caching is disabled")
|
||||||
|
|
||||||
num_required_blocks = cdiv(num_tokens, self.block_size)
|
num_required_blocks = cdiv(num_tokens, self.block_size)
|
||||||
if (num_required_blocks > self.free_block_queue.num_free_blocks -
|
if (num_required_blocks > self.free_block_queue.num_free_blocks):
|
||||||
num_evictable_computed_blocks):
|
|
||||||
# Cannot allocate new blocks.
|
# Cannot allocate new blocks.
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -225,8 +217,7 @@ class KVCacheManager:
|
|||||||
# preallocated blocks.
|
# preallocated blocks.
|
||||||
num_new_blocks = min(
|
num_new_blocks = min(
|
||||||
num_required_blocks + self.num_preallocate_blocks,
|
num_required_blocks + self.num_preallocate_blocks,
|
||||||
self.free_block_queue.num_free_blocks -
|
self.free_block_queue.num_free_blocks,
|
||||||
num_evictable_computed_blocks,
|
|
||||||
# Should not exceed the maximum number of blocks per request.
|
# Should not exceed the maximum number of blocks per request.
|
||||||
# This is especially because the block table has the shape
|
# This is especially because the block table has the shape
|
||||||
# [..., max_num_blocks_per_req].
|
# [..., max_num_blocks_per_req].
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user