[Core] Use tuple for kv cache group block ids (#19175)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-06-09 22:01:17 -07:00 committed by GitHub
parent 6cd4ae8acd
commit 646d62f636
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 140 additions and 142 deletions

View File

@ -117,7 +117,7 @@ def test_prefill(hash_algo):
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
# Check full block metadata
parent_block_hash = None
@ -141,13 +141,13 @@ def test_prefill(hash_algo):
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[5]]
assert blocks.get_block_ids() == ([5], )
for block in computed_blocks.blocks[0]:
assert block.ref_cnt == 2
@ -175,13 +175,13 @@ def test_prefill(hash_algo):
req2 = make_request("2", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[6]]
assert blocks.get_block_ids() == ([6], )
# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
@ -205,7 +205,7 @@ def test_prefill(hash_algo):
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
# This block ID order also checks the eviction order.
assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
assert manager.block_pool.free_block_queue.num_free_blocks == 0
assert manager.block_pool.free_block_queue.free_list_head is None
assert manager.block_pool.free_block_queue.free_list_tail is None
@ -236,8 +236,8 @@ def test_prefill_hybrid_model():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4], [5, 6, 7, 8],
[9, 10, 11, 12]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
8], [9, 10, 11, 12])
# Check full block metadata
parent_block_hash = None
@ -263,14 +263,14 @@ def test_prefill_hybrid_model():
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert computed_blocks.get_block_ids() == [[1, 2, 3], [0, 6, 7],
[0, 10, 11]]
assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6,
7], [0, 10, 11])
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[13], [14], [15]]
assert blocks.get_block_ids() == ([13], [14], [15])
for block_per_group in computed_blocks.blocks:
for block in block_per_group:
if block != manager.block_pool.null_block:
@ -374,7 +374,7 @@ def test_prefill_plp():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
# Check full block metadata
@ -400,13 +400,13 @@ def test_prefill_plp():
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[5]]
assert blocks.get_block_ids() == ([5], )
for block in computed_blocks.blocks[0]:
assert block.ref_cnt == 2
@ -444,7 +444,7 @@ def test_prefill_plp():
block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
assert block_ids != [[1, 2, 3, 4]]
assert block_ids != ([1, 2, 3, 4], )
# Request #2 block hashes are valid since request #0 hashes are.
# Check block reference counts.
@ -474,7 +474,7 @@ def test_decode():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
# Append slots without allocating a new block.
req0.num_computed_tokens = 55
@ -546,12 +546,12 @@ def test_evict():
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert computed_blocks.get_block_ids() == [[1, 2]]
assert computed_blocks.get_block_ids() == ([1, 2], )
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[10]]
assert blocks.get_block_ids() == ([10], )
assert manager.block_pool.free_block_queue.num_free_blocks == 7
@ -865,7 +865,7 @@ def test_mm_prefix_caching():
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
@ -926,7 +926,7 @@ def test_cache_key_salting():
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
@ -1042,7 +1042,7 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
blocks = manager.allocate_slots(req0, 55)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
@ -1053,7 +1053,7 @@ def test_reset_prefix_cache():
blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[5]]
assert blocks.get_block_ids() == ([5], )
# Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache()

View File

@ -71,7 +71,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),
block_ids=[[0]], # block_ids should be list[list[int]]
block_ids=([0], ), # block_ids should be tuple[list[int]]
num_computed_tokens=0,
lora_request=None,
))
@ -116,10 +116,10 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
# This is safe since we currently only use single KV cache groups
block_table = multi_group_block_table[0]
# req_state.block_ids is now list[list[int]] for MultiGroupBlockTable
# req_state.block_ids is now tuple[list[int], ...] for MultiGroupBlockTable
# Extract the first group's block IDs
if isinstance(req_state.block_ids[0], list):
# New format: list[list[int]] - extract first group
# New format: tuple[list[int], ...] - extract first group
req_block_ids = req_state.block_ids[0]
else:
# Legacy format: list[int] - use directly
@ -210,7 +210,7 @@ def test_update_states_request_resumed(model_runner):
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=[[]],
new_block_ids=([], ),
num_computed_tokens=0,
)

View File

@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
sampling_params=_create_sampling_params(),
mm_inputs=[],
mm_positions=[],
block_ids=[[]],
block_ids=([], ),
generator=None,
num_computed_tokens=len(output_token_ids),
output_token_ids=output_token_ids,

View File

@ -123,7 +123,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),
block_ids=[[0]],
block_ids=([0], ),
num_computed_tokens=0,
lora_request=None,
))
@ -251,7 +251,7 @@ def test_update_states_request_resumed(model_runner):
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=[[]],
new_block_ids=([], ),
num_computed_tokens=0,
)

View File

@ -89,8 +89,8 @@ class BlockPool:
BlockHashWithGroupId(block_hash, group_id))
if not cached_blocks_one_group:
return None
first_block_id = next(iter(cached_blocks_one_group))
cached_blocks.append(cached_blocks_one_group[first_block_id])
first_block = next(iter(cached_blocks_one_group.values()))
cached_blocks.append(first_block)
return cached_blocks
def cache_full_blocks(
@ -260,7 +260,7 @@ class BlockPool:
return True
return False
def touch(self, blocks: list[list[KVCacheBlock]]) -> None:
def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
@ -299,7 +299,7 @@ class BlockPool:
bool: True if the prefix cache is successfully reset,
False otherwise.
"""
num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks())
num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks()
if num_used_blocks != 1: # The null block is always marked as used
logger.warning(
"Failed to reset prefix cache because some "

View File

@ -5,8 +5,7 @@ from typing import Callable, Optional
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.single_type_kv_cache_manager import (
FullAttentionManager, SingleTypeKVCacheManager,
get_manager_for_kv_cache_spec)
FullAttentionManager, get_manager_for_kv_cache_spec)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
from vllm.v1.request import Request
@ -30,25 +29,21 @@ class KVCacheCoordinator(ABC):
self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
enable_kv_cache_events)
self.single_type_managers: list[SingleTypeKVCacheManager] = []
# Needs special handling for find_longest_cache_hit if eagle is enabled
self.use_eagle = use_eagle
for i in range(len(self.kv_cache_config.kv_cache_groups)):
kv_cache_spec = self.kv_cache_config.kv_cache_groups[
i].kv_cache_spec
self.single_type_managers.append(
get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=i,
caching_hash_fn=caching_hash_fn,
))
self.single_type_managers = tuple(
get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_group.kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=i,
caching_hash_fn=caching_hash_fn,
) for i, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups))
def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int,
new_computed_blocks: list[list[KVCacheBlock]]) -> int:
new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int:
"""
Get the number of blocks needed to be allocated for the request.
@ -70,7 +65,7 @@ class KVCacheCoordinator(ABC):
def save_new_computed_blocks(
self, request_id: str,
new_computed_blocks: list[list[KVCacheBlock]]) -> None:
new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> None:
"""
Add the new computed blocks to the request.
@ -84,7 +79,7 @@ class KVCacheCoordinator(ABC):
new_computed_blocks[i])
def allocate_new_blocks(self, request_id: str,
num_tokens: int) -> list[list[KVCacheBlock]]:
num_tokens: int) -> tuple[list[KVCacheBlock], ...]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
@ -97,11 +92,9 @@ class KVCacheCoordinator(ABC):
Returns:
The new allocated blocks.
"""
new_blocks = []
for manager in self.single_type_managers:
new_blocks.append(
manager.allocate_new_blocks(request_id, num_tokens))
return new_blocks
return tuple(
manager.allocate_new_blocks(request_id, num_tokens)
for manager in self.single_type_managers)
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
num_computed_tokens: int) -> None:
@ -159,19 +152,20 @@ class KVCacheCoordinator(ABC):
for manager in self.single_type_managers:
manager.remove_skipped_blocks(request_id, num_computed_tokens)
def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]:
def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
"""
Get the blocks for the request.
"""
return [
return tuple(
manager.req_to_blocks.get(request_id) or []
for manager in self.single_type_managers
]
for manager in self.single_type_managers)
@abstractmethod
def find_longest_cache_hit(
self, block_hashes: list[BlockHash],
max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
self,
block_hashes: list[BlockHash],
max_cache_hit_length: int,
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
pass
@ -195,8 +189,10 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
"UnitaryKVCacheCoordinator assumes only one kv cache group")
def find_longest_cache_hit(
self, block_hashes: list[BlockHash],
max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
self,
block_hashes: list[BlockHash],
max_cache_hit_length: int,
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
hit_blocks = self.single_type_managers[0].find_longest_cache_hit(
block_hashes=block_hashes,
max_length=max_cache_hit_length,
@ -275,11 +271,24 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
"KVCacheCoordinator assumes the block_size of full attention "
"layers is divisible by other layers now.")
if max(self.full_attention_group_ids) < min(self.other_group_ids):
self.full_attn_first = True
elif max(self.other_group_ids) < min(self.full_attention_group_ids):
self.full_attn_first = False
else:
raise ValueError(
"HybridKVCacheCoordinator assumes the full "
"attention group ids and other attention group ids "
"do not interleave, either full attention group ids "
"are before other attention group ids or vice versa."
"This is for simplifying merging hit_blocks_full_attn and "
"hit_blocks_other_attn to hit_blocks.")
def find_longest_cache_hit(
self,
block_hashes: list[BlockHash],
max_cache_hit_length: int,
) -> tuple[list[list[KVCacheBlock]], int]:
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
"""
Find the longest cache hit for the request.
@ -318,27 +327,25 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
))
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
# NOTE: the prefix cache hit length must be a multiply of block_size as
# NOTE: the prefix cache hit length must be a multiple of block_size as
# we don't support partial block cache hit yet. The cache hit length
# of other attention is ensured to be a multiply of the block size of
# of other attention is ensured to be a multiple of the block size of
# full attention layers in current implementation, because hit_length is
# a multiply of other attention's block size, and other attention's
# block size is a multiply of full attention's block size (verified in
# a multiple of other attention's block size, and other attention's
# block size is a multiple of full attention's block size (verified in
# `verify_and_split_kv_cache_groups`).
assert hit_length % self.full_attention_block_size == 0
# Truncate the full attention cache hit to the length of the
# cache hit of the other attention.
for i in range(len(hit_blocks_full_attn)):
del hit_blocks_full_attn[i][hit_length //
self.full_attention_block_size:]
for group_hit_blocks in hit_blocks_full_attn:
del group_hit_blocks[hit_length // self.full_attention_block_size:]
# Merge the hit blocks of full attention and other attention.
hit_blocks = hit_blocks_other_attn
for group_id, blocks in enumerate(hit_blocks_full_attn):
# NOTE: there is only one full attention group in most cases. So
# the time complexity of insert is fine.
hit_blocks.insert(group_id, blocks)
if self.full_attn_first:
hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn
else:
hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn
return hit_blocks, hit_length
@ -351,8 +358,6 @@ def get_kv_cache_coordinator(
use_eagle, enable_caching,
caching_hash_fn,
enable_kv_cache_events)
else:
return HybridKVCacheCoordinator(kv_cache_config, max_model_len,
use_eagle, enable_caching,
caching_hash_fn,
enable_kv_cache_events)
return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn,
enable_kv_cache_events)

View File

@ -25,7 +25,7 @@ class KVCacheBlocks:
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
structure from the Scheduler.
"""
blocks: list[list[KVCacheBlock]]
blocks: tuple[list[KVCacheBlock], ...]
"""
blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
We don't use block of tokens as the outer dimension because it assumes all
@ -37,21 +37,19 @@ class KVCacheBlocks:
def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
"""Adds two KVCacheBlocks instances."""
return KVCacheBlocks(
[blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)])
tuple(blk1 + blk2
for blk1, blk2 in zip(self.blocks, other.blocks)))
def get_block_ids(self) -> list[list[int]]:
def get_block_ids(self) -> tuple[list[int], ...]:
"""
Converts the KVCacheBlocks instance to block_ids.
Returns:
list[list[int]]: A two-level list where
* the outer list corresponds to KV cache groups
tuple[list[int], ...]: A tuple of lists where
* the outer tuple corresponds to KV cache groups
* each inner list contains the block_ids of the blocks in that group
"""
block_ids = []
for group in self.blocks:
block_ids.append([blk.block_id for blk in group])
return block_ids
return tuple([blk.block_id for blk in group] for group in self.blocks)
def get_unhashed_block_ids(self) -> list[int]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
@ -63,7 +61,7 @@ class KVCacheBlocks:
def new_empty(self) -> "KVCacheBlocks":
"""Creates a new KVCacheBlocks instance with no blocks."""
return KVCacheBlocks([[] for _ in range(len(self.blocks))])
return KVCacheBlocks(tuple([] for _ in range(len(self.blocks))))
class KVCacheManager:
@ -232,9 +230,8 @@ class KVCacheManager:
if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks
else:
new_computed_block_list = [
[] for _ in range(len(self.kv_cache_config.kv_cache_groups))
]
new_computed_block_list = tuple(
[] for _ in range(len(self.kv_cache_config.kv_cache_groups)))
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
@ -267,7 +264,7 @@ class KVCacheManager:
if self.enable_caching:
self.block_pool.touch(new_computed_block_list)
else:
assert all(not blocks for blocks in new_computed_block_list), (
assert not any(new_computed_block_list), (
"Computed blocks should be empty when "
"prefix caching is disabled")
@ -378,17 +375,18 @@ class KVCacheManager:
"""
return self.block_pool.take_events()
def get_block_ids(self, request_id: str) -> list[list[int]]:
def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
"""Get the block ids of a request."""
return KVCacheBlocks(
self.coordinator.get_blocks(request_id)).get_block_ids()
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
num_computed_tokens: int) -> None:
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request."""
block_hashes = self.req_to_block_hashes[request.request_id]
self.coordinator.cache_blocks(request, block_hashes,
num_computed_tokens)
def create_empty_block_list(self) -> KVCacheBlocks:
"""Creates a new KVCacheBlocks instance with no blocks."""
return KVCacheBlocks([[] for _ in range(self.num_kv_cache_groups)])
return KVCacheBlocks(tuple([]
for _ in range(self.num_kv_cache_groups)))

View File

@ -27,7 +27,7 @@ class NewRequestData:
mm_hashes: list[str]
mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams
block_ids: list[list[int]]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
lora_request: Optional[LoRARequest]
@ -35,7 +35,7 @@ class NewRequestData:
def from_request(
cls,
request: Request,
block_ids: list[list[int]],
block_ids: tuple[list[int], ...],
) -> NewRequestData:
return cls(
req_id=request.request_id,
@ -86,7 +86,7 @@ class CachedRequestData:
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption: bool
new_token_ids: list[int]
new_block_ids: list[list[int]]
new_block_ids: tuple[list[int], ...]
num_computed_tokens: int
@classmethod
@ -95,7 +95,7 @@ class CachedRequestData:
request: Request,
resumed_from_preemption: bool,
new_token_ids: list[int],
new_block_ids: list[list[int]],
new_block_ids: tuple[list[int], ...],
) -> CachedRequestData:
return cls(
req_id=request.request_id,

View File

@ -180,7 +180,7 @@ class Scheduler(SchedulerInterface):
# uses structured decoding.
structured_output_request_ids: dict[str, int] = {}
req_to_new_block_ids: dict[str, list[list[int]]] = {}
req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Encoder-related.
@ -471,7 +471,7 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# Count the number of prifix cached tokens.
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
@ -588,7 +588,7 @@ class Scheduler(SchedulerInterface):
request: Request,
num_scheduled_tokens: int,
num_scheduled_spec_tokens: int,
new_block_ids: list[list[int]],
new_block_ids: tuple[list[int], ...],
resumed_from_preemption: bool,
) -> CachedRequestData:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
@ -1015,11 +1015,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens = min(num_computed_tokens, request.num_tokens)
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1
self.kv_cache_manager.cache_blocks(
request,
self.kv_cache_manager.req_to_block_hashes[request.request_id],
num_computed_tokens,
)
self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
# Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens

View File

@ -197,7 +197,7 @@ class SingleTypeKVCacheManager(ABC):
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> list[list[KVCacheBlock]]:
) -> tuple[list[KVCacheBlock], ...]:
"""
Get the longest cache hit prefix of the blocks that is not longer than
`max_length`. The prefix should be a common prefix hit for all the
@ -222,7 +222,7 @@ class SingleTypeKVCacheManager(ABC):
element is a list of cached blocks for the i-th kv cache group
in `kv_cache_group_ids`.
For example, sliding window manager should return a list like
[[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]] for block size 4
([NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]) for block size 4
and sliding window 8 and len(kv_cache_group_ids) = 1.
"""
@ -254,27 +254,25 @@ class FullAttentionManager(SingleTypeKVCacheManager):
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> list[list[KVCacheBlock]]:
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, FullAttentionSpec), (
"FullAttentionManager can only be used for full attention groups")
computed_blocks: list[list[KVCacheBlock]] = [
[] for _ in range(len(kv_cache_group_ids))
]
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids)))
max_num_blocks = max_length // kv_cache_spec.block_size
for i in range(max_num_blocks):
block_hash = block_hashes[i]
for i, block_hash in zip(range(max_num_blocks), block_hashes):
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if cached_block := block_pool.get_cached_block(
block_hash, kv_cache_group_ids):
for j in range(len(kv_cache_group_ids)):
computed_blocks[j].append(cached_block[j])
for computed, cached in zip(computed_blocks, cached_block):
computed.append(cached)
else:
break
if use_eagle and len(computed_blocks[0]) > 0:
for j in range(len(kv_cache_group_ids)):
computed_blocks[j].pop()
if use_eagle and computed_blocks[0]:
for computed in computed_blocks:
computed.pop()
return computed_blocks
def remove_skipped_blocks(self, request_id: str,
@ -311,7 +309,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> list[list[KVCacheBlock]]:
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
"SlidingWindowManager can only be used for sliding window groups")
@ -332,23 +330,23 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
max_num_blocks = max_length // kv_cache_spec.block_size
computed_blocks = [[block_pool.null_block] * max_num_blocks
for _ in range(len(kv_cache_group_ids))]
computed_blocks = tuple([block_pool.null_block] * max_num_blocks
for _ in range(len(kv_cache_group_ids)))
num_contiguous_blocks = 0
match_found = False
# Search from right to left and early stop when a match is found.
for i in range(max_num_blocks - 1, -1, -1):
if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids):
for j in range(len(kv_cache_group_ids)):
computed_blocks[j][i] = cached_block[j]
for computed, cached in zip(computed_blocks, cached_block):
computed[i] = cached
num_contiguous_blocks += 1
if (num_contiguous_blocks >= sliding_window_contiguous_blocks):
if num_contiguous_blocks >= sliding_window_contiguous_blocks:
# Trim the trailing blocks.
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
for j in range(len(kv_cache_group_ids)):
del computed_blocks[j][i + num_contiguous_blocks:]
for computed in computed_blocks:
del computed[i + num_contiguous_blocks:]
match_found = True
break
else:
@ -356,11 +354,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
if not match_found:
# The first `num_contiguous_blocks` is a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
for j in range(len(kv_cache_group_ids)):
del computed_blocks[j][num_contiguous_blocks:]
if use_eagle and len(computed_blocks[0]) > 0:
for j in range(len(kv_cache_group_ids)):
computed_blocks[j].pop()
for computed in computed_blocks:
del computed[num_contiguous_blocks:]
if use_eagle and computed_blocks[0]:
for computed in computed_blocks:
computed.pop()
return computed_blocks
def remove_skipped_blocks(self, request_id: str,

View File

@ -112,11 +112,12 @@ class MultiGroupBlockTable:
for block_size in block_sizes
]
def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
def append_row(self, block_ids: tuple[list[int], ...],
row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.append_row(block_ids[i], row_idx)
def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.add_row(block_ids[i], row_idx)

View File

@ -30,7 +30,7 @@ class CachedRequestState:
sampling_params: SamplingParams
generator: Optional[torch.Generator]
block_ids: list[list[int]]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
output_token_ids: list[int]