[v1][KVCacheManager] pass num_new_computed_tokens to kv cache manager (#18001)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang 2025-05-14 10:09:39 +08:00 committed by GitHub
parent 40de1ef455
commit f2ae883b67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 119 additions and 53 deletions

View File

@ -81,7 +81,9 @@ def test_prefill(hash_algo):
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
# Check full block metadata
@ -108,7 +110,9 @@ def test_prefill(hash_algo):
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5]
for block in computed_blocks.blocks:
assert block.ref_cnt == 2
@ -140,7 +144,9 @@ def test_prefill(hash_algo):
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
blocks = manager.allocate_slots(req2, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [6]
# Although we only have 6 free blocks, we have 8 blocks in
@ -161,7 +167,9 @@ def test_prefill(hash_algo):
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks)
blocks = manager.allocate_slots(req3, 16 * 10,
len(computed_blocks.blocks) * 16,
computed_blocks)
# This block ID order also checks the eviction order.
assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
assert manager.block_pool.free_block_queue.num_free_blocks == 0
@ -197,7 +205,9 @@ def test_prefill_plp():
assert len(manager.req_to_block_hashes[req0.request_id]) == 0
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
req0_block_hashes = [b.block_hash for b in blocks.blocks]
@ -226,7 +236,9 @@ def test_prefill_plp():
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5]
for block in computed_blocks.blocks:
assert block.ref_cnt == 2
@ -259,7 +271,9 @@ def test_prefill_plp():
assert len(manager.req_to_block_hashes[req2.request_id]) == 0
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 55, computed_blocks)
blocks = manager.allocate_slots(req2, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
@ -290,14 +304,18 @@ def test_decode():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
# Append slots without allocating a new block.
req0.num_computed_tokens = 55
for _ in range(4):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 4)
new_blocks = manager.allocate_slots(req0, 4,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0
assert manager.single_type_manager.req_to_blocks[
req0.request_id][-1].block_hash is None
@ -308,7 +326,9 @@ def test_decode():
# the preallocated block.
for _ in range(9 + 10):
req0.append_output_token_ids(7)
new_blocks = manager.allocate_slots(req0, 19)
new_blocks = manager.allocate_slots(req0, 19,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 1
assert manager.single_type_manager.req_to_blocks[
req0.request_id][-2].block_hash is not None
@ -328,7 +348,9 @@ def test_evict():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
blocks = manager.allocate_slots(req0, 5 * 16 + 7,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 6 # 5 full + 1 partial
# 3 blocks.
@ -337,7 +359,9 @@ def test_evict():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
blocks = manager.allocate_slots(req1, 3 * 16,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 3 # 3 full blocks
last_token_id += 3 * 16
@ -357,7 +381,9 @@ def test_evict():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert computed_blocks.get_block_ids() == [1, 2]
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks)
blocks = manager.allocate_slots(req2, 3,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [10]
assert manager.block_pool.free_block_queue.num_free_blocks == 7
@ -380,7 +406,9 @@ def test_hash_block_correct_reuse():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
blocks = manager.allocate_slots(req, num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
# Deallocate the block.
@ -392,7 +420,9 @@ def test_hash_block_correct_reuse():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
blocks = manager.allocate_slots(req, num_tokens - 1,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert manager.block_pool.blocks[
@ -417,7 +447,9 @@ def test_computed_blocks_not_evicted():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
blocks = manager.allocate_slots(req0, num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 1
@ -426,7 +458,9 @@ def test_computed_blocks_not_evicted():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
blocks = manager.allocate_slots(req1, num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 2
@ -443,6 +477,7 @@ def test_computed_blocks_not_evicted():
assert num_computed_tokens == block_size
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 2
@ -464,7 +499,9 @@ def test_basic_prefix_caching_disabled():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 10, computed_blocks)
blocks = manager.allocate_slots(req1, 10,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 3
# Free the blocks.
@ -475,7 +512,9 @@ def test_basic_prefix_caching_disabled():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 16, computed_blocks)
blocks = manager.allocate_slots(req2, 16,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 4
# New requests should not have any blocks.
@ -483,7 +522,9 @@ def test_basic_prefix_caching_disabled():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 4, computed_blocks)
blocks = manager.allocate_slots(req3, 4,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert not blocks
@ -581,14 +622,18 @@ def test_mm_prefix_caching():
assert block_hashes[1].extra_keys == ("aaa", "bbb")
assert block_hashes[2].extra_keys == ("bbb", )
blocks = manager.allocate_slots(req0, 59, computed_blocks)
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5)
new_blocks = manager.allocate_slots(req0, 5,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0
# The just completed block should have hashes with extra keys.
@ -638,14 +683,18 @@ def test_cache_key_salting():
assert block_hashes[1].extra_keys is None
assert block_hashes[2].extra_keys is None
blocks = manager.allocate_slots(req0, 59, computed_blocks)
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5)
new_blocks = manager.allocate_slots(req0, 5,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0
# Now one more block that should not have extra keys.
@ -691,7 +740,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
manager.allocate_slots(req0, 48, computed_blocks)
manager.allocate_slots(req0, 48,
len(computed_blocks.blocks) * 16, computed_blocks)
block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
@ -699,7 +749,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert computed_blocks.blocks == block_part0
assert num_computed_tokens == 3 * 16
manager.allocate_slots(req1, 48, computed_blocks)
manager.allocate_slots(req1, 48,
len(computed_blocks.blocks) * 16, computed_blocks)
block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| ... |
@ -713,7 +764,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
manager.allocate_slots(req2, block_size * 2, computed_blocks)
manager.allocate_slots(req2, block_size * 2,
len(computed_blocks.blocks) * 16, computed_blocks)
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# but it cannot be allocated due to insufficient free blocks (2).
@ -724,7 +776,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
assert computed_blocks.blocks == block_part1
assert num_computed_tokens == 6 * 16
# Req3 cannot be allocated.
assert manager.allocate_slots(req3, 48, computed_blocks) is None
assert manager.allocate_slots(req3, 48,
len(computed_blocks.blocks) * 16,
computed_blocks) is None
# Block 0-2 are used by Req 1.
assert {block.ref_cnt for block in block_part1[:3]} == {1}
# Block 3-5 are free.
@ -751,7 +805,9 @@ def test_reset_prefix_cache():
computed_blocks, _ = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert len(computed_blocks.blocks) == 3
blocks = manager.allocate_slots(req1, 7, computed_blocks)
blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5]
# Failed to reset prefix cache because some blocks are not freed yet.
@ -782,7 +838,8 @@ def test_prefix_cache_stats_disabled():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
manager.allocate_slots(req, 16, computed_blocks)
manager.allocate_slots(req, 16,
len(computed_blocks.blocks) * 16, computed_blocks)
manager.reset_prefix_cache()
# Ensure prefix_cache_stats remains None
@ -860,7 +917,8 @@ def test_eagle_enabled_removes_last_block():
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks)
manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
manager.free(req)
# New request with same tokens + Eagle enabled
@ -889,7 +947,8 @@ def test_eagle_with_partial_blocks():
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks)
manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
manager.free(req)
# New request with Eagle enabled
@ -928,7 +987,8 @@ def test_eagle_with_sliding_window():
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks)
manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
# record the block hash of the first block in the request for later use
block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
assert block_hash_first_block is not None

View File

@ -121,13 +121,6 @@ class KVCacheManager:
- A list of blocks that are computed for the request.
- The number of computed tokens.
"""
# Request already has blocks from async load via KVConnector.
num_existing_blocks = len(
self.single_type_manager.req_to_blocks[request.request_id])
if num_existing_blocks > 0:
return KVCacheBlocks.create_empty(), request.num_computed_tokens
# Prefix caching is disabled or
# When the request requires prompt logprobs, we skip prefix caching.
if (not self.enable_caching
@ -172,6 +165,7 @@ class KVCacheManager:
self,
request: Request,
num_new_tokens: int,
num_new_computed_tokens: int = 0,
new_computed_blocks: Optional[KVCacheBlocks] = None,
num_lookahead_tokens: int = 0,
delay_cache_blocks: bool = False,
@ -183,8 +177,10 @@ class KVCacheManager:
num_new_tokens: The number of tokens to allocate, including external
tokens. Note that this does not include tokens that have
already been computed locally (i.e. new_computed_blocks).
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
num_new_computed_tokens: The number of new computed tokens just
hitting the prefix caching, excluding external tokens.
new_computed_blocks: The cached blocks for the above new computed
tokens.
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
@ -229,7 +225,7 @@ class KVCacheManager:
# The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens +
len(new_computed_block_list) * self.block_size)
num_new_computed_tokens)
num_tokens_need_slot = min(
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len)

View File

@ -18,7 +18,7 @@ from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
@ -311,12 +311,14 @@ class Scheduler(SchedulerInterface):
break
request = self.waiting[0]
num_prealloc_computed_tokens = 0
# P/D: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
request.status = RequestStatus.WAITING
num_prealloc_computed_tokens = (
request.num_computed_tokens)
else:
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
@ -345,18 +347,25 @@ class Scheduler(SchedulerInterface):
continue
# Get already-cached tokens.
new_computed_blocks, num_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(
request)
if num_prealloc_computed_tokens == 0:
new_computed_blocks, num_native_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(
request)
else:
# P/D: skip checking prefix cache if loaded from remote kvs.
new_computed_blocks = KVCacheBlocks.create_empty()
num_native_computed_tokens = 0
# Get externally-cached tokens if using a KVConnector.
num_external_tokens, load_kv_async = (
num_external_computed_tokens, load_kv_async = (
(0, False) if self.connector is None else
self.connector.get_num_new_matched_tokens(
request, num_computed_tokens))
request, num_native_computed_tokens))
# Total computed tokens (local + external).
num_computed_tokens += num_external_tokens
num_computed_tokens = (num_native_computed_tokens +
num_external_computed_tokens +
num_prealloc_computed_tokens)
encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget
@ -390,7 +399,8 @@ class Scheduler(SchedulerInterface):
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_tokens,
num_new_tokens + num_external_computed_tokens,
num_native_computed_tokens,
new_computed_blocks,
num_lookahead_tokens=self.num_lookahead_tokens,
delay_cache_blocks=load_kv_async,
@ -406,7 +416,7 @@ class Scheduler(SchedulerInterface):
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_tokens,
num_external_computed_tokens,
)
self.waiting.popleft()