mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 09:42:14 +08:00
[Misc] Improve code readability of KVCacheManager (#21673)
Signed-off-by: tanruixiang <tanruixiang0104@gmail.com> Signed-off-by: Ruixiang Tan <819464715@qq.com> Signed-off-by: GitHub <noreply@github.com>
This commit is contained in:
parent
36ede45989
commit
8f4a1c9a04
@ -112,9 +112,9 @@ def test_kv_cache_block():
|
|||||||
assert block.block_hash is None
|
assert block.block_hash is None
|
||||||
|
|
||||||
# Test reference count manipulation
|
# Test reference count manipulation
|
||||||
block.incr_ref()
|
block.ref_cnt += 1
|
||||||
assert block.ref_cnt == 1
|
assert block.ref_cnt == 1
|
||||||
block.decr_ref()
|
block.ref_cnt -= 1
|
||||||
assert block.ref_cnt == 0
|
assert block.ref_cnt == 0
|
||||||
|
|
||||||
# Test block hash setting and resetting
|
# Test block hash setting and resetting
|
||||||
|
|||||||
@ -276,7 +276,7 @@ class BlockPool:
|
|||||||
# candidate), so remove it.
|
# candidate), so remove it.
|
||||||
if block.ref_cnt == 0 and not block.is_null:
|
if block.ref_cnt == 0 and not block.is_null:
|
||||||
self.free_block_queue.remove(block)
|
self.free_block_queue.remove(block)
|
||||||
block.incr_ref()
|
block.ref_cnt += 1
|
||||||
|
|
||||||
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
|
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
|
||||||
"""Free a list of blocks. The blocks should be ordered by their
|
"""Free a list of blocks. The blocks should be ordered by their
|
||||||
|
|||||||
@ -126,14 +126,17 @@ class KVCacheCoordinator(ABC):
|
|||||||
def get_num_common_prefix_blocks(self, request_id: str,
|
def get_num_common_prefix_blocks(self, request_id: str,
|
||||||
num_running_requests: int) -> list[int]:
|
num_running_requests: int) -> list[int]:
|
||||||
"""
|
"""
|
||||||
Get the number of common prefix blocks for a request.
|
Get the number of common prefix blocks for all requests in the RUNNING
|
||||||
|
state for each kv cache group.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request_id: The request ID.
|
request_id: The request ID.
|
||||||
num_running_requests: The number of requests in the RUNNING state.
|
num_running_requests: The total number of requests in the RUNNING
|
||||||
|
state.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[int]: The number of common prefix blocks.
|
list[int]: The number of common prefix blocks for all requests in
|
||||||
|
the RUNNING state for each kv cache group.
|
||||||
"""
|
"""
|
||||||
num_blocks_per_group = [
|
num_blocks_per_group = [
|
||||||
manager.get_num_common_prefix_blocks(request_id,
|
manager.get_num_common_prefix_blocks(request_id,
|
||||||
|
|||||||
@ -170,10 +170,6 @@ class KVCacheManager:
|
|||||||
self.block_size, request)
|
self.block_size, request)
|
||||||
self.req_to_block_hashes[request.request_id] = block_hashes
|
self.req_to_block_hashes[request.request_id] = block_hashes
|
||||||
|
|
||||||
if self.log_stats:
|
|
||||||
assert self.prefix_cache_stats is not None
|
|
||||||
self.prefix_cache_stats.requests += 1
|
|
||||||
|
|
||||||
# NOTE: When all tokens hit the cache, we must recompute the last token
|
# NOTE: When all tokens hit the cache, we must recompute the last token
|
||||||
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
|
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
|
||||||
# This can trigger recomputation of an entire block, rather than just
|
# This can trigger recomputation of an entire block, rather than just
|
||||||
@ -187,6 +183,7 @@ class KVCacheManager:
|
|||||||
|
|
||||||
if self.log_stats:
|
if self.log_stats:
|
||||||
assert self.prefix_cache_stats is not None
|
assert self.prefix_cache_stats is not None
|
||||||
|
self.prefix_cache_stats.requests += 1
|
||||||
self.prefix_cache_stats.queries += request.num_tokens
|
self.prefix_cache_stats.queries += request.num_tokens
|
||||||
self.prefix_cache_stats.hits += num_new_computed_tokens
|
self.prefix_cache_stats.hits += num_new_computed_tokens
|
||||||
|
|
||||||
|
|||||||
@ -154,14 +154,6 @@ class KVCacheBlock:
|
|||||||
# Whether the block is a null block that should never be cached.
|
# Whether the block is a null block that should never be cached.
|
||||||
is_null: bool = False
|
is_null: bool = False
|
||||||
|
|
||||||
# TODO(Jialin): For performance, let callers handle ref_cnt bumps to
|
|
||||||
# avoid function calls.
|
|
||||||
def incr_ref(self):
|
|
||||||
self.ref_cnt += 1
|
|
||||||
|
|
||||||
def decr_ref(self):
|
|
||||||
self.ref_cnt -= 1
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def block_hash(self) -> Optional[BlockHashWithGroupId]:
|
def block_hash(self) -> Optional[BlockHashWithGroupId]:
|
||||||
return self._block_hash
|
return self._block_hash
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import itertools
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
@ -177,14 +178,17 @@ class SingleTypeKVCacheManager(ABC):
|
|||||||
def get_num_common_prefix_blocks(self, request_id: str,
|
def get_num_common_prefix_blocks(self, request_id: str,
|
||||||
num_running_requests: int) -> int:
|
num_running_requests: int) -> int:
|
||||||
"""
|
"""
|
||||||
Get the number of common prefix blocks for a request.
|
Get the number of common prefix blocks for all requests in the RUNNING
|
||||||
|
state.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request_id: The request ID.
|
request_id: The request ID.
|
||||||
num_running_requests: The number of requests in the RUNNING state.
|
num_running_requests: The total number of requests in the RUNNING
|
||||||
|
state.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The number of common prefix blocks.
|
The number of common prefix blocks for all requests in the RUNNING
|
||||||
|
state.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -264,7 +268,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
|||||||
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
|
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
|
||||||
[] for _ in range(len(kv_cache_group_ids)))
|
[] for _ in range(len(kv_cache_group_ids)))
|
||||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||||
for i, block_hash in zip(range(max_num_blocks), block_hashes):
|
for block_hash in itertools.islice(block_hashes, max_num_blocks):
|
||||||
# block_hashes is a chain of block hashes. If a block hash is not
|
# block_hashes is a chain of block hashes. If a block hash is not
|
||||||
# in the cached_block_hash_to_id, the following block hashes are
|
# in the cached_block_hash_to_id, the following block hashes are
|
||||||
# not computed yet for sure.
|
# not computed yet for sure.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user