mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 03:34:57 +08:00
[V1] Move KV block hashes from Request to KVCacheManager (#12922)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
b21f0f9d17
commit
3243158336
@ -51,7 +51,7 @@ def test_prefill():
|
|||||||
all_token_ids = common_token_ids + unique_token_ids
|
all_token_ids = common_token_ids + unique_token_ids
|
||||||
req0 = make_request("0", all_token_ids)
|
req0 = make_request("0", all_token_ids)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||||
assert len(req0.kv_block_hashes) == 3
|
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
|
||||||
assert not computed_blocks
|
assert not computed_blocks
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
blocks = manager.allocate_slots(req0, 55, computed_blocks)
|
blocks = manager.allocate_slots(req0, 55, computed_blocks)
|
||||||
@ -76,7 +76,7 @@ def test_prefill():
|
|||||||
unique_token_ids = [3] * 5
|
unique_token_ids = [3] * 5
|
||||||
req1 = make_request("1", common_token_ids + unique_token_ids)
|
req1 = make_request("1", common_token_ids + unique_token_ids)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
assert len(req1.kv_block_hashes) == 3
|
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||||
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
|
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
|
||||||
assert num_computed_tokens == 3 * 16
|
assert num_computed_tokens == 3 * 16
|
||||||
num_new_tokens = 53 - 3 * 16
|
num_new_tokens = 53 - 3 * 16
|
||||||
@ -107,7 +107,7 @@ def test_prefill():
|
|||||||
unique_token_ids = [3] * 6
|
unique_token_ids = [3] * 6
|
||||||
req2 = make_request("2", common_token_ids + unique_token_ids)
|
req2 = make_request("2", common_token_ids + unique_token_ids)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||||
assert len(req2.kv_block_hashes) == 3
|
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
|
||||||
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
|
assert [b.block_id for b in computed_blocks] == [0, 1, 2]
|
||||||
assert num_computed_tokens == 3 * 16
|
assert num_computed_tokens == 3 * 16
|
||||||
num_new_tokens = 53 - 3 * 16
|
num_new_tokens = 53 - 3 * 16
|
||||||
@ -494,10 +494,11 @@ def test_mm_prefix_caching():
|
|||||||
# Completed block should have hashes with extra keys.
|
# Completed block should have hashes with extra keys.
|
||||||
assert not computed_blocks
|
assert not computed_blocks
|
||||||
assert num_computed_tokens == 0
|
assert num_computed_tokens == 0
|
||||||
assert len(req0.kv_block_hashes) == 3
|
block_hashes = manager.req_to_block_hashes[req0.request_id]
|
||||||
assert req0.kv_block_hashes[0].extra_keys == ("aaa", )
|
assert len(block_hashes) == 3
|
||||||
assert req0.kv_block_hashes[1].extra_keys == ("aaa", "bbb")
|
assert block_hashes[0].extra_keys == ("aaa", )
|
||||||
assert req0.kv_block_hashes[2].extra_keys == ("bbb", )
|
assert block_hashes[1].extra_keys == ("aaa", "bbb")
|
||||||
|
assert block_hashes[2].extra_keys == ("bbb", )
|
||||||
|
|
||||||
blocks = manager.allocate_slots(req0, 59, computed_blocks)
|
blocks = manager.allocate_slots(req0, 59, computed_blocks)
|
||||||
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
|
assert [b.block_id for b in blocks] == [0, 1, 2, 3, 4]
|
||||||
@ -510,8 +511,8 @@ def test_mm_prefix_caching():
|
|||||||
assert new_blocks is not None and len(new_blocks) == 0
|
assert new_blocks is not None and len(new_blocks) == 0
|
||||||
|
|
||||||
# The just completed block should have hashes with extra keys.
|
# The just completed block should have hashes with extra keys.
|
||||||
assert len(req0.kv_block_hashes) == 4
|
assert len(block_hashes) == 4
|
||||||
assert req0.kv_block_hashes[3].extra_keys == ("ccc", )
|
assert block_hashes[3].extra_keys == ("ccc", )
|
||||||
|
|
||||||
# Cache hit.
|
# Cache hit.
|
||||||
unique_token_ids = [-1] * 7 + [200] * 5
|
unique_token_ids = [-1] * 7 + [200] * 5
|
||||||
@ -613,7 +614,7 @@ def test_reset_prefix_cache():
|
|||||||
all_token_ids = full_block_token_ids + unique_token_ids
|
all_token_ids = full_block_token_ids + unique_token_ids
|
||||||
req1 = make_request("1", all_token_ids)
|
req1 = make_request("1", all_token_ids)
|
||||||
computed_blocks, _ = manager.get_computed_blocks(req1)
|
computed_blocks, _ = manager.get_computed_blocks(req1)
|
||||||
assert len(req1.kv_block_hashes) == 3
|
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||||
assert len(computed_blocks) == 3
|
assert len(computed_blocks) == 3
|
||||||
blocks = manager.allocate_slots(req1, 7, computed_blocks)
|
blocks = manager.allocate_slots(req1, 7, computed_blocks)
|
||||||
assert [b.block_id for b in blocks] == [4]
|
assert [b.block_id for b in blocks] == [4]
|
||||||
|
|||||||
@ -72,6 +72,12 @@ class KVCacheManager:
|
|||||||
self.req_to_blocks: DefaultDict[str,
|
self.req_to_blocks: DefaultDict[str,
|
||||||
List[KVCacheBlock]] = defaultdict(list)
|
List[KVCacheBlock]] = defaultdict(list)
|
||||||
|
|
||||||
|
# Mapping from request ID to kv block hashes.
|
||||||
|
# This is to avoid recomputing the block hashes for each call of
|
||||||
|
# `get_computed_blocks` or `allocate_slots`.
|
||||||
|
self.req_to_block_hashes: DefaultDict[
|
||||||
|
str, List[BlockHashType]] = defaultdict(list)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def usage(self) -> float:
|
def usage(self) -> float:
|
||||||
return 1.0 - (self.free_block_queue.num_free_blocks /
|
return 1.0 - (self.free_block_queue.num_free_blocks /
|
||||||
@ -97,11 +103,11 @@ class KVCacheManager:
|
|||||||
computed_blocks = []
|
computed_blocks = []
|
||||||
|
|
||||||
# The block hashes for the request may already be computed
|
# The block hashes for the request may already be computed
|
||||||
# if the request was preempted and resumed.
|
# if the scheduler has tried to schedule the request before.
|
||||||
if not request.kv_block_hashes:
|
block_hashes = self.req_to_block_hashes[request.request_id]
|
||||||
request.set_kv_block_hashes(
|
if not block_hashes:
|
||||||
hash_request_tokens(self.block_size, request))
|
block_hashes = hash_request_tokens(self.block_size, request)
|
||||||
block_hashes = request.kv_block_hashes
|
self.req_to_block_hashes[request.request_id] = block_hashes
|
||||||
|
|
||||||
for block_hash in block_hashes:
|
for block_hash in block_hashes:
|
||||||
# block_hashes is a chain of block hashes. If a block hash is not
|
# block_hashes is a chain of block hashes. If a block hash is not
|
||||||
@ -435,7 +441,8 @@ class KVCacheManager:
|
|||||||
full_blocks: The list of blocks to update hash metadata.
|
full_blocks: The list of blocks to update hash metadata.
|
||||||
prev_block: The previous block in the chain.
|
prev_block: The previous block in the chain.
|
||||||
"""
|
"""
|
||||||
num_cached_block_hashes = len(request.kv_block_hashes)
|
block_hashes = self.req_to_block_hashes[request.request_id]
|
||||||
|
num_cached_block_hashes = len(block_hashes)
|
||||||
|
|
||||||
# Update the new blocks with the block hashes through the chain.
|
# Update the new blocks with the block hashes through the chain.
|
||||||
prev_block_hash_value = None
|
prev_block_hash_value = None
|
||||||
@ -468,7 +475,7 @@ class KVCacheManager:
|
|||||||
# this request (either the prompt tokens or the previously
|
# this request (either the prompt tokens or the previously
|
||||||
# generated tokens with preemption). In this case we simply
|
# generated tokens with preemption). In this case we simply
|
||||||
# reuse the block hash.
|
# reuse the block hash.
|
||||||
block_hash = request.kv_block_hashes[blk_idx]
|
block_hash = block_hashes[blk_idx]
|
||||||
else:
|
else:
|
||||||
# Otherwise compute the block hash and cache it in the request
|
# Otherwise compute the block hash and cache it in the request
|
||||||
# in case it will be preempted in the future.
|
# in case it will be preempted in the future.
|
||||||
@ -490,9 +497,17 @@ class KVCacheManager:
|
|||||||
# Compute the hash of the current block.
|
# Compute the hash of the current block.
|
||||||
block_hash = hash_block_tokens(prev_block_hash_value,
|
block_hash = hash_block_tokens(prev_block_hash_value,
|
||||||
block_tokens, extra_keys)
|
block_tokens, extra_keys)
|
||||||
request.append_kv_block_hashes(block_hash)
|
block_hashes.append(block_hash)
|
||||||
|
|
||||||
# Update and added the full block to the cache.
|
# Update and added the full block to the cache.
|
||||||
blk.block_hash = block_hash
|
blk.block_hash = block_hash
|
||||||
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
|
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
|
||||||
prev_block_hash_value = block_hash.hash_value
|
prev_block_hash_value = block_hash.hash_value
|
||||||
|
|
||||||
|
def free_block_hashes(self, request: Request) -> None:
|
||||||
|
"""Discard the block hashes for the request.
|
||||||
|
|
||||||
|
NOTE: Unlike `free`, this method should be called only when the request
|
||||||
|
is finished, not when it is preempted.
|
||||||
|
"""
|
||||||
|
self.req_to_block_hashes.pop(request.request_id, None)
|
||||||
|
|||||||
@ -579,6 +579,7 @@ class Scheduler:
|
|||||||
def _free_request(self, request: Request) -> None:
|
def _free_request(self, request: Request) -> None:
|
||||||
assert request.is_finished()
|
assert request.is_finished()
|
||||||
self.kv_cache_manager.free(request)
|
self.kv_cache_manager.free(request)
|
||||||
|
self.kv_cache_manager.free_block_hashes(request)
|
||||||
self.encoder_cache_manager.free(request)
|
self.encoder_cache_manager.free(request)
|
||||||
self._cached_reqs_data.pop(request.request_id, None)
|
self._cached_reqs_data.pop(request.request_id, None)
|
||||||
del self.requests[request.request_id]
|
del self.requests[request.request_id]
|
||||||
|
|||||||
@ -12,7 +12,6 @@ from vllm.v1.utils import ConstantList
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.multimodal import MultiModalKwargs
|
from vllm.multimodal import MultiModalKwargs
|
||||||
from vllm.multimodal.inputs import PlaceholderRange
|
from vllm.multimodal.inputs import PlaceholderRange
|
||||||
from vllm.v1.core.kv_cache_utils import BlockHashType
|
|
||||||
|
|
||||||
|
|
||||||
class Request:
|
class Request:
|
||||||
@ -63,11 +62,6 @@ class Request:
|
|||||||
if self.mm_hashes:
|
if self.mm_hashes:
|
||||||
assert len(self.mm_inputs) == len(self.mm_hashes)
|
assert len(self.mm_inputs) == len(self.mm_hashes)
|
||||||
|
|
||||||
# Cache the computed kv block hashes of the request to avoid
|
|
||||||
# recomputing.
|
|
||||||
self._kv_block_hashes: List[BlockHashType] = []
|
|
||||||
self.kv_block_hashes = ConstantList(self._kv_block_hashes)
|
|
||||||
|
|
||||||
# Read-only views
|
# Read-only views
|
||||||
# Prevent directly appending to the these lists since
|
# Prevent directly appending to the these lists since
|
||||||
# they should also be updated simultaneously.
|
# they should also be updated simultaneously.
|
||||||
@ -124,13 +118,6 @@ class Request:
|
|||||||
num_tokens = self.mm_positions[input_id]["length"]
|
num_tokens = self.mm_positions[input_id]["length"]
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
def set_kv_block_hashes(self, value: List["BlockHashType"]) -> None:
|
|
||||||
self._kv_block_hashes = value
|
|
||||||
self.kv_block_hashes = ConstantList(self._kv_block_hashes)
|
|
||||||
|
|
||||||
def append_kv_block_hashes(self, block_hash: "BlockHashType") -> None:
|
|
||||||
self._kv_block_hashes.append(block_hash)
|
|
||||||
|
|
||||||
|
|
||||||
class RequestStatus(enum.IntEnum):
|
class RequestStatus(enum.IntEnum):
|
||||||
"""Status of a request."""
|
"""Status of a request."""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user