mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:55:01 +08:00
[Core] Use tuple for kv cache group block ids (#19175)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
6cd4ae8acd
commit
646d62f636
@ -117,7 +117,7 @@ def test_prefill(hash_algo):
|
|||||||
blocks = manager.allocate_slots(req0, 55,
|
blocks = manager.allocate_slots(req0, 55,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||||
|
|
||||||
# Check full block metadata
|
# Check full block metadata
|
||||||
parent_block_hash = None
|
parent_block_hash = None
|
||||||
@ -141,13 +141,13 @@ def test_prefill(hash_algo):
|
|||||||
req1 = make_request("1", common_token_ids + unique_token_ids)
|
req1 = make_request("1", common_token_ids + unique_token_ids)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||||
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
|
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
|
||||||
assert num_computed_tokens == 3 * 16
|
assert num_computed_tokens == 3 * 16
|
||||||
num_new_tokens = 53 - 3 * 16
|
num_new_tokens = 53 - 3 * 16
|
||||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == [[5]]
|
assert blocks.get_block_ids() == ([5], )
|
||||||
for block in computed_blocks.blocks[0]:
|
for block in computed_blocks.blocks[0]:
|
||||||
assert block.ref_cnt == 2
|
assert block.ref_cnt == 2
|
||||||
|
|
||||||
@ -175,13 +175,13 @@ def test_prefill(hash_algo):
|
|||||||
req2 = make_request("2", common_token_ids + unique_token_ids)
|
req2 = make_request("2", common_token_ids + unique_token_ids)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||||
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
|
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
|
||||||
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
|
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
|
||||||
assert num_computed_tokens == 3 * 16
|
assert num_computed_tokens == 3 * 16
|
||||||
num_new_tokens = 53 - 3 * 16
|
num_new_tokens = 53 - 3 * 16
|
||||||
blocks = manager.allocate_slots(req2, num_new_tokens,
|
blocks = manager.allocate_slots(req2, num_new_tokens,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == [[6]]
|
assert blocks.get_block_ids() == ([6], )
|
||||||
|
|
||||||
# Although we only have 6 free blocks, we have 8 blocks in
|
# Although we only have 6 free blocks, we have 8 blocks in
|
||||||
# the free block queue due to lazy removal.
|
# the free block queue due to lazy removal.
|
||||||
@ -205,7 +205,7 @@ def test_prefill(hash_algo):
|
|||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
# This block ID order also checks the eviction order.
|
# This block ID order also checks the eviction order.
|
||||||
assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
|
assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
|
||||||
assert manager.block_pool.free_block_queue.num_free_blocks == 0
|
assert manager.block_pool.free_block_queue.num_free_blocks == 0
|
||||||
assert manager.block_pool.free_block_queue.free_list_head is None
|
assert manager.block_pool.free_block_queue.free_list_head is None
|
||||||
assert manager.block_pool.free_block_queue.free_list_tail is None
|
assert manager.block_pool.free_block_queue.free_list_tail is None
|
||||||
@ -236,8 +236,8 @@ def test_prefill_hybrid_model():
|
|||||||
blocks = manager.allocate_slots(req0, 55,
|
blocks = manager.allocate_slots(req0, 55,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == [[1, 2, 3, 4], [5, 6, 7, 8],
|
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
|
||||||
[9, 10, 11, 12]]
|
8], [9, 10, 11, 12])
|
||||||
|
|
||||||
# Check full block metadata
|
# Check full block metadata
|
||||||
parent_block_hash = None
|
parent_block_hash = None
|
||||||
@ -263,14 +263,14 @@ def test_prefill_hybrid_model():
|
|||||||
req1 = make_request("1", common_token_ids + unique_token_ids)
|
req1 = make_request("1", common_token_ids + unique_token_ids)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||||
assert computed_blocks.get_block_ids() == [[1, 2, 3], [0, 6, 7],
|
assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6,
|
||||||
[0, 10, 11]]
|
7], [0, 10, 11])
|
||||||
assert num_computed_tokens == 3 * 16
|
assert num_computed_tokens == 3 * 16
|
||||||
num_new_tokens = 53 - 3 * 16
|
num_new_tokens = 53 - 3 * 16
|
||||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == [[13], [14], [15]]
|
assert blocks.get_block_ids() == ([13], [14], [15])
|
||||||
for block_per_group in computed_blocks.blocks:
|
for block_per_group in computed_blocks.blocks:
|
||||||
for block in block_per_group:
|
for block in block_per_group:
|
||||||
if block != manager.block_pool.null_block:
|
if block != manager.block_pool.null_block:
|
||||||
@ -374,7 +374,7 @@ def test_prefill_plp():
|
|||||||
blocks = manager.allocate_slots(req0, 55,
|
blocks = manager.allocate_slots(req0, 55,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||||
req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
|
req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
|
||||||
|
|
||||||
# Check full block metadata
|
# Check full block metadata
|
||||||
@ -400,13 +400,13 @@ def test_prefill_plp():
|
|||||||
req1 = make_request("1", common_token_ids + unique_token_ids)
|
req1 = make_request("1", common_token_ids + unique_token_ids)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||||
assert computed_blocks.get_block_ids() == [[1, 2, 3]]
|
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
|
||||||
assert num_computed_tokens == 3 * 16
|
assert num_computed_tokens == 3 * 16
|
||||||
num_new_tokens = 53 - 3 * 16
|
num_new_tokens = 53 - 3 * 16
|
||||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == [[5]]
|
assert blocks.get_block_ids() == ([5], )
|
||||||
for block in computed_blocks.blocks[0]:
|
for block in computed_blocks.blocks[0]:
|
||||||
assert block.ref_cnt == 2
|
assert block.ref_cnt == 2
|
||||||
|
|
||||||
@ -444,7 +444,7 @@ def test_prefill_plp():
|
|||||||
block_ids = blocks.get_block_ids()
|
block_ids = blocks.get_block_ids()
|
||||||
# Duplicate cached blocks have different ids but same hashes vs request #0
|
# Duplicate cached blocks have different ids but same hashes vs request #0
|
||||||
assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
|
assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
|
||||||
assert block_ids != [[1, 2, 3, 4]]
|
assert block_ids != ([1, 2, 3, 4], )
|
||||||
|
|
||||||
# Request #2 block hashes are valid since request #0 hashes are.
|
# Request #2 block hashes are valid since request #0 hashes are.
|
||||||
# Check block reference counts.
|
# Check block reference counts.
|
||||||
@ -474,7 +474,7 @@ def test_decode():
|
|||||||
blocks = manager.allocate_slots(req0, 55,
|
blocks = manager.allocate_slots(req0, 55,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||||
|
|
||||||
# Append slots without allocating a new block.
|
# Append slots without allocating a new block.
|
||||||
req0.num_computed_tokens = 55
|
req0.num_computed_tokens = 55
|
||||||
@ -546,12 +546,12 @@ def test_evict():
|
|||||||
# Touch the first 2 blocks.
|
# Touch the first 2 blocks.
|
||||||
req2 = make_request("2", list(range(2 * 16 + 3)))
|
req2 = make_request("2", list(range(2 * 16 + 3)))
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||||
assert computed_blocks.get_block_ids() == [[1, 2]]
|
assert computed_blocks.get_block_ids() == ([1, 2], )
|
||||||
assert num_computed_tokens == 2 * 16
|
assert num_computed_tokens == 2 * 16
|
||||||
blocks = manager.allocate_slots(req2, 3,
|
blocks = manager.allocate_slots(req2, 3,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == [[10]]
|
assert blocks.get_block_ids() == ([10], )
|
||||||
assert manager.block_pool.free_block_queue.num_free_blocks == 7
|
assert manager.block_pool.free_block_queue.num_free_blocks == 7
|
||||||
|
|
||||||
|
|
||||||
@ -865,7 +865,7 @@ def test_mm_prefix_caching():
|
|||||||
blocks = manager.allocate_slots(req0, 59,
|
blocks = manager.allocate_slots(req0, 59,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||||
req0.num_computed_tokens = 59
|
req0.num_computed_tokens = 59
|
||||||
|
|
||||||
# Append slots without allocating a new block.
|
# Append slots without allocating a new block.
|
||||||
@ -926,7 +926,7 @@ def test_cache_key_salting():
|
|||||||
blocks = manager.allocate_slots(req0, 59,
|
blocks = manager.allocate_slots(req0, 59,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||||
req0.num_computed_tokens = 59
|
req0.num_computed_tokens = 59
|
||||||
|
|
||||||
# Append slots without allocating a new block.
|
# Append slots without allocating a new block.
|
||||||
@ -1042,7 +1042,7 @@ def test_reset_prefix_cache():
|
|||||||
all_token_ids = full_block_token_ids + unique_token_ids
|
all_token_ids = full_block_token_ids + unique_token_ids
|
||||||
req0 = make_request("0", all_token_ids)
|
req0 = make_request("0", all_token_ids)
|
||||||
blocks = manager.allocate_slots(req0, 55)
|
blocks = manager.allocate_slots(req0, 55)
|
||||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||||
|
|
||||||
unique_token_ids = [4] * 7
|
unique_token_ids = [4] * 7
|
||||||
all_token_ids = full_block_token_ids + unique_token_ids
|
all_token_ids = full_block_token_ids + unique_token_ids
|
||||||
@ -1053,7 +1053,7 @@ def test_reset_prefix_cache():
|
|||||||
blocks = manager.allocate_slots(req1, 7,
|
blocks = manager.allocate_slots(req1, 7,
|
||||||
len(computed_blocks.blocks[0]) * 16,
|
len(computed_blocks.blocks[0]) * 16,
|
||||||
computed_blocks)
|
computed_blocks)
|
||||||
assert blocks.get_block_ids() == [[5]]
|
assert blocks.get_block_ids() == ([5], )
|
||||||
|
|
||||||
# Failed to reset prefix cache because some blocks are not freed yet.
|
# Failed to reset prefix cache because some blocks are not freed yet.
|
||||||
assert not manager.reset_prefix_cache()
|
assert not manager.reset_prefix_cache()
|
||||||
|
|||||||
@ -71,7 +71,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
|||||||
mm_hashes=[],
|
mm_hashes=[],
|
||||||
mm_positions=[],
|
mm_positions=[],
|
||||||
sampling_params=SamplingParams(),
|
sampling_params=SamplingParams(),
|
||||||
block_ids=[[0]], # block_ids should be list[list[int]]
|
block_ids=([0], ), # block_ids should be tuple[list[int]]
|
||||||
num_computed_tokens=0,
|
num_computed_tokens=0,
|
||||||
lora_request=None,
|
lora_request=None,
|
||||||
))
|
))
|
||||||
@ -116,10 +116,10 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
|
|||||||
# This is safe since we currently only use single KV cache groups
|
# This is safe since we currently only use single KV cache groups
|
||||||
block_table = multi_group_block_table[0]
|
block_table = multi_group_block_table[0]
|
||||||
|
|
||||||
# req_state.block_ids is now list[list[int]] for MultiGroupBlockTable
|
# req_state.block_ids is now tuple[list[int], ...] for MultiGroupBlockTable
|
||||||
# Extract the first group's block IDs
|
# Extract the first group's block IDs
|
||||||
if isinstance(req_state.block_ids[0], list):
|
if isinstance(req_state.block_ids[0], list):
|
||||||
# New format: list[list[int]] - extract first group
|
# New format: tuple[list[int], ...] - extract first group
|
||||||
req_block_ids = req_state.block_ids[0]
|
req_block_ids = req_state.block_ids[0]
|
||||||
else:
|
else:
|
||||||
# Legacy format: list[int] - use directly
|
# Legacy format: list[int] - use directly
|
||||||
@ -210,7 +210,7 @@ def test_update_states_request_resumed(model_runner):
|
|||||||
req_id=req_id,
|
req_id=req_id,
|
||||||
resumed_from_preemption=False,
|
resumed_from_preemption=False,
|
||||||
new_token_ids=[],
|
new_token_ids=[],
|
||||||
new_block_ids=[[]],
|
new_block_ids=([], ),
|
||||||
num_computed_tokens=0,
|
num_computed_tokens=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -203,7 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
|
|||||||
sampling_params=_create_sampling_params(),
|
sampling_params=_create_sampling_params(),
|
||||||
mm_inputs=[],
|
mm_inputs=[],
|
||||||
mm_positions=[],
|
mm_positions=[],
|
||||||
block_ids=[[]],
|
block_ids=([], ),
|
||||||
generator=None,
|
generator=None,
|
||||||
num_computed_tokens=len(output_token_ids),
|
num_computed_tokens=len(output_token_ids),
|
||||||
output_token_ids=output_token_ids,
|
output_token_ids=output_token_ids,
|
||||||
|
|||||||
@ -123,7 +123,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
|||||||
mm_hashes=[],
|
mm_hashes=[],
|
||||||
mm_positions=[],
|
mm_positions=[],
|
||||||
sampling_params=SamplingParams(),
|
sampling_params=SamplingParams(),
|
||||||
block_ids=[[0]],
|
block_ids=([0], ),
|
||||||
num_computed_tokens=0,
|
num_computed_tokens=0,
|
||||||
lora_request=None,
|
lora_request=None,
|
||||||
))
|
))
|
||||||
@ -251,7 +251,7 @@ def test_update_states_request_resumed(model_runner):
|
|||||||
req_id=req_id,
|
req_id=req_id,
|
||||||
resumed_from_preemption=False,
|
resumed_from_preemption=False,
|
||||||
new_token_ids=[],
|
new_token_ids=[],
|
||||||
new_block_ids=[[]],
|
new_block_ids=([], ),
|
||||||
num_computed_tokens=0,
|
num_computed_tokens=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -89,8 +89,8 @@ class BlockPool:
|
|||||||
BlockHashWithGroupId(block_hash, group_id))
|
BlockHashWithGroupId(block_hash, group_id))
|
||||||
if not cached_blocks_one_group:
|
if not cached_blocks_one_group:
|
||||||
return None
|
return None
|
||||||
first_block_id = next(iter(cached_blocks_one_group))
|
first_block = next(iter(cached_blocks_one_group.values()))
|
||||||
cached_blocks.append(cached_blocks_one_group[first_block_id])
|
cached_blocks.append(first_block)
|
||||||
return cached_blocks
|
return cached_blocks
|
||||||
|
|
||||||
def cache_full_blocks(
|
def cache_full_blocks(
|
||||||
@ -260,7 +260,7 @@ class BlockPool:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def touch(self, blocks: list[list[KVCacheBlock]]) -> None:
|
def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
|
||||||
"""Touch a block increases its reference count by 1, and may remove
|
"""Touch a block increases its reference count by 1, and may remove
|
||||||
the block from the free queue. This is used when a block is hit by
|
the block from the free queue. This is used when a block is hit by
|
||||||
another request with the same prefix.
|
another request with the same prefix.
|
||||||
@ -299,7 +299,7 @@ class BlockPool:
|
|||||||
bool: True if the prefix cache is successfully reset,
|
bool: True if the prefix cache is successfully reset,
|
||||||
False otherwise.
|
False otherwise.
|
||||||
"""
|
"""
|
||||||
num_used_blocks = (self.num_gpu_blocks - self.get_num_free_blocks())
|
num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks()
|
||||||
if num_used_blocks != 1: # The null block is always marked as used
|
if num_used_blocks != 1: # The null block is always marked as used
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to reset prefix cache because some "
|
"Failed to reset prefix cache because some "
|
||||||
|
|||||||
@ -5,8 +5,7 @@ from typing import Callable, Optional
|
|||||||
from vllm.v1.core.block_pool import BlockPool
|
from vllm.v1.core.block_pool import BlockPool
|
||||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||||
FullAttentionManager, SingleTypeKVCacheManager,
|
FullAttentionManager, get_manager_for_kv_cache_spec)
|
||||||
get_manager_for_kv_cache_spec)
|
|
||||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
@ -30,25 +29,21 @@ class KVCacheCoordinator(ABC):
|
|||||||
|
|
||||||
self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
|
self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
|
||||||
enable_kv_cache_events)
|
enable_kv_cache_events)
|
||||||
self.single_type_managers: list[SingleTypeKVCacheManager] = []
|
|
||||||
|
|
||||||
# Needs special handling for find_longest_cache_hit if eagle is enabled
|
# Needs special handling for find_longest_cache_hit if eagle is enabled
|
||||||
self.use_eagle = use_eagle
|
self.use_eagle = use_eagle
|
||||||
|
self.single_type_managers = tuple(
|
||||||
for i in range(len(self.kv_cache_config.kv_cache_groups)):
|
get_manager_for_kv_cache_spec(
|
||||||
kv_cache_spec = self.kv_cache_config.kv_cache_groups[
|
kv_cache_spec=kv_cache_group.kv_cache_spec,
|
||||||
i].kv_cache_spec
|
block_pool=self.block_pool,
|
||||||
self.single_type_managers.append(
|
kv_cache_group_id=i,
|
||||||
get_manager_for_kv_cache_spec(
|
caching_hash_fn=caching_hash_fn,
|
||||||
kv_cache_spec=kv_cache_spec,
|
) for i, kv_cache_group in enumerate(
|
||||||
block_pool=self.block_pool,
|
self.kv_cache_config.kv_cache_groups))
|
||||||
kv_cache_group_id=i,
|
|
||||||
caching_hash_fn=caching_hash_fn,
|
|
||||||
))
|
|
||||||
|
|
||||||
def get_num_blocks_to_allocate(
|
def get_num_blocks_to_allocate(
|
||||||
self, request_id: str, num_tokens: int,
|
self, request_id: str, num_tokens: int,
|
||||||
new_computed_blocks: list[list[KVCacheBlock]]) -> int:
|
new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int:
|
||||||
"""
|
"""
|
||||||
Get the number of blocks needed to be allocated for the request.
|
Get the number of blocks needed to be allocated for the request.
|
||||||
|
|
||||||
@ -70,7 +65,7 @@ class KVCacheCoordinator(ABC):
|
|||||||
|
|
||||||
def save_new_computed_blocks(
|
def save_new_computed_blocks(
|
||||||
self, request_id: str,
|
self, request_id: str,
|
||||||
new_computed_blocks: list[list[KVCacheBlock]]) -> None:
|
new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> None:
|
||||||
"""
|
"""
|
||||||
Add the new computed blocks to the request.
|
Add the new computed blocks to the request.
|
||||||
|
|
||||||
@ -84,7 +79,7 @@ class KVCacheCoordinator(ABC):
|
|||||||
new_computed_blocks[i])
|
new_computed_blocks[i])
|
||||||
|
|
||||||
def allocate_new_blocks(self, request_id: str,
|
def allocate_new_blocks(self, request_id: str,
|
||||||
num_tokens: int) -> list[list[KVCacheBlock]]:
|
num_tokens: int) -> tuple[list[KVCacheBlock], ...]:
|
||||||
"""
|
"""
|
||||||
Allocate new blocks for the request to give it at least `num_tokens`
|
Allocate new blocks for the request to give it at least `num_tokens`
|
||||||
token slots.
|
token slots.
|
||||||
@ -97,11 +92,9 @@ class KVCacheCoordinator(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
The new allocated blocks.
|
The new allocated blocks.
|
||||||
"""
|
"""
|
||||||
new_blocks = []
|
return tuple(
|
||||||
for manager in self.single_type_managers:
|
manager.allocate_new_blocks(request_id, num_tokens)
|
||||||
new_blocks.append(
|
for manager in self.single_type_managers)
|
||||||
manager.allocate_new_blocks(request_id, num_tokens))
|
|
||||||
return new_blocks
|
|
||||||
|
|
||||||
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
|
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
|
||||||
num_computed_tokens: int) -> None:
|
num_computed_tokens: int) -> None:
|
||||||
@ -159,19 +152,20 @@ class KVCacheCoordinator(ABC):
|
|||||||
for manager in self.single_type_managers:
|
for manager in self.single_type_managers:
|
||||||
manager.remove_skipped_blocks(request_id, num_computed_tokens)
|
manager.remove_skipped_blocks(request_id, num_computed_tokens)
|
||||||
|
|
||||||
def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]:
|
def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
|
||||||
"""
|
"""
|
||||||
Get the blocks for the request.
|
Get the blocks for the request.
|
||||||
"""
|
"""
|
||||||
return [
|
return tuple(
|
||||||
manager.req_to_blocks.get(request_id) or []
|
manager.req_to_blocks.get(request_id) or []
|
||||||
for manager in self.single_type_managers
|
for manager in self.single_type_managers)
|
||||||
]
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def find_longest_cache_hit(
|
def find_longest_cache_hit(
|
||||||
self, block_hashes: list[BlockHash],
|
self,
|
||||||
max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
|
block_hashes: list[BlockHash],
|
||||||
|
max_cache_hit_length: int,
|
||||||
|
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -195,8 +189,10 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
"UnitaryKVCacheCoordinator assumes only one kv cache group")
|
"UnitaryKVCacheCoordinator assumes only one kv cache group")
|
||||||
|
|
||||||
def find_longest_cache_hit(
|
def find_longest_cache_hit(
|
||||||
self, block_hashes: list[BlockHash],
|
self,
|
||||||
max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
|
block_hashes: list[BlockHash],
|
||||||
|
max_cache_hit_length: int,
|
||||||
|
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
|
||||||
hit_blocks = self.single_type_managers[0].find_longest_cache_hit(
|
hit_blocks = self.single_type_managers[0].find_longest_cache_hit(
|
||||||
block_hashes=block_hashes,
|
block_hashes=block_hashes,
|
||||||
max_length=max_cache_hit_length,
|
max_length=max_cache_hit_length,
|
||||||
@ -275,11 +271,24 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
"KVCacheCoordinator assumes the block_size of full attention "
|
"KVCacheCoordinator assumes the block_size of full attention "
|
||||||
"layers is divisible by other layers now.")
|
"layers is divisible by other layers now.")
|
||||||
|
|
||||||
|
if max(self.full_attention_group_ids) < min(self.other_group_ids):
|
||||||
|
self.full_attn_first = True
|
||||||
|
elif max(self.other_group_ids) < min(self.full_attention_group_ids):
|
||||||
|
self.full_attn_first = False
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"HybridKVCacheCoordinator assumes the full "
|
||||||
|
"attention group ids and other attention group ids "
|
||||||
|
"do not interleave, either full attention group ids "
|
||||||
|
"are before other attention group ids or vice versa."
|
||||||
|
"This is for simplifying merging hit_blocks_full_attn and "
|
||||||
|
"hit_blocks_other_attn to hit_blocks.")
|
||||||
|
|
||||||
def find_longest_cache_hit(
|
def find_longest_cache_hit(
|
||||||
self,
|
self,
|
||||||
block_hashes: list[BlockHash],
|
block_hashes: list[BlockHash],
|
||||||
max_cache_hit_length: int,
|
max_cache_hit_length: int,
|
||||||
) -> tuple[list[list[KVCacheBlock]], int]:
|
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
|
||||||
"""
|
"""
|
||||||
Find the longest cache hit for the request.
|
Find the longest cache hit for the request.
|
||||||
|
|
||||||
@ -318,27 +327,25 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
))
|
))
|
||||||
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
|
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
|
||||||
|
|
||||||
# NOTE: the prefix cache hit length must be a multiply of block_size as
|
# NOTE: the prefix cache hit length must be a multiple of block_size as
|
||||||
# we don't support partial block cache hit yet. The cache hit length
|
# we don't support partial block cache hit yet. The cache hit length
|
||||||
# of other attention is ensured to be a multiply of the block size of
|
# of other attention is ensured to be a multiple of the block size of
|
||||||
# full attention layers in current implementation, because hit_length is
|
# full attention layers in current implementation, because hit_length is
|
||||||
# a multiply of other attention's block size, and other attention's
|
# a multiple of other attention's block size, and other attention's
|
||||||
# block size is a multiply of full attention's block size (verified in
|
# block size is a multiple of full attention's block size (verified in
|
||||||
# `verify_and_split_kv_cache_groups`).
|
# `verify_and_split_kv_cache_groups`).
|
||||||
assert hit_length % self.full_attention_block_size == 0
|
assert hit_length % self.full_attention_block_size == 0
|
||||||
|
|
||||||
# Truncate the full attention cache hit to the length of the
|
# Truncate the full attention cache hit to the length of the
|
||||||
# cache hit of the other attention.
|
# cache hit of the other attention.
|
||||||
for i in range(len(hit_blocks_full_attn)):
|
for group_hit_blocks in hit_blocks_full_attn:
|
||||||
del hit_blocks_full_attn[i][hit_length //
|
del group_hit_blocks[hit_length // self.full_attention_block_size:]
|
||||||
self.full_attention_block_size:]
|
|
||||||
|
|
||||||
# Merge the hit blocks of full attention and other attention.
|
# Merge the hit blocks of full attention and other attention.
|
||||||
hit_blocks = hit_blocks_other_attn
|
if self.full_attn_first:
|
||||||
for group_id, blocks in enumerate(hit_blocks_full_attn):
|
hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn
|
||||||
# NOTE: there is only one full attention group in most cases. So
|
else:
|
||||||
# the time complexity of insert is fine.
|
hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn
|
||||||
hit_blocks.insert(group_id, blocks)
|
|
||||||
return hit_blocks, hit_length
|
return hit_blocks, hit_length
|
||||||
|
|
||||||
|
|
||||||
@ -351,8 +358,6 @@ def get_kv_cache_coordinator(
|
|||||||
use_eagle, enable_caching,
|
use_eagle, enable_caching,
|
||||||
caching_hash_fn,
|
caching_hash_fn,
|
||||||
enable_kv_cache_events)
|
enable_kv_cache_events)
|
||||||
else:
|
return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
|
||||||
return HybridKVCacheCoordinator(kv_cache_config, max_model_len,
|
enable_caching, caching_hash_fn,
|
||||||
use_eagle, enable_caching,
|
enable_kv_cache_events)
|
||||||
caching_hash_fn,
|
|
||||||
enable_kv_cache_events)
|
|
||||||
|
|||||||
@ -25,7 +25,7 @@ class KVCacheBlocks:
|
|||||||
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
|
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
|
||||||
structure from the Scheduler.
|
structure from the Scheduler.
|
||||||
"""
|
"""
|
||||||
blocks: list[list[KVCacheBlock]]
|
blocks: tuple[list[KVCacheBlock], ...]
|
||||||
"""
|
"""
|
||||||
blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
|
blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
|
||||||
We don't use block of tokens as the outer dimension because it assumes all
|
We don't use block of tokens as the outer dimension because it assumes all
|
||||||
@ -37,21 +37,19 @@ class KVCacheBlocks:
|
|||||||
def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
|
def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
|
||||||
"""Adds two KVCacheBlocks instances."""
|
"""Adds two KVCacheBlocks instances."""
|
||||||
return KVCacheBlocks(
|
return KVCacheBlocks(
|
||||||
[blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)])
|
tuple(blk1 + blk2
|
||||||
|
for blk1, blk2 in zip(self.blocks, other.blocks)))
|
||||||
|
|
||||||
def get_block_ids(self) -> list[list[int]]:
|
def get_block_ids(self) -> tuple[list[int], ...]:
|
||||||
"""
|
"""
|
||||||
Converts the KVCacheBlocks instance to block_ids.
|
Converts the KVCacheBlocks instance to block_ids.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[list[int]]: A two-level list where
|
tuple[list[int], ...]: A tuple of lists where
|
||||||
* the outer list corresponds to KV cache groups
|
* the outer tuple corresponds to KV cache groups
|
||||||
* each inner list contains the block_ids of the blocks in that group
|
* each inner list contains the block_ids of the blocks in that group
|
||||||
"""
|
"""
|
||||||
block_ids = []
|
return tuple([blk.block_id for blk in group] for group in self.blocks)
|
||||||
for group in self.blocks:
|
|
||||||
block_ids.append([blk.block_id for blk in group])
|
|
||||||
return block_ids
|
|
||||||
|
|
||||||
def get_unhashed_block_ids(self) -> list[int]:
|
def get_unhashed_block_ids(self) -> list[int]:
|
||||||
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
|
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
|
||||||
@ -63,7 +61,7 @@ class KVCacheBlocks:
|
|||||||
|
|
||||||
def new_empty(self) -> "KVCacheBlocks":
|
def new_empty(self) -> "KVCacheBlocks":
|
||||||
"""Creates a new KVCacheBlocks instance with no blocks."""
|
"""Creates a new KVCacheBlocks instance with no blocks."""
|
||||||
return KVCacheBlocks([[] for _ in range(len(self.blocks))])
|
return KVCacheBlocks(tuple([] for _ in range(len(self.blocks))))
|
||||||
|
|
||||||
|
|
||||||
class KVCacheManager:
|
class KVCacheManager:
|
||||||
@ -232,9 +230,8 @@ class KVCacheManager:
|
|||||||
if new_computed_blocks is not None:
|
if new_computed_blocks is not None:
|
||||||
new_computed_block_list = new_computed_blocks.blocks
|
new_computed_block_list = new_computed_blocks.blocks
|
||||||
else:
|
else:
|
||||||
new_computed_block_list = [
|
new_computed_block_list = tuple(
|
||||||
[] for _ in range(len(self.kv_cache_config.kv_cache_groups))
|
[] for _ in range(len(self.kv_cache_config.kv_cache_groups)))
|
||||||
]
|
|
||||||
|
|
||||||
# Free the blocks that are skipped during the attention computation
|
# Free the blocks that are skipped during the attention computation
|
||||||
# (e.g., tokens outside the sliding window).
|
# (e.g., tokens outside the sliding window).
|
||||||
@ -267,7 +264,7 @@ class KVCacheManager:
|
|||||||
if self.enable_caching:
|
if self.enable_caching:
|
||||||
self.block_pool.touch(new_computed_block_list)
|
self.block_pool.touch(new_computed_block_list)
|
||||||
else:
|
else:
|
||||||
assert all(not blocks for blocks in new_computed_block_list), (
|
assert not any(new_computed_block_list), (
|
||||||
"Computed blocks should be empty when "
|
"Computed blocks should be empty when "
|
||||||
"prefix caching is disabled")
|
"prefix caching is disabled")
|
||||||
|
|
||||||
@ -378,17 +375,18 @@ class KVCacheManager:
|
|||||||
"""
|
"""
|
||||||
return self.block_pool.take_events()
|
return self.block_pool.take_events()
|
||||||
|
|
||||||
def get_block_ids(self, request_id: str) -> list[list[int]]:
|
def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
|
||||||
"""Get the block ids of a request."""
|
"""Get the block ids of a request."""
|
||||||
return KVCacheBlocks(
|
return KVCacheBlocks(
|
||||||
self.coordinator.get_blocks(request_id)).get_block_ids()
|
self.coordinator.get_blocks(request_id)).get_block_ids()
|
||||||
|
|
||||||
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
|
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
|
||||||
num_computed_tokens: int) -> None:
|
|
||||||
"""Cache the blocks for the request."""
|
"""Cache the blocks for the request."""
|
||||||
|
block_hashes = self.req_to_block_hashes[request.request_id]
|
||||||
self.coordinator.cache_blocks(request, block_hashes,
|
self.coordinator.cache_blocks(request, block_hashes,
|
||||||
num_computed_tokens)
|
num_computed_tokens)
|
||||||
|
|
||||||
def create_empty_block_list(self) -> KVCacheBlocks:
|
def create_empty_block_list(self) -> KVCacheBlocks:
|
||||||
"""Creates a new KVCacheBlocks instance with no blocks."""
|
"""Creates a new KVCacheBlocks instance with no blocks."""
|
||||||
return KVCacheBlocks([[] for _ in range(self.num_kv_cache_groups)])
|
return KVCacheBlocks(tuple([]
|
||||||
|
for _ in range(self.num_kv_cache_groups)))
|
||||||
|
|||||||
@ -27,7 +27,7 @@ class NewRequestData:
|
|||||||
mm_hashes: list[str]
|
mm_hashes: list[str]
|
||||||
mm_positions: list[PlaceholderRange]
|
mm_positions: list[PlaceholderRange]
|
||||||
sampling_params: SamplingParams
|
sampling_params: SamplingParams
|
||||||
block_ids: list[list[int]]
|
block_ids: tuple[list[int], ...]
|
||||||
num_computed_tokens: int
|
num_computed_tokens: int
|
||||||
lora_request: Optional[LoRARequest]
|
lora_request: Optional[LoRARequest]
|
||||||
|
|
||||||
@ -35,7 +35,7 @@ class NewRequestData:
|
|||||||
def from_request(
|
def from_request(
|
||||||
cls,
|
cls,
|
||||||
request: Request,
|
request: Request,
|
||||||
block_ids: list[list[int]],
|
block_ids: tuple[list[int], ...],
|
||||||
) -> NewRequestData:
|
) -> NewRequestData:
|
||||||
return cls(
|
return cls(
|
||||||
req_id=request.request_id,
|
req_id=request.request_id,
|
||||||
@ -86,7 +86,7 @@ class CachedRequestData:
|
|||||||
# request's block IDs instead of appending to the existing block IDs.
|
# request's block IDs instead of appending to the existing block IDs.
|
||||||
resumed_from_preemption: bool
|
resumed_from_preemption: bool
|
||||||
new_token_ids: list[int]
|
new_token_ids: list[int]
|
||||||
new_block_ids: list[list[int]]
|
new_block_ids: tuple[list[int], ...]
|
||||||
num_computed_tokens: int
|
num_computed_tokens: int
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -95,7 +95,7 @@ class CachedRequestData:
|
|||||||
request: Request,
|
request: Request,
|
||||||
resumed_from_preemption: bool,
|
resumed_from_preemption: bool,
|
||||||
new_token_ids: list[int],
|
new_token_ids: list[int],
|
||||||
new_block_ids: list[list[int]],
|
new_block_ids: tuple[list[int], ...],
|
||||||
) -> CachedRequestData:
|
) -> CachedRequestData:
|
||||||
return cls(
|
return cls(
|
||||||
req_id=request.request_id,
|
req_id=request.request_id,
|
||||||
|
|||||||
@ -180,7 +180,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
# uses structured decoding.
|
# uses structured decoding.
|
||||||
structured_output_request_ids: dict[str, int] = {}
|
structured_output_request_ids: dict[str, int] = {}
|
||||||
|
|
||||||
req_to_new_block_ids: dict[str, list[list[int]]] = {}
|
req_to_new_block_ids: dict[str, tuple[list[int], ...]] = {}
|
||||||
num_scheduled_tokens: dict[str, int] = {}
|
num_scheduled_tokens: dict[str, int] = {}
|
||||||
token_budget = self.max_num_scheduled_tokens
|
token_budget = self.max_num_scheduled_tokens
|
||||||
# Encoder-related.
|
# Encoder-related.
|
||||||
@ -471,7 +471,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
token_budget -= num_new_tokens
|
token_budget -= num_new_tokens
|
||||||
request.status = RequestStatus.RUNNING
|
request.status = RequestStatus.RUNNING
|
||||||
request.num_computed_tokens = num_computed_tokens
|
request.num_computed_tokens = num_computed_tokens
|
||||||
# Count the number of prifix cached tokens.
|
# Count the number of prefix cached tokens.
|
||||||
if request.num_cached_tokens < 0:
|
if request.num_cached_tokens < 0:
|
||||||
request.num_cached_tokens = num_computed_tokens
|
request.num_cached_tokens = num_computed_tokens
|
||||||
# Encoder-related.
|
# Encoder-related.
|
||||||
@ -588,7 +588,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
request: Request,
|
request: Request,
|
||||||
num_scheduled_tokens: int,
|
num_scheduled_tokens: int,
|
||||||
num_scheduled_spec_tokens: int,
|
num_scheduled_spec_tokens: int,
|
||||||
new_block_ids: list[list[int]],
|
new_block_ids: tuple[list[int], ...],
|
||||||
resumed_from_preemption: bool,
|
resumed_from_preemption: bool,
|
||||||
) -> CachedRequestData:
|
) -> CachedRequestData:
|
||||||
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
|
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
|
||||||
@ -1015,11 +1015,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
num_computed_tokens = min(num_computed_tokens, request.num_tokens)
|
num_computed_tokens = min(num_computed_tokens, request.num_tokens)
|
||||||
if num_computed_tokens == request.num_tokens:
|
if num_computed_tokens == request.num_tokens:
|
||||||
num_computed_tokens -= 1
|
num_computed_tokens -= 1
|
||||||
self.kv_cache_manager.cache_blocks(
|
self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
|
||||||
request,
|
|
||||||
self.kv_cache_manager.req_to_block_hashes[request.request_id],
|
|
||||||
num_computed_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update the request state for scheduling.
|
# Update the request state for scheduling.
|
||||||
request.num_computed_tokens = num_computed_tokens
|
request.num_computed_tokens = num_computed_tokens
|
||||||
|
|||||||
@ -197,7 +197,7 @@ class SingleTypeKVCacheManager(ABC):
|
|||||||
block_pool: BlockPool,
|
block_pool: BlockPool,
|
||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
) -> list[list[KVCacheBlock]]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
"""
|
"""
|
||||||
Get the longest cache hit prefix of the blocks that is not longer than
|
Get the longest cache hit prefix of the blocks that is not longer than
|
||||||
`max_length`. The prefix should be a common prefix hit for all the
|
`max_length`. The prefix should be a common prefix hit for all the
|
||||||
@ -222,7 +222,7 @@ class SingleTypeKVCacheManager(ABC):
|
|||||||
element is a list of cached blocks for the i-th kv cache group
|
element is a list of cached blocks for the i-th kv cache group
|
||||||
in `kv_cache_group_ids`.
|
in `kv_cache_group_ids`.
|
||||||
For example, sliding window manager should return a list like
|
For example, sliding window manager should return a list like
|
||||||
[[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]] for block size 4
|
([NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]) for block size 4
|
||||||
and sliding window 8 and len(kv_cache_group_ids) = 1.
|
and sliding window 8 and len(kv_cache_group_ids) = 1.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -254,27 +254,25 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
|||||||
block_pool: BlockPool,
|
block_pool: BlockPool,
|
||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
) -> list[list[KVCacheBlock]]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
assert isinstance(kv_cache_spec, FullAttentionSpec), (
|
assert isinstance(kv_cache_spec, FullAttentionSpec), (
|
||||||
"FullAttentionManager can only be used for full attention groups")
|
"FullAttentionManager can only be used for full attention groups")
|
||||||
computed_blocks: list[list[KVCacheBlock]] = [
|
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
|
||||||
[] for _ in range(len(kv_cache_group_ids))
|
[] for _ in range(len(kv_cache_group_ids)))
|
||||||
]
|
|
||||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||||
for i in range(max_num_blocks):
|
for i, block_hash in zip(range(max_num_blocks), block_hashes):
|
||||||
block_hash = block_hashes[i]
|
|
||||||
# block_hashes is a chain of block hashes. If a block hash is not
|
# block_hashes is a chain of block hashes. If a block hash is not
|
||||||
# in the cached_block_hash_to_id, the following block hashes are
|
# in the cached_block_hash_to_id, the following block hashes are
|
||||||
# not computed yet for sure.
|
# not computed yet for sure.
|
||||||
if cached_block := block_pool.get_cached_block(
|
if cached_block := block_pool.get_cached_block(
|
||||||
block_hash, kv_cache_group_ids):
|
block_hash, kv_cache_group_ids):
|
||||||
for j in range(len(kv_cache_group_ids)):
|
for computed, cached in zip(computed_blocks, cached_block):
|
||||||
computed_blocks[j].append(cached_block[j])
|
computed.append(cached)
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
if use_eagle and len(computed_blocks[0]) > 0:
|
if use_eagle and computed_blocks[0]:
|
||||||
for j in range(len(kv_cache_group_ids)):
|
for computed in computed_blocks:
|
||||||
computed_blocks[j].pop()
|
computed.pop()
|
||||||
return computed_blocks
|
return computed_blocks
|
||||||
|
|
||||||
def remove_skipped_blocks(self, request_id: str,
|
def remove_skipped_blocks(self, request_id: str,
|
||||||
@ -311,7 +309,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
|||||||
block_pool: BlockPool,
|
block_pool: BlockPool,
|
||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
) -> list[list[KVCacheBlock]]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
|
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
|
||||||
"SlidingWindowManager can only be used for sliding window groups")
|
"SlidingWindowManager can only be used for sliding window groups")
|
||||||
|
|
||||||
@ -332,23 +330,23 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
|||||||
# sliding_window_contiguous_blocks),
|
# sliding_window_contiguous_blocks),
|
||||||
# which is good for low cache hit rate scenarios.
|
# which is good for low cache hit rate scenarios.
|
||||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||||
computed_blocks = [[block_pool.null_block] * max_num_blocks
|
computed_blocks = tuple([block_pool.null_block] * max_num_blocks
|
||||||
for _ in range(len(kv_cache_group_ids))]
|
for _ in range(len(kv_cache_group_ids)))
|
||||||
num_contiguous_blocks = 0
|
num_contiguous_blocks = 0
|
||||||
match_found = False
|
match_found = False
|
||||||
# Search from right to left and early stop when a match is found.
|
# Search from right to left and early stop when a match is found.
|
||||||
for i in range(max_num_blocks - 1, -1, -1):
|
for i in range(max_num_blocks - 1, -1, -1):
|
||||||
if cached_block := block_pool.get_cached_block(
|
if cached_block := block_pool.get_cached_block(
|
||||||
block_hashes[i], kv_cache_group_ids):
|
block_hashes[i], kv_cache_group_ids):
|
||||||
for j in range(len(kv_cache_group_ids)):
|
for computed, cached in zip(computed_blocks, cached_block):
|
||||||
computed_blocks[j][i] = cached_block[j]
|
computed[i] = cached
|
||||||
num_contiguous_blocks += 1
|
num_contiguous_blocks += 1
|
||||||
if (num_contiguous_blocks >= sliding_window_contiguous_blocks):
|
if num_contiguous_blocks >= sliding_window_contiguous_blocks:
|
||||||
# Trim the trailing blocks.
|
# Trim the trailing blocks.
|
||||||
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
|
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
|
||||||
# when sliding_window_contiguous_blocks=2.
|
# when sliding_window_contiguous_blocks=2.
|
||||||
for j in range(len(kv_cache_group_ids)):
|
for computed in computed_blocks:
|
||||||
del computed_blocks[j][i + num_contiguous_blocks:]
|
del computed[i + num_contiguous_blocks:]
|
||||||
match_found = True
|
match_found = True
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@ -356,11 +354,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
|||||||
if not match_found:
|
if not match_found:
|
||||||
# The first `num_contiguous_blocks` is a cache hit even if
|
# The first `num_contiguous_blocks` is a cache hit even if
|
||||||
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
|
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
|
||||||
for j in range(len(kv_cache_group_ids)):
|
for computed in computed_blocks:
|
||||||
del computed_blocks[j][num_contiguous_blocks:]
|
del computed[num_contiguous_blocks:]
|
||||||
if use_eagle and len(computed_blocks[0]) > 0:
|
if use_eagle and computed_blocks[0]:
|
||||||
for j in range(len(kv_cache_group_ids)):
|
for computed in computed_blocks:
|
||||||
computed_blocks[j].pop()
|
computed.pop()
|
||||||
return computed_blocks
|
return computed_blocks
|
||||||
|
|
||||||
def remove_skipped_blocks(self, request_id: str,
|
def remove_skipped_blocks(self, request_id: str,
|
||||||
|
|||||||
@ -112,11 +112,12 @@ class MultiGroupBlockTable:
|
|||||||
for block_size in block_sizes
|
for block_size in block_sizes
|
||||||
]
|
]
|
||||||
|
|
||||||
def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
|
def append_row(self, block_ids: tuple[list[int], ...],
|
||||||
|
row_idx: int) -> None:
|
||||||
for i, block_table in enumerate(self.block_tables):
|
for i, block_table in enumerate(self.block_tables):
|
||||||
block_table.append_row(block_ids[i], row_idx)
|
block_table.append_row(block_ids[i], row_idx)
|
||||||
|
|
||||||
def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
|
def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
|
||||||
for i, block_table in enumerate(self.block_tables):
|
for i, block_table in enumerate(self.block_tables):
|
||||||
block_table.add_row(block_ids[i], row_idx)
|
block_table.add_row(block_ids[i], row_idx)
|
||||||
|
|
||||||
|
|||||||
@ -30,7 +30,7 @@ class CachedRequestState:
|
|||||||
sampling_params: SamplingParams
|
sampling_params: SamplingParams
|
||||||
generator: Optional[torch.Generator]
|
generator: Optional[torch.Generator]
|
||||||
|
|
||||||
block_ids: list[list[int]]
|
block_ids: tuple[list[int], ...]
|
||||||
num_computed_tokens: int
|
num_computed_tokens: int
|
||||||
output_token_ids: list[int]
|
output_token_ids: list[int]
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user