[Core] Use tuple for kv cache group block ids (#19175)
Signed-off-by: Nick Hill <nhill@redhat.com>
parent 6cd4ae8acd
commit 646d62f636
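Summary of the change: the per-KV-cache-group container for blocks and block IDs moves from list[list[...]] to tuple[list[...], ...] across the scheduler / KV cache manager interface. The number of KV cache groups is fixed for a given model, so a tuple encodes that invariant and guards against accidental mutation of the group dimension, while the inner per-group lists stay mutable and keep growing as blocks are allocated. A minimal, self-contained sketch of the new shape (the KVCacheBlock here is a simplified stand-in with just a block_id, not the full vLLM class):

from dataclasses import dataclass

@dataclass
class KVCacheBlock:
    # Simplified stand-in for vllm.v1.core.kv_cache_utils.KVCacheBlock.
    block_id: int

@dataclass
class KVCacheBlocks:
    # One list of blocks per KV cache group; the group dimension is now a
    # tuple because its length never changes for a given model.
    blocks: tuple[list[KVCacheBlock], ...]

    def get_block_ids(self) -> tuple[list[int], ...]:
        return tuple([blk.block_id for blk in group] for group in self.blocks)

# A single-group model wraps one inner list in a 1-tuple:
kvb = KVCacheBlocks(([KVCacheBlock(1), KVCacheBlock(2)], ))
assert kvb.get_block_ids() == ([1, 2], )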
@@ -117,7 +117,7 @@ def test_prefill(hash_algo):
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
 
     # Check full block metadata
     parent_block_hash = None
@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
+    assert computed_blocks.get_block_ids() == ([1, 2, 3], )
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[5]]
+    assert blocks.get_block_ids() == ([5], )
     for block in computed_blocks.blocks[0]:
         assert block.ref_cnt == 2
 
@@ -175,13 +175,13 @@ def test_prefill(hash_algo):
     req2 = make_request("2", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
     assert len(manager.req_to_block_hashes[req2.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
+    assert computed_blocks.get_block_ids() == ([1, 2, 3], )
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req2, num_new_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[6]]
+    assert blocks.get_block_ids() == ([6], )
 
     # Although we only have 6 free blocks, we have 8 blocks in
     # the free block queue due to lazy removal.
@@ -205,7 +205,7 @@ def test_prefill(hash_algo):
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     # This block ID order also checks the eviction order.
-    assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
+    assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
     assert manager.block_pool.free_block_queue.num_free_blocks == 0
     assert manager.block_pool.free_block_queue.free_list_head is None
     assert manager.block_pool.free_block_queue.free_list_tail is None
@@ -236,8 +236,8 @@ def test_prefill_hybrid_model():
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4], [5, 6, 7, 8],
-                                      [9, 10, 11, 12]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
+                                                     8], [9, 10, 11, 12])
 
     # Check full block metadata
     parent_block_hash = None
@@ -263,14 +263,14 @@ def test_prefill_hybrid_model():
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [[1, 2, 3], [0, 6, 7],
-                                               [0, 10, 11]]
+    assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6,
+                                                           7], [0, 10, 11])
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[13], [14], [15]]
+    assert blocks.get_block_ids() == ([13], [14], [15])
     for block_per_group in computed_blocks.blocks:
         for block in block_per_group:
             if block != manager.block_pool.null_block:
@@ -374,7 +374,7 @@ def test_prefill_plp():
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
     req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
 
     # Check full block metadata
@@ -400,13 +400,13 @@ def test_prefill_plp():
     req1 = make_request("1", common_token_ids + unique_token_ids)
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
     assert len(manager.req_to_block_hashes[req1.request_id]) == 3
-    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
+    assert computed_blocks.get_block_ids() == ([1, 2, 3], )
     assert num_computed_tokens == 3 * 16
     num_new_tokens = 53 - 3 * 16
     blocks = manager.allocate_slots(req1, num_new_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[5]]
+    assert blocks.get_block_ids() == ([5], )
     for block in computed_blocks.blocks[0]:
         assert block.ref_cnt == 2
 
@@ -444,7 +444,7 @@ def test_prefill_plp():
     block_ids = blocks.get_block_ids()
     # Duplicate cached blocks have different ids but same hashes vs request #0
     assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
-    assert block_ids != [[1, 2, 3, 4]]
+    assert block_ids != ([1, 2, 3, 4], )
 
     # Request #2 block hashes are valid since request #0 hashes are.
     # Check block reference counts.
@@ -474,7 +474,7 @@ def test_decode():
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
 
     # Append slots without allocating a new block.
     req0.num_computed_tokens = 55
@@ -546,12 +546,12 @@ def test_evict():
     # Touch the first 2 blocks.
     req2 = make_request("2", list(range(2 * 16 + 3)))
     computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
-    assert computed_blocks.get_block_ids() == [[1, 2]]
+    assert computed_blocks.get_block_ids() == ([1, 2], )
     assert num_computed_tokens == 2 * 16
     blocks = manager.allocate_slots(req2, 3,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[10]]
+    assert blocks.get_block_ids() == ([10], )
     assert manager.block_pool.free_block_queue.num_free_blocks == 7
 
 
@@ -865,7 +865,7 @@ def test_mm_prefix_caching():
     blocks = manager.allocate_slots(req0, 59,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
     req0.num_computed_tokens = 59
 
     # Append slots without allocating a new block.
@@ -926,7 +926,7 @@ def test_cache_key_salting():
     blocks = manager.allocate_slots(req0, 59,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
     req0.num_computed_tokens = 59
 
     # Append slots without allocating a new block.
@@ -1042,7 +1042,7 @@ def test_reset_prefix_cache():
     all_token_ids = full_block_token_ids + unique_token_ids
     req0 = make_request("0", all_token_ids)
     blocks = manager.allocate_slots(req0, 55)
-    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
 
     unique_token_ids = [4] * 7
     all_token_ids = full_block_token_ids + unique_token_ids
@@ -1053,7 +1053,7 @@ def test_reset_prefix_cache():
     blocks = manager.allocate_slots(req1, 7,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == [[5]]
+    assert blocks.get_block_ids() == ([5], )
 
     # Failed to reset prefix cache because some blocks are not freed yet.
     assert not manager.reset_prefix_cache()
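The assertion updates in the tests above are mechanical: get_block_ids() now returns a tuple of lists, and in Python a tuple never compares equal to a list, so every expected value changes shape rather than content. A quick sketch:

old = [[1, 2, 3, 4]]      # previous return shape: list[list[int]]
new = ([1, 2, 3, 4], )    # new return shape: tuple[list[int], ...]
assert old != new         # tuple vs. list: never equal, even with same contents
assert list(new) == old   # the contents themselves are unchanged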
@@ -71,7 +71,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
             mm_hashes=[],
             mm_positions=[],
             sampling_params=SamplingParams(),
-            block_ids=[[0]],  # block_ids should be list[list[int]]
+            block_ids=([0], ),  # block_ids should be tuple[list[int]]
             num_computed_tokens=0,
             lora_request=None,
         ))
@@ -116,10 +116,10 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
     # This is safe since we currently only use single KV cache groups
     block_table = multi_group_block_table[0]
 
-    # req_state.block_ids is now list[list[int]] for MultiGroupBlockTable
+    # req_state.block_ids is now tuple[list[int], ...] for MultiGroupBlockTable
     # Extract the first group's block IDs
     if isinstance(req_state.block_ids[0], list):
-        # New format: list[list[int]] - extract first group
+        # New format: tuple[list[int], ...] - extract first group
         req_block_ids = req_state.block_ids[0]
     else:
         # Legacy format: list[int] - use directly
@@ -210,7 +210,7 @@ def test_update_states_request_resumed(model_runner):
         req_id=req_id,
         resumed_from_preemption=False,
         new_token_ids=[],
-        new_block_ids=[[]],
+        new_block_ids=([], ),
         num_computed_tokens=0,
     )
 
@@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
         sampling_params=_create_sampling_params(),
         mm_inputs=[],
         mm_positions=[],
-        block_ids=[[]],
+        block_ids=([], ),
         generator=None,
         num_computed_tokens=len(output_token_ids),
         output_token_ids=output_token_ids,
@@ -123,7 +123,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
             mm_hashes=[],
             mm_positions=[],
             sampling_params=SamplingParams(),
-            block_ids=[[0]],
+            block_ids=([0], ),
             num_computed_tokens=0,
             lora_request=None,
         ))
@@ -251,7 +251,7 @@ def test_update_states_request_resumed(model_runner):
         req_id=req_id,
         resumed_from_preemption=False,
         new_token_ids=[],
-        new_block_ids=[[]],
+        new_block_ids=([], ),
         num_computed_tokens=0,
     )
@@ -89,8 +89,8 @@ class BlockPool:
                 BlockHashWithGroupId(block_hash, group_id))
             if not cached_blocks_one_group:
                 return None
-            first_block_id = next(iter(cached_blocks_one_group))
-            cached_blocks.append(cached_blocks_one_group[first_block_id])
+            first_block = next(iter(cached_blocks_one_group.values()))
+            cached_blocks.append(first_block)
         return cached_blocks
 
     def cache_full_blocks(
@@ -260,7 +260,7 @@ class BlockPool:
                 return True
         return False
 
-    def touch(self, blocks: list[list[KVCacheBlock]]) -> None:
+    def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
         """Touch a block increases its reference count by 1, and may remove
         the block from the free queue. This is used when a block is hit by
         another request with the same prefix.
@@ -299,7 +299,7 @@ class BlockPool:
            bool: True if the prefix cache is successfully reset,
            False otherwise.
        """
-        num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks())
+        num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks()
         if num_used_blocks != 1:  # The null block is always marked as used
             logger.warning(
                 "Failed to reset prefix cache because some "
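The get_cached_block cleanup above replaces a key fetch followed by a dict lookup with a single pass over the dict's values; both pick the first-inserted entry, since Python dicts preserve insertion order. A sketch with hypothetical hash/block values:

cached = {"hash_a": "block_a", "hash_b": "block_b"}
# before: take a key, then index the dict with it (two steps)
first_key = next(iter(cached))
assert cached[first_key] == "block_a"
# after: take the first value directly (one step)
assert next(iter(cached.values())) == "block_a"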
@@ -5,8 +5,7 @@ from typing import Callable, Optional
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
 from vllm.v1.core.single_type_kv_cache_manager import (
-    FullAttentionManager, SingleTypeKVCacheManager,
-    get_manager_for_kv_cache_spec)
+    FullAttentionManager, get_manager_for_kv_cache_spec)
 from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
 from vllm.v1.request import Request
 
@@ -30,25 +29,21 @@ class KVCacheCoordinator(ABC):
 
         self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
                                     enable_kv_cache_events)
-        self.single_type_managers: list[SingleTypeKVCacheManager] = []
 
         # Needs special handling for find_longest_cache_hit if eagle is enabled
         self.use_eagle = use_eagle
 
-        for i in range(len(self.kv_cache_config.kv_cache_groups)):
-            kv_cache_spec = self.kv_cache_config.kv_cache_groups[
-                i].kv_cache_spec
-            self.single_type_managers.append(
-                get_manager_for_kv_cache_spec(
-                    kv_cache_spec=kv_cache_spec,
-                    block_pool=self.block_pool,
-                    kv_cache_group_id=i,
-                    caching_hash_fn=caching_hash_fn,
-                ))
+        self.single_type_managers = tuple(
+            get_manager_for_kv_cache_spec(
+                kv_cache_spec=kv_cache_group.kv_cache_spec,
+                block_pool=self.block_pool,
+                kv_cache_group_id=i,
+                caching_hash_fn=caching_hash_fn,
+            ) for i, kv_cache_group in enumerate(
+                self.kv_cache_config.kv_cache_groups))
 
     def get_num_blocks_to_allocate(
             self, request_id: str, num_tokens: int,
-            new_computed_blocks: list[list[KVCacheBlock]]) -> int:
+            new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int:
         """
         Get the number of blocks needed to be allocated for the request.
 
@@ -70,7 +65,7 @@ class KVCacheCoordinator(ABC):
 
     def save_new_computed_blocks(
             self, request_id: str,
-            new_computed_blocks: list[list[KVCacheBlock]]) -> None:
+            new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> None:
         """
         Add the new computed blocks to the request.
 
@@ -84,7 +79,7 @@ class KVCacheCoordinator(ABC):
                 new_computed_blocks[i])
 
     def allocate_new_blocks(self, request_id: str,
-                            num_tokens: int) -> list[list[KVCacheBlock]]:
+                            num_tokens: int) -> tuple[list[KVCacheBlock], ...]:
         """
         Allocate new blocks for the request to give it at least `num_tokens`
         token slots.
@@ -97,11 +92,9 @@ class KVCacheCoordinator(ABC):
         Returns:
             The new allocated blocks.
         """
-        new_blocks = []
-        for manager in self.single_type_managers:
-            new_blocks.append(
-                manager.allocate_new_blocks(request_id, num_tokens))
-        return new_blocks
+        return tuple(
+            manager.allocate_new_blocks(request_id, num_tokens)
+            for manager in self.single_type_managers)
 
     def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
                      num_computed_tokens: int) -> None:
@@ -159,19 +152,20 @@ class KVCacheCoordinator(ABC):
         for manager in self.single_type_managers:
             manager.remove_skipped_blocks(request_id, num_computed_tokens)
 
-    def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]:
+    def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
         """
         Get the blocks for the request.
         """
-        return [
+        return tuple(
             manager.req_to_blocks.get(request_id) or []
-            for manager in self.single_type_managers
-        ]
+            for manager in self.single_type_managers)
 
     @abstractmethod
     def find_longest_cache_hit(
-            self, block_hashes: list[BlockHash],
-            max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
+        self,
+        block_hashes: list[BlockHash],
+        max_cache_hit_length: int,
+    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
         pass
 
 
@@ -195,8 +189,10 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
             "UnitaryKVCacheCoordinator assumes only one kv cache group")
 
     def find_longest_cache_hit(
-            self, block_hashes: list[BlockHash],
-            max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
+        self,
+        block_hashes: list[BlockHash],
+        max_cache_hit_length: int,
+    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
         hit_blocks = self.single_type_managers[0].find_longest_cache_hit(
             block_hashes=block_hashes,
             max_length=max_cache_hit_length,
@@ -275,11 +271,24 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
             "KVCacheCoordinator assumes the block_size of full attention "
             "layers is divisible by other layers now.")
 
+        if max(self.full_attention_group_ids) < min(self.other_group_ids):
+            self.full_attn_first = True
+        elif max(self.other_group_ids) < min(self.full_attention_group_ids):
+            self.full_attn_first = False
+        else:
+            raise ValueError(
+                "HybridKVCacheCoordinator assumes the full "
+                "attention group ids and other attention group ids "
+                "do not interleave, either full attention group ids "
+                "are before other attention group ids or vice versa. "
+                "This is for simplifying merging hit_blocks_full_attn and "
+                "hit_blocks_other_attn to hit_blocks.")
+
     def find_longest_cache_hit(
         self,
         block_hashes: list[BlockHash],
         max_cache_hit_length: int,
-    ) -> tuple[list[list[KVCacheBlock]], int]:
+    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
         """
         Find the longest cache hit for the request.
 
@@ -318,27 +327,25 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
         ))
         hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
 
-        # NOTE: the prefix cache hit length must be a multiply of block_size as
+        # NOTE: the prefix cache hit length must be a multiple of block_size as
         # we don't support partial block cache hit yet. The cache hit length
-        # of other attention is ensured to be a multiply of the block size of
+        # of other attention is ensured to be a multiple of the block size of
         # full attention layers in current implementation, because hit_length is
-        # a multiply of other attention's block size, and other attention's
-        # block size is a multiply of full attention's block size (verified in
+        # a multiple of other attention's block size, and other attention's
+        # block size is a multiple of full attention's block size (verified in
        # `verify_and_split_kv_cache_groups`).
         assert hit_length % self.full_attention_block_size == 0
 
         # Truncate the full attention cache hit to the length of the
         # cache hit of the other attention.
-        for i in range(len(hit_blocks_full_attn)):
-            del hit_blocks_full_attn[i][hit_length //
-                                        self.full_attention_block_size:]
+        for group_hit_blocks in hit_blocks_full_attn:
+            del group_hit_blocks[hit_length // self.full_attention_block_size:]
 
         # Merge the hit blocks of full attention and other attention.
-        hit_blocks = hit_blocks_other_attn
-        for group_id, blocks in enumerate(hit_blocks_full_attn):
-            # NOTE: there is only one full attention group in most cases. So
-            # the time complexity of insert is fine.
-            hit_blocks.insert(group_id, blocks)
+        if self.full_attn_first:
+            hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn
+        else:
+            hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn
         return hit_blocks, hit_length
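The merge rewrite above trades list.insert calls for plain concatenation, relying on the new full_attn_first flag (validated in the constructor hunk) guaranteeing that full attention group IDs and other group IDs do not interleave. A sketch of the two orderings, with hypothetical group layouts (the string labels are illustrative only; each element stands for one group's list of hit blocks):

hit_full = [["f0"], ["f1"]]    # full attention groups
hit_other = [["o0"]]           # e.g. sliding window groups

# full attention groups have the lower ids -> they come first:
assert hit_full + hit_other == [["f0"], ["f1"], ["o0"]]
# otherwise the other groups come first:
assert hit_other + hit_full == [["o0"], ["f0"], ["f1"]]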
@@ -351,8 +358,6 @@ def get_kv_cache_coordinator(
                                          use_eagle, enable_caching,
                                          caching_hash_fn,
                                          enable_kv_cache_events)
-    else:
-        return HybridKVCacheCoordinator(kv_cache_config, max_model_len,
-                                        use_eagle, enable_caching,
-                                        caching_hash_fn,
-                                        enable_kv_cache_events)
+    return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
+                                    enable_caching, caching_hash_fn,
+                                    enable_kv_cache_events)
@@ -21,11 +21,11 @@ logger = init_logger(__name__)
 @dataclass
 class KVCacheBlocks:
     """
-    The allocation result of KVCacheManager, work as the interface between 
-    Scheduler and KVCacheManager, to hide KVCacheManager's internal data 
+    The allocation result of KVCacheManager, work as the interface between
+    Scheduler and KVCacheManager, to hide KVCacheManager's internal data
     structure from the Scheduler.
     """
-    blocks: list[list[KVCacheBlock]]
+    blocks: tuple[list[KVCacheBlock], ...]
     """
     blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
     We don't use block of tokens as the outer dimension because it assumes all
@@ -37,21 +37,19 @@ class KVCacheBlocks:
     def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
         """Adds two KVCacheBlocks instances."""
         return KVCacheBlocks(
-            [blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)])
+            tuple(blk1 + blk2
+                  for blk1, blk2 in zip(self.blocks, other.blocks)))
 
-    def get_block_ids(self) -> list[list[int]]:
+    def get_block_ids(self) -> tuple[list[int], ...]:
         """
         Converts the KVCacheBlocks instance to block_ids.
 
         Returns:
-            list[list[int]]: A two-level list where
-            * the outer list corresponds to KV cache groups
+            tuple[list[int], ...]: A tuple of lists where
+            * the outer tuple corresponds to KV cache groups
             * each inner list contains the block_ids of the blocks in that group
         """
-        block_ids = []
-        for group in self.blocks:
-            block_ids.append([blk.block_id for blk in group])
-        return block_ids
+        return tuple([blk.block_id for blk in group] for group in self.blocks)
 
     def get_unhashed_block_ids(self) -> list[int]:
         """Get block_ids of unhashed blocks from KVCacheBlocks instance."""
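One consequence of the tuple container worth noting: the group dimension of KVCacheBlocks.blocks can no longer be resized or reassigned in place, while each group's inner list remains mutable and keeps accumulating blocks. A small sketch:

groups: tuple[list[int], ...] = ([1, 2], [3])

groups[0].append(4)             # inner lists stay mutable: blocks still accumulate
assert groups == ([1, 2, 4], [3])

try:
    groups[0] = [9]             # the group dimension itself is now fixed
except TypeError:
    pass                        # tuples reject item assignment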
@@ -63,7 +61,7 @@ class KVCacheBlocks:
 
     def new_empty(self) -> "KVCacheBlocks":
         """Creates a new KVCacheBlocks instance with no blocks."""
-        return KVCacheBlocks([[] for _ in range(len(self.blocks))])
+        return KVCacheBlocks(tuple([] for _ in range(len(self.blocks))))
 
 
 class KVCacheManager:
@@ -232,9 +230,8 @@ class KVCacheManager:
         if new_computed_blocks is not None:
             new_computed_block_list = new_computed_blocks.blocks
         else:
-            new_computed_block_list = [
-                [] for _ in range(len(self.kv_cache_config.kv_cache_groups))
-            ]
+            new_computed_block_list = tuple(
+                [] for _ in range(len(self.kv_cache_config.kv_cache_groups)))
 
         # Free the blocks that are skipped during the attention computation
         # (e.g., tokens outside the sliding window).
@@ -267,7 +264,7 @@ class KVCacheManager:
         if self.enable_caching:
             self.block_pool.touch(new_computed_block_list)
         else:
-            assert all(not blocks for blocks in new_computed_block_list), (
+            assert not any(new_computed_block_list), (
                 "Computed blocks should be empty when "
                 "prefix caching is disabled")
 
@@ -378,17 +375,18 @@ class KVCacheManager:
         """
         return self.block_pool.take_events()
 
-    def get_block_ids(self, request_id: str) -> list[list[int]]:
+    def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
         """Get the block ids of a request."""
         return KVCacheBlocks(
             self.coordinator.get_blocks(request_id)).get_block_ids()
 
-    def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
-                     num_computed_tokens: int) -> None:
+    def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
         """Cache the blocks for the request."""
+        block_hashes = self.req_to_block_hashes[request.request_id]
         self.coordinator.cache_blocks(request, block_hashes,
                                       num_computed_tokens)
 
     def create_empty_block_list(self) -> KVCacheBlocks:
         """Creates a new KVCacheBlocks instance with no blocks."""
-        return KVCacheBlocks([[] for _ in range(self.num_kv_cache_groups)])
+        return KVCacheBlocks(tuple([]
+                                   for _ in range(self.num_kv_cache_groups)))
@@ -27,7 +27,7 @@ class NewRequestData:
     mm_hashes: list[str]
     mm_positions: list[PlaceholderRange]
     sampling_params: SamplingParams
-    block_ids: list[list[int]]
+    block_ids: tuple[list[int], ...]
     num_computed_tokens: int
     lora_request: Optional[LoRARequest]
 
@@ -35,7 +35,7 @@ class NewRequestData:
     def from_request(
         cls,
         request: Request,
-        block_ids: list[list[int]],
+        block_ids: tuple[list[int], ...],
     ) -> NewRequestData:
         return cls(
             req_id=request.request_id,
@@ -86,7 +86,7 @@ class CachedRequestData:
     # request's block IDs instead of appending to the existing block IDs.
     resumed_from_preemption: bool
     new_token_ids: list[int]
-    new_block_ids: list[list[int]]
+    new_block_ids: tuple[list[int], ...]
     num_computed_tokens: int
 
     @classmethod
@@ -95,7 +95,7 @@ class CachedRequestData:
         request: Request,
         resumed_from_preemption: bool,
         new_token_ids: list[int],
-        new_block_ids: list[list[int]],
+        new_block_ids: tuple[list[int], ...],
     ) -> CachedRequestData:
         return cls(
             req_id=request.request_id,
@@ -180,7 +180,7 @@ class Scheduler(SchedulerInterface):
         # uses structured decoding.
         structured_output_request_ids: dict[str, int] = {}
 
-        req_to_new_block_ids: dict[str, list[list[int]]] = {}
+        req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {}
         num_scheduled_tokens: dict[str, int] = {}
         token_budget = self.max_num_scheduled_tokens
         # Encoder-related.
@@ -471,7 +471,7 @@ class Scheduler(SchedulerInterface):
                 token_budget -= num_new_tokens
                 request.status = RequestStatus.RUNNING
                 request.num_computed_tokens = num_computed_tokens
-                # Count the number of prifix cached tokens.
+                # Count the number of prefix cached tokens.
                 if request.num_cached_tokens < 0:
                     request.num_cached_tokens = num_computed_tokens
                 # Encoder-related.
@@ -588,7 +588,7 @@ class Scheduler(SchedulerInterface):
         request: Request,
         num_scheduled_tokens: int,
         num_scheduled_spec_tokens: int,
-        new_block_ids: list[list[int]],
+        new_block_ids: tuple[list[int], ...],
         resumed_from_preemption: bool,
     ) -> CachedRequestData:
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
@@ -1015,11 +1015,7 @@ class Scheduler(SchedulerInterface):
         num_computed_tokens = min(num_computed_tokens, request.num_tokens)
         if num_computed_tokens == request.num_tokens:
             num_computed_tokens -= 1
-        self.kv_cache_manager.cache_blocks(
-            request,
-            self.kv_cache_manager.req_to_block_hashes[request.request_id],
-            num_computed_tokens,
-        )
+        self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
 
         # Update the request state for scheduling.
         request.num_computed_tokens = num_computed_tokens
@@ -197,7 +197,7 @@ class SingleTypeKVCacheManager(ABC):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
-    ) -> list[list[KVCacheBlock]]:
+    ) -> tuple[list[KVCacheBlock], ...]:
         """
         Get the longest cache hit prefix of the blocks that is not longer than
         `max_length`. The prefix should be a common prefix hit for all the
@@ -222,7 +222,7 @@
             element is a list of cached blocks for the i-th kv cache group
             in `kv_cache_group_ids`.
             For example, sliding window manager should return a list like
-            [[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]] for block size 4
+            ([NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]) for block size 4
             and sliding window 8 and len(kv_cache_group_ids) = 1.
         """
 
@@ -254,27 +254,25 @@ class FullAttentionManager(SingleTypeKVCacheManager):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
-    ) -> list[list[KVCacheBlock]]:
+    ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(kv_cache_spec, FullAttentionSpec), (
             "FullAttentionManager can only be used for full attention groups")
-        computed_blocks: list[list[KVCacheBlock]] = [
-            [] for _ in range(len(kv_cache_group_ids))
-        ]
+        computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
+            [] for _ in range(len(kv_cache_group_ids)))
         max_num_blocks = max_length // kv_cache_spec.block_size
-        for i in range(max_num_blocks):
-            block_hash = block_hashes[i]
+        for i, block_hash in zip(range(max_num_blocks), block_hashes):
             # block_hashes is a chain of block hashes. If a block hash is not
             # in the cached_block_hash_to_id, the following block hashes are
             # not computed yet for sure.
             if cached_block := block_pool.get_cached_block(
                     block_hash, kv_cache_group_ids):
-                for j in range(len(kv_cache_group_ids)):
-                    computed_blocks[j].append(cached_block[j])
+                for computed, cached in zip(computed_blocks, cached_block):
+                    computed.append(cached)
             else:
                 break
-        if use_eagle and len(computed_blocks[0]) > 0:
-            for j in range(len(kv_cache_group_ids)):
-                computed_blocks[j].pop()
+        if use_eagle and computed_blocks[0]:
+            for computed in computed_blocks:
+                computed.pop()
         return computed_blocks
 
     def remove_skipped_blocks(self, request_id: str,
@@ -311,7 +309,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
-    ) -> list[list[KVCacheBlock]]:
+    ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(kv_cache_spec, SlidingWindowSpec), (
             "SlidingWindowManager can only be used for sliding window groups")
 
@@ -332,23 +330,23 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
         # sliding_window_contiguous_blocks),
         # which is good for low cache hit rate scenarios.
         max_num_blocks = max_length // kv_cache_spec.block_size
-        computed_blocks = [[block_pool.null_block] * max_num_blocks
-                           for _ in range(len(kv_cache_group_ids))]
+        computed_blocks = tuple([block_pool.null_block] * max_num_blocks
+                                for _ in range(len(kv_cache_group_ids)))
         num_contiguous_blocks = 0
         match_found = False
         # Search from right to left and early stop when a match is found.
         for i in range(max_num_blocks - 1, -1, -1):
             if cached_block := block_pool.get_cached_block(
                     block_hashes[i], kv_cache_group_ids):
-                for j in range(len(kv_cache_group_ids)):
-                    computed_blocks[j][i] = cached_block[j]
+                for computed, cached in zip(computed_blocks, cached_block):
+                    computed[i] = cached
                 num_contiguous_blocks += 1
-                if (num_contiguous_blocks >= sliding_window_contiguous_blocks):
+                if num_contiguous_blocks >= sliding_window_contiguous_blocks:
                     # Trim the trailing blocks.
                     # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
                     # when sliding_window_contiguous_blocks=2.
-                    for j in range(len(kv_cache_group_ids)):
-                        del computed_blocks[j][i + num_contiguous_blocks:]
+                    for computed in computed_blocks:
+                        del computed[i + num_contiguous_blocks:]
                     match_found = True
                     break
             else:
@@ -356,11 +354,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
         if not match_found:
             # The first `num_contiguous_blocks` is a cache hit even if
             # `num_contiguous_blocks < sliding_window_contiguous_blocks`.
-            for j in range(len(kv_cache_group_ids)):
-                del computed_blocks[j][num_contiguous_blocks:]
-        if use_eagle and len(computed_blocks[0]) > 0:
-            for j in range(len(kv_cache_group_ids)):
-                computed_blocks[j].pop()
+            for computed in computed_blocks:
+                del computed[num_contiguous_blocks:]
+        if use_eagle and computed_blocks[0]:
+            for computed in computed_blocks:
+                computed.pop()
         return computed_blocks
 
     def remove_skipped_blocks(self, request_id: str,
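The manager rewrites above replace index-based loops over kv_cache_group_ids with zip over the group-parallel sequences; zip also caps the full-attention scan at max_num_blocks without manual indexing. A sketch with hypothetical block labels:

computed_blocks = ([], [])                # one list per KV cache group
cached_block = ("blk_g0", "blk_g1")       # one cached hit per group
for computed, cached in zip(computed_blocks, cached_block):
    computed.append(cached)
assert computed_blocks == (["blk_g0"], ["blk_g1"])

# zip(range(n), hashes) visits at most n hashes, replacing block_hashes[i]:
assert [h for _, h in zip(range(2), ["h0", "h1", "h2"])] == ["h0", "h1"]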
@@ -112,11 +112,12 @@ class MultiGroupBlockTable:
             for block_size in block_sizes
         ]
 
-    def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
+    def append_row(self, block_ids: tuple[list[int], ...],
+                   row_idx: int) -> None:
         for i, block_table in enumerate(self.block_tables):
             block_table.append_row(block_ids[i], row_idx)
 
-    def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
+    def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
         for i, block_table in enumerate(self.block_tables):
             block_table.add_row(block_ids[i], row_idx)
 
 
@@ -30,7 +30,7 @@ class CachedRequestState:
     sampling_params: SamplingParams
     generator: Optional[torch.Generator]
 
-    block_ids: list[list[int]]
+    block_ids: tuple[list[int], ...]
     num_computed_tokens: int
     output_token_ids: list[int]