mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 05:55:01 +08:00
[v1][KVCacheManager] Rename BlockHashType to BlockHash (#19015)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent
d32aa2e670
commit
f32fcd9444
@ -104,7 +104,7 @@ class KVCacheBlock:
|
|||||||
block_id: int
|
block_id: int
|
||||||
# The block hash (will be assigned when the block is full,
|
# The block hash (will be assigned when the block is full,
|
||||||
# and will be reset when the block is evicted).
|
# and will be reset when the block is evicted).
|
||||||
block_hash: BlockHashType
|
block_hash: BlockHash
|
||||||
# The number of requests using this block now.
|
# The number of requests using this block now.
|
||||||
ref_cnt: int
|
ref_cnt: int
|
||||||
|
|
||||||
|
|||||||
@ -100,8 +100,8 @@ def test_kv_cache_block():
|
|||||||
assert block.ref_cnt == 0
|
assert block.ref_cnt == 0
|
||||||
|
|
||||||
# Test block hash setting and resetting
|
# Test block hash setting and resetting
|
||||||
block_hash = vllm.v1.core.kv_cache_utils.BlockHashType(hash_value=123,
|
block_hash = vllm.v1.core.kv_cache_utils.BlockHash(hash_value=123,
|
||||||
token_ids=(1, 2, 3))
|
token_ids=(1, 2, 3))
|
||||||
block.block_hash = block_hash
|
block.block_hash = block_hash
|
||||||
assert block.block_hash == block_hash
|
assert block.block_hash == block_hash
|
||||||
|
|
||||||
@ -282,7 +282,7 @@ def test_hash_block_tokens(hash_fn):
|
|||||||
|
|
||||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||||
curr_block_token_ids, extra_keys)
|
curr_block_token_ids, extra_keys)
|
||||||
assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHashType)
|
assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHash)
|
||||||
assert block_hash.hash_value == hash_fn(
|
assert block_hash.hash_value == hash_fn(
|
||||||
(parent_block_hash, curr_block_token_ids, extra_keys))
|
(parent_block_hash, curr_block_token_ids, extra_keys))
|
||||||
assert block_hash.token_ids == curr_block_token_ids
|
assert block_hash.token_ids == curr_block_token_ids
|
||||||
@ -306,10 +306,8 @@ def test_hash_request_tokens(hash_fn):
|
|||||||
block_hashes = hash_request_tokens(hash_fn, block_size, request)
|
block_hashes = hash_request_tokens(hash_fn, block_size, request)
|
||||||
|
|
||||||
assert len(block_hashes) == 2
|
assert len(block_hashes) == 2
|
||||||
assert isinstance(block_hashes[0],
|
assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash)
|
||||||
vllm.v1.core.kv_cache_utils.BlockHashType)
|
assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash)
|
||||||
assert isinstance(block_hashes[1],
|
|
||||||
vllm.v1.core.kv_cache_utils.BlockHashType)
|
|
||||||
|
|
||||||
# Check the first block
|
# Check the first block
|
||||||
assert block_hashes[0].token_ids == (0, 1, 2)
|
assert block_hashes[0].token_ids == (0, 1, 2)
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from vllm.sampling_params import SamplingParams
|
|||||||
from vllm.utils import sha256
|
from vllm.utils import sha256
|
||||||
from vllm.v1.core.block_pool import BlockPool
|
from vllm.v1.core.block_pool import BlockPool
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||||
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
|
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||||
hash_block_tokens)
|
hash_block_tokens)
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheGroupSpec, SlidingWindowSpec)
|
KVCacheGroupSpec, SlidingWindowSpec)
|
||||||
@ -547,7 +547,7 @@ def test_cache_blocks(hash_fn):
|
|||||||
|
|
||||||
# Test that blocks are cached correctly for 2 full blocks from the start.
|
# Test that blocks are cached correctly for 2 full blocks from the start.
|
||||||
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
|
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
|
||||||
block_hashes: list[BlockHashType] = []
|
block_hashes: list[BlockHash] = []
|
||||||
|
|
||||||
block_pool.cache_full_blocks(
|
block_pool.cache_full_blocks(
|
||||||
request=req,
|
request=req,
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.v1.core.block_pool import BlockPool
|
from vllm.v1.core.block_pool import BlockPool
|
||||||
from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
|
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||||
from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager
|
from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager
|
||||||
from vllm.v1.kv_cache_interface import SlidingWindowSpec
|
from vllm.v1.kv_cache_interface import SlidingWindowSpec
|
||||||
|
|
||||||
@ -32,7 +32,7 @@ def test_sliding_window_possible_cached_prefix():
|
|||||||
|
|
||||||
def run_one_case(block_is_cached, expect_length):
|
def run_one_case(block_is_cached, expect_length):
|
||||||
block_hash_list = [
|
block_hash_list = [
|
||||||
BlockHashType(i, ()) for i in range(len(block_is_cached))
|
BlockHash(i, ()) for i in range(len(block_is_cached))
|
||||||
]
|
]
|
||||||
|
|
||||||
block_pool.cached_block_hash_to_block.clear()
|
block_pool.cached_block_hash_to_block.clear()
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from typing import Callable, Optional
|
|||||||
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
|
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
|
||||||
BlockStored, KVCacheEvent)
|
BlockStored, KVCacheEvent)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
|
from vllm.v1.core.kv_cache_utils import (BlockHash, FreeKVCacheBlockQueue,
|
||||||
KVCacheBlock,
|
KVCacheBlock,
|
||||||
generate_block_hash_extra_keys,
|
generate_block_hash_extra_keys,
|
||||||
hash_block_tokens)
|
hash_block_tokens)
|
||||||
@ -55,7 +55,7 @@ class BlockPool:
|
|||||||
# if there is already an identical block in the cache. This is because
|
# if there is already an identical block in the cache. This is because
|
||||||
# we want to make sure the allocated block IDs won't change so that
|
# we want to make sure the allocated block IDs won't change so that
|
||||||
# block tables are append-only.
|
# block tables are append-only.
|
||||||
self.cached_block_hash_to_block: dict[BlockHashType, dict[
|
self.cached_block_hash_to_block: dict[BlockHash, dict[
|
||||||
int, KVCacheBlock]] = defaultdict(dict)
|
int, KVCacheBlock]] = defaultdict(dict)
|
||||||
|
|
||||||
# To represent a placeholder block with block_id=0.
|
# To represent a placeholder block with block_id=0.
|
||||||
@ -67,7 +67,7 @@ class BlockPool:
|
|||||||
self.kv_event_queue: list[KVCacheEvent] = []
|
self.kv_event_queue: list[KVCacheEvent] = []
|
||||||
|
|
||||||
def get_cached_block(self,
|
def get_cached_block(self,
|
||||||
block_hash: BlockHashType) -> Optional[KVCacheBlock]:
|
block_hash: BlockHash) -> Optional[KVCacheBlock]:
|
||||||
"""Get a cached block by the block hash, or None if cache miss.
|
"""Get a cached block by the block hash, or None if cache miss.
|
||||||
If there are duplicated blocks, we return the first block in the cache.
|
If there are duplicated blocks, we return the first block in the cache.
|
||||||
|
|
||||||
@ -87,7 +87,7 @@ class BlockPool:
|
|||||||
self,
|
self,
|
||||||
request: Request,
|
request: Request,
|
||||||
blocks: list[KVCacheBlock],
|
blocks: list[KVCacheBlock],
|
||||||
block_hashes: list[BlockHashType],
|
block_hashes: list[BlockHash],
|
||||||
num_cached_blocks: int,
|
num_cached_blocks: int,
|
||||||
num_full_blocks: int,
|
num_full_blocks: int,
|
||||||
block_size: int,
|
block_size: int,
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from vllm.distributed.kv_events import KVCacheEvent
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import sha256
|
from vllm.utils import sha256
|
||||||
from vllm.v1.core.block_pool import BlockPool
|
from vllm.v1.core.block_pool import BlockPool
|
||||||
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
|
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||||
hash_request_tokens)
|
hash_request_tokens)
|
||||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||||
get_manager_for_kv_cache_spec)
|
get_manager_for_kv_cache_spec)
|
||||||
@ -92,7 +92,7 @@ class KVCacheManager:
|
|||||||
# This is to avoid recomputing the block hashes for each call of
|
# This is to avoid recomputing the block hashes for each call of
|
||||||
# `get_computed_blocks` or `allocate_slots`.
|
# `get_computed_blocks` or `allocate_slots`.
|
||||||
self.req_to_block_hashes: defaultdict[
|
self.req_to_block_hashes: defaultdict[
|
||||||
str, list[BlockHashType]] = defaultdict(list)
|
str, list[BlockHash]] = defaultdict(list)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def usage(self) -> float:
|
def usage(self) -> float:
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from vllm.v1.request import Request
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BlockHashType(NamedTuple):
|
class BlockHash(NamedTuple):
|
||||||
"""Hash value of a block (int), the token IDs in the block, and extra keys.
|
"""Hash value of a block (int), the token IDs in the block, and extra keys.
|
||||||
We keep a tuple of token IDs and extra keys to reduce the likelihood of
|
We keep a tuple of token IDs and extra keys to reduce the likelihood of
|
||||||
hash collisions when the hash value is the same. By using SHA256 however,
|
hash collisions when the hash value is the same. By using SHA256 however,
|
||||||
@ -117,7 +117,7 @@ class KVCacheBlock:
|
|||||||
ref_cnt: int = 0
|
ref_cnt: int = 0
|
||||||
# The hash of the block composed of (block hash, tuple of token IDs).
|
# The hash of the block composed of (block hash, tuple of token IDs).
|
||||||
# It is only available when the block is full.
|
# It is only available when the block is full.
|
||||||
_block_hash: Optional[BlockHashType] = None
|
_block_hash: Optional[BlockHash] = None
|
||||||
|
|
||||||
# Used to construct a doubly linked list for free blocks.
|
# Used to construct a doubly linked list for free blocks.
|
||||||
# These two attributes should only be manipulated by FreeKVCacheBlockQueue.
|
# These two attributes should only be manipulated by FreeKVCacheBlockQueue.
|
||||||
@ -131,11 +131,11 @@ class KVCacheBlock:
|
|||||||
self.ref_cnt -= 1
|
self.ref_cnt -= 1
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def block_hash(self) -> Optional[BlockHashType]:
|
def block_hash(self) -> Optional[BlockHash]:
|
||||||
return self._block_hash
|
return self._block_hash
|
||||||
|
|
||||||
@block_hash.setter
|
@block_hash.setter
|
||||||
def block_hash(self, block_hash: BlockHashType):
|
def block_hash(self, block_hash: BlockHash):
|
||||||
assert self.block_hash is None, (
|
assert self.block_hash is None, (
|
||||||
"The block already has a hash. This should not happen.")
|
"The block already has a hash. This should not happen.")
|
||||||
self._block_hash = block_hash
|
self._block_hash = block_hash
|
||||||
@ -398,7 +398,7 @@ def hash_block_tokens(
|
|||||||
hash_function: Callable,
|
hash_function: Callable,
|
||||||
parent_block_hash: Optional[int],
|
parent_block_hash: Optional[int],
|
||||||
curr_block_token_ids: Sequence[int],
|
curr_block_token_ids: Sequence[int],
|
||||||
extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHashType:
|
extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash:
|
||||||
"""Computes a hash value corresponding to the contents of a block and
|
"""Computes a hash value corresponding to the contents of a block and
|
||||||
the contents of the preceding block(s). The hash value is used for
|
the contents of the preceding block(s). The hash value is used for
|
||||||
prefix caching. We use LRU cache for this function to avoid recomputing
|
prefix caching. We use LRU cache for this function to avoid recomputing
|
||||||
@ -419,14 +419,14 @@ def hash_block_tokens(
|
|||||||
parent_block_hash = NONE_HASH
|
parent_block_hash = NONE_HASH
|
||||||
|
|
||||||
curr_block_token_ids_tuple = tuple(curr_block_token_ids)
|
curr_block_token_ids_tuple = tuple(curr_block_token_ids)
|
||||||
return BlockHashType(
|
return BlockHash(
|
||||||
hash_function(
|
hash_function(
|
||||||
(parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
|
(parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
|
||||||
curr_block_token_ids_tuple, extra_keys)
|
curr_block_token_ids_tuple, extra_keys)
|
||||||
|
|
||||||
|
|
||||||
def hash_request_tokens(hash_function: Any, block_size: int,
|
def hash_request_tokens(hash_function: Any, block_size: int,
|
||||||
request: Request) -> list[BlockHashType]:
|
request: Request) -> list[BlockHash]:
|
||||||
"""Computes hash values of a chain of blocks given a sequence of
|
"""Computes hash values of a chain of blocks given a sequence of
|
||||||
token IDs. The hash value is used for prefix caching.
|
token IDs. The hash value is used for prefix caching.
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from typing import Callable
|
|||||||
|
|
||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
from vllm.v1.core.block_pool import BlockPool
|
from vllm.v1.core.block_pool import BlockPool
|
||||||
from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
|
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
|
||||||
SlidingWindowSpec)
|
SlidingWindowSpec)
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
@ -133,7 +133,7 @@ class SingleTypeKVCacheManager(ABC):
|
|||||||
req_blocks.extend(new_blocks)
|
req_blocks.extend(new_blocks)
|
||||||
return new_blocks
|
return new_blocks
|
||||||
|
|
||||||
def cache_blocks(self, request: Request, block_hashes: list[BlockHashType],
|
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
|
||||||
num_tokens: int) -> None:
|
num_tokens: int) -> None:
|
||||||
"""
|
"""
|
||||||
Cache the blocks for the request.
|
Cache the blocks for the request.
|
||||||
@ -187,7 +187,7 @@ class SingleTypeKVCacheManager(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
|
def find_longest_cache_hit(self, block_hashes: list[BlockHash],
|
||||||
max_length: int) -> list[KVCacheBlock]:
|
max_length: int) -> list[KVCacheBlock]:
|
||||||
"""
|
"""
|
||||||
Get the longest cache hit prefix of the blocks that is not longer than
|
Get the longest cache hit prefix of the blocks that is not longer than
|
||||||
@ -228,7 +228,7 @@ class SingleTypeKVCacheManager(ABC):
|
|||||||
|
|
||||||
class FullAttentionManager(SingleTypeKVCacheManager):
|
class FullAttentionManager(SingleTypeKVCacheManager):
|
||||||
|
|
||||||
def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
|
def find_longest_cache_hit(self, block_hashes: list[BlockHash],
|
||||||
max_length: int) -> list[KVCacheBlock]:
|
max_length: int) -> list[KVCacheBlock]:
|
||||||
computed_blocks: list[KVCacheBlock] = []
|
computed_blocks: list[KVCacheBlock] = []
|
||||||
max_num_blocks = max_length // self.block_size
|
max_num_blocks = max_length // self.block_size
|
||||||
@ -280,7 +280,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
|||||||
self.sliding_window_contiguous_blocks += 1
|
self.sliding_window_contiguous_blocks += 1
|
||||||
self._null_block = block_pool.null_block
|
self._null_block = block_pool.null_block
|
||||||
|
|
||||||
def find_longest_cache_hit(self, block_hashes: list[BlockHashType],
|
def find_longest_cache_hit(self, block_hashes: list[BlockHash],
|
||||||
max_length: int) -> list[KVCacheBlock]:
|
max_length: int) -> list[KVCacheBlock]:
|
||||||
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
|
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
|
||||||
# optimize the time complexity from O(max_num_blocks) to
|
# optimize the time complexity from O(max_num_blocks) to
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user