[Core] Use KVCacheBlock as much as possible instead of dict[block_id, KVCacheBlock] (#24830)
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
parent ae002924e9
commit 4f8c4b890a
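
The gist of the change: cached_block_hash_to_block used to be dict[BlockHashWithGroupId, dict[int, KVCacheBlock]], allocating an inner dict per hash even though almost every hash maps to exactly one block. The new BlockHashToBlockMap stores the KVCacheBlock directly and only promotes an entry to an inner dict on a hash collision, cutting allocations and GC pressure (the diff's "NOTE #2"). A minimal standalone sketch of the idea, in which TinyBlock and TinyBlockMap are illustrative names, not vLLM APIs:

    from dataclasses import dataclass
    from typing import Optional, Union

    @dataclass
    class TinyBlock:
        # Stand-in for KVCacheBlock: just an ID.
        block_id: int

    class TinyBlockMap:
        # Union-typed map: hash -> block, or hash -> {block_id: block}.

        def __init__(self) -> None:
            self._cache: dict[bytes, Union[TinyBlock,
                                           dict[int, TinyBlock]]] = {}

        def insert(self, key: bytes, block: TinyBlock) -> None:
            cur = self._cache.get(key)
            if cur is None:
                # Common case: one block per hash, no inner dict allocated.
                self._cache[key] = block
            elif isinstance(cur, TinyBlock):
                # First collision: promote the entry to an inner dict.
                self._cache[key] = {cur.block_id: cur, block.block_id: block}
            else:
                cur[block.block_id] = block

        def get_one(self, key: bytes) -> Optional[TinyBlock]:
            cur = self._cache.get(key)
            if isinstance(cur, TinyBlock):
                return cur
            if isinstance(cur, dict):
                return next(iter(cur.values()))
            return None

    m = TinyBlockMap()
    m.insert(b"h", TinyBlock(0))
    m.insert(b"h", TinyBlock(7))  # promotes the entry to an inner dict
    assert m.get_one(b"h") is not None and len(m._cache) == 1

The diff below follows this shape exactly: the first file updates the prefix-caching tests, the second updates two specialized-manager tests, and the third introduces the class in the block pool module.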
@@ -14,10 +14,11 @@ from vllm.multimodal.inputs import (MultiModalFeatureSpec,
                                     MultiModalKwargsItem, PlaceholderRange)
 from vllm.sampling_params import SamplingParams
 from vllm.utils import sha256, sha256_cbor
-from vllm.v1.core.block_pool import BlockPool
+from vllm.v1.core.block_pool import BlockHashToBlockMap, BlockPool
 from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
-from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
-                                         get_block_hash, get_group_id,
+from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
+                                         KVCacheBlock, get_block_hash,
+                                         get_group_id,
                                          get_request_block_hasher,
                                          hash_block_tokens, init_none_hash,
                                          make_block_hash_with_group_id)
@@ -138,7 +139,7 @@ def test_prefill(hash_fn):
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
+    assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )

     # Check full block metadata
     parent_block_hash = None
@@ -171,7 +172,7 @@ def test_prefill(hash_fn):
     blocks = manager.allocate_slots(req1, num_new_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == ([5], )
+    assert blocks is not None and blocks.get_block_ids() == ([5], )
     for block in computed_blocks.blocks[0]:
         assert block.ref_cnt == 2

@@ -207,7 +208,7 @@ def test_prefill(hash_fn):
     blocks = manager.allocate_slots(req2, num_new_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == ([6], )
+    assert blocks is not None and blocks.get_block_ids() == ([6], )

     # Although we only have 6 free blocks, we have 8 blocks in
     # the free block queue due to lazy removal.
@@ -227,7 +228,9 @@ def test_prefill(hash_fn):
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
     # This block ID order also checks the eviction order.
-    assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
+    assert blocks is not None and blocks.get_block_ids() == ([
+        7, 8, 9, 10, 4, 5, 6, 3, 2, 1
+    ], )

     assert free_block_queue.num_free_blocks == 0
     assert (free_block_queue.fake_free_list_head.next_free_block
@@ -261,8 +264,9 @@ def test_prefill_hybrid_model():
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
-                                                     8], [9, 10, 11, 12])
+    assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], [
+        5, 6, 7, 8
+    ], [9, 10, 11, 12])

     # Check full block metadata
     parent_block_hash = None
@@ -298,7 +302,7 @@ def test_prefill_hybrid_model():
     blocks = manager.allocate_slots(req1, num_new_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == ([13], [14], [15])
+    assert blocks is not None and blocks.get_block_ids() == ([13], [14], [15])
     for block_per_group in computed_blocks.blocks:
         for block in block_per_group:
             if block != manager.block_pool.null_block:
@@ -309,14 +313,15 @@ def test_prefill_hybrid_model():
     manager.free(req1)

     cached_block_hash_to_block_bak = copy.copy(
-        manager.block_pool.cached_block_hash_to_block)
+        manager.block_pool.cached_block_hash_to_block._cache)

-    def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes],
+    def test_partial_request_hit(request_id: str,
+                                 hash_to_evict: list[BlockHashWithGroupId],
                                  expect_hit_length: int):
         req = make_request(request_id, common_token_ids + unique_token_ids,
                            block_size, sha256)
         for hash_with_group_id in hash_to_evict:
-            manager.block_pool.cached_block_hash_to_block.pop(
+            manager.block_pool.cached_block_hash_to_block._cache.pop(
                 hash_with_group_id)
         computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
         assert len(req.block_hashes) == 3
@@ -324,7 +329,7 @@ def test_prefill_hybrid_model():
         for block_per_group in computed_blocks.blocks:
             assert len(block_per_group) == num_computed_tokens // block_size
         for hash_with_group_id in hash_to_evict:
-            manager.block_pool.cached_block_hash_to_block[
+            manager.block_pool.cached_block_hash_to_block._cache[
                 hash_with_group_id] = cached_block_hash_to_block_bak[
                     hash_with_group_id]
         manager.free(req)
@@ -362,7 +367,8 @@ def test_prefill_hybrid_model():
     # total cache miss.
     # The cache hit length of full attention is 1 * block_size.
     # The cache hit length of sliding window is 2 * block_size.
-    # Then it is cache miss as the two type of layers have different hit length.
+    # Then it is a cache miss, as the two types of layers
+    # have different hit lengths.
     test_partial_request_hit("8", [
         make_block_hash_with_group_id(block_hashes[2], 0),
         make_block_hash_with_group_id(block_hashes[0], 1),
@@ -406,7 +412,7 @@ def test_prefill_plp():
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
+    assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
     req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]

     # Check full block metadata
@@ -441,7 +447,7 @@ def test_prefill_plp():
     blocks = manager.allocate_slots(req1, num_new_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == ([5], )
+    assert blocks is not None and blocks.get_block_ids() == ([5], )
     for block in computed_blocks.blocks[0]:
         assert block.ref_cnt == 2

@@ -478,6 +484,7 @@ def test_prefill_plp():
     blocks = manager.allocate_slots(req2, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
+    assert blocks is not None
     block_ids = blocks.get_block_ids()
     # Duplicate cached blocks have different ids but same hashes vs request #0
     assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
@@ -513,7 +520,7 @@ def test_decode():
     blocks = manager.allocate_slots(req0, 55,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
+    assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )

     # Append slots without allocating a new block.
     req0.num_computed_tokens = 55
@@ -558,7 +565,8 @@ def test_evict():
     blocks = manager.allocate_slots(req0, 5 * 16 + 7,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks[0]) == 6  # 5 full + 1 partial
+    # 5 full + 1 partial
+    assert blocks is not None and len(blocks.blocks[0]) == 6

     # 3 blocks.
     req1 = make_request("1", list(range(last_token_id,
@@ -570,7 +578,7 @@ def test_evict():
     blocks = manager.allocate_slots(req1, 3 * 16,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks[0]) == 3  # 3 full blocks
+    assert blocks is not None and len(blocks.blocks[0]) == 3  # 3 full blocks
     last_token_id += 3 * 16

     # 10 - (6 + 3) == 1
@@ -592,7 +600,7 @@ def test_evict():
     blocks = manager.allocate_slots(req2, 3,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == ([10], )
+    assert blocks is not None and blocks.get_block_ids() == ([10], )
     assert manager.block_pool.free_block_queue.num_free_blocks == 7


@@ -617,7 +625,7 @@ def test_hash_block_correct_reuse():
     blocks = manager.allocate_slots(req, num_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks[0]) == 1
+    assert blocks is not None and len(blocks.blocks[0]) == 1

     # Deallocate the block.
     manager.free(req)
@@ -631,7 +639,7 @@ def test_hash_block_correct_reuse():
     blocks = manager.allocate_slots(req, num_tokens - 1,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks[0]) == 1
+    assert blocks is not None and len(blocks.blocks[0]) == 1

     assert manager.block_pool.blocks[blocks.blocks[0]
                                      [0].block_id].block_hash is None
@@ -658,7 +666,7 @@ def test_computed_blocks_not_evicted():
     blocks = manager.allocate_slots(req0, num_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks[0]) == 1
+    assert blocks is not None and len(blocks.blocks[0]) == 1
     assert blocks.blocks[0][0].block_id == 1

     # Allocate another block.
@@ -670,7 +678,7 @@ def test_computed_blocks_not_evicted():
     blocks = manager.allocate_slots(req1, num_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks[0]) == 1
+    assert blocks is not None and len(blocks.blocks[0]) == 1
     assert blocks.blocks[0][0].block_id == 2

     # Free the blocks.
@@ -688,7 +696,7 @@ def test_computed_blocks_not_evicted():
     blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks[0]) == 1
+    assert blocks is not None and len(blocks.blocks[0]) == 1
     assert blocks.blocks[0][0].block_id == 2


@@ -712,7 +720,7 @@ def test_basic_prefix_caching_disabled():
     blocks = manager.allocate_slots(req1, 10,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks[0]) == 3
+    assert blocks is not None and len(blocks.blocks[0]) == 3

     # Free the blocks.
     manager.free(req1)
@@ -726,7 +734,7 @@ def test_basic_prefix_caching_disabled():
     blocks = manager.allocate_slots(req2, 16,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert len(blocks.blocks[0]) == 4
+    assert blocks is not None and len(blocks.blocks[0]) == 4

     # New requests should not have any blocks.
     req3 = make_request("3", list(range(4)), block_size, sha256)
@@ -773,7 +781,8 @@ def test_cache_blocks(hash_fn):
     assert len(block_pool.cached_block_hash_to_block) == 2
     assert all([block.block_hash is not None for block in blocks])

-    # Test that blocks that don't start from the beginning are cached correctly.
+    # Test that blocks that don't start from the beginning are cached
+    # correctly.
     blocks += [KVCacheBlock(block_id=2)]
     block_pool.cache_full_blocks(
         request=req,
@@ -1101,7 +1110,7 @@ def test_reset_prefix_cache():
     all_token_ids = full_block_token_ids + unique_token_ids
     req0 = make_request("0", all_token_ids, block_size, sha256)
     blocks = manager.allocate_slots(req0, 55)
-    assert blocks.get_block_ids() == ([1, 2, 3, 4], )
+    assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )

     unique_token_ids = [4] * 7
     all_token_ids = full_block_token_ids + unique_token_ids
@@ -1112,7 +1121,7 @@ def test_reset_prefix_cache():
     blocks = manager.allocate_slots(req1, 7,
                                     len(computed_blocks.blocks[0]) * 16,
                                     computed_blocks)
-    assert blocks.get_block_ids() == ([5], )
+    assert blocks is not None and blocks.get_block_ids() == ([5], )

     # Failed to reset prefix cache because some blocks are not freed yet.
     assert not manager.reset_prefix_cache()
@@ -1168,49 +1177,41 @@ def test_maybe_evict_cached_block():
     # Manually add all blocks to cached_blocks
     for block, block_hash in zip(pool.blocks, block_hashes):
         block.block_hash = block_hash
-        pool.cached_block_hash_to_block[block_hash][block.block_id] = block
+        pool.cached_block_hash_to_block.insert(block_hash, block)

     block0, block1, block2, block3 = pool.blocks
-    assert pool.cached_block_hash_to_block == {
+    assert pool.cached_block_hash_to_block._cache == {
         block_hash0: {
             block0.block_id: block0,
-            block3.block_id: block3
+            block3.block_id: block3,
         },
-        block_hash1: {
-            block1.block_id: block1
-        },
-        block_hash2: {
-            block2.block_id: block2
-        }
+        block_hash1: block1,
+        block_hash2: block2,
     }
     # Evict block1
     pool._maybe_evict_cached_block(block1)
-    assert pool.cached_block_hash_to_block == {
+    assert pool.cached_block_hash_to_block._cache == {
         block_hash0: {
             block0.block_id: block0,
             block3.block_id: block3
         },
-        block_hash2: {
-            block2.block_id: block2
-        }
+        block_hash2: block2,
     }
     # Evict block0: block_hash0 entry should NOT be removed, as block3
     # also uses the same hash
     pool._maybe_evict_cached_block(block0)
-    assert pool.cached_block_hash_to_block == {
+    assert pool.cached_block_hash_to_block._cache == {
         block_hash0: {
             block3.block_id: block3
         },
-        block_hash2: {
-            block2.block_id: block2
-        }
+        block_hash2: block2,
     }
     # Evict block2
     pool._maybe_evict_cached_block(block2)
-    assert pool.cached_block_hash_to_block == {block_hash0: {3: block3}}
+    assert pool.cached_block_hash_to_block._cache == {block_hash0: {3: block3}}
     # Evict block3
     pool._maybe_evict_cached_block(block3)
-    assert pool.cached_block_hash_to_block == {}
+    assert pool.cached_block_hash_to_block._cache == {}


 @pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
@@ -1374,7 +1375,7 @@ def test_eagle_with_sliding_window():
     # Evict the first block in the request
     assert manager.block_pool.get_cached_block(
         block_hash_first_block, kv_cache_group_ids=[0]) is not None
-    manager.block_pool.cached_block_hash_to_block.pop(
+    manager.block_pool.cached_block_hash_to_block._cache.pop(
         make_block_hash_with_group_id(block_hash_first_block, 0))

     # New request
@@ -1386,3 +1387,78 @@ def test_eagle_with_sliding_window():
     # there will be no matched prefix.
     assert len(computed_blocks.blocks[0]) == 0
     assert num_tokens == 0
+
+
+def test_block_lookup_cache_single_block_per_key():
+    cache = BlockHashToBlockMap()
+    key0 = BlockHashWithGroupId(b"hash0")
+    key1 = BlockHashWithGroupId(b"hash1")
+    key2 = BlockHashWithGroupId(b"hash2")
+    block0 = KVCacheBlock(0)
+    block1 = KVCacheBlock(1)
+
+    assert cache.get_one_block(key0) is None
+    assert cache.get_one_block(key1) is None
+    assert cache.get_one_block(key2) is None
+    # key0 inserted
+    cache.insert(key0, block0)
+    assert cache.get_one_block(key0) is block0
+    assert cache.get_one_block(key1) is None
+    assert cache.get_one_block(key2) is None
+    # key1 inserted
+    cache.insert(key1, block1)
+    assert cache.get_one_block(key0) is block0
+    assert cache.get_one_block(key1) is block1
+    assert cache.get_one_block(key2) is None
+    # No block popped due to block_id mismatch
+    assert cache.pop(key0, 100) is None
+    assert cache.get_one_block(key0) is block0
+    assert cache.get_one_block(key1) is block1
+    assert cache.get_one_block(key2) is None
+    # Block popped with (key0, block ID 0)
+    assert cache.pop(key0, 0) is block0
+    assert cache.get_one_block(key0) is None
+    assert cache.get_one_block(key1) is block1
+    assert cache.get_one_block(key2) is None
+    # No block popped due to block_id mismatch
+    assert cache.pop(key0, 1) is None
+    assert cache.get_one_block(key0) is None
+    assert cache.get_one_block(key1) is block1
+    assert cache.get_one_block(key2) is None
+    # Block popped with (key1, block ID 1)
+    assert cache.pop(key1, 1) is block1
+    assert cache.get_one_block(key0) is None
+    assert cache.get_one_block(key1) is None
+    assert cache.get_one_block(key2) is None
+
+
+def test_block_lookup_cache_multi_blocks_per_key():
+    cache = BlockHashToBlockMap()
+    key0 = BlockHashWithGroupId(b"hash0")
+    key1 = BlockHashWithGroupId(b"hash1")
+    block00 = KVCacheBlock(0)
+    block01 = KVCacheBlock(1)
+    block10 = KVCacheBlock(10)
+    block11 = KVCacheBlock(11)
+
+    assert cache.get_one_block(key0) is None
+    assert cache.get_one_block(key1) is None
+
+    cache.insert(key0, block00)
+    cache.insert(key0, block01)
+    cache.insert(key1, block10)
+    cache.insert(key1, block11)
+
+    assert cache.get_one_block(key0) is block00
+    assert cache.pop(key0, 0) is block00
+    assert cache.get_one_block(key0) is block01
+    assert cache.pop(key0, 1) is block01
+    assert cache.get_one_block(key0) is None
+    assert cache.pop(key0, 2) is None
+
+    assert cache.get_one_block(key1) is block10
+    assert cache.pop(key1, 10) is block10
+    assert cache.get_one_block(key1) is block11
+    assert cache.pop(key1, 11) is block11
+    assert cache.get_one_block(key1) is None
+    assert cache.pop(key1, 12) is None
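
One pattern repeated throughout the test hunks above is prepending `assert blocks is not None` to the existing assertions. Presumably this is because `allocate_slots` is typed as returning an optional result, so the assert both fails fast on an unexpected None and narrows the type for static checkers. A minimal sketch of the pattern, using a hypothetical `allocate_slots` (not vLLM's actual signature):

    from typing import Optional

    class KVCacheBlocksSketch:
        def get_block_ids(self) -> tuple[list[int], ...]:
            return ([1, 2, 3, 4], )

    def allocate_slots(num_tokens: int) -> Optional[KVCacheBlocksSketch]:
        # Hypothetical: returns None when there is not enough free space.
        return KVCacheBlocksSketch() if num_tokens <= 64 else None

    blocks = allocate_slots(55)
    # Without the narrowing assert, calling a method on `blocks` would be
    # a type error on Optional[...] under mypy/pyright.
    assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )

The next file's hunks update two specialized-manager tests to seed the cache through insert() rather than raw dict indexing.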
@@ -47,16 +47,15 @@ def test_chunked_local_attention_possible_cached_prefix():
         BlockHash(str(i).encode()) for i in range(len(block_is_cached))
     ]

-    block_pool.cached_block_hash_to_block.clear()
+    block_pool.cached_block_hash_to_block._cache.clear()

     # Mock the block pool with the cached blocks
     for i, (block_hash,
             is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
         if is_cached:
-            block_pool.cached_block_hash_to_block[
-                make_block_hash_with_group_id(block_hash, 0)] = {
-                    i: block_pool.blocks[i + 10],
-                }
+            block_pool.cached_block_hash_to_block.insert(
+                make_block_hash_with_group_id(block_hash, 0),
+                block_pool.blocks[i + 10])

     computed_blocks = manager.find_longest_cache_hit(
         block_hashes=block_hash_list,
@@ -112,16 +111,15 @@ def test_sliding_window_possible_cached_prefix():
         BlockHash(str(i).encode()) for i in range(len(block_is_cached))
     ]

-    block_pool.cached_block_hash_to_block.clear()
+    block_pool.cached_block_hash_to_block._cache.clear()

     # Mock the block pool with the cached blocks
     for i, (block_hash,
             is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
         if is_cached:
-            block_pool.cached_block_hash_to_block[
-                make_block_hash_with_group_id(block_hash, 0)] = {
-                    i: block_pool.blocks[i + 10],
-                }
+            block_pool.cached_block_hash_to_block.insert(
+                make_block_hash_with_group_id(block_hash, 0),
+                block_pool.blocks[i + 10])

     computed_blocks = manager.find_longest_cache_hit(
         block_hashes=block_hash_list,
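
A note on the backup/restore dance in the earlier `test_prefill_hybrid_model` hunks (`copy.copy(..._cache)`, then `_cache.pop(...)`, then re-assignment from the backup): a shallow copy shares its values with the original, which is safe here because whole keys are popped and put back, never mutated in place. A small self-contained illustration of that assumption, with made-up keys and values:

    import copy

    cache = {b"h0": {0: "block0"}, b"h1": "block1"}  # union-shaped values
    backup = copy.copy(cache)  # shallow: values are shared with `cache`

    cache.pop(b"h0")  # simulate evicting one hash for the test
    assert b"h0" not in cache
    cache[b"h0"] = backup[b"h0"]  # restore the popped entry from the backup
    assert cache == backup

The final file's hunks introduce BlockHashToBlockMap itself and rewire BlockPool to use it.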
@@ -1,8 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections import defaultdict
 from collections.abc import Iterable
-from typing import Optional
+from typing import Any, Optional, Union

 from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
                                         BlockRemoved, BlockStored,
@@ -19,6 +18,103 @@ from vllm.v1.request import Request
 logger = init_logger(__name__)

+
+class BlockHashToBlockMap:
+    """
+    Cache of blocks that are used for prefix caching. It maps a block hash
+    directly to a single block or to multiple blocks
+    (i.e. {block_hash: KVCacheBlocks}):
+    - Mostly, a block_hash maps to a single KVCacheBlock, and KVCacheBlocks
+      is simply that KVCacheBlock.
+    - Otherwise, KVCacheBlocks is a dict of {block_id: KVCacheBlock}.
+
+    A cached block is a full block with a block hash that can be used
+    for prefix caching.
+    The cached block may be used by running requests or in the
+    free_block_queue that could potentially be evicted.
+
+    NOTE #1: We currently don't de-duplicate the blocks in the cache,
+    meaning that if a block becomes full and is cached, we don't check
+    if there is already an identical block in the cache. This is because
+    we want to make sure the allocated block IDs won't change so that
+    block tables are append-only.
+    NOTE #2: The union type is introduced in order to reduce GC costs
+    from the inner dict.
+    """
+
+    def __init__(self):
+        self._cache: dict[BlockHashWithGroupId,
+                          Union[KVCacheBlock, dict[int, KVCacheBlock]]] = {}
+
+    def get_one_block(self,
+                      key: BlockHashWithGroupId) -> Optional[KVCacheBlock]:
+        """
+        Gets any block with the given block hash key.
+        """
+        blocks = self._cache.get(key)
+        if blocks is not None:
+            if isinstance(blocks, KVCacheBlock):
+                return blocks
+            if isinstance(blocks, dict):
+                return next(iter(blocks.values()))
+            self._unexpected_blocks_type(blocks)
+        return None
+
+    def insert(self, key: BlockHashWithGroupId, block: KVCacheBlock) -> None:
+        """
+        Inserts the KVCacheBlock into the cache.
+        """
+        blocks = self._cache.get(key)
+        if blocks is None:
+            # When key is not found, attach a single block to the key
+            self._cache[key] = block
+        elif isinstance(blocks, KVCacheBlock):
+            # If there's a block with the same key, merge the original block
+            # and the new block into a dict
+            self._cache[key] = {blocks.block_id: blocks, block.block_id: block}
+        elif isinstance(blocks, dict):
+            # If it's already a dict, simply insert the block
+            blocks[block.block_id] = block
+        else:
+            self._unexpected_blocks_type(blocks)
+
+    def pop(self, key: BlockHashWithGroupId,
+            block_id: int) -> Optional[KVCacheBlock]:
+        """
+        Checks if block_hash exists and pops block_id from the cache.
+        """
+        blocks = self._cache.pop(key, None)
+        if blocks is None:
+            # block_hash not found in the cache
+            return None
+        # TODO(Jialin): If key is found, block_id should always be present
+        # in blocks. We currently keep the original behaviour for safety.
+        #
+        # Will add a block_id == blocks.block_id assertion and
+        # use del blocks[block_id] instead as a followup.
+        if isinstance(blocks, KVCacheBlock):
+            if blocks.block_id == block_id:
+                return blocks
+            # If the single block ID doesn't match, we should put the
+            # block back (it should happen rarely)
+            self._cache[key] = blocks
+            return None
+        if isinstance(blocks, dict):
+            # Try to pop block_id from the block dict, and if the dict
+            # still contains blocks, put it back into the cache.
+            block = blocks.pop(block_id, None)
+            if len(blocks) > 0:
+                self._cache[key] = blocks
+            return block
+        self._unexpected_blocks_type(blocks)
+        return None
+
+    def __len__(self) -> int:
+        return len(self._cache)
+
+    def _unexpected_blocks_type(self, blocks: Any) -> None:
+        raise AssertionError(f"Invalid KV cache block type {type(blocks)}")
+
+
 class BlockPool:
     """BlockPool that manages KVCacheBlocks.
     It provides methods to allocate, free and cache the kv cache blocks. The
@@ -51,17 +147,9 @@ class BlockPool:
         # enabled).
         self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)

-        # {block_hash: {block ID: block}}. A cached block is
-        # a full block with a block hash that can be used for prefix caching.
-        # The cached block may be used by running requests or in the
-        # free_block_queue that could potentially be evicted.
-        # NOTE: We currently don't de-duplicate the blocks in the cache,
-        # meaning that if a block becomes full and is cached, we don't check
-        # if there is already an identical block in the cache. This is because
-        # we want to make sure the allocated block IDs won't change so that
-        # block tables are append-only.
-        self.cached_block_hash_to_block: dict[BlockHashWithGroupId, dict[
-            int, KVCacheBlock]] = defaultdict(dict)
+        # Cache for block lookup
+        self.cached_block_hash_to_block: BlockHashToBlockMap = \
+            BlockHashToBlockMap()

         # To represent a placeholder block with block_id=0.
         # The ref_cnt of null_block is not maintained, needs special care to
@@ -90,12 +178,11 @@ class BlockPool:
         for group_id in kv_cache_group_ids:
             block_hash_with_group_id = make_block_hash_with_group_id(
                 block_hash, group_id)
-            cached_blocks_one_group = self.cached_block_hash_to_block.get(
+            block = self.cached_block_hash_to_block.get_one_block(
                 block_hash_with_group_id)
-            if not cached_blocks_one_group:
+            if not block:
                 return None
-            first_block = next(iter(cached_blocks_one_group.values()))
-            cached_blocks.append(first_block)
+            cached_blocks.append(block)
         return cached_blocks

     def cache_full_blocks(
@@ -140,8 +227,8 @@ class BlockPool:
             block_hash_with_group_id = make_block_hash_with_group_id(
                 block_hash, kv_cache_group_id)
             blk.block_hash = block_hash_with_group_id
-            self.cached_block_hash_to_block[block_hash_with_group_id][
-                blk.block_id] = blk
+            self.cached_block_hash_to_block.insert(block_hash_with_group_id,
+                                                   blk)
             if new_hashes is not None:
                 new_hashes.append(maybe_convert_block_hash(block_hash))

@@ -211,15 +298,14 @@ class BlockPool:
         if block_hash is None:
             # The block doesn't have a hash, eviction is not needed
             return False
-        blocks_by_id = self.cached_block_hash_to_block.get(block_hash)
-        if blocks_by_id is None:
-            # block_hash not found in cached_block_hash_to_block,
+        if self.cached_block_hash_to_block.pop(block_hash,
+                                               block.block_id) is None:
+            # block not found in cached_block_hash_to_block,
             # eviction is not needed
             return False

         block.reset_hash()
-        blocks_by_id.pop(block.block_id, None)
-        if len(blocks_by_id) == 0:
-            del self.cached_block_hash_to_block[block_hash]

         if self.enable_kv_cache_events:
             # FIXME (Chen): Not sure whether we should return `hash_value`
@@ -283,7 +369,7 @@ class BlockPool:
             return False

         # Remove all hashes so that no new blocks will hit.
-        self.cached_block_hash_to_block = defaultdict(dict)
+        self.cached_block_hash_to_block = BlockHashToBlockMap()

         # Remove all hashes from all blocks.
         for block in self.blocks:
|||||||
Loading…
x
Reference in New Issue
Block a user