[v1][KVCacheManager] pass num_new_computed_tokens to kv cache manager (#18001)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang 2025-05-14 10:09:39 +08:00 committed by GitHub
parent 40de1ef455
commit f2ae883b67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 119 additions and 53 deletions

View File

@ -81,7 +81,9 @@ def test_prefill(hash_algo):
assert len(manager.req_to_block_hashes[req0.request_id]) == 3 assert len(manager.req_to_block_hashes[req0.request_id]) == 3
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks) blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4] assert blocks.get_block_ids() == [1, 2, 3, 4]
# Check full block metadata # Check full block metadata
@ -108,7 +110,9 @@ def test_prefill(hash_algo):
assert computed_blocks.get_block_ids() == [1, 2, 3] assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16 num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5] assert blocks.get_block_ids() == [5]
for block in computed_blocks.blocks: for block in computed_blocks.blocks:
assert block.ref_cnt == 2 assert block.ref_cnt == 2
@ -140,7 +144,9 @@ def test_prefill(hash_algo):
assert computed_blocks.get_block_ids() == [1, 2, 3] assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16 num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) blocks = manager.allocate_slots(req2, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [6] assert blocks.get_block_ids() == [6]
# Although we only have 6 free blocks, we have 8 blocks in # Although we only have 6 free blocks, we have 8 blocks in
@ -161,7 +167,9 @@ def test_prefill(hash_algo):
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks) blocks = manager.allocate_slots(req3, 16 * 10,
len(computed_blocks.blocks) * 16,
computed_blocks)
# This block ID order also checks the eviction order. # This block ID order also checks the eviction order.
assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1] assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
assert manager.block_pool.free_block_queue.num_free_blocks == 0 assert manager.block_pool.free_block_queue.num_free_blocks == 0
@ -197,7 +205,9 @@ def test_prefill_plp():
assert len(manager.req_to_block_hashes[req0.request_id]) == 0 assert len(manager.req_to_block_hashes[req0.request_id]) == 0
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks) blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4] assert blocks.get_block_ids() == [1, 2, 3, 4]
req0_block_hashes = [b.block_hash for b in blocks.blocks] req0_block_hashes = [b.block_hash for b in blocks.blocks]
@ -226,7 +236,9 @@ def test_prefill_plp():
assert computed_blocks.get_block_ids() == [1, 2, 3] assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16 num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5] assert blocks.get_block_ids() == [5]
for block in computed_blocks.blocks: for block in computed_blocks.blocks:
assert block.ref_cnt == 2 assert block.ref_cnt == 2
@ -259,7 +271,9 @@ def test_prefill_plp():
assert len(manager.req_to_block_hashes[req2.request_id]) == 0 assert len(manager.req_to_block_hashes[req2.request_id]) == 0
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 55, computed_blocks) blocks = manager.allocate_slots(req2, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
block_ids = blocks.get_block_ids() block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0 # Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks.blocks] == req0_block_hashes assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
@ -290,14 +304,18 @@ def test_decode():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks) blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4] assert blocks.get_block_ids() == [1, 2, 3, 4]
# Append slots without allocating a new block. # Append slots without allocating a new block.
req0.num_computed_tokens = 55 req0.num_computed_tokens = 55
for _ in range(4): for _ in range(4):
req0.append_output_token_ids(8) req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 4) new_blocks = manager.allocate_slots(req0, 4,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0 assert new_blocks is not None and len(new_blocks.blocks) == 0
assert manager.single_type_manager.req_to_blocks[ assert manager.single_type_manager.req_to_blocks[
req0.request_id][-1].block_hash is None req0.request_id][-1].block_hash is None
@ -308,7 +326,9 @@ def test_decode():
# the preallocated block. # the preallocated block.
for _ in range(9 + 10): for _ in range(9 + 10):
req0.append_output_token_ids(7) req0.append_output_token_ids(7)
new_blocks = manager.allocate_slots(req0, 19) new_blocks = manager.allocate_slots(req0, 19,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 1 assert new_blocks is not None and len(new_blocks.blocks) == 1
assert manager.single_type_manager.req_to_blocks[ assert manager.single_type_manager.req_to_blocks[
req0.request_id][-2].block_hash is not None req0.request_id][-2].block_hash is not None
@ -328,7 +348,9 @@ def test_evict():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks) blocks = manager.allocate_slots(req0, 5 * 16 + 7,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 6 # 5 full + 1 partial assert len(blocks.blocks) == 6 # 5 full + 1 partial
# 3 blocks. # 3 blocks.
@ -337,7 +359,9 @@ def test_evict():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks) blocks = manager.allocate_slots(req1, 3 * 16,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 3 # 3 full blocks assert len(blocks.blocks) == 3 # 3 full blocks
last_token_id += 3 * 16 last_token_id += 3 * 16
@ -357,7 +381,9 @@ def test_evict():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert computed_blocks.get_block_ids() == [1, 2] assert computed_blocks.get_block_ids() == [1, 2]
assert num_computed_tokens == 2 * 16 assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks) blocks = manager.allocate_slots(req2, 3,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [10] assert blocks.get_block_ids() == [10]
assert manager.block_pool.free_block_queue.num_free_blocks == 7 assert manager.block_pool.free_block_queue.num_free_blocks == 7
@ -380,7 +406,9 @@ def test_hash_block_correct_reuse():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens, computed_blocks) blocks = manager.allocate_slots(req, num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1 assert len(blocks.blocks) == 1
# Deallocate the block. # Deallocate the block.
@ -392,7 +420,9 @@ def test_hash_block_correct_reuse():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks) blocks = manager.allocate_slots(req, num_tokens - 1,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1 assert len(blocks.blocks) == 1
assert manager.block_pool.blocks[ assert manager.block_pool.blocks[
@ -417,7 +447,9 @@ def test_computed_blocks_not_evicted():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks) blocks = manager.allocate_slots(req0, num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1 assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 1 assert blocks.blocks[0].block_id == 1
@ -426,7 +458,9 @@ def test_computed_blocks_not_evicted():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks) blocks = manager.allocate_slots(req1, num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1 assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 2 assert blocks.blocks[0].block_id == 2
@ -443,6 +477,7 @@ def test_computed_blocks_not_evicted():
assert num_computed_tokens == block_size assert num_computed_tokens == block_size
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks) computed_blocks)
assert len(blocks.blocks) == 1 assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 2 assert blocks.blocks[0].block_id == 2
@ -464,7 +499,9 @@ def test_basic_prefix_caching_disabled():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 10, computed_blocks) blocks = manager.allocate_slots(req1, 10,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 3 assert len(blocks.blocks) == 3
# Free the blocks. # Free the blocks.
@ -475,7 +512,9 @@ def test_basic_prefix_caching_disabled():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 16, computed_blocks) blocks = manager.allocate_slots(req2, 16,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 4 assert len(blocks.blocks) == 4
# New requests should not have any blocks. # New requests should not have any blocks.
@ -483,7 +522,9 @@ def test_basic_prefix_caching_disabled():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 4, computed_blocks) blocks = manager.allocate_slots(req3, 4,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert not blocks assert not blocks
@ -581,14 +622,18 @@ def test_mm_prefix_caching():
assert block_hashes[1].extra_keys == ("aaa", "bbb") assert block_hashes[1].extra_keys == ("aaa", "bbb")
assert block_hashes[2].extra_keys == ("bbb", ) assert block_hashes[2].extra_keys == ("bbb", )
blocks = manager.allocate_slots(req0, 59, computed_blocks) blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4] assert blocks.get_block_ids() == [1, 2, 3, 4]
req0.num_computed_tokens = 59 req0.num_computed_tokens = 59
# Append slots without allocating a new block. # Append slots without allocating a new block.
for _ in range(5): for _ in range(5):
req0.append_output_token_ids(8) req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5) new_blocks = manager.allocate_slots(req0, 5,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0 assert new_blocks is not None and len(new_blocks.blocks) == 0
# The just completed block should have hashes with extra keys. # The just completed block should have hashes with extra keys.
@ -638,14 +683,18 @@ def test_cache_key_salting():
assert block_hashes[1].extra_keys is None assert block_hashes[1].extra_keys is None
assert block_hashes[2].extra_keys is None assert block_hashes[2].extra_keys is None
blocks = manager.allocate_slots(req0, 59, computed_blocks) blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4] assert blocks.get_block_ids() == [1, 2, 3, 4]
req0.num_computed_tokens = 59 req0.num_computed_tokens = 59
# Append slots without allocating a new block. # Append slots without allocating a new block.
for _ in range(5): for _ in range(5):
req0.append_output_token_ids(8) req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5) new_blocks = manager.allocate_slots(req0, 5,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0 assert new_blocks is not None and len(new_blocks.blocks) == 0
# Now one more block that should not have extra keys. # Now one more block that should not have extra keys.
@ -691,7 +740,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
manager.allocate_slots(req0, 48, computed_blocks) manager.allocate_slots(req0, 48,
len(computed_blocks.blocks) * 16, computed_blocks)
block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id] block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
@ -699,7 +749,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert computed_blocks.blocks == block_part0 assert computed_blocks.blocks == block_part0
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
manager.allocate_slots(req1, 48, computed_blocks) manager.allocate_slots(req1, 48,
len(computed_blocks.blocks) * 16, computed_blocks)
block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id] block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| ... | # | Req1-5(F)| ... |
@ -713,7 +764,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
manager.allocate_slots(req2, block_size * 2, computed_blocks) manager.allocate_slots(req2, block_size * 2,
len(computed_blocks.blocks) * 16, computed_blocks)
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# but it cannot be allocated due to insufficient free blocks (2). # but it cannot be allocated due to insufficient free blocks (2).
@ -724,7 +776,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
assert computed_blocks.blocks == block_part1 assert computed_blocks.blocks == block_part1
assert num_computed_tokens == 6 * 16 assert num_computed_tokens == 6 * 16
# Req3 cannot be allocated. # Req3 cannot be allocated.
assert manager.allocate_slots(req3, 48, computed_blocks) is None assert manager.allocate_slots(req3, 48,
len(computed_blocks.blocks) * 16,
computed_blocks) is None
# Block 0-2 are used by Req 1. # Block 0-2 are used by Req 1.
assert {block.ref_cnt for block in block_part1[:3]} == {1} assert {block.ref_cnt for block in block_part1[:3]} == {1}
# Block 3-5 are free. # Block 3-5 are free.
@ -751,7 +805,9 @@ def test_reset_prefix_cache():
computed_blocks, _ = manager.get_computed_blocks(req1) computed_blocks, _ = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert len(computed_blocks.blocks) == 3 assert len(computed_blocks.blocks) == 3
blocks = manager.allocate_slots(req1, 7, computed_blocks) blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5] assert blocks.get_block_ids() == [5]
# Failed to reset prefix cache because some blocks are not freed yet. # Failed to reset prefix cache because some blocks are not freed yet.
@ -782,7 +838,8 @@ def test_prefix_cache_stats_disabled():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
manager.allocate_slots(req, 16, computed_blocks) manager.allocate_slots(req, 16,
len(computed_blocks.blocks) * 16, computed_blocks)
manager.reset_prefix_cache() manager.reset_prefix_cache()
# Ensure prefix_cache_stats remains None # Ensure prefix_cache_stats remains None
@ -860,7 +917,8 @@ def test_eagle_enabled_removes_last_block():
# Prime the cache # Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req) computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks) manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
manager.free(req) manager.free(req)
# New request with same tokens + Eagle enabled # New request with same tokens + Eagle enabled
@ -889,7 +947,8 @@ def test_eagle_with_partial_blocks():
# Prime the cache # Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req) computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks) manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
manager.free(req) manager.free(req)
# New request with Eagle enabled # New request with Eagle enabled
@ -928,7 +987,8 @@ def test_eagle_with_sliding_window():
# Prime the cache # Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req) computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks) manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
# record the block hash of the first block in the request for later use # record the block hash of the first block in the request for later use
block_hash_first_block = manager.req_to_block_hashes[req.request_id][0] block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
assert block_hash_first_block is not None assert block_hash_first_block is not None

View File

@ -121,13 +121,6 @@ class KVCacheManager:
- A list of blocks that are computed for the request. - A list of blocks that are computed for the request.
- The number of computed tokens. - The number of computed tokens.
""" """
# Request already has blocks from async load via KVConnector.
num_existing_blocks = len(
self.single_type_manager.req_to_blocks[request.request_id])
if num_existing_blocks > 0:
return KVCacheBlocks.create_empty(), request.num_computed_tokens
# Prefix caching is disabled or # Prefix caching is disabled or
# When the request requires prompt logprobs, we skip prefix caching. # When the request requires prompt logprobs, we skip prefix caching.
if (not self.enable_caching if (not self.enable_caching
@ -172,6 +165,7 @@ class KVCacheManager:
self, self,
request: Request, request: Request,
num_new_tokens: int, num_new_tokens: int,
num_new_computed_tokens: int = 0,
new_computed_blocks: Optional[KVCacheBlocks] = None, new_computed_blocks: Optional[KVCacheBlocks] = None,
num_lookahead_tokens: int = 0, num_lookahead_tokens: int = 0,
delay_cache_blocks: bool = False, delay_cache_blocks: bool = False,
@ -183,8 +177,10 @@ class KVCacheManager:
num_new_tokens: The number of tokens to allocate, including external num_new_tokens: The number of tokens to allocate, including external
tokens. Note that this does not include tokens that have tokens. Note that this does not include tokens that have
already been computed locally (i.e. new_computed_blocks). already been computed locally (i.e. new_computed_blocks).
new_computed_blocks: The new computed blocks just hitting the num_new_computed_tokens: The number of new computed tokens just
prefix caching. hitting the prefix caching, excluding external tokens.
new_computed_blocks: The cached blocks for the above new computed
tokens.
num_lookahead_tokens: The number of speculative tokens to allocate. num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such This is used by spec decode proposers with kv-cache such
as eagle. as eagle.
@ -229,7 +225,7 @@ class KVCacheManager:
# The number of computed tokens is the number of computed tokens plus # The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits # the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens + num_computed_tokens = (request.num_computed_tokens +
len(new_computed_block_list) * self.block_size) num_new_computed_tokens)
num_tokens_need_slot = min( num_tokens_need_slot = min(
num_computed_tokens + num_new_tokens + num_lookahead_tokens, num_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len) self.max_model_len)

View File

@ -18,7 +18,7 @@ from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget) compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput) SchedulerOutput)
@ -311,12 +311,14 @@ class Scheduler(SchedulerInterface):
break break
request = self.waiting[0] request = self.waiting[0]
num_prealloc_computed_tokens = 0
# P/D: skip request if still waiting for remote kvs. # P/D: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request) is_ready = self._update_waiting_for_remote_kv(request)
if is_ready: if is_ready:
request.status = RequestStatus.WAITING request.status = RequestStatus.WAITING
num_prealloc_computed_tokens = (
request.num_computed_tokens)
else: else:
self.waiting.popleft() self.waiting.popleft()
skipped_waiting_requests.appendleft(request) skipped_waiting_requests.appendleft(request)
@ -345,18 +347,25 @@ class Scheduler(SchedulerInterface):
continue continue
# Get already-cached tokens. # Get already-cached tokens.
new_computed_blocks, num_computed_tokens = \ if num_prealloc_computed_tokens == 0:
self.kv_cache_manager.get_computed_blocks( new_computed_blocks, num_native_computed_tokens = \
request) self.kv_cache_manager.get_computed_blocks(
request)
else:
# P/D: skip checking prefix cache if loaded from remote kvs.
new_computed_blocks = KVCacheBlocks.create_empty()
num_native_computed_tokens = 0
# Get externally-cached tokens if using a KVConnector. # Get externally-cached tokens if using a KVConnector.
num_external_tokens, load_kv_async = ( num_external_computed_tokens, load_kv_async = (
(0, False) if self.connector is None else (0, False) if self.connector is None else
self.connector.get_num_new_matched_tokens( self.connector.get_num_new_matched_tokens(
request, num_computed_tokens)) request, num_native_computed_tokens))
# Total computed tokens (local + external). # Total computed tokens (local + external).
num_computed_tokens += num_external_tokens num_computed_tokens = (num_native_computed_tokens +
num_external_computed_tokens +
num_prealloc_computed_tokens)
encoder_inputs_to_schedule = None encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget new_encoder_budget = encoder_budget
@ -390,7 +399,8 @@ class Scheduler(SchedulerInterface):
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request,
num_new_tokens + num_external_tokens, num_new_tokens + num_external_computed_tokens,
num_native_computed_tokens,
new_computed_blocks, new_computed_blocks,
num_lookahead_tokens=self.num_lookahead_tokens, num_lookahead_tokens=self.num_lookahead_tokens,
delay_cache_blocks=load_kv_async, delay_cache_blocks=load_kv_async,
@ -406,7 +416,7 @@ class Scheduler(SchedulerInterface):
self.connector.update_state_after_alloc( self.connector.update_state_after_alloc(
request, request,
new_computed_blocks + new_blocks, new_computed_blocks + new_blocks,
num_external_tokens, num_external_computed_tokens,
) )
self.waiting.popleft() self.waiting.popleft()