[v1] Introduce KVCacheBlocks as interface between Scheduler and KVCacheManager (#17479)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
parent 0d115460a7
commit aabcd2cae3
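This change replaces the bare list[KVCacheBlock] that previously flowed between the KVCacheManager and the Scheduler with a small KVCacheBlocks wrapper, so callers use blocks.blocks, blocks.get_block_ids(), and wrapper addition instead of list comprehensions. Below is a minimal, self-contained sketch of the interface this commit adds; the Block class is only an illustrative stand-in for vLLM's KVCacheBlock, and the example values are made up:

from dataclasses import dataclass

@dataclass
class Block:
    # Illustrative stand-in for vLLM's KVCacheBlock; only block_id is modeled here.
    block_id: int

@dataclass
class KVCacheBlocks:
    # Mirrors the dataclass introduced in this commit.
    blocks: list[Block]

    def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
        # Concatenate two groups of blocks (the scheduler uses this to merge
        # prefix-cache hits with newly allocated blocks).
        return KVCacheBlocks(self.blocks + other.blocks)

    @classmethod
    def create_empty(cls) -> "KVCacheBlocks":
        # Returned instead of a bare [] when there is no prefix-cache hit.
        return cls([])

    def get_block_ids(self) -> list[int]:
        return [block.block_id for block in self.blocks]

# Made-up example of how callers now build block-id lists.
computed = KVCacheBlocks([Block(1), Block(2), Block(3)])
new = KVCacheBlocks([Block(5)])
assert (computed + new).get_block_ids() == [1, 2, 3, 5]
assert not KVCacheBlocks.create_empty().blocks  # the "no prefix hit" case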
@@ -542,7 +542,7 @@ def test_allocate_with_lookahead():
num_tokens=3,
num_lookahead_tokens=2, # Total required: 3+2=5 tokens
)
assert len(blocks) == 2 # ceil(5/4)=2 blocks
assert len(blocks.blocks) == 2 # ceil(5/4)=2 blocks

# Test case 2: With precomputed blocks
kv_cache_manager = KVCacheManager(kv_cache_config=config,

@@ -553,7 +553,7 @@ def test_allocate_with_lookahead():
num_tokens=3,
num_lookahead_tokens=2,
)
assert len(blocks) == 2
assert len(blocks.blocks) == 2

# Test case 3: With precomputed blocks
# required_blocks = ceil((3 + 4) / 4) = 2

@@ -564,4 +564,4 @@ def test_allocate_with_lookahead():
num_tokens=3,
num_lookahead_tokens=4,
)
assert len(blocks) == 2
assert len(blocks.blocks) == 2
@@ -79,10 +79,10 @@ def test_prefill(hash_algo):
req0 = make_request("0", all_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
assert blocks.get_block_ids() == [1, 2, 3, 4]

# Check full block metadata
parent_block_hash = None

@@ -105,12 +105,12 @@ def test_prefill(hash_algo):
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5]
for block in computed_blocks:
assert blocks.get_block_ids() == [5]
for block in computed_blocks.blocks:
assert block.ref_cnt == 2

# At this point, we should have 5 free blocks left.

@@ -137,11 +137,11 @@ def test_prefill(hash_algo):
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [6]
assert blocks.get_block_ids() == [6]

# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.

@@ -159,11 +159,11 @@ def test_prefill(hash_algo):
# Cache miss and eviction.
req3 = make_request("3", [99] * (16 * 10))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks)
# This block ID order also checks the eviction order.
assert [b.block_id for b in blocks] == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None
@@ -195,11 +195,11 @@ def test_prefill_plp():
req0 = make_request("0", all_token_ids, prompt_logprobs=5)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
req0_block_hashes = [b.block_hash for b in blocks]
assert blocks.get_block_ids() == [1, 2, 3, 4]
req0_block_hashes = [b.block_hash for b in blocks.blocks]

# Check full block metadata
parent_block_hash = None

@@ -223,12 +223,12 @@ def test_prefill_plp():
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert [b.block_id for b in computed_blocks] == [1, 2, 3]
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
assert [b.block_id for b in blocks] == [5]
for block in computed_blocks:
assert blocks.get_block_ids() == [5]
for block in computed_blocks.blocks:
assert block.ref_cnt == 2

# At this point, we should have 5 free blocks left.

@@ -257,12 +257,12 @@ def test_prefill_plp():
prompt_logprobs=5)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 55, computed_blocks)
block_ids = [b.block_id for b in blocks]
block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks] == req0_block_hashes
assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
assert block_ids != [1, 2, 3, 4]

# Request #2 block hashes are valid since request #0 hashes are.
@@ -288,17 +288,17 @@ def test_decode():
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
assert blocks.get_block_ids() == [1, 2, 3, 4]

# Append slots without allocating a new block.
req0.num_computed_tokens = 55
for _ in range(4):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 4)
assert new_blocks is not None and len(new_blocks) == 0
assert new_blocks is not None and len(new_blocks.blocks) == 0
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None

# Append slots with allocating a new block.

@@ -308,7 +308,7 @@ def test_decode():
for _ in range(9 + 10):
req0.append_output_token_ids(7)
new_blocks = manager.allocate_slots(req0, 19)
assert new_blocks is not None and len(new_blocks) == 1
assert new_blocks is not None and len(new_blocks.blocks) == 1
assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
@@ -323,19 +323,19 @@ def test_evict():
last_token_id = 5 * 16 + 7
req0 = make_request("0", list(range(last_token_id)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
assert len(blocks) == 6 # 5 full + 1 partial
assert len(blocks.blocks) == 6 # 5 full + 1 partial

# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
last_token_id + 3 * 16)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
assert len(blocks) == 3 # 3 full blocks
assert len(blocks.blocks) == 3 # 3 full blocks
last_token_id += 3 * 16

# 10 - (6 + 3) == 1

@@ -352,10 +352,10 @@ def test_evict():
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert [b.block_id for b in computed_blocks] == [1, 2]
assert computed_blocks.get_block_ids() == [1, 2]
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks)
assert [b.block_id for b in blocks] == [10]
assert blocks.get_block_ids() == [10]
assert manager.block_pool.free_block_queue.num_free_blocks == 7
@@ -375,10 +375,10 @@ def test_hash_block_correct_reuse():
num_tokens = block_size * 1
req = make_request("0", list(range(num_tokens)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
assert len(blocks) == 1
assert len(blocks.blocks) == 1

# Deallocate the block.
manager.free(req)

@@ -387,12 +387,13 @@ def test_hash_block_correct_reuse():
# block is cleared.
req = make_request("1", list(range(num_tokens - 1)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
assert len(blocks) == 1
assert len(blocks.blocks) == 1

assert manager.block_pool.blocks[blocks[0].block_id].block_hash is None
assert manager.block_pool.blocks[
blocks.blocks[0].block_id].block_hash is None
def test_computed_blocks_not_evicted():

@@ -411,20 +412,20 @@ def test_computed_blocks_not_evicted():
num_tokens = block_size * 1
req0 = make_request("0", list(range(num_tokens)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 1
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 1

# Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 2
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 2

# Free the blocks.
manager.free(req0)

@@ -434,14 +435,14 @@ def test_computed_blocks_not_evicted():
# cached block rather than the first one.
req2 = make_request("2", list(range(num_tokens * 2)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks) == 1
assert computed_blocks[0].block_id == 1
assert len(computed_blocks.blocks) == 1
assert computed_blocks.blocks[0].block_id == 1
assert num_computed_tokens == block_size

blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
computed_blocks)
assert len(blocks) == 1
assert blocks[0].block_id == 2
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 2
def test_basic_prefix_caching_disabled():

@@ -458,10 +459,10 @@ def test_basic_prefix_caching_disabled():
req1 = make_request("1", list(range(10))) # 2 blocks and some more

computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 10, computed_blocks)
assert len(blocks) == 3
assert len(blocks.blocks) == 3

# Free the blocks.
manager.free(req1)

@@ -469,15 +470,15 @@ def test_basic_prefix_caching_disabled():
# No caching.
req2 = make_request("2", list(range(16))) # shared prefix
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 16, computed_blocks)
assert len(blocks) == 4
assert len(blocks.blocks) == 4

# New requests should not have any blocks.
req3 = make_request("3", list(range(4)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 4, computed_blocks)
assert not blocks
@@ -569,7 +570,7 @@ def test_mm_prefix_caching():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)

# Completed block should have hashes with extra keys.
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req0.request_id]
assert len(block_hashes) == 3

@@ -578,14 +579,14 @@ def test_mm_prefix_caching():
assert block_hashes[2].extra_keys == ("bbb", )

blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
assert blocks.get_block_ids() == [1, 2, 3, 4]
req0.num_computed_tokens = 59

# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5)
assert new_blocks is not None and len(new_blocks) == 0
assert new_blocks is not None and len(new_blocks.blocks) == 0

# The just completed block should have hashes with extra keys.
assert len(block_hashes) == 4

@@ -603,7 +604,7 @@ def test_mm_prefix_caching():
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(computed_blocks) == 3
assert len(computed_blocks.blocks) == 3
assert num_computed_tokens == 3 * 16
@@ -626,7 +627,7 @@ def test_cache_key_salting():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)

# Completed block should have hashes with extra keys.
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req0.request_id]
assert len(block_hashes) == 3

@@ -635,14 +636,14 @@ def test_cache_key_salting():
assert block_hashes[2].extra_keys is None

blocks = manager.allocate_slots(req0, 59, computed_blocks)
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
assert blocks.get_block_ids() == [1, 2, 3, 4]
req0.num_computed_tokens = 59

# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5)
assert new_blocks is not None and len(new_blocks) == 0
assert new_blocks is not None and len(new_blocks.blocks) == 0

# Now one more block that should not have extra keys.
assert len(block_hashes) == 4

@@ -653,14 +654,14 @@ def test_cache_key_salting():
req1 = make_request("1", token_ids, cache_salt="salt1")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
# Should match only a prefix of 3 blocks.
assert len(computed_blocks) == 3
assert len(computed_blocks.blocks) == 3
assert num_computed_tokens == 3 * block_size

# Test cache miss with same content but different salt.
token_ids = common_token_ids + [4] * 11
req2 = make_request("2", token_ids, cache_salt="salt2")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks) == 0
assert len(computed_blocks.blocks) == 0
assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req2.request_id]
assert len(block_hashes) == 3
@@ -685,7 +686,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
common_token_ids = [i for i in range(3) for _ in range(16)]
req0 = make_request("0", common_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
manager.allocate_slots(req0, 48, computed_blocks)
block_part0 = manager.req_to_blocks[req0.request_id]

@@ -693,7 +694,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1 = make_request("1", common_token_ids * 2)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert computed_blocks == block_part0
assert computed_blocks.blocks == block_part0
assert num_computed_tokens == 3 * 16
manager.allocate_slots(req1, 48, computed_blocks)
block_part1 = manager.req_to_blocks[req1.request_id]

@@ -707,7 +708,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2 = make_request("2", [7] * block_size * 2)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
manager.allocate_slots(req2, block_size * 2, computed_blocks)

@@ -717,7 +718,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
assert manager.block_pool.free_block_queue.num_free_blocks == 5
req3 = make_request("3", common_token_ids * 3)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert computed_blocks == block_part1
assert computed_blocks.blocks == block_part1
assert num_computed_tokens == 6 * 16
# Req3 cannot be allocated.
assert manager.allocate_slots(req3, 48, computed_blocks) is None
@@ -739,16 +740,16 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
blocks = manager.allocate_slots(req0, 55)
assert [b.block_id for b in blocks] == [1, 2, 3, 4]
assert blocks.get_block_ids() == [1, 2, 3, 4]

unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
req1 = make_request("1", all_token_ids)
computed_blocks, _ = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert len(computed_blocks) == 3
assert len(computed_blocks.blocks) == 3
blocks = manager.allocate_slots(req1, 7, computed_blocks)
assert [b.block_id for b in blocks] == [5]
assert blocks.get_block_ids() == [5]

# Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache()
@@ -776,7 +777,7 @@ def test_prefix_cache_stats_disabled():
# Call all functions that check whether log_stats is disabled.
req = make_request("0", list(range(16)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks
assert not computed_blocks.blocks
assert num_computed_tokens == 0
manager.allocate_slots(req, 16, computed_blocks)
manager.reset_prefix_cache()
@@ -866,7 +867,7 @@ def test_eagle_enabled_removes_last_block():
# Should retain 1 block:
# 1. Original 3 blocks → pop last hash → 2 matched blocks
# 2. drop last matched block → 1 remaining block
assert len(computed_blocks) == 1
assert len(computed_blocks.blocks) == 1
assert num_tokens == 1 * block_size # 16 tokens

@@ -892,7 +893,7 @@ def test_eagle_with_partial_blocks():
req_eagle = make_request("partial_eagle", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks) == 1
assert len(computed_blocks.blocks) == 1
assert num_tokens == 1 * block_size

@@ -934,7 +935,7 @@ def test_eagle_with_sliding_window():
req_eagle = make_request("partial_eagle", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks) == 1
assert len(computed_blocks.blocks) == 1
assert num_tokens == 1 * block_size

# Evict the first block in the request

@@ -948,5 +949,5 @@ def test_eagle_with_sliding_window():
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# not considered. But after dropping the last matched block due to eagle,
# there will be no matched prefix.
assert len(computed_blocks) == 0
assert len(computed_blocks.blocks) == 0
assert num_tokens == 0
@@ -2,6 +2,7 @@

from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional

from vllm.distributed.kv_events import KVCacheEvent

@@ -18,6 +19,24 @@ from vllm.v1.request import Request, RequestStatus
logger = init_logger(__name__)


@dataclass
class KVCacheBlocks:
blocks: list[KVCacheBlock]

def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
"""Adds two KVCacheBlocks instances."""
return KVCacheBlocks(self.blocks + other.blocks)

@classmethod
def create_empty(cls) -> "KVCacheBlocks":
"""Creates a new KVCacheBlocks instance with no blocks."""
return cls([])

def get_block_ids(self) -> list[int]:
"""Converts the KVCacheBlocks instance to a list of block IDs."""
return [block.block_id for block in self.blocks]


class KVCacheManager:

def __init__(
@@ -94,8 +113,8 @@ class KVCacheManager:
self.prefix_cache_stats = PrefixCacheStats()
return stats

def get_computed_blocks(
self, request: Request) -> tuple[list[KVCacheBlock], int]:
def get_computed_blocks(self,
request: Request) -> tuple[KVCacheBlocks, int]:
"""Get the computed (cached) blocks for the request.
Note that the computed blocks must be full.

@@ -109,7 +128,7 @@ class KVCacheManager:
"""
if not self.enable_caching:
# Prefix caching is disabled.
return [], 0
return KVCacheBlocks.create_empty(), 0

# The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before.

@@ -124,7 +143,7 @@ class KVCacheManager:
self.prefix_cache_stats.requests += 1
# When the request requires prompt logprobs, we skip prefix caching.
if request.sampling_params.prompt_logprobs is not None:
return [], 0
return KVCacheBlocks.create_empty(), 0

if len(block_hashes) * self.block_size == request.num_tokens:
# When prompt length is divisible by the block size and all

@@ -157,15 +176,15 @@ class KVCacheManager:
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size
return computed_blocks, num_computed_tokens
return KVCacheBlocks(computed_blocks), num_computed_tokens

def allocate_slots(
self,
request: Request,
num_tokens: int,
new_computed_blocks: Optional[list[KVCacheBlock]] = None,
new_computed_blocks: Optional[KVCacheBlocks] = None,
num_lookahead_tokens: int = 0,
) -> Optional[list[KVCacheBlock]]:
) -> Optional[KVCacheBlocks]:
"""Add slots for a request with new tokens to append.

Args:

@@ -173,7 +192,7 @@ class KVCacheManager:
num_tokens: The number of tokens to allocate, including external
tokens. Note that this does not include tokens that have
already been computed locally (i.e. new_computed_blocks).
new_computed_blocks: A list of new computed blocks just hitting the
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such

@@ -199,7 +218,10 @@ class KVCacheManager:
if num_tokens == 0:
raise ValueError("num_tokens must be greater than 0")

new_computed_blocks = new_computed_blocks or []
if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks
else:
new_computed_block_list = []

req_blocks = self.req_to_blocks[request.request_id]
@@ -216,17 +238,18 @@ class KVCacheManager:
# The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens +
len(new_computed_blocks) * self.block_size)
len(new_computed_block_list) * self.block_size)
num_required_blocks = cdiv(
num_computed_tokens + num_tokens + num_lookahead_tokens,
self.block_size)
num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_blocks))
len(new_computed_block_list))

# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it cannot be counted as a free block
# when allocating this request.
num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks
num_evictable_computed_blocks = sum(1
for blk in new_computed_block_list
if blk.ref_cnt == 0)
if (num_new_blocks > self.block_pool.get_num_free_blocks() -
num_evictable_computed_blocks):

@@ -235,15 +258,15 @@ class KVCacheManager:

# Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching:
self.block_pool.touch(new_computed_blocks)
self.block_pool.touch(new_computed_block_list)
else:
assert not new_computed_blocks, (
assert not new_computed_block_list, (
"Computed blocks should be empty when "
"prefix caching is disabled")

# Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
req_blocks.extend(new_computed_blocks)
req_blocks.extend(new_computed_block_list)

# Start to handle new blocks

@@ -267,12 +290,12 @@ class KVCacheManager:
req_blocks.extend(new_blocks)

if not self.enable_caching:
return new_blocks
return KVCacheBlocks(new_blocks)

# Use `new_computed_blocks` for a new request, and `num_cached_block`
# for a running request.
num_cached_blocks = self.num_cached_block.get(request.request_id,
len(new_computed_blocks))
# Use `new_computed_block_list` for a new request, and
# `num_cached_block` for a running request.
num_cached_blocks = self.num_cached_block.get(
request.request_id, len(new_computed_block_list))
# Speculated tokens might be rejected in the future, so we does
# not cache any speculated tokens. We only cache blocks with
# generated (accepted) tokens.

@@ -291,7 +314,7 @@ class KVCacheManager:

self.num_cached_block[
request.request_id] = num_full_blocks_after_append
return new_blocks
return KVCacheBlocks(new_blocks)

def free(self, request: Request) -> None:
"""Free the blocks allocated for the request.
@@ -261,9 +261,8 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = [
b.block_id for b in new_blocks
]
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1

@@ -407,9 +406,8 @@ class Scheduler(SchedulerInterface):

if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_block_ids[request.request_id] = [
b.block_id for b in computed_blocks + new_blocks
]
req_to_new_block_ids[request.request_id] = (
computed_blocks + new_blocks).get_block_ids()
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
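The two scheduler hunks above are where the wrapper pays off: instead of building [b.block_id for b in computed_blocks + new_blocks], the scheduler now calls (computed_blocks + new_blocks).get_block_ids(), relying on KVCacheBlocks.__add__ to merge the prefix-cache hits with the newly allocated blocks. Using the illustrative `computed` and `new` values from the sketch near the top of this page, the old and new expressions are equivalent:

old_ids = [b.block_id for b in (computed + new).blocks]  # old-style list comprehension
new_ids = (computed + new).get_block_ids()               # helper introduced by this commit
assert old_ids == new_ids == [1, 2, 3, 5]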