[Misc] Improve code readability of KVCacheManager (#21673)

Signed-off-by: tanruixiang <tanruixiang0104@gmail.com>
Signed-off-by: Ruixiang Tan <819464715@qq.com>
Signed-off-by: GitHub <noreply@github.com>
This commit is contained in:
Ruixiang Tan 2025-07-30 22:20:43 +08:00 committed by GitHub
parent 36ede45989
commit 8f4a1c9a04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 18 additions and 22 deletions

View File

@@ -112,9 +112,9 @@ def test_kv_cache_block():
assert block.block_hash is None assert block.block_hash is None
# Test reference count manipulation # Test reference count manipulation
block.incr_ref() block.ref_cnt += 1
assert block.ref_cnt == 1 assert block.ref_cnt == 1
block.decr_ref() block.ref_cnt -= 1
assert block.ref_cnt == 0 assert block.ref_cnt == 0
# Test block hash setting and resetting # Test block hash setting and resetting

View File

@@ -276,7 +276,7 @@ class BlockPool:
# candidate), so remove it. # candidate), so remove it.
if block.ref_cnt == 0 and not block.is_null: if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.remove(block) self.free_block_queue.remove(block)
block.incr_ref() block.ref_cnt += 1
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
"""Free a list of blocks. The blocks should be ordered by their """Free a list of blocks. The blocks should be ordered by their

View File

@@ -126,14 +126,17 @@ class KVCacheCoordinator(ABC):
def get_num_common_prefix_blocks(self, request_id: str, def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> list[int]: num_running_requests: int) -> list[int]:
""" """
Get the number of common prefix blocks for a request. Get the number of common prefix blocks for all requests in the RUNNING
state for each kv cache group.
Args: Args:
request_id: The request ID. request_id: The request ID.
num_running_requests: The number of requests in the RUNNING state. num_running_requests: The total number of requests in the RUNNING
state.
Returns: Returns:
list[int]: The number of common prefix blocks. list[int]: The number of common prefix blocks for all requests in
the RUNNING state for each kv cache group.
""" """
num_blocks_per_group = [ num_blocks_per_group = [
manager.get_num_common_prefix_blocks(request_id, manager.get_num_common_prefix_blocks(request_id,

View File

@@ -170,10 +170,6 @@ class KVCacheManager:
self.block_size, request) self.block_size, request)
self.req_to_block_hashes[request.request_id] = block_hashes self.req_to_block_hashes[request.request_id] = block_hashes
if self.log_stats:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.requests += 1
# NOTE: When all tokens hit the cache, we must recompute the last token # NOTE: When all tokens hit the cache, we must recompute the last token
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1. # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
# This can trigger recomputation of an entire block, rather than just # This can trigger recomputation of an entire block, rather than just
@@ -187,6 +183,7 @@ class KVCacheManager:
if self.log_stats: if self.log_stats:
assert self.prefix_cache_stats is not None assert self.prefix_cache_stats is not None
self.prefix_cache_stats.requests += 1
self.prefix_cache_stats.queries += request.num_tokens self.prefix_cache_stats.queries += request.num_tokens
self.prefix_cache_stats.hits += num_new_computed_tokens self.prefix_cache_stats.hits += num_new_computed_tokens

View File

@@ -154,14 +154,6 @@ class KVCacheBlock:
# Whether the block is a null block that should never be cached. # Whether the block is a null block that should never be cached.
is_null: bool = False is_null: bool = False
# TODO(Jialin): For performance, let callers handle ref_cnt bumps to
# avoid function calls.
def incr_ref(self):
self.ref_cnt += 1
def decr_ref(self):
self.ref_cnt -= 1
@property @property
def block_hash(self) -> Optional[BlockHashWithGroupId]: def block_hash(self) -> Optional[BlockHashWithGroupId]:
return self._block_hash return self._block_hash

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from typing import Callable from typing import Callable
@@ -177,14 +178,17 @@ class SingleTypeKVCacheManager(ABC):
def get_num_common_prefix_blocks(self, request_id: str, def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int: num_running_requests: int) -> int:
""" """
Get the number of common prefix blocks for a request. Get the number of common prefix blocks for all requests in the RUNNING
state.
Args: Args:
request_id: The request ID. request_id: The request ID.
num_running_requests: The number of requests in the RUNNING state. num_running_requests: The total number of requests in the RUNNING
state.
Returns: Returns:
The number of common prefix blocks. The number of common prefix blocks for all requests in the RUNNING
state.
""" """
raise NotImplementedError raise NotImplementedError
@@ -264,7 +268,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids))) [] for _ in range(len(kv_cache_group_ids)))
max_num_blocks = max_length // kv_cache_spec.block_size max_num_blocks = max_length // kv_cache_spec.block_size
for i, block_hash in zip(range(max_num_blocks), block_hashes): for block_hash in itertools.islice(block_hashes, max_num_blocks):
# block_hashes is a chain of block hashes. If a block hash is not # block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are # in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure. # not computed yet for sure.