[Core] Use tuple for kv cache group block ids (#19175)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-06-09 22:01:17 -07:00 committed by GitHub
parent 6cd4ae8acd
commit 646d62f636
12 changed files with 140 additions and 142 deletions

View File

@@ -117,7 +117,7 @@ def test_prefill(hash_algo):
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )

     # Check full block metadata
     parent_block_hash = None
@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
+    assert computed_blocks.get_block_ids() == ([1, 2, 3], )
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[5]]
+    assert blocks.get_block_ids() == ([5], )
     for block in computed_blocks.blocks[0]:
         assert block.ref_cnt == 2
@@ -175,13 +175,13 @@ def test_prefill(hash_algo):
     req2 = make_request("2", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
     assert len(manager.req_to_block_hashes[req2.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
+    assert computed_blocks.get_block_ids() == ([1, 2, 3], )
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req2, num_new_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[6]]
+    assert blocks.get_block_ids() == ([6], )

     # Although we only have 6 free blocks, we have 8 blocks in
     # the free block queue due to lazy removal.
@@ -205,7 +205,7 @@ def test_prefill(hash_algo):
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     # This block ID order also checks the eviction order.
-    assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
+    assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
     assert manager.block_pool.free_block_queue.num_free_blocks == 0
     assert manager.block_pool.free_block_queue.free_list_head is None
     assert manager.block_pool.free_block_queue.free_list_tail is None
@@ -236,8 +236,8 @@ def test_prefill_hybrid_model():
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4], [5, 6, 7, 8],
-                                      [9, 10, 11, 12]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
+                                                     8], [9, 10, 11, 12])

     # Check full block metadata
     parent_block_hash = None
@@ -263,14 +263,14 @@ def test_prefill_hybrid_model():
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [[1, 2, 3], [0, 6, 7],
-                                               [0, 10, 11]]
+    assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6,
+                                                           7], [0, 10, 11])
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[13], [14], [15]]
+    assert blocks.get_block_ids() == ([13], [14], [15])
     for block_per_group in computed_blocks.blocks:
         for block in block_per_group:
             if block != manager.block_pool.null_block:
@@ -374,7 +374,7 @@ def test_prefill_plp():
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
     req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]

     # Check full block metadata
@@ -400,13 +400,13 @@ def test_prefill_plp():
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
+    assert computed_blocks.get_block_ids() == ([1, 2, 3], )
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[5]]
+    assert blocks.get_block_ids() == ([5], )
     for block in computed_blocks.blocks[0]:
         assert block.ref_cnt == 2
@@ -444,7 +444,7 @@ def test_prefill_plp():
     block_ids = blocks.get_block_ids()
     # Duplicate cached blocks have different ids but same hashes vs request #0
     assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
-    assert block_ids != [[1, 2, 3, 4]]
+    assert block_ids != ([1, 2, 3, 4], )

     # Request #2 block hashes are valid since request #0 hashes are.
     # Check block reference counts.
@@ -474,7 +474,7 @@ def test_decode():
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )

     # Append slots without allocating a new block.
     req0.num_computed_tokens = 55
@@ -546,12 +546,12 @@ def test_evict():
     # Touch the first 2 blocks.
     req2 = make_request("2", list(range(2 * 16 + 3)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert computed_blocks.get_block_ids() == [[1, 2]]
+    assert computed_blocks.get_block_ids() == ([1, 2], )
     assert num_computed_tokens == 2 * 16
     blocks = manager.allocate_slots(req2, 3,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[10]]
+    assert blocks.get_block_ids() == ([10], )
     assert manager.block_pool.free_block_queue.num_free_blocks == 7
@@ -865,7 +865,7 @@ def test_mm_prefix_caching():
     blocks = manager.allocate_slots(req0, 59,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
     req0.num_computed_tokens = 59

     # Append slots without allocating a new block.
@@ -926,7 +926,7 @@ def test_cache_key_salting():
     blocks = manager.allocate_slots(req0, 59,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
     req0.num_computed_tokens = 59

     # Append slots without allocating a new block.
@@ -1042,7 +1042,7 @@ def test_reset_prefix_cache():
     all_token_ids = full_block_token_ids + unique_token_ids
     req0 = make_request("0", all_token_ids)
     blocks = manager.allocate_slots(req0, 55)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )

     unique_token_ids = [4] * 7
     all_token_ids = full_block_token_ids + unique_token_ids
@@ -1053,7 +1053,7 @@ def test_reset_prefix_cache():
     blocks = manager.allocate_slots(req1, 7,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[5]]
+    assert blocks.get_block_ids() == ([5], )
     # Failed to reset prefix cache because some blocks are not freed yet.
     assert not manager.reset_prefix_cache()
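
The assertions above capture the user-visible change: KVCacheBlocks.get_block_ids() now returns a tuple of per-group block-id lists, one list per KV cache group, instead of a list of lists. A minimal standalone sketch of the new shape (toy classes for illustration, not vLLM's actual ones):

    from dataclasses import dataclass

    @dataclass
    class ToyBlock:
        block_id: int

    @dataclass
    class ToyKVCacheBlocks:
        # Outer tuple: one entry per KV cache group (the group count is fixed).
        # Inner lists: blocks of each group (they grow as slots are allocated).
        blocks: tuple[list[ToyBlock], ...]

        def get_block_ids(self) -> tuple[list[int], ...]:
            return tuple([blk.block_id for blk in group] for group in self.blocks)

    kv_blocks = ToyKVCacheBlocks(([ToyBlock(1), ToyBlock(2), ToyBlock(3), ToyBlock(4)], ))
    assert kv_blocks.get_block_ids() == ([1, 2, 3, 4], )  # single-group tuple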

View File

@@ -71,7 +71,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
             mm_hashes=[],
             mm_positions=[],
             sampling_params=SamplingParams(),
-            block_ids=[[0]],  # block_ids should be list[list[int]]
+            block_ids=([0], ),  # block_ids should be tuple[list[int]]
             num_computed_tokens=0,
             lora_request=None,
         ))
@@ -116,10 +116,10 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
     # This is safe since we currently only use single KV cache groups
     block_table = multi_group_block_table[0]
-    # req_state.block_ids is now list[list[int]] for MultiGroupBlockTable
+    # req_state.block_ids is now tuple[list[int], ...] for MultiGroupBlockTable
     # Extract the first group's block IDs
     if isinstance(req_state.block_ids[0], list):
-        # New format: list[list[int]] - extract first group
+        # New format: tuple[list[int], ...] - extract first group
         req_block_ids = req_state.block_ids[0]
     else:
         # Legacy format: list[int] - use directly
@@ -210,7 +210,7 @@ def test_update_states_request_resumed(model_runner):
         req_id=req_id,
         resumed_from_preemption=False,
         new_token_ids=[],
-        new_block_ids=[[]],
+        new_block_ids=([], ),
         num_computed_tokens=0,
     )

View File

@@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
         sampling_params=_create_sampling_params(),
         mm_inputs=[],
         mm_positions=[],
-        block_ids=[[]],
+        block_ids=([], ),
         generator=None,
         num_computed_tokens=len(output_token_ids),
         output_token_ids=output_token_ids,

View File

@@ -123,7 +123,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
             mm_hashes=[],
             mm_positions=[],
             sampling_params=SamplingParams(),
-            block_ids=[[0]],
+            block_ids=([0], ),
             num_computed_tokens=0,
             lora_request=None,
         ))
@@ -251,7 +251,7 @@ def test_update_states_request_resumed(model_runner):
         req_id=req_id,
         resumed_from_preemption=False,
         new_token_ids=[],
-        new_block_ids=[[]],
+        new_block_ids=([], ),
         num_computed_tokens=0,
     )
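
A recurring detail in these test updates is the trailing comma in ([0], ) and ([], ): parentheses alone are just grouping in Python, so the comma is what turns a single per-group list into a one-element tuple. A quick illustration of the distinction:

    # Parentheses without a comma do not create a tuple.
    assert ([0]) == [0]                  # still a plain list
    assert ([0], ) != [[0]]              # a tuple of one list, not a list of lists
    assert isinstance(([0], ), tuple)
    assert not isinstance(([0]), tuple)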

View File

@@ -89,8 +89,8 @@ class BlockPool:
                 BlockHashWithGroupId(block_hash, group_id))
             if not cached_blocks_one_group:
                 return None
-            first_block_id = next(iter(cached_blocks_one_group))
-            cached_blocks.append(cached_blocks_one_group[first_block_id])
+            first_block = next(iter(cached_blocks_one_group.values()))
+            cached_blocks.append(first_block)
         return cached_blocks

     def cache_full_blocks(
@@ -260,7 +260,7 @@ class BlockPool:
             return True
         return False

-    def touch(self, blocks: list[list[KVCacheBlock]]) -> None:
+    def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
         """Touch a block increases its reference count by 1, and may remove
         the block from the free queue. This is used when a block is hit by
         another request with the same prefix.
@@ -299,7 +299,7 @@ class BlockPool:
             bool: True if the prefix cache is successfully reset,
             False otherwise.
         """
-        num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks())
+        num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks()
         if num_used_blocks != 1:  # The null block is always marked as used
             logger.warning(
                 "Failed to reset prefix cache because some "

View File

@@ -5,8 +5,7 @@ from typing import Callable, Optional
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
 from vllm.v1.core.single_type_kv_cache_manager import (
-    FullAttentionManager, SingleTypeKVCacheManager,
-    get_manager_for_kv_cache_spec)
+    FullAttentionManager, get_manager_for_kv_cache_spec)
 from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
 from vllm.v1.request import Request
@@ -30,25 +29,21 @@ class KVCacheCoordinator(ABC):
         self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
                                     enable_kv_cache_events)
-        self.single_type_managers: list[SingleTypeKVCacheManager] = []
-
         # Needs special handling for find_longest_cache_hit if eagle is enabled
         self.use_eagle = use_eagle
-
-        for i in range(len(self.kv_cache_config.kv_cache_groups)):
-            kv_cache_spec = self.kv_cache_config.kv_cache_groups[
-                i].kv_cache_spec
-            self.single_type_managers.append(
-                get_manager_for_kv_cache_spec(
-                    kv_cache_spec=kv_cache_spec,
-                    block_pool=self.block_pool,
-                    kv_cache_group_id=i,
-                    caching_hash_fn=caching_hash_fn,
-                ))
+        self.single_type_managers = tuple(
+            get_manager_for_kv_cache_spec(
+                kv_cache_spec=kv_cache_group.kv_cache_spec,
+                block_pool=self.block_pool,
+                kv_cache_group_id=i,
+                caching_hash_fn=caching_hash_fn,
+            ) for i, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups))

     def get_num_blocks_to_allocate(
             self, request_id: str, num_tokens: int,
-            new_computed_blocks: list[list[KVCacheBlock]]) -> int:
+            new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int:
         """
         Get the number of blocks needed to be allocated for the request.
@@ -70,7 +65,7 @@ class KVCacheCoordinator(ABC):
     def save_new_computed_blocks(
             self, request_id: str,
-            new_computed_blocks: list[list[KVCacheBlock]]) -> None:
+            new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> None:
         """
         Add the new computed blocks to the request.
@@ -84,7 +79,7 @@ class KVCacheCoordinator(ABC):
                 new_computed_blocks[i])

     def allocate_new_blocks(self, request_id: str,
-                            num_tokens: int) -> list[list[KVCacheBlock]]:
+                            num_tokens: int) -> tuple[list[KVCacheBlock], ...]:
         """
         Allocate new blocks for the request to give it at least `num_tokens`
         token slots.
@@ -97,11 +92,9 @@ class KVCacheCoordinator(ABC):
         Returns:
             The new allocated blocks.
         """
-        new_blocks = []
-        for manager in self.single_type_managers:
-            new_blocks.append(
-                manager.allocate_new_blocks(request_id, num_tokens))
-        return new_blocks
+        return tuple(
+            manager.allocate_new_blocks(request_id, num_tokens)
+            for manager in self.single_type_managers)

     def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
                      num_computed_tokens: int) -> None:
@@ -159,19 +152,20 @@ class KVCacheCoordinator(ABC):
         for manager in self.single_type_managers:
             manager.remove_skipped_blocks(request_id, num_computed_tokens)

-    def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]:
+    def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
         """
         Get the blocks for the request.
         """
-        return [
-            manager.req_to_blocks.get(request_id) or []
-            for manager in self.single_type_managers
-        ]
+        return tuple(
+            manager.req_to_blocks.get(request_id) or []
+            for manager in self.single_type_managers)

     @abstractmethod
     def find_longest_cache_hit(
-            self, block_hashes: list[BlockHash],
-            max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
+            self,
+            block_hashes: list[BlockHash],
+            max_cache_hit_length: int,
+    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
         pass
@@ -195,8 +189,10 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
             "UnitaryKVCacheCoordinator assumes only one kv cache group")

     def find_longest_cache_hit(
-            self, block_hashes: list[BlockHash],
-            max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
+            self,
+            block_hashes: list[BlockHash],
+            max_cache_hit_length: int,
+    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
         hit_blocks = self.single_type_managers[0].find_longest_cache_hit(
             block_hashes=block_hashes,
             max_length=max_cache_hit_length,
@@ -275,11 +271,24 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
             "KVCacheCoordinator assumes the block_size of full attention "
             "layers is divisible by other layers now.")

+        if max(self.full_attention_group_ids) < min(self.other_group_ids):
+            self.full_attn_first = True
+        elif max(self.other_group_ids) < min(self.full_attention_group_ids):
+            self.full_attn_first = False
+        else:
+            raise ValueError(
+                "HybridKVCacheCoordinator assumes the full "
+                "attention group ids and other attention group ids "
+                "do not interleave, either full attention group ids "
+                "are before other attention group ids or vice versa."
+                "This is for simplifying merging hit_blocks_full_attn and "
+                "hit_blocks_other_attn to hit_blocks.")
+
     def find_longest_cache_hit(
         self,
         block_hashes: list[BlockHash],
         max_cache_hit_length: int,
-    ) -> tuple[list[list[KVCacheBlock]], int]:
+    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
         """
         Find the longest cache hit for the request.
@@ -318,27 +327,25 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
             ))
         hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size

-        # NOTE: the prefix cache hit length must be a multiply of block_size as
+        # NOTE: the prefix cache hit length must be a multiple of block_size as
         # we don't support partial block cache hit yet. The cache hit length
-        # of other attention is ensured to be a multiply of the block size of
+        # of other attention is ensured to be a multiple of the block size of
         # full attention layers in current implementation, because hit_length is
-        # a multiply of other attention's block size, and other attention's
-        # block size is a multiply of full attention's block size (verified in
+        # a multiple of other attention's block size, and other attention's
+        # block size is a multiple of full attention's block size (verified in
         # `verify_and_split_kv_cache_groups`).
         assert hit_length % self.full_attention_block_size == 0

         # Truncate the full attention cache hit to the length of the
         # cache hit of the other attention.
-        for i in range(len(hit_blocks_full_attn)):
-            del hit_blocks_full_attn[i][hit_length //
-                                        self.full_attention_block_size:]
+        for group_hit_blocks in hit_blocks_full_attn:
+            del group_hit_blocks[hit_length // self.full_attention_block_size:]

         # Merge the hit blocks of full attention and other attention.
-        hit_blocks = hit_blocks_other_attn
-        for group_id, blocks in enumerate(hit_blocks_full_attn):
-            # NOTE: there is only one full attention group in most cases. So
-            # the time complexity of insert is fine.
-            hit_blocks.insert(group_id, blocks)
+        if self.full_attn_first:
+            hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn
+        else:
+            hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn
         return hit_blocks, hit_length
@@ -351,8 +358,6 @@ def get_kv_cache_coordinator(
                                           use_eagle, enable_caching,
                                           caching_hash_fn,
                                           enable_kv_cache_events)
-    else:
-        return HybridKVCacheCoordinator(kv_cache_config, max_model_len,
-                                        use_eagle, enable_caching,
-                                        caching_hash_fn,
-                                        enable_kv_cache_events)
+    return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
+                                    enable_caching, caching_hash_fn,
+                                    enable_kv_cache_events)
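
The new full_attn_first flag set up in __init__ is what lets find_longest_cache_hit replace the per-group insert() loop with a single concatenation: once the full-attention group ids are known not to interleave with the other group ids, the merged per-group result is just the two sequences joined in the right order. A standalone sketch of that merge logic (plain lists stand in for the per-group block lists):

    # Assumed layout: full-attention groups first, e.g. ids [0, 1] vs [2, 3].
    full_attention_group_ids = [0, 1]
    other_group_ids = [2, 3]

    if max(full_attention_group_ids) < min(other_group_ids):
        full_attn_first = True
    elif max(other_group_ids) < min(full_attention_group_ids):
        full_attn_first = False
    else:
        raise ValueError("group ids must not interleave")

    hit_blocks_full_attn = [["f0"], ["f1"]]    # hits per full-attention group
    hit_blocks_other_attn = [["o2"], ["o3"]]   # hits per other group

    if full_attn_first:
        hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn
    else:
        hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn

    assert hit_blocks == [["f0"], ["f1"], ["o2"], ["o3"]]  # group order preserved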

View File

@@ -25,7 +25,7 @@ class KVCacheBlocks:
     Scheduler and KVCacheManager, to hide KVCacheManager's internal data
     structure from the Scheduler.
     """
-    blocks: list[list[KVCacheBlock]]
+    blocks: tuple[list[KVCacheBlock], ...]
     """
     blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
     We don't use block of tokens as the outer dimension because it assumes all
@@ -37,21 +37,19 @@ class KVCacheBlocks:
     def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
         """Adds two KVCacheBlocks instances."""
         return KVCacheBlocks(
-            [blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)])
+            tuple(blk1 + blk2
+                  for blk1, blk2 in zip(self.blocks, other.blocks)))

-    def get_block_ids(self) -> list[list[int]]:
+    def get_block_ids(self) -> tuple[list[int], ...]:
         """
         Converts the KVCacheBlocks instance to block_ids.

         Returns:
-            list[list[int]]: A two-level list where
-            * the outer list corresponds to KV cache groups
+            tuple[list[int], ...]: A tuple of lists where
+            * the outer tuple corresponds to KV cache groups
             * each inner list contains the block_ids of the blocks in that group
         """
-        block_ids = []
-        for group in self.blocks:
-            block_ids.append([blk.block_id for blk in group])
-        return block_ids
+        return tuple([blk.block_id for blk in group] for group in self.blocks)

     def get_unhashed_block_ids(self) -> list[int]:
         """Get block_ids of unhashed blocks from KVCacheBlocks instance."""
@@ -63,7 +61,7 @@ class KVCacheBlocks:
     def new_empty(self) -> "KVCacheBlocks":
         """Creates a new KVCacheBlocks instance with no blocks."""
-        return KVCacheBlocks([[] for _ in range(len(self.blocks))])
+        return KVCacheBlocks(tuple([] for _ in range(len(self.blocks))))

 class KVCacheManager:
@@ -232,9 +230,8 @@ class KVCacheManager:
         if new_computed_blocks is not None:
             new_computed_block_list = new_computed_blocks.blocks
         else:
-            new_computed_block_list = [
-                [] for _ in range(len(self.kv_cache_config.kv_cache_groups))
-            ]
+            new_computed_block_list = tuple(
+                [] for _ in range(len(self.kv_cache_config.kv_cache_groups)))

         # Free the blocks that are skipped during the attention computation
         # (e.g., tokens outside the sliding window).
@@ -267,7 +264,7 @@ class KVCacheManager:
         if self.enable_caching:
             self.block_pool.touch(new_computed_block_list)
         else:
-            assert all(not blocks for blocks in new_computed_block_list), (
+            assert not any(new_computed_block_list), (
                 "Computed blocks should be empty when "
                 "prefix caching is disabled")
@@ -378,17 +375,18 @@ class KVCacheManager:
         """
         return self.block_pool.take_events()

-    def get_block_ids(self, request_id: str) -> list[list[int]]:
+    def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
         """Get the block ids of a request."""
         return KVCacheBlocks(
             self.coordinator.get_blocks(request_id)).get_block_ids()

-    def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
-                     num_computed_tokens: int) -> None:
+    def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
         """Cache the blocks for the request."""
+        block_hashes = self.req_to_block_hashes[request.request_id]
         self.coordinator.cache_blocks(request, block_hashes,
                                       num_computed_tokens)

     def create_empty_block_list(self) -> KVCacheBlocks:
         """Creates a new KVCacheBlocks instance with no blocks."""
-        return KVCacheBlocks([[] for _ in range(self.num_kv_cache_groups)])
+        return KVCacheBlocks(tuple([]
+                                   for _ in range(self.num_kv_cache_groups)))
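
Two details in this file are easy to miss. First, the outer tuple fixes the per-group structure while each inner list stays mutable, which is exactly the shape tuple([] for ...) produces. Second, `assert not any(new_computed_block_list)` is equivalent to the old `all(not blocks ...)` form because empty lists are falsy. A small self-contained check of both:

    num_kv_cache_groups = 3
    empty = tuple([] for _ in range(num_kv_cache_groups))
    assert empty == ([], [], [])
    empty[0].append("blk")          # inner lists remain mutable
    assert empty == (["blk"], [], [])

    fresh = tuple([] for _ in range(num_kv_cache_groups))
    assert all(not blocks for blocks in fresh)   # old assertion style
    assert not any(fresh)                        # new, equivalent style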

View File

@@ -27,7 +27,7 @@ class NewRequestData:
     mm_hashes: list[str]
     mm_positions: list[PlaceholderRange]
     sampling_params: SamplingParams
-    block_ids: list[list[int]]
+    block_ids: tuple[list[int], ...]
     num_computed_tokens: int
     lora_request: Optional[LoRARequest]
@@ -35,7 +35,7 @@ class NewRequestData:
     def from_request(
         cls,
         request: Request,
-        block_ids: list[list[int]],
+        block_ids: tuple[list[int], ...],
     ) -> NewRequestData:
         return cls(
             req_id=request.request_id,
@@ -86,7 +86,7 @@ class CachedRequestData:
     # request's block IDs instead of appending to the existing block IDs.
     resumed_from_preemption: bool
     new_token_ids: list[int]
-    new_block_ids: list[list[int]]
+    new_block_ids: tuple[list[int], ...]
     num_computed_tokens: int

     @classmethod
@@ -95,7 +95,7 @@ class CachedRequestData:
         request: Request,
         resumed_from_preemption: bool,
         new_token_ids: list[int],
-        new_block_ids: list[list[int]],
+        new_block_ids: tuple[list[int], ...],
     ) -> CachedRequestData:
         return cls(
             req_id=request.request_id,

View File

@@ -180,7 +180,7 @@ class Scheduler(SchedulerInterface):
         # uses structured decoding.
         structured_output_request_ids: dict[str, int] = {}

-        req_to_new_block_ids: dict[str, list[list[int]]] = {}
+        req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {}
         num_scheduled_tokens: dict[str, int] = {}
         token_budget = self.max_num_scheduled_tokens
         # Encoder-related.
@@ -471,7 +471,7 @@ class Scheduler(SchedulerInterface):
                 token_budget -= num_new_tokens
                 request.status = RequestStatus.RUNNING
                 request.num_computed_tokens = num_computed_tokens
-                # Count the number of prifix cached tokens.
+                # Count the number of prefix cached tokens.
                 if request.num_cached_tokens < 0:
                     request.num_cached_tokens = num_computed_tokens
                 # Encoder-related.
@@ -588,7 +588,7 @@ class Scheduler(SchedulerInterface):
         request: Request,
         num_scheduled_tokens: int,
         num_scheduled_spec_tokens: int,
-        new_block_ids: list[list[int]],
+        new_block_ids: tuple[list[int], ...],
         resumed_from_preemption: bool,
     ) -> CachedRequestData:
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
@@ -1015,11 +1015,7 @@ class Scheduler(SchedulerInterface):
         num_computed_tokens = min(num_computed_tokens, request.num_tokens)
         if num_computed_tokens == request.num_tokens:
             num_computed_tokens -= 1
-        self.kv_cache_manager.cache_blocks(
-            request,
-            self.kv_cache_manager.req_to_block_hashes[request.request_id],
-            num_computed_tokens,
-        )
+        self.kv_cache_manager.cache_blocks(request, num_computed_tokens)

         # Update the request state for scheduling.
         request.num_computed_tokens = num_computed_tokens

View File

@@ -197,7 +197,7 @@ class SingleTypeKVCacheManager(ABC):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
-    ) -> list[list[KVCacheBlock]]:
+    ) -> tuple[list[KVCacheBlock], ...]:
         """
         Get the longest cache hit prefix of the blocks that is not longer than
         `max_length`. The prefix should be a common prefix hit for all the
@@ -222,7 +222,7 @@ class SingleTypeKVCacheManager(ABC):
             element is a list of cached blocks for the i-th kv cache group
             in `kv_cache_group_ids`.
             For example, sliding window manager should return a list like
-            [[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]] for block size 4
+            ([NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]) for block size 4
             and sliding window 8 and len(kv_cache_group_ids) = 1.
         """
@@ -254,27 +254,25 @@ class FullAttentionManager(SingleTypeKVCacheManager):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
-    ) -> list[list[KVCacheBlock]]:
+    ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(kv_cache_spec, FullAttentionSpec), (
             "FullAttentionManager can only be used for full attention groups")
-        computed_blocks: list[list[KVCacheBlock]] = [
-            [] for _ in range(len(kv_cache_group_ids))
-        ]
+        computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
+            [] for _ in range(len(kv_cache_group_ids)))
         max_num_blocks = max_length // kv_cache_spec.block_size
-        for i in range(max_num_blocks):
-            block_hash = block_hashes[i]
+        for i, block_hash in zip(range(max_num_blocks), block_hashes):
             # block_hashes is a chain of block hashes. If a block hash is not
             # in the cached_block_hash_to_id, the following block hashes are
             # not computed yet for sure.
             if cached_block := block_pool.get_cached_block(
                     block_hash, kv_cache_group_ids):
-                for j in range(len(kv_cache_group_ids)):
-                    computed_blocks[j].append(cached_block[j])
+                for computed, cached in zip(computed_blocks, cached_block):
+                    computed.append(cached)
             else:
                 break
-        if use_eagle and len(computed_blocks[0]) > 0:
-            for j in range(len(kv_cache_group_ids)):
-                computed_blocks[j].pop()
+        if use_eagle and computed_blocks[0]:
+            for computed in computed_blocks:
+                computed.pop()
         return computed_blocks

     def remove_skipped_blocks(self, request_id: str,
@@ -311,7 +309,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
-    ) -> list[list[KVCacheBlock]]:
+    ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(kv_cache_spec, SlidingWindowSpec), (
             "SlidingWindowManager can only be used for sliding window groups")
@@ -332,23 +330,23 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
         # sliding_window_contiguous_blocks),
         # which is good for low cache hit rate scenarios.
         max_num_blocks = max_length // kv_cache_spec.block_size
-        computed_blocks = [[block_pool.null_block] * max_num_blocks
-                           for _ in range(len(kv_cache_group_ids))]
+        computed_blocks = tuple([block_pool.null_block] * max_num_blocks
+                                for _ in range(len(kv_cache_group_ids)))
         num_contiguous_blocks = 0
         match_found = False
         # Search from right to left and early stop when a match is found.
         for i in range(max_num_blocks - 1, -1, -1):
             if cached_block := block_pool.get_cached_block(
                     block_hashes[i], kv_cache_group_ids):
-                for j in range(len(kv_cache_group_ids)):
-                    computed_blocks[j][i] = cached_block[j]
+                for computed, cached in zip(computed_blocks, cached_block):
+                    computed[i] = cached
                 num_contiguous_blocks += 1
-                if (num_contiguous_blocks >= sliding_window_contiguous_blocks):
+                if num_contiguous_blocks >= sliding_window_contiguous_blocks:
                     # Trim the trailing blocks.
                     # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
                     # when sliding_window_contiguous_blocks=2.
-                    for j in range(len(kv_cache_group_ids)):
-                        del computed_blocks[j][i + num_contiguous_blocks:]
+                    for computed in computed_blocks:
+                        del computed[i + num_contiguous_blocks:]
                     match_found = True
                     break
             else:
@@ -356,11 +354,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
         if not match_found:
             # The first `num_contiguous_blocks` is a cache hit even if
             # `num_contiguous_blocks < sliding_window_contiguous_blocks`.
-            for j in range(len(kv_cache_group_ids)):
-                del computed_blocks[j][num_contiguous_blocks:]
-        if use_eagle and len(computed_blocks[0]) > 0:
-            for j in range(len(kv_cache_group_ids)):
-                computed_blocks[j].pop()
+            for computed in computed_blocks:
+                del computed[num_contiguous_blocks:]
+        if use_eagle and computed_blocks[0]:
+            for computed in computed_blocks:
+                computed.pop()
         return computed_blocks

     def remove_skipped_blocks(self, request_id: str,
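
Both managers now pair each group's accumulator with its cached block via zip rather than indexing by group id; since computed_blocks and cached_block are aligned per KV cache group, zip walks them in lockstep and the inner lists are still mutated in place. A standalone sketch of the pattern:

    # Two KV cache groups; each loop iteration delivers one block per group.
    computed_blocks: tuple[list[str], ...] = tuple([] for _ in range(2))
    cache_hits = [("g0-b0", "g1-b0"), ("g0-b1", "g1-b1")]

    for cached_block in cache_hits:
        for computed, cached in zip(computed_blocks, cached_block):
            computed.append(cached)   # in-place growth of each group's list

    assert computed_blocks == (["g0-b0", "g0-b1"], ["g1-b0", "g1-b1"])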

View File

@@ -112,11 +112,12 @@ class MultiGroupBlockTable:
             for block_size in block_sizes
         ]

-    def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
+    def append_row(self, block_ids: tuple[list[int], ...],
+                   row_idx: int) -> None:
         for i, block_table in enumerate(self.block_tables):
             block_table.append_row(block_ids[i], row_idx)

-    def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
+    def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
         for i, block_table in enumerate(self.block_tables):
             block_table.add_row(block_ids[i], row_idx)

View File

@@ -30,7 +30,7 @@ class CachedRequestState:
     sampling_params: SamplingParams
     generator: Optional[torch.Generator]

-    block_ids: list[list[int]]
+    block_ids: tuple[list[int], ...]
     num_computed_tokens: int
     output_token_ids: list[int]