[Core] Use tuple for kv cache group block ids (#19175)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-06-09 22:01:17 -07:00 committed by GitHub
parent 6cd4ae8acd
commit 646d62f636
12 changed files with 140 additions and 142 deletions
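In short: everywhere block IDs are passed per KV cache group — from the KV cache manager through the scheduler output to the model runner — the outer container changes from list[list[int]] to tuple[list[int], ...], one list of block IDs per group. A rough illustration with made-up block IDs (not taken from the diff below):

    # Before: a list of per-group lists.
    block_ids_old: list[list[int]] = [[1, 2, 3, 4]]
    # After: a tuple of per-group lists; note the trailing comma for a single group.
    block_ids_new: tuple[list[int], ...] = ([1, 2, 3, 4], )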

View File

@ -117,7 +117,7 @@ def test_prefill(hash_algo):
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
# Check full block metadata
parent_block_hash = None
@ -141,13 +141,13 @@ def test_prefill(hash_algo):
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[5]]
assert blocks.get_block_ids() == ([5], )
for block in computed_blocks.blocks[0]:
assert block.ref_cnt == 2
@ -175,13 +175,13 @@ def test_prefill(hash_algo):
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[6]]
assert blocks.get_block_ids() == ([6], )
# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
@ -205,7 +205,7 @@ def test_prefill(hash_algo):
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
# This block ID order also checks the eviction order.
assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None
@ -236,8 +236,8 @@ def test_prefill_hybrid_model():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4], [5, 6, 7, 8],
[9, 10, 11, 12]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
8], [9, 10, 11, 12])
# Check full block metadata
parent_block_hash = None
@ -263,14 +263,14 @@ def test_prefill_hybrid_model():
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert computed_blocks.get_block_ids() == [[1, 2, 3], [0, 6, 7],
[0, 10, 11]]
assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6,
7], [0, 10, 11])
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[13], [14], [15]]
assert blocks.get_block_ids() == ([13], [14], [15])
for block_per_group in computed_blocks.blocks:
for block in block_per_group:
if block != manager.block_pool.null_block:
@ -374,7 +374,7 @@ def test_prefill_plp():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
# Check full block metadata
@ -400,13 +400,13 @@ def test_prefill_plp():
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[5]]
assert blocks.get_block_ids() == ([5], )
for block in computed_blocks.blocks[0]:
assert block.ref_cnt == 2
@ -444,7 +444,7 @@ def test_prefill_plp():
block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
assert block_ids != [[1, 2, 3, 4]]
assert block_ids != ([1, 2, 3, 4], )
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
@ -474,7 +474,7 @@ def test_decode():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
# Append slots without allocating a new block.
req0.num_computed_tokens = 55
@ -546,12 +546,12 @@ def test_evict():
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert computed_blocks.get_block_ids() == [[1, 2]]
assert computed_blocks.get_block_ids() == ([1, 2], )
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[10]]
assert blocks.get_block_ids() == ([10], )
assert manager.block_pool.free_block_queue.num_free_blocks == 7
@ -865,7 +865,7 @@ def test_mm_prefix_caching():
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
@ -926,7 +926,7 @@ def test_cache_key_salting():
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
@ -1042,7 +1042,7 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
blocks = manager.allocate_slots(req0, 55)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
@ -1053,7 +1053,7 @@ def test_reset_prefix_cache():
blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[5]]
assert blocks.get_block_ids() == ([5], )
# Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache()

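The assertions in the tests above change because a one-element tuple and a one-element list of lists never compare equal in Python. A standalone reminder of the semantics involved (plain Python, nothing vLLM-specific):

    assert ([1, 2, 3]) == [1, 2, 3]        # parentheses alone do not make a tuple
    assert ([1, 2, 3], ) != [[1, 2, 3]]    # a tuple and a list never compare equal
    assert ([1, 2, 3], )[0] == [1, 2, 3]   # the inner list itself is unchanged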
View File

@ -71,7 +71,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),
block_ids=[[0]], # block_ids should be list[list[int]]
block_ids=([0], ), # block_ids should be tuple[list[int]]
num_computed_tokens=0,
lora_request=None,
))
@ -116,10 +116,10 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
# This is safe since we currently only use single KV cache groups
block_table = multi_group_block_table[0]
# req_state.block_ids is now list[list[int]] for MultiGroupBlockTable
# req_state.block_ids is now tuple[list[int], ...] for MultiGroupBlockTable
# Extract the first group's block IDs
if isinstance(req_state.block_ids[0], list):
# New format: list[list[int]] - extract first group
# New format: tuple[list[int], ...] - extract first group
req_block_ids = req_state.block_ids[0]
else:
# Legacy format: list[int] - use directly
@ -210,7 +210,7 @@ def test_update_states_request_resumed(model_runner):
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=[[]],
new_block_ids=([], ),
num_computed_tokens=0,
)

View File

@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
sampling_params=_create_sampling_params(),
mm_inputs=[],
mm_positions=[],
block_ids=[[]],
block_ids=([], ),
generator=None,
num_computed_tokens=len(output_token_ids),
output_token_ids=output_token_ids,

View File

@ -123,7 +123,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),
block_ids=[[0]],
block_ids=([0], ),
num_computed_tokens=0,
lora_request=None,
))
@ -251,7 +251,7 @@ def test_update_states_request_resumed(model_runner):
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=[[]],
new_block_ids=([], ),
num_computed_tokens=0,
)

View File

@ -89,8 +89,8 @@ class BlockPool:
BlockHashWithGroupId(block_hash, group_id))
if not cached_blocks_one_group:
return None
first_block_id = next(iter(cached_blocks_one_group))
cached_blocks.append(cached_blocks_one_group[first_block_id])
first_block = next(iter(cached_blocks_one_group.values()))
cached_blocks.append(first_block)
return cached_blocks
def cache_full_blocks(
@ -260,7 +260,7 @@ class BlockPool:
return True
return False
def touch(self, blocks: list[list[KVCacheBlock]]) -> None:
def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
@ -299,7 +299,7 @@ class BlockPool:
bool: True if the prefix cache is successfully reset,
False otherwise.
"""
num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks())
num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks()
if num_used_blocks != 1: # The null block is always marked as used
logger.warning(
"Failed to reset prefix cache because some "

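The get_cached_block change above drops the intermediate block-id lookup and takes the first value straight from the dict, while touch() now accepts the per-group tuple. The dict idiom in isolation, with hypothetical stand-in values:

    # Maps block_id -> cached block; Python dicts preserve insertion order.
    cached_blocks_one_group = {7: "block-7", 9: "block-9"}
    # Old: fetch the first key, then index back into the dict.
    # New: take the first value directly.
    first_block = next(iter(cached_blocks_one_group.values()))
    assert first_block == "block-7"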
View File

@ -5,8 +5,7 @@ from typing import Callable, Optional
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.single_type_kv_cache_manager import (
FullAttentionManager, SingleTypeKVCacheManager,
get_manager_for_kv_cache_spec)
FullAttentionManager, get_manager_for_kv_cache_spec)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
from vllm.v1.request import Request
@ -30,25 +29,21 @@ class KVCacheCoordinator(ABC):
self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
enable_kv_cache_events)
self.single_type_managers: list[SingleTypeKVCacheManager] = []
# Needs special handling for find_longest_cache_hit if eagle is enabled
self.use_eagle = use_eagle
for i in range(len(self.kv_cache_config.kv_cache_groups)):
kv_cache_spec = self.kv_cache_config.kv_cache_groups[
i].kv_cache_spec
self.single_type_managers.append(
get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=i,
caching_hash_fn=caching_hash_fn,
))
self.single_type_managers = tuple(
get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_group.kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=i,
caching_hash_fn=caching_hash_fn,
) for i, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups))
def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int,
new_computed_blocks: list[list[KVCacheBlock]]) -> int:
new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int:
"""
Get the number of blocks needed to be allocated for the request.
@ -70,7 +65,7 @@ class KVCacheCoordinator(ABC):
def save_new_computed_blocks(
self, request_id: str,
new_computed_blocks: list[list[KVCacheBlock]]) -> None:
new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> None:
"""
Add the new computed blocks to the request.
@ -84,7 +79,7 @@ class KVCacheCoordinator(ABC):
new_computed_blocks[i])
def allocate_new_blocks(self, request_id: str,
num_tokens: int) -> list[list[KVCacheBlock]]:
num_tokens: int) -> tuple[list[KVCacheBlock], ...]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
@ -97,11 +92,9 @@ class KVCacheCoordinator(ABC):
Returns:
The new allocated blocks.
"""
new_blocks = []
for manager in self.single_type_managers:
new_blocks.append(
manager.allocate_new_blocks(request_id, num_tokens))
return new_blocks
return tuple(
manager.allocate_new_blocks(request_id, num_tokens)
for manager in self.single_type_managers)
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
num_computed_tokens: int) -> None:
@ -159,19 +152,20 @@ class KVCacheCoordinator(ABC):
for manager in self.single_type_managers:
manager.remove_skipped_blocks(request_id, num_computed_tokens)
def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]:
def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
"""
Get the blocks for the request.
"""
return [
return tuple(
manager.req_to_blocks.get(request_id) or []
for manager in self.single_type_managers
]
for manager in self.single_type_managers)
@abstractmethod
def find_longest_cache_hit(
self, block_hashes: list[BlockHash],
max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
self,
block_hashes: list[BlockHash],
max_cache_hit_length: int,
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
pass
@ -195,8 +189,10 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
"UnitaryKVCacheCoordinator assumes only one kv cache group")
def find_longest_cache_hit(
self, block_hashes: list[BlockHash],
max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
self,
block_hashes: list[BlockHash],
max_cache_hit_length: int,
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
hit_blocks = self.single_type_managers[0].find_longest_cache_hit(
block_hashes=block_hashes,
max_length=max_cache_hit_length,
@ -275,11 +271,24 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
"KVCacheCoordinator assumes the block_size of full attention "
"layers is divisible by other layers now.")
if max(self.full_attention_group_ids) < min(self.other_group_ids):
self.full_attn_first = True
elif max(self.other_group_ids) < min(self.full_attention_group_ids):
self.full_attn_first = False
else:
raise ValueError(
"HybridKVCacheCoordinator assumes the full "
"attention group ids and other attention group ids "
"do not interleave, either full attention group ids "
"are before other attention group ids or vice versa."
"This is for simplifying merging hit_blocks_full_attn and "
"hit_blocks_other_attn to hit_blocks.")
def find_longest_cache_hit(
self,
block_hashes: list[BlockHash],
max_cache_hit_length: int,
) -> tuple[list[list[KVCacheBlock]], int]:
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
"""
Find the longest cache hit for the request.
@ -318,27 +327,25 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
))
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
# NOTE: the prefix cache hit length must be a multiply of block_size as
# NOTE: the prefix cache hit length must be a multiple of block_size as
# we don't support partial block cache hit yet. The cache hit length
# of other attention is ensured to be a multiply of the block size of
# of other attention is ensured to be a multiple of the block size of
# full attention layers in current implementation, because hit_length is
# a multiply of other attention's block size, and other attention's
# block size is a multiply of full attention's block size (verified in
# a multiple of other attention's block size, and other attention's
# block size is a multiple of full attention's block size (verified in
# `verify_and_split_kv_cache_groups`).
assert hit_length % self.full_attention_block_size == 0
# Truncate the full attention cache hit to the length of the
# cache hit of the other attention.
for i in range(len(hit_blocks_full_attn)):
del hit_blocks_full_attn[i][hit_length //
self.full_attention_block_size:]
for group_hit_blocks in hit_blocks_full_attn:
del group_hit_blocks[hit_length // self.full_attention_block_size:]
# Merge the hit blocks of full attention and other attention.
hit_blocks = hit_blocks_other_attn
for group_id, blocks in enumerate(hit_blocks_full_attn):
# NOTE: there is only one full attention group in most cases. So
# the time complexity of insert is fine.
hit_blocks.insert(group_id, blocks)
if self.full_attn_first:
hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn
else:
hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn
return hit_blocks, hit_length
@ -351,8 +358,6 @@ def get_kv_cache_coordinator(
use_eagle, enable_caching,
caching_hash_fn,
enable_kv_cache_events)
else:
return HybridKVCacheCoordinator(kv_cache_config, max_model_len,
use_eagle, enable_caching,
caching_hash_fn,
enable_kv_cache_events)
return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn,
enable_kv_cache_events)

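Replacing the insert-based merge with tuple concatenation is what the new full_attn_first check above enables: because the full-attention group IDs and the other group IDs never interleave, the merged per-group tuple is one concatenation or the other. Roughly, with hypothetical stand-in values:

    hit_blocks_full_attn = (["full-0"], )               # one entry per full-attention group
    hit_blocks_other_attn = (["other-0"], ["other-1"])  # e.g. sliding-window groups
    full_attn_first = True                              # full-attention group IDs come first
    hit_blocks = (hit_blocks_full_attn + hit_blocks_other_attn
                  if full_attn_first else
                  hit_blocks_other_attn + hit_blocks_full_attn)
    assert hit_blocks == (["full-0"], ["other-0"], ["other-1"])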
View File

@ -21,11 +21,11 @@ logger = init_logger(__name__)
@dataclass
class KVCacheBlocks:
"""
The allocation result of KVCacheManager, work as the interface between
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
The allocation result of KVCacheManager, work as the interface between
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
structure from the Scheduler.
"""
blocks: list[list[KVCacheBlock]]
blocks: tuple[list[KVCacheBlock], ...]
"""
blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
We don't use block of tokens as the outer dimension because it assumes all
@ -37,21 +37,19 @@ class KVCacheBlocks:
def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
"""Adds two KVCacheBlocks instances."""
return KVCacheBlocks(
[blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)])
tuple(blk1 + blk2
for blk1, blk2 in zip(self.blocks, other.blocks)))
def get_block_ids(self) -> list[list[int]]:
def get_block_ids(self) -> tuple[list[int], ...]:
"""
Converts the KVCacheBlocks instance to block_ids.
Returns:
list[list[int]]: A two-level list where
* the outer list corresponds to KV cache groups
tuple[list[int], ...]: A tuple of lists where
* the outer tuple corresponds to KV cache groups
* each inner list contains the block_ids of the blocks in that group
"""
block_ids = []
for group in self.blocks:
block_ids.append([blk.block_id for blk in group])
return block_ids
return tuple([blk.block_id for blk in group] for group in self.blocks)
def get_unhashed_block_ids(self) -> list[int]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
@ -63,7 +61,7 @@ class KVCacheBlocks:
def new_empty(self) -> "KVCacheBlocks":
"""Creates a new KVCacheBlocks instance with no blocks."""
return KVCacheBlocks([[] for _ in range(len(self.blocks))])
return KVCacheBlocks(tuple([] for _ in range(len(self.blocks))))
class KVCacheManager:
@ -232,9 +230,8 @@ class KVCacheManager:
if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks
else:
new_computed_block_list = [
[] for _ in range(len(self.kv_cache_config.kv_cache_groups))
]
new_computed_block_list = tuple(
[] for _ in range(len(self.kv_cache_config.kv_cache_groups)))
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
@ -267,7 +264,7 @@ class KVCacheManager:
if self.enable_caching:
self.block_pool.touch(new_computed_block_list)
else:
assert all(not blocks for blocks in new_computed_block_list), (
assert not any(new_computed_block_list), (
"Computed blocks should be empty when "
"prefix caching is disabled")
@ -378,17 +375,18 @@ class KVCacheManager:
"""
return self.block_pool.take_events()
def get_block_ids(self, request_id: str) -> list[list[int]]:
def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
"""Get the block ids of a request."""
return KVCacheBlocks(
self.coordinator.get_blocks(request_id)).get_block_ids()
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
num_computed_tokens: int) -> None:
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request."""
block_hashes = self.req_to_block_hashes[request.request_id]
self.coordinator.cache_blocks(request, block_hashes,
num_computed_tokens)
def create_empty_block_list(self) -> KVCacheBlocks:
"""Creates a new KVCacheBlocks instance with no blocks."""
return KVCacheBlocks([[] for _ in range(self.num_kv_cache_groups)])
return KVCacheBlocks(tuple([]
for _ in range(self.num_kv_cache_groups)))

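A self-contained sketch of how KVCacheBlocks behaves after this change, using a toy Block class rather than the real KVCacheBlock:

    from dataclasses import dataclass

    @dataclass
    class Block:                         # stand-in for KVCacheBlock
        block_id: int

    @dataclass
    class Blocks:                        # stand-in mirroring KVCacheBlocks above
        blocks: tuple[list[Block], ...]  # blocks[i][j]: j-th block of KV cache group i

        def __add__(self, other: "Blocks") -> "Blocks":
            return Blocks(tuple(a + b for a, b in zip(self.blocks, other.blocks)))

        def get_block_ids(self) -> tuple[list[int], ...]:
            return tuple([blk.block_id for blk in group] for group in self.blocks)

    a = Blocks(([Block(1), Block(2)], ))
    b = Blocks(([Block(3)], ))
    assert (a + b).get_block_ids() == ([1, 2, 3], )

The outer tuple fixes the number of groups at construction time, while the inner lists stay mutable so blocks can still be appended per group.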
View File

@ -27,7 +27,7 @@ class NewRequestData:
mm_hashes: list[str]
mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams
block_ids: list[list[int]]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
lora_request: Optional[LoRARequest]
@ -35,7 +35,7 @@ class NewRequestData:
def from_request(
cls,
request: Request,
block_ids: list[list[int]],
block_ids: tuple[list[int], ...],
) -> NewRequestData:
return cls(
req_id=request.request_id,
@ -86,7 +86,7 @@ class CachedRequestData:
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption: bool
new_token_ids: list[int]
new_block_ids: list[list[int]]
new_block_ids: tuple[list[int], ...]
num_computed_tokens: int
@classmethod
@ -95,7 +95,7 @@ class CachedRequestData:
request: Request,
resumed_from_preemption: bool,
new_token_ids: list[int],
new_block_ids: list[list[int]],
new_block_ids: tuple[list[int], ...],
) -> CachedRequestData:
return cls(
req_id=request.request_id,

View File

@ -180,7 +180,7 @@ class Scheduler(SchedulerInterface):
# uses structured decoding.
structured_output_request_ids: dict[str, int] = {}
req_to_new_block_ids: dict[str, list[list[int]]] = {}
req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Encoder-related.
@ -471,7 +471,7 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# Count the number of prifix cached tokens.
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
@ -588,7 +588,7 @@ class Scheduler(SchedulerInterface):
request: Request,
num_scheduled_tokens: int,
num_scheduled_spec_tokens: int,
new_block_ids: list[list[int]],
new_block_ids: tuple[list[int], ...],
resumed_from_preemption: bool,
) -> CachedRequestData:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
@ -1015,11 +1015,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens = min(num_computed_tokens, request.num_tokens)
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1
self.kv_cache_manager.cache_blocks(
request,
self.kv_cache_manager.req_to_block_hashes[request.request_id],
num_computed_tokens,
)
self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
# Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens

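On the scheduler side the same tuples simply flow through the per-request bookkeeping; an illustrative shape for the mapping above (request IDs and block IDs are made up):

    req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {
        "req-0": ([7, 8], ),   # one KV cache group, two newly allocated blocks
        "req-1": ([], ),       # no new blocks needed this step
    }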
View File

@ -197,7 +197,7 @@ class SingleTypeKVCacheManager(ABC):
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> list[list[KVCacheBlock]]:
) -> tuple[list[KVCacheBlock], ...]:
"""
Get the longest cache hit prefix of the blocks that is not longer than
`max_length`. The prefix should be a common prefix hit for all the
@ -222,7 +222,7 @@ class SingleTypeKVCacheManager(ABC):
element is a list of cached blocks for the i-th kv cache group
in `kv_cache_group_ids`.
For example, sliding window manager should return a list like
[[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]] for block size 4
([NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]) for block size 4
and sliding window 8 and len(kv_cache_group_ids) = 1.
"""
@ -254,27 +254,25 @@ class FullAttentionManager(SingleTypeKVCacheManager):
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> list[list[KVCacheBlock]]:
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, FullAttentionSpec), (
"FullAttentionManager can only be used for full attention groups")
computed_blocks: list[list[KVCacheBlock]] = [
[] for _ in range(len(kv_cache_group_ids))
]
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids)))
max_num_blocks = max_length // kv_cache_spec.block_size
for i in range(max_num_blocks):
block_hash = block_hashes[i]
for i, block_hash in zip(range(max_num_blocks), block_hashes):
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if cached_block := block_pool.get_cached_block(
block_hash, kv_cache_group_ids):
for j in range(len(kv_cache_group_ids)):
computed_blocks[j].append(cached_block[j])
for computed, cached in zip(computed_blocks, cached_block):
computed.append(cached)
else:
break
if use_eagle and len(computed_blocks[0]) > 0:
for j in range(len(kv_cache_group_ids)):
computed_blocks[j].pop()
if use_eagle and computed_blocks[0]:
for computed in computed_blocks:
computed.pop()
return computed_blocks
def remove_skipped_blocks(self, request_id: str,
@ -311,7 +309,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> list[list[KVCacheBlock]]:
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
"SlidingWindowManager can only be used for sliding window groups")
@ -332,23 +330,23 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
max_num_blocks = max_length // kv_cache_spec.block_size
computed_blocks = [[block_pool.null_block] * max_num_blocks
for _ in range(len(kv_cache_group_ids))]
computed_blocks = tuple([block_pool.null_block] * max_num_blocks
for _ in range(len(kv_cache_group_ids)))
num_contiguous_blocks = 0
match_found = False
# Search from right to left and early stop when a match is found.
for i in range(max_num_blocks - 1, -1, -1):
if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids):
for j in range(len(kv_cache_group_ids)):
computed_blocks[j][i] = cached_block[j]
for computed, cached in zip(computed_blocks, cached_block):
computed[i] = cached
num_contiguous_blocks += 1
if (num_contiguous_blocks >= sliding_window_contiguous_blocks):
if num_contiguous_blocks >= sliding_window_contiguous_blocks:
# Trim the trailing blocks.
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
for j in range(len(kv_cache_group_ids)):
del computed_blocks[j][i + num_contiguous_blocks:]
for computed in computed_blocks:
del computed[i + num_contiguous_blocks:]
match_found = True
break
else:
@ -356,11 +354,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
if not match_found:
# The first `num_contiguous_blocks` is a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
for j in range(len(kv_cache_group_ids)):
del computed_blocks[j][num_contiguous_blocks:]
if use_eagle and len(computed_blocks[0]) > 0:
for j in range(len(kv_cache_group_ids)):
computed_blocks[j].pop()
for computed in computed_blocks:
del computed[num_contiguous_blocks:]
if use_eagle and computed_blocks[0]:
for computed in computed_blocks:
computed.pop()
return computed_blocks
def remove_skipped_blocks(self, request_id: str,

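Both find_longest_cache_hit implementations above now allocate the per-group result as a tuple of lists up front and mutate the inner lists in lockstep instead of indexing by group. The pattern in isolation, with toy block IDs:

    num_groups = 2
    computed_blocks: tuple[list[int], ...] = tuple([] for _ in range(num_groups))
    cached_block = (7, 107)      # pretend cache hit for each group at one position
    for computed, cached in zip(computed_blocks, cached_block):
        computed.append(cached)  # the outer tuple is fixed; the inner lists stay mutable
    assert computed_blocks == ([7], [107])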
View File

@ -112,11 +112,12 @@ class MultiGroupBlockTable:
for block_size in block_sizes
]
def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
def append_row(self, block_ids: tuple[list[int], ...],
row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.append_row(block_ids[i], row_idx)
def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.add_row(block_ids[i], row_idx)

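MultiGroupBlockTable fans the tuple out positionally, one inner list per per-group table; schematically, with a hypothetical stand-in for the single-group BlockTable:

    class _Table:                      # stand-in for one group's BlockTable
        def __init__(self) -> None:
            self.rows: list[list[int]] = [[]]

        def append_row(self, ids: list[int], row_idx: int) -> None:
            self.rows[row_idx].extend(ids)

    tables = [_Table(), _Table()]      # one table per KV cache group
    block_ids: tuple[list[int], ...] = ([1, 2], [11, 12])
    for i, table in enumerate(tables):
        table.append_row(block_ids[i], row_idx=0)
    assert tables[0].rows == [[1, 2]] and tables[1].rows == [[11, 12]]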
View File

@ -30,7 +30,7 @@ class CachedRequestState:
sampling_params: SamplingParams
generator: Optional[torch.Generator]
block_ids: list[list[int]]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
output_token_ids: list[int]