[v1] Introduce KVCacheBlocks as interface between Scheduler and KVCacheManager (#17479)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang 2025-05-06 23:50:34 +08:00 committed by GitHub
parent 0d115460a7
commit aabcd2cae3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 121 additions and 99 deletions

View File

@ -542,7 +542,7 @@ def test_allocate_with_lookahead():
num_tokens=3, num_tokens=3,
num_lookahead_tokens=2, # Total required: 3+2=5 tokens num_lookahead_tokens=2, # Total required: 3+2=5 tokens
) )
assert len(blocks) == 2 # ceil(5/4)=2 blocks assert len(blocks.blocks) == 2 # ceil(5/4)=2 blocks
# Test case 2: With precomputed blocks # Test case 2: With precomputed blocks
kv_cache_manager = KVCacheManager(kv_cache_config=config, kv_cache_manager = KVCacheManager(kv_cache_config=config,
@ -553,7 +553,7 @@ def test_allocate_with_lookahead():
num_tokens=3, num_tokens=3,
num_lookahead_tokens=2, num_lookahead_tokens=2,
) )
assert len(blocks) == 2 assert len(blocks.blocks) == 2
# Test case 3: With precomputed blocks # Test case 3: With precomputed blocks
# required_blocks = ceil((3 + 4) / 4) = 2 # required_blocks = ceil((3 + 4) / 4) = 2
@ -564,4 +564,4 @@ def test_allocate_with_lookahead():
num_tokens=3, num_tokens=3,
num_lookahead_tokens=4, num_lookahead_tokens=4,
) )
assert len(blocks) == 2 assert len(blocks.blocks) == 2

View File

@ -79,10 +79,10 @@ def test_prefill(hash_algo):
req0 = make_request("0", all_token_ids) req0 = make_request("0", all_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(manager.req_to_block_hashes[req0.request_id]) == 3 assert len(manager.req_to_block_hashes[req0.request_id]) == 3
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks) blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4] assert blocks.get_block_ids() == [1, 2, 3, 4]
# Check full block metadata # Check full block metadata
parent_block_hash = None parent_block_hash = None
@ -105,12 +105,12 @@ def test_prefill(hash_algo):
req1 = make_request("1", common_token_ids + unique_token_ids) req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [1, 2, 3] assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16 num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5] assert blocks.get_block_ids() == [5]
for block in computed_blocks: for block in computed_blocks.blocks:
assert block.ref_cnt == 2 assert block.ref_cnt == 2
# At this point, we should have 5 free blocks left. # At this point, we should have 5 free blocks left.
@ -137,11 +137,11 @@ def test_prefill(hash_algo):
req2 = make_request("2", common_token_ids + unique_token_ids) req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3 assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [1, 2, 3] assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16 num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks) blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [6] assert blocks.get_block_ids() == [6]
# Although we only have 6 free blocks, we have 8 blocks in # Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal. # the free block queue due to lazy removal.
@ -159,11 +159,11 @@ def test_prefill(hash_algo):
# Cache miss and eviction. # Cache miss and eviction.
req3 = make_request("3", [99] * (16 * 10)) req3 = make_request("3", [99] * (16 * 10))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks) blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks)
# This block ID order also checks the eviction order. # This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1] assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
assert manager.block_pool.free_block_queue.num_free_blocks == 0 assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None assert manager.block_pool.free_block_queue.free_list_tail is None
@ -195,11 +195,11 @@ def test_prefill_plp():
req0 = make_request("0", all_token_ids, prompt_logprobs=5) req0 = make_request("0", all_token_ids, prompt_logprobs=5)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(manager.req_to_block_hashes[req0.request_id]) == 3 assert len(manager.req_to_block_hashes[req0.request_id]) == 3
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks) blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4] assert blocks.get_block_ids() == [1, 2, 3, 4]
req0_block_hashes = [b.block_hash for b in blocks] req0_block_hashes = [b.block_hash for b in blocks.blocks]
# Check full block metadata # Check full block metadata
parent_block_hash = None parent_block_hash = None
@ -223,12 +223,12 @@ def test_prefill_plp():
req1 = make_request("1", common_token_ids + unique_token_ids) req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [1, 2, 3] assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16 num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks) blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5] assert blocks.get_block_ids() == [5]
for block in computed_blocks: for block in computed_blocks.blocks:
assert block.ref_cnt == 2 assert block.ref_cnt == 2
# At this point, we should have 5 free blocks left. # At this point, we should have 5 free blocks left.
@ -257,12 +257,12 @@ def test_prefill_plp():
prompt_logprobs=5) prompt_logprobs=5)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3 assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 55, computed_blocks) blocks = manager.allocate_slots(req2, 55, computed_blocks)
block_ids = [b.block_id for b in blocks] block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0 # Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks] == req0_block_hashes assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
assert block_ids != [1, 2, 3, 4] assert block_ids != [1, 2, 3, 4]
# Request #2 block hashes are valid since request #0 hashes are. # Request #2 block hashes are valid since request #0 hashes are.
@ -288,17 +288,17 @@ def test_decode():
unique_token_ids = [3] * 7 unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids) req0 = make_request("0", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks) blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4] assert blocks.get_block_ids() == [1, 2, 3, 4]
# Append slots without allocating a new block. # Append slots without allocating a new block.
req0.num_computed_tokens = 55 req0.num_computed_tokens = 55
for _ in range(4): for _ in range(4):
req0.append_output_token_ids(8) req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 4) new_blocks = manager.allocate_slots(req0, 4)
assert new_blocks is not None and len(new_blocks) == 0 assert new_blocks is not None and len(new_blocks.blocks) == 0
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
# Append slots with allocating a new block. # Append slots with allocating a new block.
@ -308,7 +308,7 @@ def test_decode():
for _ in range(9 + 10): for _ in range(9 + 10):
req0.append_output_token_ids(7) req0.append_output_token_ids(7)
new_blocks = manager.allocate_slots(req0, 19) new_blocks = manager.allocate_slots(req0, 19)
assert new_blocks is not None and len(new_blocks) == 1 assert new_blocks is not None and len(new_blocks.blocks) == 1
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
@ -323,19 +323,19 @@ def test_evict():
last_token_id = 5 * 16 + 7 last_token_id = 5 * 16 + 7
req0 = make_request("0", list(range(last_token_id))) req0 = make_request("0", list(range(last_token_id)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks) blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
assert len(blocks) == 6 # 5 full + 1 partial assert len(blocks.blocks) == 6 # 5 full + 1 partial
# 3 blocks. # 3 blocks.
req1 = make_request("1", list(range(last_token_id, req1 = make_request("1", list(range(last_token_id,
last_token_id + 3 * 16))) last_token_id + 3 * 16)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks) blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
assert len(blocks) == 3 # 3 full blocks assert len(blocks.blocks) == 3 # 3 full blocks
last_token_id += 3 * 16 last_token_id += 3 * 16
# 10 - (6 + 3) == 1 # 10 - (6 + 3) == 1
@ -352,10 +352,10 @@ def test_evict():
# Touch the first 2 blocks. # Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3))) req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_blocks] == [1, 2] assert computed_blocks.get_block_ids() == [1, 2]
assert num_computed_tokens == 2 * 16 assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks) blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [10] assert blocks.get_block_ids() == [10]
assert manager.block_pool.free_block_queue.num_free_blocks == 7 assert manager.block_pool.free_block_queue.num_free_blocks == 7
@ -375,10 +375,10 @@ def test_hash_block_correct_reuse():
num_tokens = block_size * 1 num_tokens = block_size * 1
req = make_request("0", list(range(num_tokens))) req = make_request("0", list(range(num_tokens)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens, computed_blocks) blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
assert len(blocks) == 1 assert len(blocks.blocks) == 1
# Deallocate the block. # Deallocate the block.
manager.free(req) manager.free(req)
@ -387,12 +387,13 @@ def test_hash_block_correct_reuse():
# block is cleared. # block is cleared.
req = make_request("1", list(range(num_tokens - 1))) req = make_request("1", list(range(num_tokens - 1)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks) blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
assert len(blocks) == 1 assert len(blocks.blocks) == 1
assert manager.block_pool.blocks[blocks[0].block_id].block_hash is None assert manager.block_pool.blocks[
blocks.blocks[0].block_id].block_hash is None
def test_computed_blocks_not_evicted(): def test_computed_blocks_not_evicted():
@ -411,20 +412,20 @@ def test_computed_blocks_not_evicted():
num_tokens = block_size * 1 num_tokens = block_size * 1
req0 = make_request("0", list(range(num_tokens))) req0 = make_request("0", list(range(num_tokens)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks) blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
assert len(blocks) == 1 assert len(blocks.blocks) == 1
assert blocks[0].block_id == 1 assert blocks.blocks[0].block_id == 1
# Allocate another block. # Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks) blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
assert len(blocks) == 1 assert len(blocks.blocks) == 1
assert blocks[0].block_id == 2 assert blocks.blocks[0].block_id == 2
# Free the blocks. # Free the blocks.
manager.free(req0) manager.free(req0)
@ -434,14 +435,14 @@ def test_computed_blocks_not_evicted():
# cached block rather than the first one. # cached block rather than the first one.
req2 = make_request("2", list(range(num_tokens * 2))) req2 = make_request("2", list(range(num_tokens * 2)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks) == 1 assert len(computed_blocks.blocks) == 1
assert computed_blocks[0].block_id == 1 assert computed_blocks.blocks[0].block_id == 1
assert num_computed_tokens == block_size assert num_computed_tokens == block_size
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
computed_blocks) computed_blocks)
assert len(blocks) == 1 assert len(blocks.blocks) == 1
assert blocks[0].block_id == 2 assert blocks.blocks[0].block_id == 2
def test_basic_prefix_caching_disabled(): def test_basic_prefix_caching_disabled():
@ -458,10 +459,10 @@ def test_basic_prefix_caching_disabled():
req1 = make_request("1", list(range(10))) # 2 blocks and some more req1 = make_request("1", list(range(10))) # 2 blocks and some more
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 10, computed_blocks) blocks = manager.allocate_slots(req1, 10, computed_blocks)
assert len(blocks) == 3 assert len(blocks.blocks) == 3
# Free the blocks. # Free the blocks.
manager.free(req1) manager.free(req1)
@ -469,15 +470,15 @@ def test_basic_prefix_caching_disabled():
# No caching. # No caching.
req2 = make_request("2", list(range(16))) # shared prefix req2 = make_request("2", list(range(16))) # shared prefix
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 16, computed_blocks) blocks = manager.allocate_slots(req2, 16, computed_blocks)
assert len(blocks) == 4 assert len(blocks.blocks) == 4
# New requests should not have any blocks. # New requests should not have any blocks.
req3 = make_request("3", list(range(4))) req3 = make_request("3", list(range(4)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 4, computed_blocks) blocks = manager.allocate_slots(req3, 4, computed_blocks)
assert not blocks assert not blocks
@ -569,7 +570,7 @@ def test_mm_prefix_caching():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
# Completed block should have hashes with extra keys. # Completed block should have hashes with extra keys.
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req0.request_id] block_hashes = manager.req_to_block_hashes[req0.request_id]
assert len(block_hashes) == 3 assert len(block_hashes) == 3
@ -578,14 +579,14 @@ def test_mm_prefix_caching():
assert block_hashes[2].extra_keys == ("bbb", ) assert block_hashes[2].extra_keys == ("bbb", )
blocks = manager.allocate_slots(req0, 59, computed_blocks) blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4] assert blocks.get_block_ids() == [1, 2, 3, 4]
req0.num_computed_tokens = 59 req0.num_computed_tokens = 59
# Append slots without allocating a new block. # Append slots without allocating a new block.
for _ in range(5): for _ in range(5):
req0.append_output_token_ids(8) req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5) new_blocks = manager.allocate_slots(req0, 5)
assert new_blocks is not None and len(new_blocks) == 0 assert new_blocks is not None and len(new_blocks.blocks) == 0
# The just completed block should have hashes with extra keys. # The just completed block should have hashes with extra keys.
assert len(block_hashes) == 4 assert len(block_hashes) == 4
@ -603,7 +604,7 @@ def test_mm_prefix_caching():
mm_positions=mm_positions, mm_positions=mm_positions,
mm_hashes=mm_hashes) mm_hashes=mm_hashes)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(computed_blocks) == 3 assert len(computed_blocks.blocks) == 3
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
@ -626,7 +627,7 @@ def test_cache_key_salting():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
# Completed block should have hashes with extra keys. # Completed block should have hashes with extra keys.
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req0.request_id] block_hashes = manager.req_to_block_hashes[req0.request_id]
assert len(block_hashes) == 3 assert len(block_hashes) == 3
@ -635,14 +636,14 @@ def test_cache_key_salting():
assert block_hashes[2].extra_keys is None assert block_hashes[2].extra_keys is None
blocks = manager.allocate_slots(req0, 59, computed_blocks) blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4] assert blocks.get_block_ids() == [1, 2, 3, 4]
req0.num_computed_tokens = 59 req0.num_computed_tokens = 59
# Append slots without allocating a new block. # Append slots without allocating a new block.
for _ in range(5): for _ in range(5):
req0.append_output_token_ids(8) req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5) new_blocks = manager.allocate_slots(req0, 5)
assert new_blocks is not None and len(new_blocks) == 0 assert new_blocks is not None and len(new_blocks.blocks) == 0
# Now one more block that should not have extra keys. # Now one more block that should not have extra keys.
assert len(block_hashes) == 4 assert len(block_hashes) == 4
@ -653,14 +654,14 @@ def test_cache_key_salting():
req1 = make_request("1", token_ids, cache_salt="salt1") req1 = make_request("1", token_ids, cache_salt="salt1")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
# Should match only a prefix of 3 blocks. # Should match only a prefix of 3 blocks.
assert len(computed_blocks) == 3 assert len(computed_blocks.blocks) == 3
assert num_computed_tokens == 3 * block_size assert num_computed_tokens == 3 * block_size
# Test cache miss with same content but different salt. # Test cache miss with same content but different salt.
token_ids = common_token_ids + [4] * 11 token_ids = common_token_ids + [4] * 11
req2 = make_request("2", token_ids, cache_salt="salt2") req2 = make_request("2", token_ids, cache_salt="salt2")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks) == 0 assert len(computed_blocks.blocks) == 0
assert num_computed_tokens == 0 assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req2.request_id] block_hashes = manager.req_to_block_hashes[req2.request_id]
assert len(block_hashes) == 3 assert len(block_hashes) == 3
@ -685,7 +686,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
common_token_ids = [i for i in range(3) for _ in range(16)] common_token_ids = [i for i in range(3) for _ in range(16)]
req0 = make_request("0", common_token_ids) req0 = make_request("0", common_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
manager.allocate_slots(req0, 48, computed_blocks) manager.allocate_slots(req0, 48, computed_blocks)
block_part0 = manager.req_to_blocks[req0.request_id] block_part0 = manager.req_to_blocks[req0.request_id]
@ -693,7 +694,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1 = make_request("1", common_token_ids * 2) req1 = make_request("1", common_token_ids * 2)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert computed_blocks == block_part0 assert computed_blocks.blocks == block_part0
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
manager.allocate_slots(req1, 48, computed_blocks) manager.allocate_slots(req1, 48, computed_blocks)
block_part1 = manager.req_to_blocks[req1.request_id] block_part1 = manager.req_to_blocks[req1.request_id]
@ -707,7 +708,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Req1-5(F)| Req2-0 | Req2-1 | ... | # | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2 = make_request("2", [7] * block_size * 2) req2 = make_request("2", [7] * block_size * 2)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
manager.allocate_slots(req2, block_size * 2, computed_blocks) manager.allocate_slots(req2, block_size * 2, computed_blocks)
@ -717,7 +718,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
assert manager.block_pool.free_block_queue.num_free_blocks == 5 assert manager.block_pool.free_block_queue.num_free_blocks == 5
req3 = make_request("3", common_token_ids * 3) req3 = make_request("3", common_token_ids * 3)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert computed_blocks == block_part1 assert computed_blocks.blocks == block_part1
assert num_computed_tokens == 6 * 16 assert num_computed_tokens == 6 * 16
# Req3 cannot be allocated. # Req3 cannot be allocated.
assert manager.allocate_slots(req3, 48, computed_blocks) is None assert manager.allocate_slots(req3, 48, computed_blocks) is None
@ -739,16 +740,16 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids) req0 = make_request("0", all_token_ids)
blocks = manager.allocate_slots(req0, 55) blocks = manager.allocate_slots(req0, 55)
assert [b.block_id for b in blocks] == [1, 2, 3, 4] assert blocks.get_block_ids() == [1, 2, 3, 4]
unique_token_ids = [4] * 7 unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids all_token_ids = full_block_token_ids + unique_token_ids
req1 = make_request("1", all_token_ids) req1 = make_request("1", all_token_ids)
computed_blocks, _ = manager.get_computed_blocks(req1) computed_blocks, _ = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3 assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert len(computed_blocks) == 3 assert len(computed_blocks.blocks) == 3
blocks = manager.allocate_slots(req1, 7, computed_blocks) blocks = manager.allocate_slots(req1, 7, computed_blocks)
assert [b.block_id for b in blocks] == [5] assert blocks.get_block_ids() == [5]
# Failed to reset prefix cache because some blocks are not freed yet. # Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache() assert not manager.reset_prefix_cache()
@ -776,7 +777,7 @@ def test_prefix_cache_stats_disabled():
# Call all functions that check whether log_stats is disabled. # Call all functions that check whether log_stats is disabled.
req = make_request("0", list(range(16))) req = make_request("0", list(range(16)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks assert not computed_blocks.blocks
assert num_computed_tokens == 0 assert num_computed_tokens == 0
manager.allocate_slots(req, 16, computed_blocks) manager.allocate_slots(req, 16, computed_blocks)
manager.reset_prefix_cache() manager.reset_prefix_cache()
@ -866,7 +867,7 @@ def test_eagle_enabled_removes_last_block():
# Should retain 1 block: # Should retain 1 block:
# 1. Original 3 blocks → pop last hash → 2 matched blocks # 1. Original 3 blocks → pop last hash → 2 matched blocks
# 2. drop last matched block → 1 remaining block # 2. drop last matched block → 1 remaining block
assert len(computed_blocks) == 1 assert len(computed_blocks.blocks) == 1
assert num_tokens == 1 * block_size # 16 tokens assert num_tokens == 1 * block_size # 16 tokens
@ -892,7 +893,7 @@ def test_eagle_with_partial_blocks():
req_eagle = make_request("partial_eagle", token_ids) req_eagle = make_request("partial_eagle", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining # Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks) == 1 assert len(computed_blocks.blocks) == 1
assert num_tokens == 1 * block_size assert num_tokens == 1 * block_size
@ -934,7 +935,7 @@ def test_eagle_with_sliding_window():
req_eagle = make_request("partial_eagle", token_ids) req_eagle = make_request("partial_eagle", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining # Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks) == 1 assert len(computed_blocks.blocks) == 1
assert num_tokens == 1 * block_size assert num_tokens == 1 * block_size
# Evict the first block in the request # Evict the first block in the request
@ -948,5 +949,5 @@ def test_eagle_with_sliding_window():
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# not considered. But after dropping the last matched block due to eagle, # not considered. But after dropping the last matched block due to eagle,
# there will be no matched prefix. # there will be no matched prefix.
assert len(computed_blocks) == 0 assert len(computed_blocks.blocks) == 0
assert num_tokens == 0 assert num_tokens == 0

View File

@ -2,6 +2,7 @@
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent
@ -18,6 +19,24 @@ from vllm.v1.request import Request, RequestStatus
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass
class KVCacheBlocks:
    """Container for the KV cache blocks passed between the Scheduler and
    the KVCacheManager.

    Callers receive and hand back this wrapper instead of a bare
    ``list[KVCacheBlock]``; the underlying list is available as ``blocks``.
    """
    # The wrapped KV cache blocks, in order.
    blocks: list[KVCacheBlock]

    def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
        """Concatenate two KVCacheBlocks instances.

        Args:
            other: The instance whose blocks are appended after this one's.

        Returns:
            A new KVCacheBlocks holding ``self.blocks`` followed by
            ``other.blocks``. Neither operand is modified.
        """
        combined = self.blocks + other.blocks
        return KVCacheBlocks(blocks=combined)

    @classmethod
    def create_empty(cls) -> "KVCacheBlocks":
        """Build an instance that wraps a fresh, empty block list."""
        return cls(blocks=[])

    def get_block_ids(self) -> list[int]:
        """Return the ``block_id`` of every wrapped block, in order."""
        return [blk.block_id for blk in self.blocks]
class KVCacheManager: class KVCacheManager:
def __init__( def __init__(
@ -94,8 +113,8 @@ class KVCacheManager:
self.prefix_cache_stats = PrefixCacheStats() self.prefix_cache_stats = PrefixCacheStats()
return stats return stats
def get_computed_blocks( def get_computed_blocks(self,
self, request: Request) -> tuple[list[KVCacheBlock], int]: request: Request) -> tuple[KVCacheBlocks, int]:
"""Get the computed (cached) blocks for the request. """Get the computed (cached) blocks for the request.
Note that the computed blocks must be full. Note that the computed blocks must be full.
@ -109,7 +128,7 @@ class KVCacheManager:
""" """
if not self.enable_caching: if not self.enable_caching:
# Prefix caching is disabled. # Prefix caching is disabled.
return [], 0 return KVCacheBlocks.create_empty(), 0
# The block hashes for the request may already be computed # The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before. # if the scheduler has tried to schedule the request before.
@ -124,7 +143,7 @@ class KVCacheManager:
self.prefix_cache_stats.requests += 1 self.prefix_cache_stats.requests += 1
# When the request requires prompt logprobs, we skip prefix caching. # When the request requires prompt logprobs, we skip prefix caching.
if request.sampling_params.prompt_logprobs is not None: if request.sampling_params.prompt_logprobs is not None:
return [], 0 return KVCacheBlocks.create_empty(), 0
if len(block_hashes) * self.block_size == request.num_tokens: if len(block_hashes) * self.block_size == request.num_tokens:
# When prompt length is divisible by the block size and all # When prompt length is divisible by the block size and all
@ -157,15 +176,15 @@ class KVCacheManager:
# sharing, `num_computed_tokens` is always a multiple of # sharing, `num_computed_tokens` is always a multiple of
# `block_size`. # `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size num_computed_tokens = len(computed_blocks) * self.block_size
return computed_blocks, num_computed_tokens return KVCacheBlocks(computed_blocks), num_computed_tokens
def allocate_slots( def allocate_slots(
self, self,
request: Request, request: Request,
num_tokens: int, num_tokens: int,
new_computed_blocks: Optional[list[KVCacheBlock]] = None, new_computed_blocks: Optional[KVCacheBlocks] = None,
num_lookahead_tokens: int = 0, num_lookahead_tokens: int = 0,
) -> Optional[list[KVCacheBlock]]: ) -> Optional[KVCacheBlocks]:
"""Add slots for a request with new tokens to append. """Add slots for a request with new tokens to append.
Args: Args:
@ -173,7 +192,7 @@ class KVCacheManager:
num_tokens: The number of tokens to allocate, including external num_tokens: The number of tokens to allocate, including external
tokens. Note that this does not include tokens that have tokens. Note that this does not include tokens that have
already been computed locally (i.e. new_computed_blocks). already been computed locally (i.e. new_computed_blocks).
new_computed_blocks: A list of new computed blocks just hitting the new_computed_blocks: The new computed blocks just hitting the
prefix caching. prefix caching.
num_lookahead_tokens: The number of speculative tokens to allocate. num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such This is used by spec decode proposers with kv-cache such
@ -199,7 +218,10 @@ class KVCacheManager:
if num_tokens == 0: if num_tokens == 0:
raise ValueError("num_tokens must be greater than 0") raise ValueError("num_tokens must be greater than 0")
new_computed_blocks = new_computed_blocks or [] if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks
else:
new_computed_block_list = []
req_blocks = self.req_to_blocks[request.request_id] req_blocks = self.req_to_blocks[request.request_id]
@ -216,17 +238,18 @@ class KVCacheManager:
# The number of computed tokens is the number of computed tokens plus # The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits # the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens + num_computed_tokens = (request.num_computed_tokens +
len(new_computed_blocks) * self.block_size) len(new_computed_block_list) * self.block_size)
num_required_blocks = cdiv( num_required_blocks = cdiv(
num_computed_tokens + num_tokens + num_lookahead_tokens, num_computed_tokens + num_tokens + num_lookahead_tokens,
self.block_size) self.block_size)
num_new_blocks = (num_required_blocks - len(req_blocks) - num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_blocks)) len(new_computed_block_list))
# If a computed block of a request is an eviction candidate (in the # If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block # free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request. # when allocating this request.
num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks num_evictable_computed_blocks = sum(1
for blk in new_computed_block_list
if blk.ref_cnt == 0) if blk.ref_cnt == 0)
if (num_new_blocks > self.block_pool.get_num_free_blocks() - if (num_new_blocks > self.block_pool.get_num_free_blocks() -
num_evictable_computed_blocks): num_evictable_computed_blocks):
@ -235,15 +258,15 @@ class KVCacheManager:
# Touch the computed blocks to make sure they won't be evicted. # Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching: if self.enable_caching:
self.block_pool.touch(new_computed_blocks) self.block_pool.touch(new_computed_block_list)
else: else:
assert not new_computed_blocks, ( assert not new_computed_block_list, (
"Computed blocks should be empty when " "Computed blocks should be empty when "
"prefix caching is disabled") "prefix caching is disabled")
# Append the new computed blocks to the request blocks until now to # Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated. # avoid the case where the new blocks cannot be allocated.
req_blocks.extend(new_computed_blocks) req_blocks.extend(new_computed_block_list)
# Start to handle new blocks # Start to handle new blocks
@ -267,12 +290,12 @@ class KVCacheManager:
req_blocks.extend(new_blocks) req_blocks.extend(new_blocks)
if not self.enable_caching: if not self.enable_caching:
return new_blocks return KVCacheBlocks(new_blocks)
# Use `new_computed_blocks` for a new request, and `num_cached_block` # Use `new_computed_block_list` for a new request, and
# for a running request. # `num_cached_block` for a running request.
num_cached_blocks = self.num_cached_block.get(request.request_id, num_cached_blocks = self.num_cached_block.get(
len(new_computed_blocks)) request.request_id, len(new_computed_block_list))
# Speculated tokens might be rejected in the future, so we does # Speculated tokens might be rejected in the future, so we does
# not cache any speculated tokens. We only cache blocks with # not cache any speculated tokens. We only cache blocks with
# generated (accepted) tokens. # generated (accepted) tokens.
@ -291,7 +314,7 @@ class KVCacheManager:
self.num_cached_block[ self.num_cached_block[
request.request_id] = num_full_blocks_after_append request.request_id] = num_full_blocks_after_append
return new_blocks return KVCacheBlocks(new_blocks)
def free(self, request: Request) -> None: def free(self, request: Request) -> None:
"""Free the blocks allocated for the request. """Free the blocks allocated for the request.

View File

@ -261,9 +261,8 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional # Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op. # cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = [ req_to_new_block_ids[request.request_id] = (
b.block_id for b in new_blocks new_blocks.get_block_ids())
]
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
req_index += 1 req_index += 1
@ -407,9 +406,8 @@ class Scheduler(SchedulerInterface):
if self.lora_config and request.lora_request: if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id) scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = [ req_to_new_block_ids[request.request_id] = (
b.block_id for b in computed_blocks + new_blocks computed_blocks + new_blocks).get_block_ids()
]
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING