[Core] Use KVCacheBlock as much as possible instead of dict[block_id, KVCacheBlock] (#24830)

Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
Jialin Ouyang 2025-09-23 15:11:14 -07:00 committed by GitHub
parent ae002924e9
commit 4f8c4b890a
3 changed files with 247 additions and 87 deletions
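
The core change: BlockPool.cached_block_hash_to_block used to be a defaultdict whose values were always inner dict[int, KVCacheBlock] maps, even though a block hash almost always maps to a single block. The new BlockHashToBlockMap (added in block_pool.py below) stores a bare KVCacheBlock in that common case and only promotes the value to an inner dict when two blocks share the same hash, which reduces the number of short-lived dicts the GC has to track. The snippet below is a minimal standalone sketch of that pattern, not vLLM code; SimpleBlock and SingleOrMultiBlockMap are hypothetical names used only for illustration.

# Minimal standalone sketch of the single-block-or-dict value pattern.
# Hypothetical names (SimpleBlock, SingleOrMultiBlockMap); the real
# implementation is the BlockHashToBlockMap class added in this commit.
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class SimpleBlock:
    block_id: int


class SingleOrMultiBlockMap:

    def __init__(self) -> None:
        self._cache: dict[bytes, Union[SimpleBlock,
                                       dict[int, SimpleBlock]]] = {}

    def insert(self, key: bytes, block: SimpleBlock) -> None:
        existing = self._cache.get(key)
        if existing is None:
            # Common case: store the block directly, no inner dict.
            self._cache[key] = block
        elif isinstance(existing, SimpleBlock):
            # Promote to an inner dict only when a second block
            # shares the same key.
            self._cache[key] = {
                existing.block_id: existing,
                block.block_id: block,
            }
        else:
            existing[block.block_id] = block

    def get_one_block(self, key: bytes) -> Optional[SimpleBlock]:
        blocks = self._cache.get(key)
        if isinstance(blocks, dict):
            return next(iter(blocks.values()))
        return blocks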


@ -14,10 +14,11 @@ from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256, sha256_cbor
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.block_pool import BlockHashToBlockMap, BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
get_block_hash, get_group_id,
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
KVCacheBlock, get_block_hash,
get_group_id,
get_request_block_hasher,
hash_block_tokens, init_none_hash,
make_block_hash_with_group_id)
@ -138,7 +139,7 @@ def test_prefill(hash_fn):
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
# Check full block metadata
parent_block_hash = None
@ -171,7 +172,7 @@ def test_prefill(hash_fn):
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([5], )
assert blocks is not None and blocks.get_block_ids() == ([5], )
for block in computed_blocks.blocks[0]:
assert block.ref_cnt == 2
@ -207,7 +208,7 @@ def test_prefill(hash_fn):
blocks = manager.allocate_slots(req2, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([6], )
assert blocks is not None and blocks.get_block_ids() == ([6], )
# Although we only have 6 free blocks, we have 8 blocks in
# the free block queue due to lazy removal.
@ -227,7 +228,9 @@ def test_prefill(hash_fn):
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
# This block ID order also checks the eviction order.
assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], )
assert blocks is not None and blocks.get_block_ids() == ([
7, 8, 9, 10, 4, 5, 6, 3, 2, 1
], )
assert free_block_queue.num_free_blocks == 0
assert (free_block_queue.fake_free_list_head.next_free_block
@ -261,8 +264,9 @@ def test_prefill_hybrid_model():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7,
8], [9, 10, 11, 12])
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], [
5, 6, 7, 8
], [9, 10, 11, 12])
# Check full block metadata
parent_block_hash = None
@ -298,7 +302,7 @@ def test_prefill_hybrid_model():
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([13], [14], [15])
assert blocks is not None and blocks.get_block_ids() == ([13], [14], [15])
for block_per_group in computed_blocks.blocks:
for block in block_per_group:
if block != manager.block_pool.null_block:
@ -309,14 +313,15 @@ def test_prefill_hybrid_model():
manager.free(req1)
cached_block_hash_to_block_bak = copy.copy(
manager.block_pool.cached_block_hash_to_block)
manager.block_pool.cached_block_hash_to_block._cache)
def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes],
def test_partial_request_hit(request_id: str,
hash_to_evict: list[BlockHashWithGroupId],
expect_hit_length: int):
req = make_request(request_id, common_token_ids + unique_token_ids,
block_size, sha256)
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block.pop(
manager.block_pool.cached_block_hash_to_block._cache.pop(
hash_with_group_id)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert len(req.block_hashes) == 3
@ -324,7 +329,7 @@ def test_prefill_hybrid_model():
for block_per_group in computed_blocks.blocks:
assert len(block_per_group) == num_computed_tokens // block_size
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block[
manager.block_pool.cached_block_hash_to_block._cache[
hash_with_group_id] = cached_block_hash_to_block_bak[
hash_with_group_id]
manager.free(req)
@ -362,7 +367,8 @@ def test_prefill_hybrid_model():
# total cache miss.
# The cache hit length of full attention is 1 * block_size.
# The cache hit length of sliding window is 2 * block_size.
# Then it is a cache miss as the two types of layers have different hit lengths.
# Then it is a cache miss as the two types of layers
# have different hit lengths.
test_partial_request_hit("8", [
make_block_hash_with_group_id(block_hashes[2], 0),
make_block_hash_with_group_id(block_hashes[0], 1),
@ -406,7 +412,7 @@ def test_prefill_plp():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
# Check full block metadata
@ -441,7 +447,7 @@ def test_prefill_plp():
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([5], )
assert blocks is not None and blocks.get_block_ids() == ([5], )
for block in computed_blocks.blocks[0]:
assert block.ref_cnt == 2
@ -478,6 +484,7 @@ def test_prefill_plp():
blocks = manager.allocate_slots(req2, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks is not None
block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
@ -513,7 +520,7 @@ def test_decode():
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
# Append slots without allocating a new block.
req0.num_computed_tokens = 55
@ -558,7 +565,8 @@ def test_evict():
blocks = manager.allocate_slots(req0, 5 * 16 + 7,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 6 # 5 full + 1 partial
# 5 full + 1 partial
assert blocks is not None and len(blocks.blocks[0]) == 6
# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
@ -570,7 +578,7 @@ def test_evict():
blocks = manager.allocate_slots(req1, 3 * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 3 # 3 full blocks
assert blocks is not None and len(blocks.blocks[0]) == 3 # 3 full blocks
last_token_id += 3 * 16
# 10 - (6 + 3) == 1
@ -592,7 +600,7 @@ def test_evict():
blocks = manager.allocate_slots(req2, 3,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([10], )
assert blocks is not None and blocks.get_block_ids() == ([10], )
assert manager.block_pool.free_block_queue.num_free_blocks == 7
@ -617,7 +625,7 @@ def test_hash_block_correct_reuse():
blocks = manager.allocate_slots(req, num_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 1
assert blocks is not None and len(blocks.blocks[0]) == 1
# Deallocate the block.
manager.free(req)
@ -631,7 +639,7 @@ def test_hash_block_correct_reuse():
blocks = manager.allocate_slots(req, num_tokens - 1,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 1
assert blocks is not None and len(blocks.blocks[0]) == 1
assert manager.block_pool.blocks[blocks.blocks[0]
[0].block_id].block_hash is None
@ -658,7 +666,7 @@ def test_computed_blocks_not_evicted():
blocks = manager.allocate_slots(req0, num_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 1
assert blocks is not None and len(blocks.blocks[0]) == 1
assert blocks.blocks[0][0].block_id == 1
# Allocate another block.
@ -670,7 +678,7 @@ def test_computed_blocks_not_evicted():
blocks = manager.allocate_slots(req1, num_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 1
assert blocks is not None and len(blocks.blocks[0]) == 1
assert blocks.blocks[0][0].block_id == 2
# Free the blocks.
@ -688,7 +696,7 @@ def test_computed_blocks_not_evicted():
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 1
assert blocks is not None and len(blocks.blocks[0]) == 1
assert blocks.blocks[0][0].block_id == 2
@ -712,7 +720,7 @@ def test_basic_prefix_caching_disabled():
blocks = manager.allocate_slots(req1, 10,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 3
assert blocks is not None and len(blocks.blocks[0]) == 3
# Free the blocks.
manager.free(req1)
@ -726,7 +734,7 @@ def test_basic_prefix_caching_disabled():
blocks = manager.allocate_slots(req2, 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks[0]) == 4
assert blocks is not None and len(blocks.blocks[0]) == 4
# New requests should not have any blocks.
req3 = make_request("3", list(range(4)), block_size, sha256)
@ -773,7 +781,8 @@ def test_cache_blocks(hash_fn):
assert len(block_pool.cached_block_hash_to_block) == 2
assert all([block.block_hash is not None for block in blocks])
# Test that blocks that don't start from the beginning are cached correctly.
# Test that blocks that don't start from the beginning are cached
# correctly.
blocks += [KVCacheBlock(block_id=2)]
block_pool.cache_full_blocks(
request=req,
@ -1101,7 +1110,7 @@ def test_reset_prefix_cache():
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids, block_size, sha256)
blocks = manager.allocate_slots(req0, 55)
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
assert blocks is not None and blocks.get_block_ids() == ([1, 2, 3, 4], )
unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
@ -1112,7 +1121,7 @@ def test_reset_prefix_cache():
blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == ([5], )
assert blocks is not None and blocks.get_block_ids() == ([5], )
# Failed to reset prefix cache because some blocks are not freed yet.
assert not manager.reset_prefix_cache()
@ -1168,49 +1177,41 @@ def test_maybe_evict_cached_block():
# Manually add all blocks to cached_blocks
for block, block_hash in zip(pool.blocks, block_hashes):
block.block_hash = block_hash
pool.cached_block_hash_to_block[block_hash][block.block_id] = block
pool.cached_block_hash_to_block.insert(block_hash, block)
block0, block1, block2, block3 = pool.blocks
assert pool.cached_block_hash_to_block == {
assert pool.cached_block_hash_to_block._cache == {
block_hash0: {
block0.block_id: block0,
block3.block_id: block3
block3.block_id: block3,
},
block_hash1: {
block1.block_id: block1
},
block_hash2: {
block2.block_id: block2
}
block_hash1: block1,
block_hash2: block2,
}
# Evict block1
pool._maybe_evict_cached_block(block1)
assert pool.cached_block_hash_to_block == {
assert pool.cached_block_hash_to_block._cache == {
block_hash0: {
block0.block_id: block0,
block3.block_id: block3
},
block_hash2: {
block2.block_id: block2
}
block_hash2: block2,
}
# Evict block0: block_hash0 entry should NOT be removed, as block3
# also uses the same hash
pool._maybe_evict_cached_block(block0)
assert pool.cached_block_hash_to_block == {
assert pool.cached_block_hash_to_block._cache == {
block_hash0: {
block3.block_id: block3
},
block_hash2: {
block2.block_id: block2
}
block_hash2: block2,
}
# Evict block2
pool._maybe_evict_cached_block(block2)
assert pool.cached_block_hash_to_block == {block_hash0: {3: block3}}
assert pool.cached_block_hash_to_block._cache == {block_hash0: {3: block3}}
# Evict block3
pool._maybe_evict_cached_block(block3)
assert pool.cached_block_hash_to_block == {}
assert pool.cached_block_hash_to_block._cache == {}
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
@ -1374,7 +1375,7 @@ def test_eagle_with_sliding_window():
# Evict the first block in the request
assert manager.block_pool.get_cached_block(
block_hash_first_block, kv_cache_group_ids=[0]) is not None
manager.block_pool.cached_block_hash_to_block.pop(
manager.block_pool.cached_block_hash_to_block._cache.pop(
make_block_hash_with_group_id(block_hash_first_block, 0))
# New request
@ -1386,3 +1387,78 @@ def test_eagle_with_sliding_window():
# there will be no matched prefix.
assert len(computed_blocks.blocks[0]) == 0
assert num_tokens == 0
def test_block_lookup_cache_single_block_per_key():
cache = BlockHashToBlockMap()
key0 = BlockHashWithGroupId(b"hash0")
key1 = BlockHashWithGroupId(b"hash1")
key2 = BlockHashWithGroupId(b"hash2")
block0 = KVCacheBlock(0)
block1 = KVCacheBlock(1)
assert cache.get_one_block(key0) is None
assert cache.get_one_block(key1) is None
assert cache.get_one_block(key2) is None
# key0 inserted
cache.insert(key0, block0)
assert cache.get_one_block(key0) is block0
assert cache.get_one_block(key1) is None
assert cache.get_one_block(key2) is None
# key1 inserted
cache.insert(key1, block1)
assert cache.get_one_block(key0) is block0
assert cache.get_one_block(key1) is block1
assert cache.get_one_block(key2) is None
# No block popped due to block_id mismatch
assert cache.pop(key0, 100) is None
assert cache.get_one_block(key0) is block0
assert cache.get_one_block(key1) is block1
assert cache.get_one_block(key2) is None
# Block popped with (key0, block ID 0)
assert cache.pop(key0, 0) is block0
assert cache.get_one_block(key0) is None
assert cache.get_one_block(key1) is block1
assert cache.get_one_block(key2) is None
# No block popped due to block_id mismatch
assert cache.pop(key0, 1) is None
assert cache.get_one_block(key0) is None
assert cache.get_one_block(key1) is block1
assert cache.get_one_block(key2) is None
# Block popped with (key1, block ID 1)
assert cache.pop(key1, 1) is block1
assert cache.get_one_block(key0) is None
assert cache.get_one_block(key1) is None
assert cache.get_one_block(key2) is None
def test_block_lookup_cache_multi_blocks_per_key():
cache = BlockHashToBlockMap()
key0 = BlockHashWithGroupId(b"hash0")
key1 = BlockHashWithGroupId(b"hash1")
block00 = KVCacheBlock(0)
block01 = KVCacheBlock(1)
block10 = KVCacheBlock(10)
block11 = KVCacheBlock(11)
assert cache.get_one_block(key0) is None
assert cache.get_one_block(key1) is None
cache.insert(key0, block00)
cache.insert(key0, block01)
cache.insert(key1, block10)
cache.insert(key1, block11)
assert cache.get_one_block(key0) is block00
assert cache.pop(key0, 0) is block00
assert cache.get_one_block(key0) is block01
assert cache.pop(key0, 1) is block01
assert cache.get_one_block(key0) is None
assert cache.pop(key0, 2) is None
assert cache.get_one_block(key1) is block10
assert cache.pop(key1, 10) is block10
assert cache.get_one_block(key1) is block11
assert cache.pop(key1, 11) is block11
assert cache.get_one_block(key1) is None
assert cache.pop(key1, 12) is None


@ -47,16 +47,15 @@ def test_chunked_local_attention_possible_cached_prefix():
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
]
block_pool.cached_block_hash_to_block.clear()
block_pool.cached_block_hash_to_block._cache.clear()
# Mock the block pool with the cached blocks
for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached:
block_pool.cached_block_hash_to_block[
make_block_hash_with_group_id(block_hash, 0)] = {
i: block_pool.blocks[i + 10],
}
block_pool.cached_block_hash_to_block.insert(
make_block_hash_with_group_id(block_hash, 0),
block_pool.blocks[i + 10])
computed_blocks = manager.find_longest_cache_hit(
block_hashes=block_hash_list,
@ -112,16 +111,15 @@ def test_sliding_window_possible_cached_prefix():
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
]
block_pool.cached_block_hash_to_block.clear()
block_pool.cached_block_hash_to_block._cache.clear()
# Mock the block pool with the cached blocks
for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached:
block_pool.cached_block_hash_to_block[
make_block_hash_with_group_id(block_hash, 0)] = {
i: block_pool.blocks[i + 10],
}
block_pool.cached_block_hash_to_block.insert(
make_block_hash_with_group_id(block_hash, 0),
block_pool.blocks[i + 10])
computed_blocks = manager.find_longest_cache_hit(
block_hashes=block_hash_list,


@ -1,8 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable
from typing import Optional
from typing import Any, Optional, Union
from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
BlockRemoved, BlockStored,
@ -19,6 +18,103 @@ from vllm.v1.request import Request
logger = init_logger(__name__)
class BlockHashToBlockMap:
"""
Cache of blocks that are used for prefix caching. It maps a block hash
directly to either a single block or multiple blocks
(i.e. {block_hash: KVCacheBlocks}):
- In most cases, block_hash maps to a single KVCacheBlock, and
KVCacheBlocks is simply that KVCacheBlock.
- Otherwise, KVCacheBlocks is a dict of {block_id: KVCacheBlock}.
A cached block is a full block with a block hash that can be used
for prefix caching.
The cached block may be used by running requests or sit in the
free_block_queue, from which it could potentially be evicted.
NOTE #1: We currently don't de-duplicate the blocks in the cache,
meaning that if a block becomes full and is cached, we don't check
if there is already an identical block in the cache. This is because
we want to make sure the allocated block IDs won't change so that
block tables are append-only.
NOTE #2: The union type is introduced in order to reduce GC costs
from the inner dict.
"""
def __init__(self):
self._cache: dict[BlockHashWithGroupId,
Union[KVCacheBlock, dict[int, KVCacheBlock]]] = {}
def get_one_block(self,
key: BlockHashWithGroupId) -> Optional[KVCacheBlock]:
"""
Gets any block with the given block hash key.
"""
blocks = self._cache.get(key)
if blocks is not None:
if isinstance(blocks, KVCacheBlock):
return blocks
if isinstance(blocks, dict):
return next(iter(blocks.values()))
self._unexpected_blocks_type(blocks)
return None
def insert(self, key: BlockHashWithGroupId, block: KVCacheBlock) -> None:
"""
Inserts the KVCacheBlock into the cache.
"""
blocks = self._cache.get(key)
if blocks is None:
# When key is not found, attach a single block to the key
self._cache[key] = block
elif isinstance(blocks, KVCacheBlock):
# If there's a block with the same key, merge the original block
# and the new block into a dict
self._cache[key] = {blocks.block_id: blocks, block.block_id: block}
elif isinstance(blocks, dict):
# If it's already a dict, simply insert the block
blocks[block.block_id] = block
else:
self._unexpected_blocks_type(blocks)
def pop(self, key: BlockHashWithGroupId,
block_id: int) -> Optional[KVCacheBlock]:
"""
Checks whether block_hash exists and pops block_id from the cache.
"""
blocks = self._cache.pop(key, None)
if blocks is None:
# block_hash not found in the cache
return None
# TODO(Jialin): If key is found, block_id should always be present
# in blocks. We currently keep the original behaviour for safety.
#
# As a follow-up, we will add a block_id == blocks.block_id assertion
# and use del blocks[block_id] instead.
if isinstance(blocks, KVCacheBlock):
if blocks.block_id == block_id:
return blocks
# If the single block ID doesn't match, we should put the
# block back (this should rarely happen)
self._cache[key] = blocks
return None
if isinstance(blocks, dict):
# Try to pop block_id from the block dict; if the dict still
# contains blocks, put it back into the cache.
block = blocks.pop(block_id, None)
if len(blocks) > 0:
self._cache[key] = blocks
return block
self._unexpected_blocks_type(blocks)
return None
def __len__(self) -> int:
return len(self._cache)
def _unexpected_blocks_type(self, blocks: Any) -> None:
raise AssertionError(f"Invalid KV cache block type {type(blocks)}")
class BlockPool:
"""BlockPool that manages KVCacheBlocks.
It provides methods to allocate, free and cache the kv cache blocks. The
@ -51,17 +147,9 @@ class BlockPool:
# enabled).
self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)
# {block_hash: {block ID: block}}. A cached block is
# a full block with a block hash that can be used for prefix caching.
# The cached block may be used by running requests or in the
# free_block_queue that could potentially be evicted.
# NOTE: We currently don't de-duplicate the blocks in the cache,
# meaning that if a block becomes full and is cached, we don't check
# if there is already an identical block in the cache. This is because
# we want to make sure the allocated block IDs won't change so that
# block tables are append-only.
self.cached_block_hash_to_block: dict[BlockHashWithGroupId, dict[
int, KVCacheBlock]] = defaultdict(dict)
# Cache for block lookup
self.cached_block_hash_to_block: BlockHashToBlockMap = \
BlockHashToBlockMap()
# To represent a placeholder block with block_id=0.
# The ref_cnt of null_block is not maintained, needs special care to
@ -90,12 +178,11 @@ class BlockPool:
for group_id in kv_cache_group_ids:
block_hash_with_group_id = make_block_hash_with_group_id(
block_hash, group_id)
cached_blocks_one_group = self.cached_block_hash_to_block.get(
block = self.cached_block_hash_to_block.get_one_block(
block_hash_with_group_id)
if not cached_blocks_one_group:
if not block:
return None
first_block = next(iter(cached_blocks_one_group.values()))
cached_blocks.append(first_block)
cached_blocks.append(block)
return cached_blocks
def cache_full_blocks(
@ -140,8 +227,8 @@ class BlockPool:
block_hash_with_group_id = make_block_hash_with_group_id(
block_hash, kv_cache_group_id)
blk.block_hash = block_hash_with_group_id
self.cached_block_hash_to_block[block_hash_with_group_id][
blk.block_id] = blk
self.cached_block_hash_to_block.insert(block_hash_with_group_id,
blk)
if new_hashes is not None:
new_hashes.append(maybe_convert_block_hash(block_hash))
@ -211,15 +298,14 @@ class BlockPool:
if block_hash is None:
# The block doesn't have a hash, so eviction is not needed
return False
blocks_by_id = self.cached_block_hash_to_block.get(block_hash)
if blocks_by_id is None:
# block_hash not found in cached_block_hash_to_block,
if self.cached_block_hash_to_block.pop(block_hash,
block.block_id) is None:
# block not found in cached_block_hash_to_block,
# eviction is not needed
return False
block.reset_hash()
blocks_by_id.pop(block.block_id, None)
if len(blocks_by_id) == 0:
del self.cached_block_hash_to_block[block_hash]
if self.enable_kv_cache_events:
# FIXME (Chen): Not sure whether we should return `hash_value`
@ -283,7 +369,7 @@ class BlockPool:
return False
# Remove all hashes so that no new blocks will hit.
self.cached_block_hash_to_block = defaultdict(dict)
self.cached_block_hash_to_block = BlockHashToBlockMap()
# Remove all hashes from all blocks.
for block in self.blocks:
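
For reference, a condensed usage sketch of the BlockHashToBlockMap API added in this commit (insert, get_one_block, pop). It mirrors the behavior exercised by the new tests above and assumes the same imports used in the test file.

from vllm.v1.core.block_pool import BlockHashToBlockMap
from vllm.v1.core.kv_cache_utils import BlockHashWithGroupId, KVCacheBlock

cache = BlockHashToBlockMap()
key = BlockHashWithGroupId(b"hash0")
block_a, block_b = KVCacheBlock(0), KVCacheBlock(1)

cache.insert(key, block_a)  # single-block case: no inner dict is created
cache.insert(key, block_b)  # a second block promotes the value to a dict
assert cache.get_one_block(key) is block_a

assert cache.pop(key, 99) is None    # block_id mismatch: nothing removed
assert cache.pop(key, 0) is block_a  # removes block 0, keeps block 1
assert cache.pop(key, 1) is block_b
assert cache.get_one_block(key) is None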