[v1] Introduce KVCacheBlocks as interface between Scheduler and KVCacheManager (#17479)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Chen Zhang 2025-05-06 23:50:34 +08:00 committed by GitHub
parent 0d115460a7
commit aabcd2cae3
4 changed files with 121 additions and 99 deletions
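
Before the per-file diffs, a quick orientation: this change replaces the bare list[KVCacheBlock] that previously flowed between the Scheduler and the KVCacheManager with a small KVCacheBlocks dataclass. The sketch below is runnable on its own and mirrors the wrapper introduced by this commit; the one-field KVCacheBlock here is a simplified stand-in for illustration (the real class carries more state, e.g. ref_cnt and block_hash):

from dataclasses import dataclass


@dataclass
class KVCacheBlock:
    # Hypothetical stand-in: only block_id matters for this sketch.
    block_id: int


@dataclass
class KVCacheBlocks:
    blocks: list[KVCacheBlock]

    def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
        """Adds two KVCacheBlocks instances."""
        return KVCacheBlocks(self.blocks + other.blocks)

    @classmethod
    def create_empty(cls) -> "KVCacheBlocks":
        """Creates a new KVCacheBlocks instance with no blocks."""
        return cls([])

    def get_block_ids(self) -> list[int]:
        """Converts the KVCacheBlocks instance to a list of block IDs."""
        return [block.block_id for block in self.blocks]


# The scheduler-facing idioms this enables (see the scheduler diff below):
computed = KVCacheBlocks([KVCacheBlock(1), KVCacheBlock(2)])
new = KVCacheBlocks([KVCacheBlock(5)])
assert (computed + new).get_block_ids() == [1, 2, 5]
assert not KVCacheBlocks.create_empty().blocks

Wrapping the list keeps the Scheduler from reaching into manager-internal block objects: it can only combine results and ask for block IDs.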

View File

@@ -542,7 +542,7 @@ def test_allocate_with_lookahead():
         num_tokens=3,
         num_lookahead_tokens=2,  # Total required: 3+2=5 tokens
     )
-    assert len(blocks) == 2  # ceil(5/4)=2 blocks
+    assert len(blocks.blocks) == 2  # ceil(5/4)=2 blocks
 
     # Test case 2: With precomputed blocks
     kv_cache_manager = KVCacheManager(kv_cache_config=config,
@@ -553,7 +553,7 @@ def test_allocate_with_lookahead():
         num_tokens=3,
         num_lookahead_tokens=2,
     )
-    assert len(blocks) == 2
+    assert len(blocks.blocks) == 2
 
     # Test case 3: With precomputed blocks
     # required_blocks = ceil((3 + 4) / 4) = 2
@@ -564,4 +564,4 @@ def test_allocate_with_lookahead():
         num_tokens=3,
         num_lookahead_tokens=4,
     )
-    assert len(blocks) == 2
+    assert len(blocks.blocks) == 2
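
The block counts asserted above follow ceiling division over the test's 4-token block size (implied by the ceil(5/4) comment). A minimal sketch of the sizing rule, with a local cdiv helper standing in for vLLM's utility of the same name:

def cdiv(a: int, b: int) -> int:
    # Ceiling division: the number of fixed-size blocks needed for `a` slots.
    return -(-a // b)

# Test case 1: 3 prompt tokens + 2 lookahead tokens -> cdiv(5, 4) == 2 blocks.
assert cdiv(3 + 2, 4) == 2
# Test case 3: 3 prompt tokens + 4 lookahead tokens -> cdiv(7, 4) == 2 blocks.
assert cdiv(3 + 4, 4) == 2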

View File

@@ -79,10 +79,10 @@ def test_prefill(hash_algo):
     req0 = make_request("0", all_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
     assert len(manager.req_to_block_hashes[req0.request_id]) == 3
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 55, computed_blocks)
-    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
 
     # Check full block metadata
     parent_block_hash = None
@@ -105,12 +105,12 @@ def test_prefill(hash_algo):
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert [b.block_id for b in computed_blocks] == [1, 2, 3]
+    assert computed_blocks.get_block_ids() == [1, 2, 3]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
-    assert [b.block_id for b in blocks] == [5]
-    for block in computed_blocks:
+    assert blocks.get_block_ids() == [5]
+    for block in computed_blocks.blocks:
         assert block.ref_cnt == 2
 
     # At this point, we should have 5 free blocks left.
@@ -137,11 +137,11 @@ def test_prefill(hash_algo):
     req2 = make_request("2", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
     assert len(manager.req_to_block_hashes[req2.request_id]) == 3
-    assert [b.block_id for b in computed_blocks] == [1, 2, 3]
+    assert computed_blocks.get_block_ids() == [1, 2, 3]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
-    assert [b.block_id for b in blocks] == [6]
+    assert blocks.get_block_ids() == [6]
 
     # Although we only have 6 free blocks, we have 8 blocks in
     # the free block queue due to lazy removal.
@@ -159,11 +159,11 @@ def test_prefill(hash_algo):
     # Cache miss and eviction.
     req3 = make_request("3", [99] * (16 * 10))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks)
     # This block ID order also checks the eviction order.
-    assert [b.block_id for b in blocks] == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
+    assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
     assert manager.block_pool.free_block_queue.num_free_blocks == 0
     assert manager.block_pool.free_block_queue.free_list_head is None
     assert manager.block_pool.free_block_queue.free_list_tail is None
@@ -195,11 +195,11 @@ def test_prefill_plp():
     req0 = make_request("0", all_token_ids, prompt_logprobs=5)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
     assert len(manager.req_to_block_hashes[req0.request_id]) == 3
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 55, computed_blocks)
-    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
-    req0_block_hashes = [b.block_hash for b in blocks]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
+    req0_block_hashes = [b.block_hash for b in blocks.blocks]
 
     # Check full block metadata
     parent_block_hash = None
@@ -223,12 +223,12 @@ def test_prefill_plp():
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert [b.block_id for b in computed_blocks] == [1, 2, 3]
+    assert computed_blocks.get_block_ids() == [1, 2, 3]
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
-    assert [b.block_id for b in blocks] == [5]
-    for block in computed_blocks:
+    assert blocks.get_block_ids() == [5]
+    for block in computed_blocks.blocks:
         assert block.ref_cnt == 2
 
     # At this point, we should have 5 free blocks left.
@@ -257,12 +257,12 @@ def test_prefill_plp():
                         prompt_logprobs=5)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
     assert len(manager.req_to_block_hashes[req2.request_id]) == 3
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req2, 55, computed_blocks)
-    block_ids = [b.block_id for b in blocks]
+    block_ids = blocks.get_block_ids()
     # Duplicate cached blocks have different ids but same hashes vs request #0
-    assert [b.block_hash for b in blocks] == req0_block_hashes
+    assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
     assert block_ids != [1, 2, 3, 4]
 
     # Request #2 block hashes are valid since request #0 hashes are.
@@ -288,17 +288,17 @@ def test_decode():
     unique_token_ids = [3] * 7
     req0 = make_request("0", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 55, computed_blocks)
-    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
 
     # Append slots without allocating a new block.
     req0.num_computed_tokens = 55
     for _ in range(4):
         req0.append_output_token_ids(8)
     new_blocks = manager.allocate_slots(req0, 4)
-    assert new_blocks is not None and len(new_blocks) == 0
+    assert new_blocks is not None and len(new_blocks.blocks) == 0
     assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
 
     # Append slots with allocating a new block.
@@ -308,7 +308,7 @@ def test_decode():
     for _ in range(9 + 10):
         req0.append_output_token_ids(7)
     new_blocks = manager.allocate_slots(req0, 19)
-    assert new_blocks is not None and len(new_blocks) == 1
+    assert new_blocks is not None and len(new_blocks.blocks) == 1
     assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None
     assert manager.req_to_blocks[req0.request_id][-1].block_hash is None
@@ -323,19 +323,19 @@ def test_evict():
     last_token_id = 5 * 16 + 7
     req0 = make_request("0", list(range(last_token_id)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
-    assert len(blocks) == 6  # 5 full + 1 partial
+    assert len(blocks.blocks) == 6  # 5 full + 1 partial
 
     # 3 blocks.
     req1 = make_request("1", list(range(last_token_id,
                                         last_token_id + 3 * 16)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
-    assert len(blocks) == 3  # 3 full blocks
+    assert len(blocks.blocks) == 3  # 3 full blocks
     last_token_id += 3 * 16
 
     # 10 - (6 + 3) == 1
@@ -352,10 +352,10 @@ def test_evict():
     # Touch the first 2 blocks.
     req2 = make_request("2", list(range(2 * 16 + 3)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert [b.block_id for b in computed_blocks] == [1, 2]
+    assert computed_blocks.get_block_ids() == [1, 2]
     assert num_computed_tokens == 2 * 16
     blocks = manager.allocate_slots(req2, 3, computed_blocks)
-    assert [b.block_id for b in blocks] == [10]
+    assert blocks.get_block_ids() == [10]
     assert manager.block_pool.free_block_queue.num_free_blocks == 7
@@ -375,10 +375,10 @@ def test_hash_block_correct_reuse():
     num_tokens = block_size * 1
     req = make_request("0", list(range(num_tokens)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
-    assert len(blocks) == 1
+    assert len(blocks.blocks) == 1
 
     # Deallocate the block.
     manager.free(req)
@@ -387,12 +387,13 @@ def test_hash_block_correct_reuse():
     # block is cleared.
     req = make_request("1", list(range(num_tokens - 1)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
-    assert len(blocks) == 1
-    assert manager.block_pool.blocks[blocks[0].block_id].block_hash is None
+    assert len(blocks.blocks) == 1
+    assert manager.block_pool.blocks[
+        blocks.blocks[0].block_id].block_hash is None
 
 
 def test_computed_blocks_not_evicted():
@@ -411,20 +412,20 @@ def test_computed_blocks_not_evicted():
     num_tokens = block_size * 1
     req0 = make_request("0", list(range(num_tokens)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
-    assert len(blocks) == 1
-    assert blocks[0].block_id == 1
+    assert len(blocks.blocks) == 1
+    assert blocks.blocks[0].block_id == 1
 
     # Allocate another block.
     req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
-    assert len(blocks) == 1
-    assert blocks[0].block_id == 2
+    assert len(blocks.blocks) == 1
+    assert blocks.blocks[0].block_id == 2
 
     # Free the blocks.
     manager.free(req0)
@@ -434,14 +435,14 @@ def test_computed_blocks_not_evicted():
     # cached block rather than the first one.
     req2 = make_request("2", list(range(num_tokens * 2)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert len(computed_blocks) == 1
-    assert computed_blocks[0].block_id == 1
+    assert len(computed_blocks.blocks) == 1
+    assert computed_blocks.blocks[0].block_id == 1
     assert num_computed_tokens == block_size
     blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
                                     computed_blocks)
-    assert len(blocks) == 1
-    assert blocks[0].block_id == 2
+    assert len(blocks.blocks) == 1
+    assert blocks.blocks[0].block_id == 2
 
 
 def test_basic_prefix_caching_disabled():
@@ -458,10 +459,10 @@ def test_basic_prefix_caching_disabled():
     req1 = make_request("1", list(range(10)))  # 2 blocks and some more
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req1, 10, computed_blocks)
-    assert len(blocks) == 3
+    assert len(blocks.blocks) == 3
 
     # Free the blocks.
     manager.free(req1)
@@ -469,15 +470,15 @@ def test_basic_prefix_caching_disabled():
     # No caching.
     req2 = make_request("2", list(range(16)))  # shared prefix
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req2, 16, computed_blocks)
-    assert len(blocks) == 4
+    assert len(blocks.blocks) == 4
 
     # New requests should not have any blocks.
     req3 = make_request("3", list(range(4)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     blocks = manager.allocate_slots(req3, 4, computed_blocks)
     assert not blocks
@@ -569,7 +570,7 @@ def test_mm_prefix_caching():
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
 
     # Completed block should have hashes with extra keys.
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     block_hashes = manager.req_to_block_hashes[req0.request_id]
     assert len(block_hashes) == 3
@@ -578,14 +579,14 @@ def test_mm_prefix_caching():
     assert block_hashes[2].extra_keys == ("bbb", )
 
     blocks = manager.allocate_slots(req0, 59, computed_blocks)
-    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
     req0.num_computed_tokens = 59
 
     # Append slots without allocating a new block.
     for _ in range(5):
         req0.append_output_token_ids(8)
     new_blocks = manager.allocate_slots(req0, 5)
-    assert new_blocks is not None and len(new_blocks) == 0
+    assert new_blocks is not None and len(new_blocks.blocks) == 0
 
     # The just completed block should have hashes with extra keys.
     assert len(block_hashes) == 4
@@ -603,7 +604,7 @@ def test_mm_prefix_caching():
                         mm_positions=mm_positions,
                         mm_hashes=mm_hashes)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
-    assert len(computed_blocks) == 3
+    assert len(computed_blocks.blocks) == 3
     assert num_computed_tokens == 3 * 16
@@ -626,7 +627,7 @@ def test_cache_key_salting():
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
 
     # Completed block should have hashes with extra keys.
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     block_hashes = manager.req_to_block_hashes[req0.request_id]
     assert len(block_hashes) == 3
@@ -635,14 +636,14 @@ def test_cache_key_salting():
     assert block_hashes[2].extra_keys is None
 
     blocks = manager.allocate_slots(req0, 59, computed_blocks)
-    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
     req0.num_computed_tokens = 59
 
     # Append slots without allocating a new block.
     for _ in range(5):
         req0.append_output_token_ids(8)
     new_blocks = manager.allocate_slots(req0, 5)
-    assert new_blocks is not None and len(new_blocks) == 0
+    assert new_blocks is not None and len(new_blocks.blocks) == 0
 
     # Now one more block that should not have extra keys.
     assert len(block_hashes) == 4
@@ -653,14 +654,14 @@ def test_cache_key_salting():
     req1 = make_request("1", token_ids, cache_salt="salt1")
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
 
     # Should match only a prefix of 3 blocks.
-    assert len(computed_blocks) == 3
+    assert len(computed_blocks.blocks) == 3
     assert num_computed_tokens == 3 * block_size
 
     # Test cache miss with same content but different salt.
     token_ids = common_token_ids + [4] * 11
     req2 = make_request("2", token_ids, cache_salt="salt2")
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert len(computed_blocks) == 0
+    assert len(computed_blocks.blocks) == 0
     assert num_computed_tokens == 0
     block_hashes = manager.req_to_block_hashes[req2.request_id]
     assert len(block_hashes) == 3
@@ -685,7 +686,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     common_token_ids = [i for i in range(3) for _ in range(16)]
     req0 = make_request("0", common_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     manager.allocate_slots(req0, 48, computed_blocks)
     block_part0 = manager.req_to_blocks[req0.request_id]
@@ -693,7 +694,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
     req1 = make_request("1", common_token_ids * 2)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
-    assert computed_blocks == block_part0
+    assert computed_blocks.blocks == block_part0
     assert num_computed_tokens == 3 * 16
     manager.allocate_slots(req1, 48, computed_blocks)
     block_part1 = manager.req_to_blocks[req1.request_id]
@@ -707,7 +708,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     # | Req1-5(F)| Req2-0 | Req2-1 | ... |
     req2 = make_request("2", [7] * block_size * 2)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     manager.allocate_slots(req2, block_size * 2, computed_blocks)
@@ -717,7 +718,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
     assert manager.block_pool.free_block_queue.num_free_blocks == 5
 
     req3 = make_request("3", common_token_ids * 3)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
-    assert computed_blocks == block_part1
+    assert computed_blocks.blocks == block_part1
     assert num_computed_tokens == 6 * 16
     # Req3 cannot be allocated.
     assert manager.allocate_slots(req3, 48, computed_blocks) is None
@@ -739,16 +740,16 @@ def test_reset_prefix_cache():
     all_token_ids = full_block_token_ids + unique_token_ids
     req0 = make_request("0", all_token_ids)
     blocks = manager.allocate_slots(req0, 55)
-    assert [b.block_id for b in blocks] == [1, 2, 3, 4]
+    assert blocks.get_block_ids() == [1, 2, 3, 4]
 
     unique_token_ids = [4] * 7
     all_token_ids = full_block_token_ids + unique_token_ids
     req1 = make_request("1", all_token_ids)
     computed_blocks, _ = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert len(computed_blocks) == 3
+    assert len(computed_blocks.blocks) == 3
     blocks = manager.allocate_slots(req1, 7, computed_blocks)
-    assert [b.block_id for b in blocks] == [5]
+    assert blocks.get_block_ids() == [5]
 
     # Failed to reset prefix cache because some blocks are not freed yet.
     assert not manager.reset_prefix_cache()
@@ -776,7 +777,7 @@ def test_prefix_cache_stats_disabled():
     # Call all functions that check whether log_stats is disabled.
     req = make_request("0", list(range(16)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
-    assert not computed_blocks
+    assert not computed_blocks.blocks
     assert num_computed_tokens == 0
     manager.allocate_slots(req, 16, computed_blocks)
     manager.reset_prefix_cache()
@@ -866,7 +867,7 @@ def test_eagle_enabled_removes_last_block():
     # Should retain 1 block:
     # 1. Original 3 blocks → pop last hash → 2 matched blocks
     # 2. drop last matched block → 1 remaining block
-    assert len(computed_blocks) == 1
+    assert len(computed_blocks.blocks) == 1
     assert num_tokens == 1 * block_size  # 16 tokens
@@ -892,7 +893,7 @@ def test_eagle_with_partial_blocks():
     req_eagle = make_request("partial_eagle", token_ids)
     computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
     # Original match: 2 full blocks → Eagle removes 1 → 1 remaining
-    assert len(computed_blocks) == 1
+    assert len(computed_blocks.blocks) == 1
     assert num_tokens == 1 * block_size
@@ -934,7 +935,7 @@ def test_eagle_with_sliding_window():
     req_eagle = make_request("partial_eagle", token_ids)
     computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
     # Original match: 2 full blocks → Eagle removes 1 → 1 remaining
-    assert len(computed_blocks) == 1
+    assert len(computed_blocks.blocks) == 1
     assert num_tokens == 1 * block_size
 
     # Evict the first block in the request
@@ -948,5 +949,5 @@ def test_eagle_with_sliding_window():
     # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
     # not considered. But after dropping the last matched block due to eagle,
     # there will be no matched prefix.
-    assert len(computed_blocks) == 0
+    assert len(computed_blocks.blocks) == 0
     assert num_tokens == 0
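
A subtlety running through the test updates above: a plain dataclass instance is always truthy, so the old `assert not computed_blocks` idiom (which checked an empty list) must become `assert not computed_blocks.blocks`. A minimal illustration, with a hypothetical Wrapper standing in for KVCacheBlocks:

from dataclasses import dataclass

@dataclass
class Wrapper:
    # Stand-in for KVCacheBlocks: a dataclass that defines neither
    # __bool__ nor __len__, so instances are always truthy.
    blocks: list

empty = Wrapper([])
assert empty             # truthy even though it wraps no blocks
assert not empty.blocks  # the emptiness check the updated tests perform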

View File

@@ -2,6 +2,7 @@
 from collections import defaultdict
 from collections.abc import Iterable
 from dataclasses import dataclass
+from typing import Optional
 
 from vllm.distributed.kv_events import KVCacheEvent
@@ -18,6 +19,24 @@ from vllm.v1.request import Request, RequestStatus
 logger = init_logger(__name__)
 
 
+@dataclass
+class KVCacheBlocks:
+    blocks: list[KVCacheBlock]
+
+    def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
+        """Adds two KVCacheBlocks instances."""
+        return KVCacheBlocks(self.blocks + other.blocks)
+
+    @classmethod
+    def create_empty(cls) -> "KVCacheBlocks":
+        """Creates a new KVCacheBlocks instance with no blocks."""
+        return cls([])
+
+    def get_block_ids(self) -> list[int]:
+        """Converts the KVCacheBlocks instance to a list of block IDs."""
+        return [block.block_id for block in self.blocks]
+
+
 class KVCacheManager:
 
     def __init__(
@@ -94,8 +113,8 @@ class KVCacheManager:
         self.prefix_cache_stats = PrefixCacheStats()
         return stats
 
-    def get_computed_blocks(
-            self, request: Request) -> tuple[list[KVCacheBlock], int]:
+    def get_computed_blocks(self,
+                            request: Request) -> tuple[KVCacheBlocks, int]:
         """Get the computed (cached) blocks for the request.
         Note that the computed blocks must be full.
@@ -109,7 +128,7 @@
         """
         if not self.enable_caching:
             # Prefix caching is disabled.
-            return [], 0
+            return KVCacheBlocks.create_empty(), 0
 
         # The block hashes for the request may already be computed
         # if the scheduler has tried to schedule the request before.
@@ -124,7 +143,7 @@
             self.prefix_cache_stats.requests += 1
 
         # When the request requires prompt logprobs, we skip prefix caching.
         if request.sampling_params.prompt_logprobs is not None:
-            return [], 0
+            return KVCacheBlocks.create_empty(), 0
 
         if len(block_hashes) * self.block_size == request.num_tokens:
             # When prompt length is divisible by the block size and all
@@ -157,15 +176,15 @@
         # sharing, `num_computed_tokens` is always a multiple of
         # `block_size`.
         num_computed_tokens = len(computed_blocks) * self.block_size
-        return computed_blocks, num_computed_tokens
+        return KVCacheBlocks(computed_blocks), num_computed_tokens
 
     def allocate_slots(
         self,
         request: Request,
         num_tokens: int,
-        new_computed_blocks: Optional[list[KVCacheBlock]] = None,
+        new_computed_blocks: Optional[KVCacheBlocks] = None,
         num_lookahead_tokens: int = 0,
-    ) -> Optional[list[KVCacheBlock]]:
+    ) -> Optional[KVCacheBlocks]:
         """Add slots for a request with new tokens to append.
 
         Args:
@@ -173,7 +192,7 @@
             num_tokens: The number of tokens to allocate, including external
                 tokens. Note that this does not include tokens that have
                 already been computed locally (i.e. new_computed_blocks).
-            new_computed_blocks: A list of new computed blocks just hitting the
+            new_computed_blocks: The new computed blocks just hitting the
                 prefix caching.
             num_lookahead_tokens: The number of speculative tokens to allocate.
                 This is used by spec decode proposers with kv-cache such
@@ -199,7 +218,10 @@
         if num_tokens == 0:
             raise ValueError("num_tokens must be greater than 0")
 
-        new_computed_blocks = new_computed_blocks or []
+        if new_computed_blocks is not None:
+            new_computed_block_list = new_computed_blocks.blocks
+        else:
+            new_computed_block_list = []
 
         req_blocks = self.req_to_blocks[request.request_id]
@@ -216,17 +238,18 @@
         # The number of computed tokens is the number of computed tokens plus
         # the new prefix caching hits
         num_computed_tokens = (request.num_computed_tokens +
-                               len(new_computed_blocks) * self.block_size)
+                               len(new_computed_block_list) * self.block_size)
         num_required_blocks = cdiv(
             num_computed_tokens + num_tokens + num_lookahead_tokens,
             self.block_size)
         num_new_blocks = (num_required_blocks - len(req_blocks) -
-                          len(new_computed_blocks))
+                          len(new_computed_block_list))
 
         # If a computed block of a request is an eviction candidate (in the
         # free queue and ref_cnt == 0), it cannot be counted as a free block
         # when allocating this request.
-        num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks
+        num_evictable_computed_blocks = sum(1
+                                            for blk in new_computed_block_list
                                             if blk.ref_cnt == 0)
         if (num_new_blocks > self.block_pool.get_num_free_blocks() -
                 num_evictable_computed_blocks):
@@ -235,15 +258,15 @@
         # Touch the computed blocks to make sure they won't be evicted.
         if self.enable_caching:
-            self.block_pool.touch(new_computed_blocks)
+            self.block_pool.touch(new_computed_block_list)
         else:
-            assert not new_computed_blocks, (
+            assert not new_computed_block_list, (
                 "Computed blocks should be empty when "
                 "prefix caching is disabled")
 
         # Append the new computed blocks to the request blocks until now to
         # avoid the case where the new blocks cannot be allocated.
-        req_blocks.extend(new_computed_blocks)
+        req_blocks.extend(new_computed_block_list)
 
         # Start to handle new blocks
@@ -267,12 +290,12 @@
             req_blocks.extend(new_blocks)
 
         if not self.enable_caching:
-            return new_blocks
+            return KVCacheBlocks(new_blocks)
 
-        # Use `new_computed_blocks` for a new request, and `num_cached_block`
-        # for a running request.
-        num_cached_blocks = self.num_cached_block.get(request.request_id,
-                                                      len(new_computed_blocks))
+        # Use `new_computed_block_list` for a new request, and
+        # `num_cached_block` for a running request.
+        num_cached_blocks = self.num_cached_block.get(
+            request.request_id, len(new_computed_block_list))
         # Speculated tokens might be rejected in the future, so we does
         # not cache any speculated tokens. We only cache blocks with
         # generated (accepted) tokens.
@@ -291,7 +314,7 @@
         self.num_cached_block[
             request.request_id] = num_full_blocks_after_append
 
-        return new_blocks
+        return KVCacheBlocks(new_blocks)
 
     def free(self, request: Request) -> None:
         """Free the blocks allocated for the request.

View File

@@ -261,9 +261,8 @@ class Scheduler(SchedulerInterface):
                 # Therefore, we might introduce some additional
                 # cycle to fill in the bitmask, which could be a big no-op.
                 structured_output_request_ids[request.request_id] = req_index
-            req_to_new_block_ids[request.request_id] = [
-                b.block_id for b in new_blocks
-            ]
+            req_to_new_block_ids[request.request_id] = (
+                new_blocks.get_block_ids())
             num_scheduled_tokens[request.request_id] = num_new_tokens
             token_budget -= num_new_tokens
             req_index += 1
@@ -407,9 +406,8 @@
             if self.lora_config and request.lora_request:
                 scheduled_loras.add(request.lora_request.lora_int_id)
-            req_to_new_block_ids[request.request_id] = [
-                b.block_id for b in computed_blocks + new_blocks
-            ]
+            req_to_new_block_ids[request.request_id] = (
+                computed_blocks + new_blocks).get_block_ids()
             num_scheduled_tokens[request.request_id] = num_new_tokens
             token_budget -= num_new_tokens
             request.status = RequestStatus.RUNNING