[v1] Refactor KVCacheManager for more hash input than token ids (#10507)
Signed-off-by: rickyx <rickyx@anyscale.com>
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
parent eebad39f26
commit 97814fbf0f
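At a high level, this refactor derives each full block's hash from the request's token IDs, chained through the previous block's hash, instead of from token IDs stored on the block itself. Below is a minimal illustrative sketch of that chaining before the diff; it is not the code in this commit, and sketch_hash_block_tokens / sketch_full_block_hashes are hypothetical stand-ins for hash_block_tokens and the hashing loop in _cache_full_blocks.

from typing import List, Optional, Tuple


def sketch_hash_block_tokens(parent_hash: Optional[int],
                             block_tokens: Tuple[int, ...]) -> int:
    # Stand-in for vllm.v1.core.kv_cache_utils.hash_block_tokens; the real
    # function may hash differently, but it also takes the parent hash and
    # the block's token tuple.
    return hash((parent_hash, block_tokens))


def sketch_full_block_hashes(all_token_ids: List[int],
                             block_size: int) -> List[int]:
    # Hash only the full blocks of a request, chaining each hash to the
    # previous one so a block's key depends on its whole prefix.
    hashes: List[int] = []
    parent_hash: Optional[int] = None
    for start in range(0, len(all_token_ids) - block_size + 1, block_size):
        block_tokens = tuple(all_token_ids[start:start + block_size])
        parent_hash = sketch_hash_block_tokens(parent_hash, block_tokens)
        hashes.append(parent_hash)
    return hashes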
@@ -1,8 +1,11 @@
"""Compare the with and without prefix caching."""
import pytest

from vllm.inputs import token_inputs
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import hash_block_tokens
from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens


def make_request(request_id, prompt_token_ids):
@@ -31,7 +34,8 @@ def test_prefill():
    # Fully cache miss
    # Incomplete 1 block (7 tokens)
    unique_token_ids = [3] * 7
    req0 = make_request("0", common_token_ids + unique_token_ids)
    all_token_ids = common_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids)
    computed_blocks = manager.get_computed_blocks(req0)
    assert not computed_blocks
    blocks = manager.allocate_slots(req0, 55, computed_blocks)
@@ -40,24 +44,16 @@ def test_prefill():
    # Check full block metadata
    parent_block_hash = None
    for block_id in (0, 1, 2):
        block_hash = hash_block_tokens(parent_block_hash,
                                       manager.block_pool[block_id].token_ids)
        block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
        block_hash = hash_block_tokens(parent_block_hash, block_tokens)
        assert manager.block_pool[block_id].block_hash == block_hash
        assert manager.block_pool[block_id].ref_cnt == 1
        assert manager.block_pool[block_id].num_hashed_tokens == 16 * (
            block_id + 1)
        assert manager.block_pool[block_id].token_ids == tuple([block_id] * 16)
        parent_block_hash = block_hash

    # Check partial/preallocated block metadata
    for block_id in (3, 4):
        assert manager.block_pool[block_id].block_hash is None
        assert manager.block_pool[block_id].ref_cnt == 1
        assert manager.block_pool[block_id].num_hashed_tokens == 0
        if block_id == 3:
            assert manager.block_pool[block_id].token_ids == [3] * 7
        else:
            assert not manager.block_pool[block_id].token_ids

    # Cache hit in the common prefix when the original block is still in use.
    # Incomplete 1 block (5 tokens)
@@ -113,7 +109,7 @@ def test_prefill():
    req3 = make_request("3", [99] * (16 * 9))
    computed_blocks = manager.get_computed_blocks(req3)
    assert not computed_blocks
    blocks = manager.allocate_slots(req2, 16 * 9, computed_blocks)
    blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
    # This block ID order also checks the eviction order.
    assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0]
    assert manager.free_block_queue.num_free_blocks == 0
@@ -148,7 +144,7 @@ def test_decode():
    req0.append_output_token_ids(8)
    new_blocks = manager.append_slots(req0, 4)
    assert new_blocks is not None and len(new_blocks) == 0
    assert len(manager.block_pool[3].token_ids) == 11
    assert manager.req_to_blocks[req0.request_id][-2].block_hash is None

    # Append slots without allocating a new block, but start using the
    # preallocated block.
@@ -159,8 +155,7 @@ def test_decode():
    req0.append_output_token_ids(7)
    new_blocks = manager.append_slots(req0, 15)
    assert new_blocks is not None and len(new_blocks) == 0
    assert len(manager.block_pool[3].token_ids) == 16
    assert len(manager.block_pool[4].token_ids) == 10
    assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None

    # Append slots with allocating a new block.
    req0.num_computed_tokens = 74
@@ -171,9 +166,6 @@ def test_decode():
    new_blocks = manager.append_slots(req0, 17)
    # Plus one preallocated block.
    assert new_blocks is not None and len(new_blocks) == 2
    assert len(manager.block_pool[4].token_ids) == 16
    assert len(manager.block_pool[5].token_ids) == 11
    assert len(manager.block_pool[6].token_ids) == 0


def test_evict():
@@ -217,3 +209,198 @@ def test_evict():
    blocks = manager.allocate_slots(req2, 3, computed_blocks)
    assert [b.block_id for b in blocks] == [6, 5]
    assert manager.free_block_queue.num_free_blocks == 6


def test_hash_block_correct_reuse():
    """
    This tests when a previously cached block is reused as a new block,
    its hash metadata should be correctly reset.
    """
    block_size = 16
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=1,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    # Allocate 1 block and cache it.
    num_tokens = block_size * 1
    req = make_request("0", list(range(num_tokens)))
    computed_blocks = manager.get_computed_blocks(req)
    assert not computed_blocks
    blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
    assert len(blocks) == 1

    # Deallocate the block.
    manager.free(req)

    # Allocate a new block that's not full, make sure hash info on the
    # block is cleared.
    req = make_request("1", list(range(num_tokens - 1)))
    computed_blocks = manager.get_computed_blocks(req)
    assert not computed_blocks
    blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
    assert len(blocks) == 1

    assert manager.block_pool[blocks[0].block_id].block_hash is None


def test_computed_blocks_not_evicted():
    """
    Test that the computed blocks are not evicted when getting new blocks
    for a request if there are any other free blocks.
    """
    block_size = 16
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=2,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    # Allocate a block and cache it.
    num_tokens = block_size * 1
    req0 = make_request("0", list(range(num_tokens)))
    computed_blocks = manager.get_computed_blocks(req0)
    assert not computed_blocks
    blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
    assert len(blocks) == 1
    assert blocks[0].block_id == 0

    # Allocate another block.
    req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
    computed_blocks = manager.get_computed_blocks(req1)
    assert not computed_blocks
    blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
    assert len(blocks) == 1
    assert blocks[0].block_id == 1

    # Free the blocks.
    manager.free(req0)
    manager.free(req1)

    # Now if we have a cache hit on the first block, we should evict the second
    # cached block rather than the first one.
    req2 = make_request("2", list(range(num_tokens * 2)))
    computed_blocks = manager.get_computed_blocks(req2)
    assert len(computed_blocks) == 1
    assert computed_blocks[0].block_id == 0

    blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
                                    computed_blocks)
    assert len(blocks) == 1
    assert blocks[0].block_id == 1


def test_basic_prefix_caching_disabled():
    """
    This tests that the prefix caching is disabled.
    """
    block_size = 4
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=4,
        sliding_window=False,
        enable_caching=False,
        num_preallocate_tokens=0,
    )

    req1 = make_request("1", list(range(10)))  # 2 blocks and some more

    computed_blocks = manager.get_computed_blocks(req1)
    assert not computed_blocks
    blocks = manager.allocate_slots(req1, 10, computed_blocks)
    assert len(blocks) == 3

    # Free the blocks.
    manager.free(req1)

    # No caching.
    req2 = make_request("2", list(range(16)))  # shared prefix
    computed_blocks = manager.get_computed_blocks(req2)
    assert not computed_blocks
    blocks = manager.allocate_slots(req2, 16, computed_blocks)
    assert len(blocks) == 4

    # New requests should not have any blocks.
    req3 = make_request("3", list(range(4)))
    computed_blocks = manager.get_computed_blocks(req3)
    assert not computed_blocks
    blocks = manager.allocate_slots(req3, 4, computed_blocks)
    assert not blocks


@pytest.mark.parametrize("num_preallocate_tokens", list(range(0, 8)))
@pytest.mark.parametrize("block_size", [4])
def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
    """
    This tests that the preallocated blocks are correctly added.
    """
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=10,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=num_preallocate_tokens,
    )
    num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size)

    req = make_request("0", list(range(block_size * 30)))
    computed_blocks = manager.get_computed_blocks(req)
    assert not computed_blocks
    # Just ask for 1 block.
    blocks = manager.allocate_slots(req, block_size, computed_blocks)
    assert len(blocks) == 1 + num_preallocated_blocks

    # Append slots to the block.
    req.num_computed_tokens = block_size * len(blocks)  # Assume all used.
    blocks = manager.append_slots(req, block_size)  # Append 1 block.
    assert len(blocks) == 1 + num_preallocated_blocks


def test_cache_blocks():
    """
    This is a unit test that tests the correctness of the _cache_full_blocks
    function of KVCacheManager.
    """
    block_size = 4
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=5,
        sliding_window=False,
        enable_caching=True,
        num_preallocate_tokens=0,
    )
    # Req:
    # Block 0: [0, 1, 2, 3]
    # Block 1: [4, 5, 6, 7]
    # Block 2: [8, 9, 10, 11]
    # Block 3: [12, 13]
    req = make_request("0", list(range(14)))

    # Test that blocks are cached correctly for 2 full blocks from the start.
    blocks = [KVCacheBlock(block_id=i) for i in range(2)]

    manager._cache_full_blocks(
        request=req,
        blk_start_idx=0,
        full_blocks=blocks,
        prev_block=None,
    )

    assert len(manager.cached_block_hash_to_block) == 2
    assert all([block.block_hash is not None for block in blocks])

    # Test that blocks that don't start from the beginning are cached correctly.
    blocks = [KVCacheBlock(block_id=2)]
    manager._cache_full_blocks(
        request=req,
        blk_start_idx=2,
        full_blocks=blocks,
        prev_block=None,
    )
    assert len(manager.cached_block_hash_to_block) == 3
    assert blocks[0].block_hash is not None
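The preallocation arithmetic these tests rely on is plain ceiling division: the manager hands out cdiv(num_preallocate_tokens, block_size) extra blocks on top of the blocks the tokens actually require. A small self-contained sketch of that arithmetic follows; the local ceil_div is a stand-in for vllm.utils.cdiv, and the numbers mirror test_preallocate_blocks above.

def ceil_div(a: int, b: int) -> int:
    # Stand-in for vllm.utils.cdiv (ceiling division on non-negative ints).
    return -(-a // b)


block_size = 4
for num_preallocate_tokens in range(0, 8):
    # Blocks needed for one block's worth of requested tokens...
    num_required_blocks = ceil_div(block_size, block_size)
    # ...plus however many preallocated blocks the configuration implies.
    num_preallocated_blocks = ceil_div(num_preallocate_tokens, block_size)
    total_blocks = num_required_blocks + num_preallocated_blocks
    # Matches the test's expectation of `1 + num_preallocated_blocks`.
    assert total_blocks == 1 + num_preallocated_blocks

The next hunks are the corresponding KVCacheManager changes.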
@@ -79,6 +79,9 @@ class KVCacheManager:
            return []

        computed_blocks = []

        # TODO(rickyx): potentially we could cache this so we don't have to
        # recompute it every time.
        block_hashes = hash_request_tokens(self.block_size,
                                           request.all_token_ids)

@@ -120,47 +123,45 @@
            # slots, but we cannot allocate new blocks due to the limit.
            return None

        # When caching is enabled, assign token IDs to already allocated blocks.
        new_token_ids = None
        parent_block = None
        if self.enable_caching:
            # Figure out the token IDs to add to the blocks.
            new_token_ids = request.all_token_ids[
                request.num_computed_tokens:request.num_computed_tokens +
                num_tokens]
        if num_new_blocks <= 0:
            # No new block is needed.
            new_blocks = []
        else:
            # Get new blocks from the free block pool considering
            # preallocated blocks.
            num_new_blocks = min(
                num_new_blocks + self.num_preallocate_blocks,
                self.free_block_queue.num_free_blocks,
            )

            # Find the last full block index.
            # TODO: This may be optimized by calculating the computed tokens.
            last_full_block_idx = len(req_blocks) - 1
            while (last_full_block_idx >= 0
                   and req_blocks[last_full_block_idx].block_hash is None):
                last_full_block_idx -= 1
            new_blocks = self._get_new_blocks(num_new_blocks)
            req_blocks.extend(new_blocks)

            parent_block = (req_blocks[last_full_block_idx]
                            if last_full_block_idx >= 0 else None)
            token_id_idx = self._add_token_ids_to_blocks(
                blocks=req_blocks[last_full_block_idx + 1:],
                token_ids=new_token_ids,
                parent_block=parent_block)
        if not self.enable_caching:
            return new_blocks

            new_token_ids = new_token_ids[token_id_idx:]
            parent_block = req_blocks[-1]
        num_computed_full_blocks = (request.num_computed_tokens //
                                    self.block_size)

        # No new block is needed. When caching is enabled, we make sure
        # token_id_idx is equal to len(new_token_ids), meaning that all tokens
        # are added to allocated blocks.
        if num_required_blocks <= len(req_blocks):
            assert not self.enable_caching or token_id_idx == num_tokens, \
                f"{token_id_idx=} != {num_tokens=}"
            return []
        # NOTE(rickyx): We are assuming the `num_tokens` are actual
        # tokens rather than lookahead slots (e.g. for speculative decoding).
        # TODO(rickyx): When supporting speculative decoding, we will need to
        # differentiate between them so that we can know how many blocks are
        # full after appending the actual tokens.
        num_full_blocks_after_append = (request.num_computed_tokens +
                                        num_tokens) // self.block_size
        assert num_full_blocks_after_append <= len(req_blocks)

        new_full_blocks = req_blocks[
            num_computed_full_blocks:num_full_blocks_after_append]
        self._cache_full_blocks(
            request=request,
            blk_start_idx=num_computed_full_blocks,
            full_blocks=new_full_blocks,
            prev_block=req_blocks[num_computed_full_blocks - 1]
            if num_computed_full_blocks >= 1 else None,
        )

        # Allocate new blocks considering preallocated blocks, and
        # add token IDs to them if caching is enabled.
        num_new_blocks = min(num_new_blocks + self.num_preallocate_blocks,
                             self.free_block_queue.num_free_blocks)
        new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids,
                                          parent_block)
        req_blocks.extend(new_blocks)
        return new_blocks

    def allocate_slots(
@@ -184,11 +185,20 @@
            raise ValueError(
                f"num_tokens must be greater than 0, got {num_tokens}")

        # If a computed block of a request is an eviction candidate (in the
        # free queue and ref_cnt == 0), it cannot be counted as a free block
        # when allocating this request.
        num_evictable_computed_blocks = len(
            [blk for blk in computed_blocks if blk.ref_cnt == 0])
        # Touch the computed blocks to make sure they won't be evicted.
        num_evictable_computed_blocks = 0
        if self.enable_caching:
            self._touch(computed_blocks)

            # If a computed block of a request is an eviction candidate (in the
            # free queue and ref_cnt == 0), it cannot be counted as a free block
            # when allocating this request.
            num_evictable_computed_blocks = len(
                [blk for blk in computed_blocks if blk.ref_cnt == 0])
        else:
            assert not computed_blocks, (
                "Computed blocks should be empty when "
                "prefix caching is disabled")

        num_required_blocks = cdiv(num_tokens, self.block_size)
        if (num_required_blocks > self.free_block_queue.num_free_blocks -
@@ -201,35 +211,28 @@
        num_new_blocks = min(
            num_required_blocks + self.num_preallocate_blocks,
            self.free_block_queue.num_free_blocks -
            num_evictable_computed_blocks)

        num_computed_tokens = len(computed_blocks) * self.block_size

        # When caching is enabled, get the new token IDs and the parent block
        # ID to generate cache keys.
        new_token_ids = None
        parent_block = None
        if self.enable_caching:
            # Touch the computed blocks to make sure they won't be evicted.
            self._touch(computed_blocks)

            # Get the token IDs for the blocks being allocated for hashing.
            new_token_ids = request.all_token_ids[
                num_computed_tokens:num_computed_tokens + num_tokens]
            if not new_token_ids:
                raise RuntimeError(
                    "Failed to infer the token IDs for allocation. "
                    f"#all_tokens={len(request.all_token_ids)} < "
                    f"#computed_tokens={num_computed_tokens}")

            # Get the parent block ID to construct the block chain.
            parent_block = computed_blocks[-1] if computed_blocks else None

        new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids,
                                          parent_block)
            num_evictable_computed_blocks,
        )

        # Concatenate the computed block IDs and the new block IDs.
        new_blocks = self._get_new_blocks(num_new_blocks)
        self.req_to_blocks[request.request_id] = computed_blocks + new_blocks

        if not self.enable_caching:
            return new_blocks

        num_computed_tokens = len(computed_blocks) * self.block_size
        num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size

        self._cache_full_blocks(
            request=request,
            blk_start_idx=len(computed_blocks),
            # The new full blocks are the full blocks that are not computed.
            full_blocks=self.req_to_blocks[request.request_id]
            [len(computed_blocks):num_full_blocks],
            prev_block=computed_blocks[-1] if computed_blocks else None,
        )

        return new_blocks

    def free(self, request: Request) -> None:
@@ -248,24 +251,17 @@
            blocks = reversed(blocks)

        for block in blocks:
            block.ref_cnt -= 1
            block.decr_ref()
            if block.ref_cnt == 0:
                self.free_block_queue.append(block)

    def _get_new_blocks(
            self,
            num_blocks: int,
            token_ids: Optional[List[int]] = None,
            parent_block: Optional[int] = None) -> List[KVCacheBlock]:
        """Get new blocks from the free block pool, and add token IDs to
        allocated blocks if caching is enabled.
    def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]:
        """Get new blocks from the free block pool.

        Note that we do not check block cache in this function.

        Args:
            num_blocks: The number of blocks to allocate.
            token_ids: The token IDs in the blocks. None if caching is disabled.
            parent_block: The parent block. Used to include block chain
                in the block hash.

        Returns:
            A list of new block.
@@ -274,56 +270,38 @@
            raise ValueError(
                f"Cannot get {num_blocks} free blocks from the pool")

        # First allocate blocks.
        ret: List[KVCacheBlock] = []
        idx = 0
        while idx < num_blocks:
            # First allocate blocks.
            curr_block = self.free_block_queue.popleft()
            assert curr_block.ref_cnt == 0

            # Evict blocks from the cache.
            # If the block is cached, evict it.
            if self.enable_caching:
                block_hash = curr_block.block_hash
                if (block_hash is not None
                        and block_hash in self.cached_block_hash_to_block):
                    if len(self.cached_block_hash_to_block[block_hash]) == 1:
                        del self.cached_block_hash_to_block[block_hash]
                    else:
                        del self.cached_block_hash_to_block[block_hash][
                            curr_block.block_id]
                curr_block.reset()
                self._evict_cached_block(curr_block)

            curr_block.ref_cnt = 1
            curr_block.incr_ref()
            ret.append(curr_block)
            idx += 1

        # Then assign token IDs to the allocated blocks.
        if self.enable_caching:
            assert token_ids is not None
            token_id_idx = self._add_token_ids_to_blocks(
                blocks=ret, token_ids=token_ids, parent_block=parent_block)
            assert token_id_idx == len(token_ids)

        return ret

    def _cache_full_block(self,
                          block: KVCacheBlock,
                          parent_block: Optional[KVCacheBlock] = None) -> None:
        """Cache a full block for prefix caching.
    def _evict_cached_block(self, block: KVCacheBlock) -> None:
        """
        If a block is cached in `cached_block_hash_to_block`, we reset its hash
        metadata and evict it from the cache.

        Args:
            block: The block to cache.
            parent_block: The parent block. None if this is the first block.
            block: The block to evict.
        """
        parent_block_hash = (parent_block.block_hash
                             if parent_block is not None else None)
        assert len(block.token_ids) == self.block_size
        block.token_ids = tuple(block.token_ids)
        block_hash = hash_block_tokens(parent_block_hash, block.token_ids)
        block.block_hash = block_hash
        block.num_hashed_tokens = self.block_size + (
            parent_block.num_hashed_tokens if parent_block is not None else 0)
        self.cached_block_hash_to_block[block_hash][block.block_id] = block
        block_hash = block.block_hash
        if block_hash and block_hash in self.cached_block_hash_to_block:
            block.reset_hash()
            del self.cached_block_hash_to_block[block_hash][block.block_id]

            if len(self.cached_block_hash_to_block[block_hash]) == 0:
                del self.cached_block_hash_to_block[block_hash]

    def _get_cached_block(self,
                          block_hash: BlockHashType) -> Optional[KVCacheBlock]:
@@ -355,43 +333,50 @@
            # candidate), so remove it.
            if block.ref_cnt == 0:
                self.free_block_queue.remove(block)
            block.ref_cnt += 1
            block.incr_ref()

    def _add_token_ids_to_blocks(
            self,
            blocks: List[KVCacheBlock],
            token_ids: List[int],
            parent_block: Optional[KVCacheBlock] = None) -> int:
        """Add token IDs to a list of allocated blocks.
        If a block becomes full after adding token IDs, cache it.
        Return the token ID index that has not been added to the blocks
        if the blocks are not enough to hold all the token IDs.
    def _cache_full_blocks(
        self,
        request: Request,
        blk_start_idx: int,
        full_blocks: List[KVCacheBlock],
        prev_block: Optional[KVCacheBlock],
    ) -> None:
        """Cache a list of full blocks for prefix caching.

        This function takes a list of blocks that will have their block hash
        metadata to be updated and cached. Given a request, it computes the
        block hashes for the blocks starting from `blk_start_idx` to the end
        of the request's full blocks, updating the metadata for each block
        and caching them in the `cached_block_hash_to_block`.

        Args:
            blocks: A list of blocks to add token IDs.
            token_ids: A list of token IDs to add.
            parent_block: The parent block. None if this is the
                first block.

        Returns:
            The starting token ID index that has not been added to the blocks
            due to insufficient given blocks.
            request: The request to cache the blocks.
            blk_start_idx: The index of the first block in the request's blocks
                to cache.
            full_blocks: The list of blocks to update hash metadata.
            prev_block: The previous block in the chain.
        """
        token_id_start = 0
        for curr_block in blocks:
            # If all token IDs are added, then the rest of the blocks are
            # preallocated blocks, so we only need to update the
            # parent_block_id. FIXME
            if token_id_start == len(token_ids):
                continue
        # Update the new blocks with the block hashes through the chain.
        prev_block_hash = (prev_block.block_hash
                           if prev_block is not None else None)
        for i, blk in enumerate(full_blocks):
            blk_idx = blk_start_idx + i

            # Add token IDs to the empty slots in the block.
            empty_slots = self.block_size - len(curr_block.token_ids)
            token_id_end = min(token_id_start + empty_slots, len(token_ids))
            curr_block.token_ids.extend(token_ids[token_id_start:token_id_end])
            # Cache the block if it becomes full.
            if len(curr_block.token_ids) == self.block_size:
                self._cache_full_block(curr_block, parent_block)
                parent_block = curr_block
            token_id_start = token_id_end
        return token_id_start
            block_tokens = request.all_token_ids[blk_idx *
                                                 self.block_size:(blk_idx +
                                                                  1) *
                                                 self.block_size]
            assert len(block_tokens) == self.block_size, (
                f"Expected {self.block_size} tokens, got {len(block_tokens)} "
                f"at {blk_idx}th block for request "
                f"{request.request_id}({request})")

            # Compute the hash of the current block.
            block_hash = hash_block_tokens(prev_block_hash,
                                           tuple(block_tokens))

            # Update and added the full block to the cache.
            blk.block_hash = block_hash
            self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
            prev_block_hash = block_hash
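The hunks above replace per-block token bookkeeping with a two-level prefix-cache map keyed first by block hash and then by block ID, which _cache_full_blocks fills and _evict_cached_block cleans up. Here is an illustrative, self-contained sketch of that structure; the names are local stand-ins, not the KVCacheManager internals themselves.

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class SketchBlock:
    block_id: int
    block_hash: Optional[int] = None


# {block_hash: {block_id: block}} -- several blocks may share one hash.
cached: Dict[int, Dict[int, SketchBlock]] = defaultdict(dict)


def cache_block(blk: SketchBlock, block_hash: int) -> None:
    # Record the hash on the block and register it under that hash.
    blk.block_hash = block_hash
    cached[block_hash][blk.block_id] = blk


def evict_cached_block(blk: SketchBlock) -> None:
    # Mirror of the eviction logic in the diff: clear the hash, drop the
    # per-block entry, and drop the hash bucket once it is empty.
    block_hash = blk.block_hash
    if block_hash is not None and block_hash in cached:
        blk.block_hash = None
        del cached[block_hash][blk.block_id]
        if not cached[block_hash]:
            del cached[block_hash]

The final hunks below update the KV-cache utilities accordingly.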
@@ -1,6 +1,6 @@
"""KV-Cache Utilities."""
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union
from dataclasses import dataclass
from typing import List, Optional, Tuple

from vllm.logger import init_logger

@@ -16,27 +16,34 @@ class KVCacheBlock:
    block_id: int
    # Reference count.
    ref_cnt: int = 0
    # Token IDs in the block. When the block is full, the type of token_ids
    # should be Tuple[int] for fast matching.
    token_ids: Union[List[int], Tuple[int]] = field(default_factory=list)
    # The hash of the block composed of (block hash, tuple of token IDs).
    # It is only available when the block is full.
    block_hash: Optional[BlockHashType] = None
    # The number of hashed tokens. More hashed tokens means the block
    # is closer to the end of a prompt and more likely to be evicted.
    num_hashed_tokens: int = 0
    _block_hash: Optional[BlockHashType] = None

    # Used to construct a doubly linked list for free blocks.
    # These two attributes should only be manipulated by FreeKVCacheBlockQueue.
    prev_free_block: Optional["KVCacheBlock"] = None
    next_free_block: Optional["KVCacheBlock"] = None

    def reset(self):
        """Reset the block metadata."""
        self.ref_cnt = 0
        self.token_ids = []
        self.block_hash = None
        self.num_hashed_tokens = 0
    def incr_ref(self):
        self.ref_cnt += 1

    def decr_ref(self):
        self.ref_cnt -= 1

    @property
    def block_hash(self) -> Optional[BlockHashType]:
        return self._block_hash

    @block_hash.setter
    def block_hash(self, block_hash: BlockHashType):
        assert self.block_hash is None, (
            "The block already has a hash. This should not happen.")
        self._block_hash = block_hash

    def reset_hash(self):
        """Reset the block hash when the block is evicted."""
        self._block_hash = None


class FreeKVCacheBlockQueue:
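The new block_hash property enforces a simple lifecycle: the hash may be set once when the block becomes full and must be cleared with reset_hash() before the block is reused. A short usage sketch, assuming the KVCacheBlock and hash_block_tokens imported in the test above (exact import paths and hash contents may differ across versions):

from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens

blk = KVCacheBlock(block_id=0)
blk.incr_ref()  # block handed out to a request
# Set once when the block fills; the setter asserts it was previously None.
blk.block_hash = hash_block_tokens(None, tuple(range(16)))
blk.reset_hash()  # cleared on eviction before reuse
# After reset_hash(), a new hash may be assigned again.
blk.block_hash = hash_block_tokens(None, tuple(range(16, 32)))
blk.decr_ref()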