From 646d62f636197d5d809d21090752700e094b8b86 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 9 Jun 2025 22:01:17 -0700 Subject: [PATCH] [Core] Use tuple for kv cache group block ids (#19175) Signed-off-by: Nick Hill --- tests/v1/core/test_prefix_caching.py | 44 ++++---- tests/v1/tpu/worker/test_tpu_model_runner.py | 8 +- tests/v1/worker/test_gpu_input_batch.py | 2 +- tests/v1/worker/test_gpu_model_runner.py | 4 +- vllm/v1/core/block_pool.py | 8 +- vllm/v1/core/kv_cache_coordinator.py | 101 ++++++++++--------- vllm/v1/core/kv_cache_manager.py | 38 ++++--- vllm/v1/core/sched/output.py | 8 +- vllm/v1/core/sched/scheduler.py | 12 +-- vllm/v1/core/single_type_kv_cache_manager.py | 50 +++++---- vllm/v1/worker/block_table.py | 5 +- vllm/v1/worker/gpu_input_batch.py | 2 +- 12 files changed, 140 insertions(+), 142 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index bf4cb539ebef..394336624aca 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -117,7 +117,7 @@ def test_prefill(hash_algo): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == ([1, 2, 3, 4], ) # Check full block metadata parent_block_hash = None @@ -141,13 +141,13 @@ def test_prefill(hash_algo): req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert computed_blocks.get_block_ids() == [[1, 2, 3]] + assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == [[5]] + assert blocks.get_block_ids() == ([5], ) for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 @@ -175,13 +175,13 @@ def test_prefill(hash_algo): req2 = make_request("2", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(manager.req_to_block_hashes[req2.request_id]) == 3 - assert computed_blocks.get_block_ids() == [[1, 2, 3]] + assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req2, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == [[6]] + assert blocks.get_block_ids() == ([6], ) # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. @@ -205,7 +205,7 @@ def test_prefill(hash_algo): len(computed_blocks.blocks[0]) * 16, computed_blocks) # This block ID order also checks the eviction order. 
- assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]] + assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], ) assert manager.block_pool.free_block_queue.num_free_blocks == 0 assert manager.block_pool.free_block_queue.free_list_head is None assert manager.block_pool.free_block_queue.free_list_tail is None @@ -236,8 +236,8 @@ def test_prefill_hybrid_model(): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == [[1, 2, 3, 4], [5, 6, 7, 8], - [9, 10, 11, 12]] + assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, + 8], [9, 10, 11, 12]) # Check full block metadata parent_block_hash = None @@ -263,14 +263,14 @@ def test_prefill_hybrid_model(): req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert computed_blocks.get_block_ids() == [[1, 2, 3], [0, 6, 7], - [0, 10, 11]] + assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6, + 7], [0, 10, 11]) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == [[13], [14], [15]] + assert blocks.get_block_ids() == ([13], [14], [15]) for block_per_group in computed_blocks.blocks: for block in block_per_group: if block != manager.block_pool.null_block: @@ -374,7 +374,7 @@ def test_prefill_plp(): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == ([1, 2, 3, 4], ) req0_block_hashes = [b.block_hash for b in blocks.blocks[0]] # Check full block metadata @@ -400,13 +400,13 @@ def test_prefill_plp(): req1 = make_request("1", common_token_ids + unique_token_ids) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert len(manager.req_to_block_hashes[req1.request_id]) == 3 - assert computed_blocks.get_block_ids() == [[1, 2, 3]] + assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 blocks = manager.allocate_slots(req1, num_new_tokens, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == [[5]] + assert blocks.get_block_ids() == ([5], ) for block in computed_blocks.blocks[0]: assert block.ref_cnt == 2 @@ -444,7 +444,7 @@ def test_prefill_plp(): block_ids = blocks.get_block_ids() # Duplicate cached blocks have different ids but same hashes vs request #0 assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes - assert block_ids != [[1, 2, 3, 4]] + assert block_ids != ([1, 2, 3, 4], ) # Request #2 block hashes are valid since request #0 hashes are. # Check block reference counts. @@ -474,7 +474,7 @@ def test_decode(): blocks = manager.allocate_slots(req0, 55, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == ([1, 2, 3, 4], ) # Append slots without allocating a new block. req0.num_computed_tokens = 55 @@ -546,12 +546,12 @@ def test_evict(): # Touch the first 2 blocks. 
req2 = make_request("2", list(range(2 * 16 + 3))) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert computed_blocks.get_block_ids() == [[1, 2]] + assert computed_blocks.get_block_ids() == ([1, 2], ) assert num_computed_tokens == 2 * 16 blocks = manager.allocate_slots(req2, 3, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == [[10]] + assert blocks.get_block_ids() == ([10], ) assert manager.block_pool.free_block_queue.num_free_blocks == 7 @@ -865,7 +865,7 @@ def test_mm_prefix_caching(): blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == ([1, 2, 3, 4], ) req0.num_computed_tokens = 59 # Append slots without allocating a new block. @@ -926,7 +926,7 @@ def test_cache_key_salting(): blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == ([1, 2, 3, 4], ) req0.num_computed_tokens = 59 # Append slots without allocating a new block. @@ -1042,7 +1042,7 @@ def test_reset_prefix_cache(): all_token_ids = full_block_token_ids + unique_token_ids req0 = make_request("0", all_token_ids) blocks = manager.allocate_slots(req0, 55) - assert blocks.get_block_ids() == [[1, 2, 3, 4]] + assert blocks.get_block_ids() == ([1, 2, 3, 4], ) unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids @@ -1053,7 +1053,7 @@ def test_reset_prefix_cache(): blocks = manager.allocate_slots(req1, 7, len(computed_blocks.blocks[0]) * 16, computed_blocks) - assert blocks.get_block_ids() == [[5]] + assert blocks.get_block_ids() == ([5], ) # Failed to reset prefix cache because some blocks are not freed yet. assert not manager.reset_prefix_cache() diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 320ebef4075e..0e7d305fef9e 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -71,7 +71,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: mm_hashes=[], mm_positions=[], sampling_params=SamplingParams(), - block_ids=[[0]], # block_ids should be list[list[int]] + block_ids=([0], ), # block_ids should be tuple[list[int]] num_computed_tokens=0, lora_request=None, )) @@ -116,10 +116,10 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool: # This is safe since we currently only use single KV cache groups block_table = multi_group_block_table[0] - # req_state.block_ids is now list[list[int]] for MultiGroupBlockTable + # req_state.block_ids is now tuple[list[int], ...] for MultiGroupBlockTable # Extract the first group's block IDs if isinstance(req_state.block_ids[0], list): - # New format: list[list[int]] - extract first group + # New format: tuple[list[int], ...] 
- extract first group req_block_ids = req_state.block_ids[0] else: # Legacy format: list[int] - use directly @@ -210,7 +210,7 @@ def test_update_states_request_resumed(model_runner): req_id=req_id, resumed_from_preemption=False, new_token_ids=[], - new_block_ids=[[]], + new_block_ids=([], ), num_computed_tokens=0, ) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 72547e86b0e9..de6ebe4f6716 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int): sampling_params=_create_sampling_params(), mm_inputs=[], mm_positions=[], - block_ids=[[]], + block_ids=([], ), generator=None, num_computed_tokens=len(output_token_ids), output_token_ids=output_token_ids, diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 3d51b53df2ce..fa0bab71b954 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -123,7 +123,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: mm_hashes=[], mm_positions=[], sampling_params=SamplingParams(), - block_ids=[[0]], + block_ids=([0], ), num_computed_tokens=0, lora_request=None, )) @@ -251,7 +251,7 @@ def test_update_states_request_resumed(model_runner): req_id=req_id, resumed_from_preemption=False, new_token_ids=[], - new_block_ids=[[]], + new_block_ids=([], ), num_computed_tokens=0, ) diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 3b2a4f936000..d21f94727cf6 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -89,8 +89,8 @@ class BlockPool: BlockHashWithGroupId(block_hash, group_id)) if not cached_blocks_one_group: return None - first_block_id = next(iter(cached_blocks_one_group)) - cached_blocks.append(cached_blocks_one_group[first_block_id]) + first_block = next(iter(cached_blocks_one_group.values())) + cached_blocks.append(first_block) return cached_blocks def cache_full_blocks( @@ -260,7 +260,7 @@ class BlockPool: return True return False - def touch(self, blocks: list[list[KVCacheBlock]]) -> None: + def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None: """Touch a block increases its reference count by 1, and may remove the block from the free queue. This is used when a block is hit by another request with the same prefix. @@ -299,7 +299,7 @@ class BlockPool: bool: True if the prefix cache is successfully reset, False otherwise. 
""" - num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks()) + num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks() if num_used_blocks != 1: # The null block is always marked as used logger.warning( "Failed to reset prefix cache because some " diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 231bad1df922..5620d9bee7a3 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -5,8 +5,7 @@ from typing import Callable, Optional from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.core.single_type_kv_cache_manager import ( - FullAttentionManager, SingleTypeKVCacheManager, - get_manager_for_kv_cache_spec) + FullAttentionManager, get_manager_for_kv_cache_spec) from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig from vllm.v1.request import Request @@ -30,25 +29,21 @@ class KVCacheCoordinator(ABC): self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events) - self.single_type_managers: list[SingleTypeKVCacheManager] = [] # Needs special handling for find_longest_cache_hit if eagle is enabled self.use_eagle = use_eagle - - for i in range(len(self.kv_cache_config.kv_cache_groups)): - kv_cache_spec = self.kv_cache_config.kv_cache_groups[ - i].kv_cache_spec - self.single_type_managers.append( - get_manager_for_kv_cache_spec( - kv_cache_spec=kv_cache_spec, - block_pool=self.block_pool, - kv_cache_group_id=i, - caching_hash_fn=caching_hash_fn, - )) + self.single_type_managers = tuple( + get_manager_for_kv_cache_spec( + kv_cache_spec=kv_cache_group.kv_cache_spec, + block_pool=self.block_pool, + kv_cache_group_id=i, + caching_hash_fn=caching_hash_fn, + ) for i, kv_cache_group in enumerate( + self.kv_cache_config.kv_cache_groups)) def get_num_blocks_to_allocate( self, request_id: str, num_tokens: int, - new_computed_blocks: list[list[KVCacheBlock]]) -> int: + new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int: """ Get the number of blocks needed to be allocated for the request. @@ -70,7 +65,7 @@ class KVCacheCoordinator(ABC): def save_new_computed_blocks( self, request_id: str, - new_computed_blocks: list[list[KVCacheBlock]]) -> None: + new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> None: """ Add the new computed blocks to the request. @@ -84,7 +79,7 @@ class KVCacheCoordinator(ABC): new_computed_blocks[i]) def allocate_new_blocks(self, request_id: str, - num_tokens: int) -> list[list[KVCacheBlock]]: + num_tokens: int) -> tuple[list[KVCacheBlock], ...]: """ Allocate new blocks for the request to give it at least `num_tokens` token slots. @@ -97,11 +92,9 @@ class KVCacheCoordinator(ABC): Returns: The new allocated blocks. """ - new_blocks = [] - for manager in self.single_type_managers: - new_blocks.append( - manager.allocate_new_blocks(request_id, num_tokens)) - return new_blocks + return tuple( + manager.allocate_new_blocks(request_id, num_tokens) + for manager in self.single_type_managers) def cache_blocks(self, request: Request, block_hashes: list[BlockHash], num_computed_tokens: int) -> None: @@ -159,19 +152,20 @@ class KVCacheCoordinator(ABC): for manager in self.single_type_managers: manager.remove_skipped_blocks(request_id, num_computed_tokens) - def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]: + def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]: """ Get the blocks for the request. 
""" - return [ + return tuple( manager.req_to_blocks.get(request_id) or [] - for manager in self.single_type_managers - ] + for manager in self.single_type_managers) @abstractmethod def find_longest_cache_hit( - self, block_hashes: list[BlockHash], - max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]: + self, + block_hashes: list[BlockHash], + max_cache_hit_length: int, + ) -> tuple[tuple[list[KVCacheBlock], ...], int]: pass @@ -195,8 +189,10 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): "UnitaryKVCacheCoordinator assumes only one kv cache group") def find_longest_cache_hit( - self, block_hashes: list[BlockHash], - max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]: + self, + block_hashes: list[BlockHash], + max_cache_hit_length: int, + ) -> tuple[tuple[list[KVCacheBlock], ...], int]: hit_blocks = self.single_type_managers[0].find_longest_cache_hit( block_hashes=block_hashes, max_length=max_cache_hit_length, @@ -275,11 +271,24 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): "KVCacheCoordinator assumes the block_size of full attention " "layers is divisible by other layers now.") + if max(self.full_attention_group_ids) < min(self.other_group_ids): + self.full_attn_first = True + elif max(self.other_group_ids) < min(self.full_attention_group_ids): + self.full_attn_first = False + else: + raise ValueError( + "HybridKVCacheCoordinator assumes the full " + "attention group ids and other attention group ids " + "do not interleave, either full attention group ids " + "are before other attention group ids or vice versa." + "This is for simplifying merging hit_blocks_full_attn and " + "hit_blocks_other_attn to hit_blocks.") + def find_longest_cache_hit( self, block_hashes: list[BlockHash], max_cache_hit_length: int, - ) -> tuple[list[list[KVCacheBlock]], int]: + ) -> tuple[tuple[list[KVCacheBlock], ...], int]: """ Find the longest cache hit for the request. @@ -318,27 +327,25 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): )) hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size - # NOTE: the prefix cache hit length must be a multiply of block_size as + # NOTE: the prefix cache hit length must be a multiple of block_size as # we don't support partial block cache hit yet. The cache hit length - # of other attention is ensured to be a multiply of the block size of + # of other attention is ensured to be a multiple of the block size of # full attention layers in current implementation, because hit_length is - # a multiply of other attention's block size, and other attention's - # block size is a multiply of full attention's block size (verified in + # a multiple of other attention's block size, and other attention's + # block size is a multiple of full attention's block size (verified in # `verify_and_split_kv_cache_groups`). assert hit_length % self.full_attention_block_size == 0 # Truncate the full attention cache hit to the length of the # cache hit of the other attention. - for i in range(len(hit_blocks_full_attn)): - del hit_blocks_full_attn[i][hit_length // - self.full_attention_block_size:] + for group_hit_blocks in hit_blocks_full_attn: + del group_hit_blocks[hit_length // self.full_attention_block_size:] # Merge the hit blocks of full attention and other attention. - hit_blocks = hit_blocks_other_attn - for group_id, blocks in enumerate(hit_blocks_full_attn): - # NOTE: there is only one full attention group in most cases. So - # the time complexity of insert is fine. 
- hit_blocks.insert(group_id, blocks) + if self.full_attn_first: + hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn + else: + hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn return hit_blocks, hit_length @@ -351,8 +358,6 @@ def get_kv_cache_coordinator( use_eagle, enable_caching, caching_hash_fn, enable_kv_cache_events) - else: - return HybridKVCacheCoordinator(kv_cache_config, max_model_len, - use_eagle, enable_caching, - caching_hash_fn, - enable_kv_cache_events) + return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, + enable_caching, caching_hash_fn, + enable_kv_cache_events) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 35fb189fda34..2e09f4c0aacf 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -21,11 +21,11 @@ logger = init_logger(__name__) @dataclass class KVCacheBlocks: """ - The allocation result of KVCacheManager, work as the interface between - Scheduler and KVCacheManager, to hide KVCacheManager's internal data + The allocation result of KVCacheManager, work as the interface between + Scheduler and KVCacheManager, to hide KVCacheManager's internal data structure from the Scheduler. """ - blocks: list[list[KVCacheBlock]] + blocks: tuple[list[KVCacheBlock], ...] """ blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens. We don't use block of tokens as the outer dimension because it assumes all @@ -37,21 +37,19 @@ class KVCacheBlocks: def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": """Adds two KVCacheBlocks instances.""" return KVCacheBlocks( - [blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)]) + tuple(blk1 + blk2 + for blk1, blk2 in zip(self.blocks, other.blocks))) - def get_block_ids(self) -> list[list[int]]: + def get_block_ids(self) -> tuple[list[int], ...]: """ Converts the KVCacheBlocks instance to block_ids. Returns: - list[list[int]]: A two-level list where - * the outer list corresponds to KV cache groups + tuple[list[int], ...]: A tuple of lists where + * the outer tuple corresponds to KV cache groups * each inner list contains the block_ids of the blocks in that group """ - block_ids = [] - for group in self.blocks: - block_ids.append([blk.block_id for blk in group]) - return block_ids + return tuple([blk.block_id for blk in group] for group in self.blocks) def get_unhashed_block_ids(self) -> list[int]: """Get block_ids of unhashed blocks from KVCacheBlocks instance.""" @@ -63,7 +61,7 @@ class KVCacheBlocks: def new_empty(self) -> "KVCacheBlocks": """Creates a new KVCacheBlocks instance with no blocks.""" - return KVCacheBlocks([[] for _ in range(len(self.blocks))]) + return KVCacheBlocks(tuple([] for _ in range(len(self.blocks)))) class KVCacheManager: @@ -232,9 +230,8 @@ class KVCacheManager: if new_computed_blocks is not None: new_computed_block_list = new_computed_blocks.blocks else: - new_computed_block_list = [ - [] for _ in range(len(self.kv_cache_config.kv_cache_groups)) - ] + new_computed_block_list = tuple( + [] for _ in range(len(self.kv_cache_config.kv_cache_groups))) # Free the blocks that are skipped during the attention computation # (e.g., tokens outside the sliding window). 
@@ -267,7 +264,7 @@ class KVCacheManager: if self.enable_caching: self.block_pool.touch(new_computed_block_list) else: - assert all(not blocks for blocks in new_computed_block_list), ( + assert not any(new_computed_block_list), ( "Computed blocks should be empty when " "prefix caching is disabled") @@ -378,17 +375,18 @@ class KVCacheManager: """ return self.block_pool.take_events() - def get_block_ids(self, request_id: str) -> list[list[int]]: + def get_block_ids(self, request_id: str) -> tuple[list[int], ...]: """Get the block ids of a request.""" return KVCacheBlocks( self.coordinator.get_blocks(request_id)).get_block_ids() - def cache_blocks(self, request: Request, block_hashes: list[BlockHash], - num_computed_tokens: int) -> None: + def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """Cache the blocks for the request.""" + block_hashes = self.req_to_block_hashes[request.request_id] self.coordinator.cache_blocks(request, block_hashes, num_computed_tokens) def create_empty_block_list(self) -> KVCacheBlocks: """Creates a new KVCacheBlocks instance with no blocks.""" - return KVCacheBlocks([[] for _ in range(self.num_kv_cache_groups)]) + return KVCacheBlocks(tuple([] + for _ in range(self.num_kv_cache_groups))) diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index b404c70eb1e4..9b0a439fe7dc 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -27,7 +27,7 @@ class NewRequestData: mm_hashes: list[str] mm_positions: list[PlaceholderRange] sampling_params: SamplingParams - block_ids: list[list[int]] + block_ids: tuple[list[int], ...] num_computed_tokens: int lora_request: Optional[LoRARequest] @@ -35,7 +35,7 @@ class NewRequestData: def from_request( cls, request: Request, - block_ids: list[list[int]], + block_ids: tuple[list[int], ...], ) -> NewRequestData: return cls( req_id=request.request_id, @@ -86,7 +86,7 @@ class CachedRequestData: # request's block IDs instead of appending to the existing block IDs. resumed_from_preemption: bool new_token_ids: list[int] - new_block_ids: list[list[int]] + new_block_ids: tuple[list[int], ...] num_computed_tokens: int @classmethod @@ -95,7 +95,7 @@ class CachedRequestData: request: Request, resumed_from_preemption: bool, new_token_ids: list[int], - new_block_ids: list[list[int]], + new_block_ids: tuple[list[int], ...], ) -> CachedRequestData: return cls( req_id=request.request_id, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index b3293d9a541f..3d7bbe7e0e39 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -180,7 +180,7 @@ class Scheduler(SchedulerInterface): # uses structured decoding. structured_output_request_ids: dict[str, int] = {} - req_to_new_block_ids: dict[str, list[list[int]]] = {} + req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. @@ -471,7 +471,7 @@ class Scheduler(SchedulerInterface): token_budget -= num_new_tokens request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens - # Count the number of prifix cached tokens. + # Count the number of prefix cached tokens. if request.num_cached_tokens < 0: request.num_cached_tokens = num_computed_tokens # Encoder-related. 
@@ -588,7 +588,7 @@ class Scheduler(SchedulerInterface): request: Request, num_scheduled_tokens: int, num_scheduled_spec_tokens: int, - new_block_ids: list[list[int]], + new_block_ids: tuple[list[int], ...], resumed_from_preemption: bool, ) -> CachedRequestData: # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating @@ -1015,11 +1015,7 @@ class Scheduler(SchedulerInterface): num_computed_tokens = min(num_computed_tokens, request.num_tokens) if num_computed_tokens == request.num_tokens: num_computed_tokens -= 1 - self.kv_cache_manager.cache_blocks( - request, - self.kv_cache_manager.req_to_block_hashes[request.request_id], - num_computed_tokens, - ) + self.kv_cache_manager.cache_blocks(request, num_computed_tokens) # Update the request state for scheduling. request.num_computed_tokens = num_computed_tokens diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 98d758f820ad..95222779c3af 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -197,7 +197,7 @@ class SingleTypeKVCacheManager(ABC): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, - ) -> list[list[KVCacheBlock]]: + ) -> tuple[list[KVCacheBlock], ...]: """ Get the longest cache hit prefix of the blocks that is not longer than `max_length`. The prefix should be a common prefix hit for all the @@ -222,7 +222,7 @@ class SingleTypeKVCacheManager(ABC): element is a list of cached blocks for the i-th kv cache group in `kv_cache_group_ids`. For example, sliding window manager should return a list like - [[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]] for block size 4 + ([NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]) for block size 4 and sliding window 8 and len(kv_cache_group_ids) = 1. """ @@ -254,27 +254,25 @@ class FullAttentionManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, - ) -> list[list[KVCacheBlock]]: + ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, FullAttentionSpec), ( "FullAttentionManager can only be used for full attention groups") - computed_blocks: list[list[KVCacheBlock]] = [ - [] for _ in range(len(kv_cache_group_ids)) - ] + computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( + [] for _ in range(len(kv_cache_group_ids))) max_num_blocks = max_length // kv_cache_spec.block_size - for i in range(max_num_blocks): - block_hash = block_hashes[i] + for i, block_hash in zip(range(max_num_blocks), block_hashes): # block_hashes is a chain of block hashes. If a block hash is not # in the cached_block_hash_to_id, the following block hashes are # not computed yet for sure. 
if cached_block := block_pool.get_cached_block( block_hash, kv_cache_group_ids): - for j in range(len(kv_cache_group_ids)): - computed_blocks[j].append(cached_block[j]) + for computed, cached in zip(computed_blocks, cached_block): + computed.append(cached) else: break - if use_eagle and len(computed_blocks[0]) > 0: - for j in range(len(kv_cache_group_ids)): - computed_blocks[j].pop() + if use_eagle and computed_blocks[0]: + for computed in computed_blocks: + computed.pop() return computed_blocks def remove_skipped_blocks(self, request_id: str, @@ -311,7 +309,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager): block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, - ) -> list[list[KVCacheBlock]]: + ) -> tuple[list[KVCacheBlock], ...]: assert isinstance(kv_cache_spec, SlidingWindowSpec), ( "SlidingWindowManager can only be used for sliding window groups") @@ -332,23 +330,23 @@ class SlidingWindowManager(SingleTypeKVCacheManager): # sliding_window_contiguous_blocks), # which is good for low cache hit rate scenarios. max_num_blocks = max_length // kv_cache_spec.block_size - computed_blocks = [[block_pool.null_block] * max_num_blocks - for _ in range(len(kv_cache_group_ids))] + computed_blocks = tuple([block_pool.null_block] * max_num_blocks + for _ in range(len(kv_cache_group_ids))) num_contiguous_blocks = 0 match_found = False # Search from right to left and early stop when a match is found. for i in range(max_num_blocks - 1, -1, -1): if cached_block := block_pool.get_cached_block( block_hashes[i], kv_cache_group_ids): - for j in range(len(kv_cache_group_ids)): - computed_blocks[j][i] = cached_block[j] + for computed, cached in zip(computed_blocks, cached_block): + computed[i] = cached num_contiguous_blocks += 1 - if (num_contiguous_blocks >= sliding_window_contiguous_blocks): + if num_contiguous_blocks >= sliding_window_contiguous_blocks: # Trim the trailing blocks. # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] # when sliding_window_contiguous_blocks=2. - for j in range(len(kv_cache_group_ids)): - del computed_blocks[j][i + num_contiguous_blocks:] + for computed in computed_blocks: + del computed[i + num_contiguous_blocks:] match_found = True break else: @@ -356,11 +354,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager): if not match_found: # The first `num_contiguous_blocks` is a cache hit even if # `num_contiguous_blocks < sliding_window_contiguous_blocks`. 
- for j in range(len(kv_cache_group_ids)): - del computed_blocks[j][num_contiguous_blocks:] - if use_eagle and len(computed_blocks[0]) > 0: - for j in range(len(kv_cache_group_ids)): - computed_blocks[j].pop() + for computed in computed_blocks: + del computed[num_contiguous_blocks:] + if use_eagle and computed_blocks[0]: + for computed in computed_blocks: + computed.pop() return computed_blocks def remove_skipped_blocks(self, request_id: str, diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 5cd5674fb522..8f4e8d64c615 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -112,11 +112,12 @@ class MultiGroupBlockTable: for block_size in block_sizes ] - def append_row(self, block_ids: list[list[int]], row_idx: int) -> None: + def append_row(self, block_ids: tuple[list[int], ...], + row_idx: int) -> None: for i, block_table in enumerate(self.block_tables): block_table.append_row(block_ids[i], row_idx) - def add_row(self, block_ids: list[list[int]], row_idx: int) -> None: + def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None: for i, block_table in enumerate(self.block_tables): block_table.add_row(block_ids[i], row_idx) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 34737029f6bf..ebb770a7ddb2 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -30,7 +30,7 @@ class CachedRequestState: sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: list[list[int]] + block_ids: tuple[list[int], ...] num_computed_tokens: int output_token_ids: list[int]
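
The core of this patch is that the per-KV-cache-group outer dimension of block collections changes from list[list[...]] to tuple[list[...], ...], which is why the test assertions above compare against one-element tuples such as ([1, 2, 3, 4], ) rather than [[1, 2, 3, 4]]. The following is a minimal, self-contained sketch of that interface, not the vLLM classes themselves: "Block" and "Blocks" are simplified stand-ins for KVCacheBlock and KVCacheBlocks (vllm/v1/core/kv_cache_manager.py), kept only to show the tuple semantics of get_block_ids(), __add__(), and the trailing-comma shape of single-group results.

from dataclasses import dataclass


@dataclass
class Block:
    # Hypothetical stand-in for KVCacheBlock: only the id is needed here.
    block_id: int


@dataclass
class Blocks:
    # One list of blocks per KV cache group; the outer container is now an
    # immutable tuple instead of a list.
    blocks: tuple[list[Block], ...]

    def __add__(self, other: "Blocks") -> "Blocks":
        # Per-group concatenation, preserving the tuple outer dimension.
        return Blocks(tuple(a + b for a, b in zip(self.blocks, other.blocks)))

    def get_block_ids(self) -> tuple[list[int], ...]:
        # Outer tuple: KV cache groups; inner lists: block ids in that group.
        return tuple([blk.block_id for blk in group] for group in self.blocks)


# Single-group case: note the trailing comma that makes the expected value a
# one-element tuple, matching the updated assertions in the tests above.
allocated = Blocks(([Block(1), Block(2), Block(3), Block(4)], ))
assert allocated.get_block_ids() == ([1, 2, 3, 4], )

# Hybrid (multi-group) case: one entry per KV cache group.
hybrid = Blocks(([Block(13)], [Block(14)], [Block(15)]))
assert hybrid.get_block_ids() == ([13], [14], [15])
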
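The HybridKVCacheCoordinator hunk above also replaces the per-group insert() loop with a single tuple concatenation, guarded by a full_attn_first flag decided once in __init__ from the (non-interleaving) group id ranges. Below is a small sketch of that merge step under simplified assumptions: plain block ids stand in for KVCacheBlock objects, and the variable names mirror, but are not, the ones computed in find_longest_cache_hit.

# Decided once in __init__: True when all full-attention group ids precede
# the other-attention group ids, False when they all follow them.
full_attn_first = True

hit_blocks_full_attn: tuple[list[int], ...] = ([1, 2, 3], )            # e.g. group 0
hit_blocks_other_attn: tuple[list[int], ...] = ([0, 6, 7], [0, 10, 11])  # e.g. groups 1-2

# Merge the per-group hits while keeping group order; tuple concatenation
# replaces the old list.insert() loop.
if full_attn_first:
    hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn
else:
    hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn

assert hit_blocks == ([1, 2, 3], [0, 6, 7], [0, 10, 11])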