mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 14:27:19 +08:00
[KVEvent] Use request.block_hash for parent block_hash (#30544)
Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu> Co-authored-by: Yifan Qiao <yifanqiao@berkeley.edu>
This commit is contained in:
parent
4ed11105d7
commit
538e830caa
@ -1356,6 +1356,69 @@ def test_kv_cache_events(blocks_to_cache: int):
|
||||
assert len(manager.block_pool.cached_block_hash_to_block) == 0
|
||||
|
||||
|
||||
def test_null_parent_block_hash():
    """Check the parent hash in a BlockStored event when the parent is null_block.

    The block physically preceding the newly cached blocks is the pool's
    `null_block`, which carries no hash. The emitted event must still report
    the logical parent hash, i.e. `req.block_hashes[num_cached_blocks - 1]`,
    and the null block itself must stay without a hash.
    """
    block_size = 1
    num_cached_blocks = 2
    num_full_blocks = 4

    pool = BlockPool(
        num_gpu_blocks=8,
        enable_caching=True,
        hash_block_size=block_size,
        enable_kv_cache_events=True,
    )

    req = make_request(
        "req_null_parent",
        prompt_token_ids=[10, 11, 12, 13],
        block_size=block_size,
        hash_fn=sha256,
    )
    assert len(req.block_hashes) == num_full_blocks

    # Physical parent is `null_block` (no hash), while the logical parent hash
    # still exists in `request.block_hashes[num_cached_blocks - 1]`.
    assert pool.null_block.block_hash is None
    new_blocks = pool.get_new_blocks(num_full_blocks - 1)
    blocks = [
        # Unpack the leading slice so every entry is a block object; without
        # the star the slice would be nested as a list at index 0.
        *new_blocks[: num_cached_blocks - 1],
        pool.null_block,  # physical parent
        *new_blocks[num_cached_blocks - 1 :],
    ]
    # One entry per full block: the leading block(s), the null block, the rest.
    assert len(blocks) == num_full_blocks

    pool.cache_full_blocks(
        request=req,
        blocks=blocks,
        num_cached_blocks=num_cached_blocks,
        num_full_blocks=num_full_blocks,
        block_size=block_size,
        kv_cache_group_id=0,
    )

    events = pool.take_events()
    assert len(events) == 1
    event = events[0]
    assert isinstance(event, BlockStored)

    expected_parent = kv_cache_utils.maybe_convert_block_hash(
        req.block_hashes[num_cached_blocks - 1]
    )
    assert event.parent_block_hash == expected_parent
    assert event.parent_block_hash is not None

    expected_new_hashes = [
        kv_cache_utils.maybe_convert_block_hash(h)
        for h in req.block_hashes[num_cached_blocks:num_full_blocks]
    ]
    assert event.block_hashes == expected_new_hashes

    # Ensure we didn't accidentally assign a hash to the null block.
    assert pool.null_block.block_hash is None
    # Sanity check: newly cached physical blocks should have hashes assigned.
    assert blocks[num_cached_blocks].block_hash is not None
    assert blocks[num_full_blocks - 1].block_hash is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
|
||||
def test_kv_cache_events_with_lora(blocks_to_cache: int):
|
||||
"""Test BlockStored events contain correct lora_id when using LoRA requests."""
|
||||
|
||||
@ -270,10 +270,8 @@ class BlockPool:
|
||||
if num_cached_blocks == 0:
|
||||
parent_block_hash: ExternalBlockHash | None = None
|
||||
else:
|
||||
parent_block = blocks[num_cached_blocks - 1]
|
||||
assert parent_block.block_hash is not None
|
||||
parent_block_hash = maybe_convert_block_hash(
|
||||
get_block_hash(parent_block.block_hash)
|
||||
block_hashes[num_cached_blocks - 1]
|
||||
)
|
||||
|
||||
self.kv_event_queue.append(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user