[Misc] Improve code readability of KVCacheManager (#21673)

Signed-off-by: tanruixiang <tanruixiang0104@gmail.com>
Signed-off-by: Ruixiang Tan <819464715@qq.com>
Signed-off-by: GitHub <noreply@github.com>
This commit is contained in:
Ruixiang Tan 2025-07-30 22:20:43 +08:00 committed by GitHub
parent 36ede45989
commit 8f4a1c9a04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 18 additions and 22 deletions

View File

@@ -112,9 +112,9 @@ def test_kv_cache_block():
assert block.block_hash is None
# Test reference count manipulation
block.incr_ref()
block.ref_cnt += 1
assert block.ref_cnt == 1
block.decr_ref()
block.ref_cnt -= 1
assert block.ref_cnt == 0
# Test block hash setting and resetting

View File

@@ -276,7 +276,7 @@ class BlockPool:
# candidate), so remove it.
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.remove(block)
block.incr_ref()
block.ref_cnt += 1
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
"""Free a list of blocks. The blocks should be ordered by their

View File

@@ -126,14 +126,17 @@ class KVCacheCoordinator(ABC):
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> list[int]:
"""
Get the number of common prefix blocks for a request.
Get the number of common prefix blocks for all requests in the RUNNING
state for each kv cache group.
Args:
request_id: The request ID.
num_running_requests: The number of requests in the RUNNING state.
num_running_requests: The total number of requests in the RUNNING
state.
Returns:
list[int]: The number of common prefix blocks.
list[int]: The number of common prefix blocks for all requests in
the RUNNING state for each kv cache group.
"""
num_blocks_per_group = [
manager.get_num_common_prefix_blocks(request_id,

View File

@@ -170,10 +170,6 @@ class KVCacheManager:
self.block_size, request)
self.req_to_block_hashes[request.request_id] = block_hashes
if self.log_stats:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.requests += 1
# NOTE: When all tokens hit the cache, we must recompute the last token
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
# This can trigger recomputation of an entire block, rather than just
@@ -187,6 +183,7 @@ class KVCacheManager:
if self.log_stats:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.requests += 1
self.prefix_cache_stats.queries += request.num_tokens
self.prefix_cache_stats.hits += num_new_computed_tokens

View File

@@ -154,14 +154,6 @@ class KVCacheBlock:
# Whether the block is a null block that should never be cached.
is_null: bool = False
# TODO(Jialin): For performance, let callers handle ref_cnt bumps to
# avoid function calls.
def incr_ref(self):
self.ref_cnt += 1
def decr_ref(self):
self.ref_cnt -= 1
@property
def block_hash(self) -> Optional[BlockHashWithGroupId]:
return self._block_hash

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Callable
@@ -177,14 +178,17 @@ class SingleTypeKVCacheManager(ABC):
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
"""
Get the number of common prefix blocks for a request.
Get the number of common prefix blocks for all requests in the RUNNING
state.
Args:
request_id: The request ID.
num_running_requests: The number of requests in the RUNNING state.
num_running_requests: The total number of requests in the RUNNING
state.
Returns:
The number of common prefix blocks.
The number of common prefix blocks for all requests in the RUNNING
state.
"""
raise NotImplementedError
@@ -264,7 +268,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids)))
max_num_blocks = max_length // kv_cache_spec.block_size
for i, block_hash in zip(range(max_num_blocks), block_hashes):
for block_hash in itertools.islice(block_hashes, max_num_blocks):
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.