mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-07 10:43:11 +08:00
[Misc] Improve code readability of KVCacheManager (#21673)
Signed-off-by: tanruixiang <tanruixiang0104@gmail.com> Signed-off-by: Ruixiang Tan <819464715@qq.com> Signed-off-by: GitHub <noreply@github.com>
This commit is contained in:
parent
36ede45989
commit
8f4a1c9a04
@ -112,9 +112,9 @@ def test_kv_cache_block():
|
||||
assert block.block_hash is None
|
||||
|
||||
# Test reference count manipulation
|
||||
block.incr_ref()
|
||||
block.ref_cnt += 1
|
||||
assert block.ref_cnt == 1
|
||||
block.decr_ref()
|
||||
block.ref_cnt -= 1
|
||||
assert block.ref_cnt == 0
|
||||
|
||||
# Test block hash setting and resetting
|
||||
|
||||
@ -276,7 +276,7 @@ class BlockPool:
|
||||
# candidate), so remove it.
|
||||
if block.ref_cnt == 0 and not block.is_null:
|
||||
self.free_block_queue.remove(block)
|
||||
block.incr_ref()
|
||||
block.ref_cnt += 1
|
||||
|
||||
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
|
||||
"""Free a list of blocks. The blocks should be ordered by their
|
||||
|
||||
@ -126,14 +126,17 @@ class KVCacheCoordinator(ABC):
|
||||
def get_num_common_prefix_blocks(self, request_id: str,
|
||||
num_running_requests: int) -> list[int]:
|
||||
"""
|
||||
Get the number of common prefix blocks for a request.
|
||||
Get the number of common prefix blocks for all requests in the RUNNING
|
||||
state for each kv cache group.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
num_running_requests: The number of requests in the RUNNING state.
|
||||
num_running_requests: The total number of requests in the RUNNING
|
||||
state.
|
||||
|
||||
Returns:
|
||||
list[int]: The number of common prefix blocks.
|
||||
list[int]: The number of common prefix blocks for all requests in
|
||||
the RUNNING state for each kv cache group.
|
||||
"""
|
||||
num_blocks_per_group = [
|
||||
manager.get_num_common_prefix_blocks(request_id,
|
||||
|
||||
@ -170,10 +170,6 @@ class KVCacheManager:
|
||||
self.block_size, request)
|
||||
self.req_to_block_hashes[request.request_id] = block_hashes
|
||||
|
||||
if self.log_stats:
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.requests += 1
|
||||
|
||||
# NOTE: When all tokens hit the cache, we must recompute the last token
|
||||
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
|
||||
# This can trigger recomputation of an entire block, rather than just
|
||||
@ -187,6 +183,7 @@ class KVCacheManager:
|
||||
|
||||
if self.log_stats:
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.requests += 1
|
||||
self.prefix_cache_stats.queries += request.num_tokens
|
||||
self.prefix_cache_stats.hits += num_new_computed_tokens
|
||||
|
||||
|
||||
@ -154,14 +154,6 @@ class KVCacheBlock:
|
||||
# Whether the block is a null block that should never be cached.
|
||||
is_null: bool = False
|
||||
|
||||
# TODO(Jialin): For performance, let callers handle ref_cnt bumps to
|
||||
# avoid function calls.
|
||||
def incr_ref(self):
|
||||
self.ref_cnt += 1
|
||||
|
||||
def decr_ref(self):
|
||||
self.ref_cnt -= 1
|
||||
|
||||
@property
|
||||
def block_hash(self) -> Optional[BlockHashWithGroupId]:
|
||||
return self._block_hash
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from typing import Callable
|
||||
@ -177,14 +178,17 @@ class SingleTypeKVCacheManager(ABC):
|
||||
def get_num_common_prefix_blocks(self, request_id: str,
|
||||
num_running_requests: int) -> int:
|
||||
"""
|
||||
Get the number of common prefix blocks for a request.
|
||||
Get the number of common prefix blocks for all requests in the RUNNING
|
||||
state.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
num_running_requests: The number of requests in the RUNNING state.
|
||||
num_running_requests: The total number of requests in the RUNNING
|
||||
state.
|
||||
|
||||
Returns:
|
||||
The number of common prefix blocks.
|
||||
The number of common prefix blocks for all requests in the RUNNING
|
||||
state.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
@ -264,7 +268,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
|
||||
[] for _ in range(len(kv_cache_group_ids)))
|
||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||
for i, block_hash in zip(range(max_num_blocks), block_hashes):
|
||||
for block_hash in itertools.islice(block_hashes, max_num_blocks):
|
||||
# block_hashes is a chain of block hashes. If a block hash is not
|
||||
# in the cached_block_hash_to_id, the following block hashes are
|
||||
# not computed yet for sure.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user