mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:44:54 +08:00
[Core] Use KVCacheBlock as much as possible instead of dict[block_id, KVCacheBlock] (#24830)
Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
parent
ae002924e9
commit
4f8c4b890a
@ -14,10 +14,11 @@ from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
||||
MultiModalKwargsItem, PlaceholderRange)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256, sha256_cbor
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.block_pool import BlockHashToBlockMap, BlockPool
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
get_block_hash, get_group_id,
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
KVCacheBlock, get_block_hash,
|
||||
get_group_id,
|
||||
get_request_block_hasher,
|
||||
hash_block_tokens, init_none_hash,
|
||||
make_block_hash_with_group_id)
|
||||
@ -138,7 +139,7 @@ def test_prefill(hash_fn):
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
|
||||
# Check full block metadata
|
||||
parent_block_hash = None
|
||||
@ -171,7 +172,7 @@ def test_prefill(hash_fn):
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([5], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([5], )
|
||||
for block in computed_blocks.blocks[0]:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
@ -207,7 +208,7 @@ def test_prefill(hash_fn):
|
||||
blocks = manager.allocate_slots(req2, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([6], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([6], )
|
||||
|
||||
# Although we only have 6 free blocks, we have 8 blocks in
|
||||
# the free block queue due to lazy removal.
|
||||
@ -227,7 +228,9 @@ def test_prefill(hash_fn):
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
# This block ID order also checks the eviction order.
|
||||
assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([
|
||||
7, 8, 9, 10, 4, 5, 6, 3, 2, 1
|
||||
], )
|
||||
|
||||
assert free_block_queue.num_free_blocks == 0
|
||||
assert (free_block_queue.fake_free_list_head.next_free_block
|
||||
@ -261,8 +264,9 @@ def test_prefill_hybrid_model():
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
|
||||
8], [9, 10, 11, 12])
|
||||
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], [
|
||||
5, 6, 7, 8
|
||||
], [9, 10, 11, 12])
|
||||
|
||||
# Check full block metadata
|
||||
parent_block_hash = None
|
||||
@ -298,7 +302,7 @@ def test_prefill_hybrid_model():
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([13], [14], [15])
|
||||
assert blocks is not None and blocks.get_block_ids() == ([13], [14], [15])
|
||||
for block_per_group in computed_blocks.blocks:
|
||||
for block in block_per_group:
|
||||
if block != manager.block_pool.null_block:
|
||||
@ -309,14 +313,15 @@ def test_prefill_hybrid_model():
|
||||
manager.free(req1)
|
||||
|
||||
cached_block_hash_to_block_bak = copy.copy(
|
||||
manager.block_pool.cached_block_hash_to_block)
|
||||
manager.block_pool.cached_block_hash_to_block._cache)
|
||||
|
||||
def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes],
|
||||
def test_partial_request_hit(request_id: str,
|
||||
hash_to_evict: list[BlockHashWithGroupId],
|
||||
expect_hit_length: int):
|
||||
req = make_request(request_id, common_token_ids + unique_token_ids,
|
||||
block_size, sha256)
|
||||
for hash_with_group_id in hash_to_evict:
|
||||
manager.block_pool.cached_block_hash_to_block.pop(
|
||||
manager.block_pool.cached_block_hash_to_block._cache.pop(
|
||||
hash_with_group_id)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert len(req.block_hashes) == 3
|
||||
@ -324,7 +329,7 @@ def test_prefill_hybrid_model():
|
||||
for block_per_group in computed_blocks.blocks:
|
||||
assert len(block_per_group) == num_computed_tokens // block_size
|
||||
for hash_with_group_id in hash_to_evict:
|
||||
manager.block_pool.cached_block_hash_to_block[
|
||||
manager.block_pool.cached_block_hash_to_block._cache[
|
||||
hash_with_group_id] = cached_block_hash_to_block_bak[
|
||||
hash_with_group_id]
|
||||
manager.free(req)
|
||||
@ -362,7 +367,8 @@ def test_prefill_hybrid_model():
|
||||
# total cache miss.
|
||||
# The cache hit length of full attention is 1 * block_size.
|
||||
# The cache hit length of sliding window is 2 * block_size.
|
||||
# Then it is cache miss as the two type of layers have different hit length.
|
||||
# Then it is cache miss as the two type of layers
|
||||
# have different hit length.
|
||||
test_partial_request_hit("8", [
|
||||
make_block_hash_with_group_id(block_hashes[2], 0),
|
||||
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||
@ -406,7 +412,7 @@ def test_prefill_plp():
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
|
||||
|
||||
# Check full block metadata
|
||||
@ -441,7 +447,7 @@ def test_prefill_plp():
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([5], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([5], )
|
||||
for block in computed_blocks.blocks[0]:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
@ -478,6 +484,7 @@ def test_prefill_plp():
|
||||
blocks = manager.allocate_slots(req2, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks is not None
|
||||
block_ids = blocks.get_block_ids()
|
||||
# Duplicate cached blocks have different ids but same hashes vs request #0
|
||||
assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
|
||||
@ -513,7 +520,7 @@ def test_decode():
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
|
||||
# Append slots without allocating a new block.
|
||||
req0.num_computed_tokens = 55
|
||||
@ -558,7 +565,8 @@ def test_evict():
|
||||
blocks = manager.allocate_slots(req0, 5 * 16 + 7,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 6 # 5 full + 1 partial
|
||||
# 5 full + 1 partial
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 6
|
||||
|
||||
# 3 blocks.
|
||||
req1 = make_request("1", list(range(last_token_id,
|
||||
@ -570,7 +578,7 @@ def test_evict():
|
||||
blocks = manager.allocate_slots(req1, 3 * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 3 # 3 full blocks
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 3 # 3 full blocks
|
||||
last_token_id += 3 * 16
|
||||
|
||||
# 10 - (6 + 3) == 1
|
||||
@ -592,7 +600,7 @@ def test_evict():
|
||||
blocks = manager.allocate_slots(req2, 3,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([10], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([10], )
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 7
|
||||
|
||||
|
||||
@ -617,7 +625,7 @@ def test_hash_block_correct_reuse():
|
||||
blocks = manager.allocate_slots(req, num_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||
|
||||
# Deallocate the block.
|
||||
manager.free(req)
|
||||
@ -631,7 +639,7 @@ def test_hash_block_correct_reuse():
|
||||
blocks = manager.allocate_slots(req, num_tokens - 1,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||
|
||||
assert manager.block_pool.blocks[blocks.blocks[0]
|
||||
[0].block_id].block_hash is None
|
||||
@ -658,7 +666,7 @@ def test_computed_blocks_not_evicted():
|
||||
blocks = manager.allocate_slots(req0, num_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||
assert blocks.blocks[0][0].block_id == 1
|
||||
|
||||
# Allocate another block.
|
||||
@ -670,7 +678,7 @@ def test_computed_blocks_not_evicted():
|
||||
blocks = manager.allocate_slots(req1, num_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||
assert blocks.blocks[0][0].block_id == 2
|
||||
|
||||
# Free the blocks.
|
||||
@ -688,7 +696,7 @@ def test_computed_blocks_not_evicted():
|
||||
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 1
|
||||
assert blocks.blocks[0][0].block_id == 2
|
||||
|
||||
|
||||
@ -712,7 +720,7 @@ def test_basic_prefix_caching_disabled():
|
||||
blocks = manager.allocate_slots(req1, 10,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 3
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 3
|
||||
|
||||
# Free the blocks.
|
||||
manager.free(req1)
|
||||
@ -726,7 +734,7 @@ def test_basic_prefix_caching_disabled():
|
||||
blocks = manager.allocate_slots(req2, 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks[0]) == 4
|
||||
assert blocks is not None and len(blocks.blocks[0]) == 4
|
||||
|
||||
# New requests should not have any blocks.
|
||||
req3 = make_request("3", list(range(4)), block_size, sha256)
|
||||
@ -773,7 +781,8 @@ def test_cache_blocks(hash_fn):
|
||||
assert len(block_pool.cached_block_hash_to_block) == 2
|
||||
assert all([block.block_hash is not None for block in blocks])
|
||||
|
||||
# Test that blocks that don't start from the beginning are cached correctly.
|
||||
# Test that blocks that don't start from the beginning are cached
|
||||
# correctly.
|
||||
blocks += [KVCacheBlock(block_id=2)]
|
||||
block_pool.cache_full_blocks(
|
||||
request=req,
|
||||
@ -1101,7 +1110,7 @@ def test_reset_prefix_cache():
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
req0 = make_request("0", all_token_ids, block_size, sha256)
|
||||
blocks = manager.allocate_slots(req0, 55)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
|
||||
unique_token_ids = [4] * 7
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
@ -1112,7 +1121,7 @@ def test_reset_prefix_cache():
|
||||
blocks = manager.allocate_slots(req1, 7,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == ([5], )
|
||||
assert blocks is not None and blocks.get_block_ids() == ([5], )
|
||||
|
||||
# Failed to reset prefix cache because some blocks are not freed yet.
|
||||
assert not manager.reset_prefix_cache()
|
||||
@ -1168,49 +1177,41 @@ def test_maybe_evict_cached_block():
|
||||
# Manually add all blocks to cached_blocks
|
||||
for block, block_hash in zip(pool.blocks, block_hashes):
|
||||
block.block_hash = block_hash
|
||||
pool.cached_block_hash_to_block[block_hash][block.block_id] = block
|
||||
pool.cached_block_hash_to_block.insert(block_hash, block)
|
||||
|
||||
block0, block1, block2, block3 = pool.blocks
|
||||
assert pool.cached_block_hash_to_block == {
|
||||
assert pool.cached_block_hash_to_block._cache == {
|
||||
block_hash0: {
|
||||
block0.block_id: block0,
|
||||
block3.block_id: block3
|
||||
block3.block_id: block3,
|
||||
},
|
||||
block_hash1: {
|
||||
block1.block_id: block1
|
||||
},
|
||||
block_hash2: {
|
||||
block2.block_id: block2
|
||||
}
|
||||
block_hash1: block1,
|
||||
block_hash2: block2,
|
||||
}
|
||||
# Evict block1
|
||||
pool._maybe_evict_cached_block(block1)
|
||||
assert pool.cached_block_hash_to_block == {
|
||||
assert pool.cached_block_hash_to_block._cache == {
|
||||
block_hash0: {
|
||||
block0.block_id: block0,
|
||||
block3.block_id: block3
|
||||
},
|
||||
block_hash2: {
|
||||
block2.block_id: block2
|
||||
}
|
||||
block_hash2: block2,
|
||||
}
|
||||
# Evict block0: block_hash0 entry should NOT be removed, as block3
|
||||
# also use the same hash
|
||||
pool._maybe_evict_cached_block(block0)
|
||||
assert pool.cached_block_hash_to_block == {
|
||||
assert pool.cached_block_hash_to_block._cache == {
|
||||
block_hash0: {
|
||||
block3.block_id: block3
|
||||
},
|
||||
block_hash2: {
|
||||
block2.block_id: block2
|
||||
}
|
||||
block_hash2: block2,
|
||||
}
|
||||
# Evict block2
|
||||
pool._maybe_evict_cached_block(block2)
|
||||
assert pool.cached_block_hash_to_block == {block_hash0: {3: block3}}
|
||||
assert pool.cached_block_hash_to_block._cache == {block_hash0: {3: block3}}
|
||||
# Evict block3
|
||||
pool._maybe_evict_cached_block(block3)
|
||||
assert pool.cached_block_hash_to_block == {}
|
||||
assert pool.cached_block_hash_to_block._cache == {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
|
||||
@ -1374,7 +1375,7 @@ def test_eagle_with_sliding_window():
|
||||
# Evict the first block in the request
|
||||
assert manager.block_pool.get_cached_block(
|
||||
block_hash_first_block, kv_cache_group_ids=[0]) is not None
|
||||
manager.block_pool.cached_block_hash_to_block.pop(
|
||||
manager.block_pool.cached_block_hash_to_block._cache.pop(
|
||||
make_block_hash_with_group_id(block_hash_first_block, 0))
|
||||
|
||||
# New request
|
||||
@ -1386,3 +1387,78 @@ def test_eagle_with_sliding_window():
|
||||
# there will be no matched prefix.
|
||||
assert len(computed_blocks.blocks[0]) == 0
|
||||
assert num_tokens == 0
|
||||
|
||||
|
||||
def test_block_lookup_cache_single_block_per_key():
|
||||
cache = BlockHashToBlockMap()
|
||||
key0 = BlockHashWithGroupId(b"hash0")
|
||||
key1 = BlockHashWithGroupId(b"hash1")
|
||||
key2 = BlockHashWithGroupId(b"hash2")
|
||||
block0 = KVCacheBlock(0)
|
||||
block1 = KVCacheBlock(1)
|
||||
|
||||
assert cache.get_one_block(key0) is None
|
||||
assert cache.get_one_block(key1) is None
|
||||
assert cache.get_one_block(key2) is None
|
||||
# key0 inserted
|
||||
cache.insert(key0, block0)
|
||||
assert cache.get_one_block(key0) is block0
|
||||
assert cache.get_one_block(key1) is None
|
||||
assert cache.get_one_block(key2) is None
|
||||
# key1 inserted
|
||||
cache.insert(key1, block1)
|
||||
assert cache.get_one_block(key0) is block0
|
||||
assert cache.get_one_block(key1) is block1
|
||||
assert cache.get_one_block(key2) is None
|
||||
# No block poped due to block_id mismatch
|
||||
assert cache.pop(key0, 100) is None
|
||||
assert cache.get_one_block(key0) is block0
|
||||
assert cache.get_one_block(key1) is block1
|
||||
assert cache.get_one_block(key2) is None
|
||||
# block poped with (key0, block ID 0)
|
||||
assert cache.pop(key0, 0) is block0
|
||||
assert cache.get_one_block(key0) is None
|
||||
assert cache.get_one_block(key1) is block1
|
||||
assert cache.get_one_block(key2) is None
|
||||
# No block poped due to block_id mismatch
|
||||
assert cache.pop(key0, 1) is None
|
||||
assert cache.get_one_block(key0) is None
|
||||
assert cache.get_one_block(key1) is block1
|
||||
assert cache.get_one_block(key2) is None
|
||||
# block poped with (key1, block ID 1)
|
||||
assert cache.pop(key1, 1) is block1
|
||||
assert cache.get_one_block(key0) is None
|
||||
assert cache.get_one_block(key1) is None
|
||||
assert cache.get_one_block(key2) is None
|
||||
|
||||
|
||||
def test_block_lookup_cache_multi_blocks_per_key():
|
||||
cache = BlockHashToBlockMap()
|
||||
key0 = BlockHashWithGroupId(b"hash0")
|
||||
key1 = BlockHashWithGroupId(b"hash1")
|
||||
block00 = KVCacheBlock(0)
|
||||
block01 = KVCacheBlock(1)
|
||||
block10 = KVCacheBlock(10)
|
||||
block11 = KVCacheBlock(11)
|
||||
|
||||
assert cache.get_one_block(key0) is None
|
||||
assert cache.get_one_block(key1) is None
|
||||
|
||||
cache.insert(key0, block00)
|
||||
cache.insert(key0, block01)
|
||||
cache.insert(key1, block10)
|
||||
cache.insert(key1, block11)
|
||||
|
||||
assert cache.get_one_block(key0) is block00
|
||||
assert cache.pop(key0, 0) is block00
|
||||
assert cache.get_one_block(key0) is block01
|
||||
assert cache.pop(key0, 1) is block01
|
||||
assert cache.get_one_block(key0) is None
|
||||
assert cache.pop(key0, 2) is None
|
||||
|
||||
assert cache.get_one_block(key1) is block10
|
||||
assert cache.pop(key1, 10) is block10
|
||||
assert cache.get_one_block(key1) is block11
|
||||
assert cache.pop(key1, 11) is block11
|
||||
assert cache.get_one_block(key1) is None
|
||||
assert cache.pop(key1, 12) is None
|
||||
|
||||
@ -47,16 +47,15 @@ def test_chunked_local_attention_possible_cached_prefix():
|
||||
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
|
||||
]
|
||||
|
||||
block_pool.cached_block_hash_to_block.clear()
|
||||
block_pool.cached_block_hash_to_block._cache.clear()
|
||||
|
||||
# Mock the block pool with the cached blocks
|
||||
for i, (block_hash,
|
||||
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
||||
if is_cached:
|
||||
block_pool.cached_block_hash_to_block[
|
||||
make_block_hash_with_group_id(block_hash, 0)] = {
|
||||
i: block_pool.blocks[i + 10],
|
||||
}
|
||||
block_pool.cached_block_hash_to_block.insert(
|
||||
make_block_hash_with_group_id(block_hash, 0),
|
||||
block_pool.blocks[i + 10])
|
||||
|
||||
computed_blocks = manager.find_longest_cache_hit(
|
||||
block_hashes=block_hash_list,
|
||||
@ -112,16 +111,15 @@ def test_sliding_window_possible_cached_prefix():
|
||||
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
|
||||
]
|
||||
|
||||
block_pool.cached_block_hash_to_block.clear()
|
||||
block_pool.cached_block_hash_to_block._cache.clear()
|
||||
|
||||
# Mock the block pool with the cached blocks
|
||||
for i, (block_hash,
|
||||
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
||||
if is_cached:
|
||||
block_pool.cached_block_hash_to_block[
|
||||
make_block_hash_with_group_id(block_hash, 0)] = {
|
||||
i: block_pool.blocks[i + 10],
|
||||
}
|
||||
block_pool.cached_block_hash_to_block.insert(
|
||||
make_block_hash_with_group_id(block_hash, 0),
|
||||
block_pool.blocks[i + 10])
|
||||
|
||||
computed_blocks = manager.find_longest_cache_hit(
|
||||
block_hashes=block_hash_list,
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
|
||||
BlockRemoved, BlockStored,
|
||||
@ -19,6 +18,103 @@ from vllm.v1.request import Request
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BlockHashToBlockMap:
|
||||
"""
|
||||
Cache of blocks that are used for prefix caching. It caches blocks
|
||||
from hash directly to a block or multiple blocks
|
||||
(i.e. {block_hash: KVCacheBlocks})
|
||||
- Mostly block_hash maps to a single KVCacheBlock, and KVCacheBlocks
|
||||
would simply be a KVCacheBlock.
|
||||
- Otherwise, KVCacheBlocks is a dict from {block_id: KVCacheBlock}
|
||||
|
||||
A cached block is a full block with a block hash that can be used
|
||||
for prefix caching.
|
||||
The cached block may be used by running requests or in the
|
||||
free_block_queue that could potentially be evicted.
|
||||
|
||||
NOTE #1: We currently don't de-duplicate the blocks in the cache,
|
||||
meaning that if a block becomes full and is cached, we don't check
|
||||
if there is already an identical block in the cache. This is because
|
||||
we want to make sure the allocated block IDs won't change so that
|
||||
block tables are append-only.
|
||||
NOTE #2: The union type is introduced in order to reduce GC costs
|
||||
from the inner dict.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._cache: dict[BlockHashWithGroupId,
|
||||
Union[KVCacheBlock, dict[int, KVCacheBlock]]] = {}
|
||||
|
||||
def get_one_block(self,
|
||||
key: BlockHashWithGroupId) -> Optional[KVCacheBlock]:
|
||||
"""
|
||||
Gets any block with the given block hash key.
|
||||
"""
|
||||
blocks = self._cache.get(key)
|
||||
if blocks is not None:
|
||||
if isinstance(blocks, KVCacheBlock):
|
||||
return blocks
|
||||
if isinstance(blocks, dict):
|
||||
return next(iter(blocks.values()))
|
||||
self._unexpected_blocks_type(blocks)
|
||||
return None
|
||||
|
||||
def insert(self, key: BlockHashWithGroupId, block: KVCacheBlock) -> None:
|
||||
"""
|
||||
Inserts the KVCacheBlock to the cache
|
||||
"""
|
||||
blocks = self._cache.get(key)
|
||||
if blocks is None:
|
||||
# When key is not found, attach a single block to the key
|
||||
self._cache[key] = block
|
||||
elif isinstance(blocks, KVCacheBlock):
|
||||
# If there's a block with the same key, merge the original block
|
||||
# and the new block into a dict
|
||||
self._cache[key] = {blocks.block_id: blocks, block.block_id: block}
|
||||
elif isinstance(blocks, dict):
|
||||
# If it's already a dict, simply insert the block
|
||||
blocks[block.block_id] = block
|
||||
else:
|
||||
self._unexpected_blocks_type(blocks)
|
||||
|
||||
def pop(self, key: BlockHashWithGroupId,
|
||||
block_id: int) -> Optional[KVCacheBlock]:
|
||||
"""
|
||||
Checks if block_hash exists and pop block_id from the cache
|
||||
"""
|
||||
blocks = self._cache.pop(key, None)
|
||||
if blocks is None:
|
||||
# block_hash not found in the cache
|
||||
return None
|
||||
# TODO(Jialin): If key is found, block_id should always present
|
||||
# in blocks. We currently keep the original behaviour for safety.
|
||||
#
|
||||
# Will add block_id == blocks.block_id assertion and
|
||||
# use del blocks[block_id] instead as followup.
|
||||
if isinstance(blocks, KVCacheBlock):
|
||||
if blocks.block_id == block_id:
|
||||
return blocks
|
||||
# If the single block ID doesn't match, we should put the
|
||||
# block back (it should happen rarely)
|
||||
self._cache[key] = blocks
|
||||
return None
|
||||
if isinstance(blocks, dict):
|
||||
# Try to pop block_id from the block dict, and if dict still
|
||||
# contain blocks, put back to the cache.
|
||||
block = blocks.pop(block_id, None)
|
||||
if len(blocks) > 0:
|
||||
self._cache[key] = blocks
|
||||
return block
|
||||
self._unexpected_blocks_type(blocks)
|
||||
return None
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._cache)
|
||||
|
||||
def _unexpected_blocks_type(self, blocks: Any) -> None:
|
||||
raise AssertionError(f"Invalid KV cache block type {type(blocks)}")
|
||||
|
||||
|
||||
class BlockPool:
|
||||
"""BlockPool that manages KVCacheBlocks.
|
||||
It provides methods to allocate, free and cache the kv cache blocks. The
|
||||
@ -51,17 +147,9 @@ class BlockPool:
|
||||
# enabled).
|
||||
self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)
|
||||
|
||||
# {block_hash: {block ID: block}}. A cached block is
|
||||
# a full block with a block hash that can be used for prefix caching.
|
||||
# The cached block may be used by running requests or in the
|
||||
# free_block_queue that could potentially be evicted.
|
||||
# NOTE: We currently don't de-duplicate the blocks in the cache,
|
||||
# meaning that if a block becomes full and is cached, we don't check
|
||||
# if there is already an identical block in the cache. This is because
|
||||
# we want to make sure the allocated block IDs won't change so that
|
||||
# block tables are append-only.
|
||||
self.cached_block_hash_to_block: dict[BlockHashWithGroupId, dict[
|
||||
int, KVCacheBlock]] = defaultdict(dict)
|
||||
# Cache for block lookup
|
||||
self.cached_block_hash_to_block: BlockHashToBlockMap = \
|
||||
BlockHashToBlockMap()
|
||||
|
||||
# To represent a placeholder block with block_id=0.
|
||||
# The ref_cnt of null_block is not maintained, needs special care to
|
||||
@ -90,12 +178,11 @@ class BlockPool:
|
||||
for group_id in kv_cache_group_ids:
|
||||
block_hash_with_group_id = make_block_hash_with_group_id(
|
||||
block_hash, group_id)
|
||||
cached_blocks_one_group = self.cached_block_hash_to_block.get(
|
||||
block = self.cached_block_hash_to_block.get_one_block(
|
||||
block_hash_with_group_id)
|
||||
if not cached_blocks_one_group:
|
||||
if not block:
|
||||
return None
|
||||
first_block = next(iter(cached_blocks_one_group.values()))
|
||||
cached_blocks.append(first_block)
|
||||
cached_blocks.append(block)
|
||||
return cached_blocks
|
||||
|
||||
def cache_full_blocks(
|
||||
@ -140,8 +227,8 @@ class BlockPool:
|
||||
block_hash_with_group_id = make_block_hash_with_group_id(
|
||||
block_hash, kv_cache_group_id)
|
||||
blk.block_hash = block_hash_with_group_id
|
||||
self.cached_block_hash_to_block[block_hash_with_group_id][
|
||||
blk.block_id] = blk
|
||||
self.cached_block_hash_to_block.insert(block_hash_with_group_id,
|
||||
blk)
|
||||
if new_hashes is not None:
|
||||
new_hashes.append(maybe_convert_block_hash(block_hash))
|
||||
|
||||
@ -211,15 +298,14 @@ class BlockPool:
|
||||
if block_hash is None:
|
||||
# The block doesn't have hash, eviction is not needed
|
||||
return False
|
||||
blocks_by_id = self.cached_block_hash_to_block.get(block_hash)
|
||||
if blocks_by_id is None:
|
||||
# block_hash not found in cached_block_hash_to_block,
|
||||
|
||||
if self.cached_block_hash_to_block.pop(block_hash,
|
||||
block.block_id) is None:
|
||||
# block not found in cached_block_hash_to_block,
|
||||
# eviction is not needed
|
||||
return False
|
||||
|
||||
block.reset_hash()
|
||||
blocks_by_id.pop(block.block_id, None)
|
||||
if len(blocks_by_id) == 0:
|
||||
del self.cached_block_hash_to_block[block_hash]
|
||||
|
||||
if self.enable_kv_cache_events:
|
||||
# FIXME (Chen): Not sure whether we should return `hash_value`
|
||||
@ -283,7 +369,7 @@ class BlockPool:
|
||||
return False
|
||||
|
||||
# Remove all hashes so that no new blocks will hit.
|
||||
self.cached_block_hash_to_block = defaultdict(dict)
|
||||
self.cached_block_hash_to_block = BlockHashToBlockMap()
|
||||
|
||||
# Remove all hashes from all blocks.
|
||||
for block in self.blocks:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user