From c280066f9dad0288a768a6234bea08171c4b88b9 Mon Sep 17 00:00:00 2001 From: Or Ozeri Date: Sat, 16 Aug 2025 02:52:52 +0300 Subject: [PATCH] [v1] Move block_hashes from KVCacheManager to Request.block_hashes (#19728) Signed-off-by: Or Ozeri --- tests/v1/core/test_async_scheduler.py | 22 +- tests/v1/core/test_kv_cache_utils.py | 50 ++-- tests/v1/core/test_prefix_caching.py | 225 ++++++++++-------- tests/v1/core/test_scheduler.py | 29 ++- .../core/test_single_type_kv_cache_manager.py | 2 - tests/v1/core/utils.py | 17 +- .../kv_connector/unit/test_nixl_connector.py | 2 + .../unit/test_remote_decode_lifecycle.py | 10 +- .../unit/test_remote_prefill_lifecycle.py | 17 +- tests/v1/kv_connector/unit/utils.py | 31 ++- vllm/utils/__init__.py | 18 ++ vllm/v1/core/block_pool.py | 75 ++---- vllm/v1/core/kv_cache_coordinator.py | 33 +-- vllm/v1/core/kv_cache_manager.py | 51 +--- vllm/v1/core/kv_cache_utils.py | 78 +++--- vllm/v1/core/sched/scheduler.py | 2 - vllm/v1/core/single_type_kv_cache_manager.py | 10 +- vllm/v1/engine/core.py | 22 +- vllm/v1/request.py | 22 +- 19 files changed, 381 insertions(+), 335 deletions(-) diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py index 3ccefbd81cab5..3a9492269f9c9 100644 --- a/tests/v1/core/test_async_scheduler.py +++ b/tests/v1/core/test_async_scheduler.py @@ -7,6 +7,7 @@ import pytest from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import RequestStatus +from vllm.v1.utils import ConstantList from .utils import create_requests, create_scheduler @@ -140,7 +141,8 @@ def test_prefix_caching_for_prefill_dedup(): requests = create_requests(num_requests=5, num_tokens=num_prompt_tokens, max_tokens=3, - same_prompt=True) + same_prompt=True, + block_size=BLOCK_SIZE) requests_copy = requests.copy() # Two requests with the same prompt. @@ -188,7 +190,8 @@ def test_prefix_caching_for_multi_turn(): block_size=BLOCK_SIZE) requests = create_requests(num_requests=5, num_tokens=num_prompt_tokens, - max_tokens=num_output_tokens) + max_tokens=num_output_tokens, + block_size=BLOCK_SIZE) for req in requests: scheduler.add_request(req) @@ -208,14 +211,19 @@ def test_prefix_caching_for_multi_turn(): # Create next-turn requests whose prompts are the full output of the # previous turn. - next_turn_requests = create_requests( - num_requests=5, - num_tokens=num_prompt_tokens + num_output_tokens, - max_tokens=num_output_tokens, - ) + next_turn_requests = create_requests(num_requests=5, + num_tokens=num_prompt_tokens + + num_output_tokens, + max_tokens=num_output_tokens, + block_size=BLOCK_SIZE) for i, req in enumerate(next_turn_requests): req.prompt_token_ids = (requests[i].prompt_token_ids + list(requests[i].output_token_ids)) + req._all_token_ids = req.prompt_token_ids.copy() + req.all_token_ids = ConstantList(req._all_token_ids) + req.block_hashes = [] + req.block_hashes = req.get_hash_new_full_blocks() + # Schedule the next-turn requests. for req in next_turn_requests: scheduler.add_request(req) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 182ea2b2345c4..e0b91e6dd7ee4 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib -from typing import Optional +from typing import Callable, Optional import pytest import torch @@ -19,7 +19,7 @@ from vllm.v1.core.kv_cache_utils import ( FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, estimate_max_model_len, generate_block_hash_extra_keys, get_kv_cache_config, get_max_concurrency_for_kv_cache_config, - hash_block_tokens, hash_request_tokens, init_none_hash, + get_request_block_hasher, hash_block_tokens, init_none_hash, is_kv_cache_type_uniform, unify_kv_cache_configs) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheTensor, @@ -33,6 +33,8 @@ from vllm.v1.request import Request def make_request( request_id: str, prompt_token_ids: list[int], + block_size: int = 3, + hash_fn: Callable = hash, mm_positions: Optional[list[PlaceholderRange]] = None, mm_hashes: Optional[list[str]] = None, cache_salt: Optional[str] = None, @@ -49,18 +51,17 @@ def make_request( mm_item = MultiModalKwargsItem.from_elems([mm_elem]) mm_kwargs = [mm_item] * len(mm_positions) - return Request( - request_id=request_id, - prompt_token_ids=prompt_token_ids, - multi_modal_kwargs=mm_kwargs, - multi_modal_hashes=mm_hashes, - multi_modal_placeholders=mm_positions, - sampling_params=SamplingParams(max_tokens=17), - pooling_params=None, - eos_token_id=100, - lora_request=None, - cache_salt=cache_salt, - ) + return Request(request_id=request_id, + prompt_token_ids=prompt_token_ids, + multi_modal_kwargs=mm_kwargs, + multi_modal_hashes=mm_hashes, + multi_modal_placeholders=mm_positions, + sampling_params=SamplingParams(max_tokens=17), + pooling_params=None, + eos_token_id=100, + lora_request=None, + cache_salt=cache_salt, + block_hasher=get_request_block_hasher(block_size, hash_fn)) def new_kv_cache_spec(block_size=16, @@ -428,12 +429,14 @@ def test_hash_block_tokens(hash_fn): @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) -def test_hash_request_tokens(hash_fn): +def test_request_block_hasher(hash_fn): import vllm.v1.core.kv_cache_utils init_none_hash(hash_fn) request = make_request( request_id="0", prompt_token_ids=[_ for _ in range(6)], + block_size=3, + hash_fn=hash_fn, mm_positions=[ PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=3, length=3), @@ -441,9 +444,7 @@ def test_hash_request_tokens(hash_fn): mm_hashes=["hash1", "hash2"], ) - block_size = 3 - block_hashes = hash_request_tokens(hash_fn, block_size, request) - + block_hashes = request.block_hashes assert len(block_hashes) == 2 assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash) assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash) @@ -464,6 +465,8 @@ def test_hash_tokens_different_mm_input(hash_fn): request1 = make_request( request_id="0", prompt_token_ids=[_ for _ in range(6)], + block_size=3, + hash_fn=hash_fn, mm_positions=[ PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=3, length=3), @@ -479,9 +482,8 @@ def test_hash_tokens_different_mm_input(hash_fn): ], mm_hashes=["hash3", "hash2"], ) - block_size = 3 - block_hashes1 = hash_request_tokens(hash_fn, block_size, request1) - block_hashes2 = hash_request_tokens(hash_fn, block_size, request2) + block_hashes1 = request1.block_hashes + block_hashes2 = request2.block_hashes assert block_hashes1[0] != block_hashes2[0] assert block_hashes1[1] != block_hashes2[1] @@ -493,12 +495,13 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn): request = make_request( request_id="0", prompt_token_ids=[_ for _ in range(6)], + block_size=3, + hash_fn=hash_fn, mm_positions=None, mm_hashes=None, ) - block_size = 3 - block_hashes = hash_request_tokens(hash_fn, block_size, request) + block_hashes = request.block_hashes assert len(block_hashes) == 2 assert block_hashes[0].token_ids == (0, 1, 2) @@ -858,6 +861,7 @@ def test_allocate_with_lookahead(): request = make_request( request_id="0", prompt_token_ids=[], + block_size=block_size, mm_positions=None, mm_hashes=None, ) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 87acdef220133..28cfca6767b1e 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -3,7 +3,7 @@ """Compare the with and without prefix caching.""" import copy -from typing import Optional +from typing import Callable, Optional import pytest import torch @@ -17,8 +17,9 @@ from vllm.utils import sha256, sha256_cbor_64bit from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - KVCacheBlock, hash_block_tokens, - init_none_hash) + KVCacheBlock, + get_request_block_hasher, + hash_block_tokens, init_none_hash) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, SlidingWindowSpec) @@ -26,6 +27,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, def make_request( request_id: str, prompt_token_ids: list[int], + block_size: int, + hash_fn: Callable, mm_positions: Optional[list[PlaceholderRange]] = None, mm_hashes: Optional[list[str]] = None, prompt_logprobs: Optional[int] = None, @@ -43,19 +46,18 @@ def make_request( mm_item = MultiModalKwargsItem.from_elems([mm_elem]) mm_kwargs = [mm_item] * len(mm_positions) - return Request( - request_id=request_id, - prompt_token_ids=prompt_token_ids, - multi_modal_kwargs=mm_kwargs, - multi_modal_hashes=mm_hashes, - multi_modal_placeholders=mm_positions, - sampling_params=SamplingParams(max_tokens=17, - prompt_logprobs=prompt_logprobs), - pooling_params=None, - eos_token_id=100, - lora_request=None, - cache_salt=cache_salt, - ) + return Request(request_id=request_id, + prompt_token_ids=prompt_token_ids, + multi_modal_kwargs=mm_kwargs, + multi_modal_hashes=mm_hashes, + multi_modal_placeholders=mm_positions, + sampling_params=SamplingParams( + max_tokens=17, prompt_logprobs=prompt_logprobs), + pooling_params=None, + eos_token_id=100, + lora_request=None, + cache_salt=cache_salt, + block_hasher=get_request_block_hasher(block_size, hash_fn)) def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: @@ -105,11 +107,11 @@ def make_kv_cache_config_hybrid_model(block_size: int, @pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"]) def test_prefill(hash_algo): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, - caching_hash_algo=hash_algo, ) # choose the hash function according to the parameter @@ -123,9 +125,9 @@ def test_prefill(hash_algo): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids) + req0 = make_request("0", all_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id]) == 3 + assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, @@ -152,9 +154,10 @@ def test_prefill(hash_algo): # Cache hit in the common prefix when the original block is still in use. # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, + hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -187,9 +190,10 @@ def test_prefill(hash_algo): # Cache hit in the common prefix when the original block is already free. # Incomplete 1 block (6 tokens) unique_token_ids = [3] * 6 - req2 = make_request("2", common_token_ids + unique_token_ids) + req2 = make_request("2", common_token_ids + unique_token_ids, block_size, + hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(manager.req_to_block_hashes[req2.request_id]) == 3 + assert len(req2.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -208,7 +212,7 @@ def test_prefill(hash_algo): manager.free(req2) # Cache miss and eviction. - req3 = make_request("3", [99] * (16 * 10)) + req3 = make_request("3", [99] * (16 * 10), block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -242,9 +246,9 @@ def test_prefill_hybrid_model(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids) + req0 = make_request("0", all_token_ids, block_size, hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id]) == 3 + assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, @@ -274,9 +278,10 @@ def test_prefill_hybrid_model(): # Cache hit in the common prefix # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, + hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6, 7], [0, 10, 11]) assert num_computed_tokens == 3 * 16 @@ -290,7 +295,7 @@ def test_prefill_hybrid_model(): if block != manager.block_pool.null_block: assert block.ref_cnt == 2 - block_hashes = manager.req_to_block_hashes[req1.request_id] + block_hashes = req1.block_hashes manager.free(req0) manager.free(req1) @@ -300,12 +305,13 @@ def test_prefill_hybrid_model(): def test_partial_request_hit(request_id: str, hash_to_evict: list[BlockHashWithGroupId], expect_hit_length: int): - req = make_request(request_id, common_token_ids + unique_token_ids) + req = make_request(request_id, common_token_ids + unique_token_ids, + block_size, hash) for hash_with_group_id in hash_to_evict: manager.block_pool.cached_block_hash_to_block.pop( hash_with_group_id) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) - assert len(manager.req_to_block_hashes[req.request_id]) == 3 + assert len(req.block_hashes) == 3 assert num_computed_tokens == expect_hit_length * block_size for block_per_group in computed_blocks.blocks: assert len(block_per_group) == num_computed_tokens // block_size @@ -364,8 +370,9 @@ def test_prefill_plp(): 2. Schedule non-plp request and validate blocks 3. Schedule plp request; no hit should occur; validate blocks ''' + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -380,9 +387,13 @@ def test_prefill_plp(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 all_token_ids = common_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids, prompt_logprobs=5) + req0 = make_request("0", + all_token_ids, + block_size, + hash_fn, + prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id]) == 0 + assert len(req0.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, @@ -411,9 +422,10 @@ def test_prefill_plp(): # Cache hit in the common prefix when the original block is still in use. # Incomplete 1 block (5 tokens) unique_token_ids = [3] * 5 - req1 = make_request("1", common_token_ids + unique_token_ids) + req1 = make_request("1", common_token_ids + unique_token_ids, block_size, + hash_fn) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(req1.block_hashes) == 3 assert computed_blocks.get_block_ids() == ([1, 2, 3], ) assert num_computed_tokens == 3 * 16 num_new_tokens = 53 - 3 * 16 @@ -447,9 +459,11 @@ def test_prefill_plp(): unique_token_ids = [3] * 6 req2 = make_request("2", common_token_ids + unique_token_ids, + block_size, + hash_fn, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(manager.req_to_block_hashes[req2.request_id]) == 0 + assert len(req2.block_hashes) == 3 assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 55, @@ -469,8 +483,9 @@ def test_prefill_plp(): def test_decode(): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -481,7 +496,8 @@ def test_decode(): # Fully cache miss # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 - req0 = make_request("0", common_token_ids + unique_token_ids) + req0 = make_request("0", common_token_ids + unique_token_ids, block_size, + hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -518,14 +534,15 @@ def test_decode(): def test_evict(): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) last_token_id = 5 * 16 + 7 - req0 = make_request("0", list(range(last_token_id))) + req0 = make_request("0", list(range(last_token_id)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -536,7 +553,8 @@ def test_evict(): # 3 blocks. req1 = make_request("1", list(range(last_token_id, - last_token_id + 3 * 16))) + last_token_id + 3 * 16)), block_size, + hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -558,7 +576,7 @@ def test_evict(): ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7] # Touch the first 2 blocks. - req2 = make_request("2", list(range(2 * 16 + 3))) + req2 = make_request("2", list(range(2 * 16 + 3)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert computed_blocks.get_block_ids() == ([1, 2], ) assert num_computed_tokens == 2 * 16 @@ -583,7 +601,7 @@ def test_hash_block_correct_reuse(): # Allocate 1 block and cache it. num_tokens = block_size * 1 - req = make_request("0", list(range(num_tokens))) + req = make_request("0", list(range(num_tokens)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -597,7 +615,7 @@ def test_hash_block_correct_reuse(): # Allocate a new block that's not full, make sure hash info on the # block is cleared. - req = make_request("1", list(range(num_tokens - 1))) + req = make_request("1", list(range(num_tokens - 1)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -624,7 +642,7 @@ def test_computed_blocks_not_evicted(): # Allocate a block and cache it. num_tokens = block_size * 1 - req0 = make_request("0", list(range(num_tokens))) + req0 = make_request("0", list(range(num_tokens)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -635,7 +653,8 @@ def test_computed_blocks_not_evicted(): assert blocks.blocks[0][0].block_id == 1 # Allocate another block. - req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) + req1 = make_request("1", list(range(num_tokens, num_tokens * 2)), + block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -651,7 +670,7 @@ def test_computed_blocks_not_evicted(): # Now if we have a cache hit on the first block, we should evict the second # cached block rather than the first one. - req2 = make_request("2", list(range(num_tokens * 2))) + req2 = make_request("2", list(range(num_tokens * 2)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 1 assert computed_blocks.blocks[0][0].block_id == 1 @@ -675,7 +694,8 @@ def test_basic_prefix_caching_disabled(): enable_caching=False, ) - req1 = make_request("1", list(range(10))) # 2 blocks and some more + req1 = make_request("1", list(range(10)), block_size, + hash) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] @@ -689,7 +709,8 @@ def test_basic_prefix_caching_disabled(): manager.free(req1) # No caching. - req2 = make_request("2", list(range(16))) # shared prefix + req2 = make_request("2", list(range(16)), block_size, + hash) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -699,7 +720,7 @@ def test_basic_prefix_caching_disabled(): assert len(blocks.blocks[0]) == 4 # New requests should not have any blocks. - req3 = make_request("3", list(range(4))) + req3 = make_request("3", list(range(4)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -727,20 +748,17 @@ def test_cache_blocks(hash_fn): # Block 1: [4, 5, 6, 7] # Block 2: [8, 9, 10, 11] # Block 3: [12, 13] - req = make_request("0", list(range(14))) + req = make_request("0", list(range(14)), block_size, hash_fn) # Test that blocks are cached correctly for 2 full blocks from the start. blocks = [KVCacheBlock(block_id=i) for i in range(2)] - block_hashes: list[BlockHash] = [] block_pool.cache_full_blocks( request=req, blocks=blocks, - block_hashes=block_hashes, num_cached_blocks=0, num_full_blocks=2, block_size=block_size, - hash_fn=hash_fn, kv_cache_group_id=0, ) @@ -752,11 +770,9 @@ def test_cache_blocks(hash_fn): block_pool.cache_full_blocks( request=req, blocks=blocks, - block_hashes=block_hashes, num_cached_blocks=2, num_full_blocks=3, block_size=block_size, - hash_fn=hash_fn, kv_cache_group_id=0, ) assert len(block_pool.cached_block_hash_to_block) == 3 @@ -775,23 +791,20 @@ def test_cache_blocks_multi_group(): # Block 1/5: [4, 5, 6, 7] # Block 2/6: [8, 9, 10, 11] # Block 3/7: [12, 13] - req = make_request("0", list(range(14))) + req = make_request("0", list(range(14)), block_size, hash) # Cache the blocks for group 0. blocks = [KVCacheBlock(block_id=i) for i in range(2)] - block_hashes: list[BlockHash] = [] block_pool.cache_full_blocks( request=req, blocks=blocks, - block_hashes=block_hashes, num_cached_blocks=0, num_full_blocks=2, block_size=block_size, - hash_fn=hash, kv_cache_group_id=0, ) assert len(block_pool.cached_block_hash_to_block) == 2 - assert len(block_hashes) == 2 + assert len(req.block_hashes) == 3 assert all([block.block_hash is not None for block in blocks]) # Cache the blocks for group 1. @@ -799,38 +812,36 @@ def test_cache_blocks_multi_group(): block_pool.cache_full_blocks( request=req, blocks=blocks, - block_hashes=block_hashes, num_cached_blocks=0, num_full_blocks=3, block_size=block_size, - hash_fn=hash, kv_cache_group_id=1, ) assert len(block_pool.cached_block_hash_to_block) == 5 - assert len(block_hashes) == 3 + assert len(req.block_hashes) == 3 assert all([block.block_hash is not None for block in blocks]) # Block hash 0: hit for group 0 and 1 # Block hash 1: hit for group 0 and 1 # Block hash 2: hit for group 1 - assert block_pool.get_cached_block(block_hashes[0], + assert block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[0]) is not None - assert block_pool.get_cached_block(block_hashes[1], + assert block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[0]) is not None - assert block_pool.get_cached_block(block_hashes[2], + assert block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[0]) is None - assert block_pool.get_cached_block(block_hashes[0], + assert block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(block_hashes[1], + assert block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(block_hashes[2], + assert block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[1]) is not None - assert block_pool.get_cached_block(block_hashes[0], + assert block_pool.get_cached_block(req.block_hashes[0], kv_cache_group_ids=[0, 1]) is not None - assert block_pool.get_cached_block(block_hashes[1], + assert block_pool.get_cached_block(req.block_hashes[1], kv_cache_group_ids=[0, 1]) is not None - assert block_pool.get_cached_block(block_hashes[2], + assert block_pool.get_cached_block(req.block_hashes[2], kv_cache_group_ids=[0, 1]) is None @@ -838,8 +849,9 @@ def test_mm_prefix_caching(): """ This tests that the multi-modal prefix caching is correct. """ + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -865,6 +877,8 @@ def test_mm_prefix_caching(): mm_hashes = common_mm_hashes + ["ccc"] req0 = make_request("0", all_token_ids, + block_size, + hash, mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) @@ -872,7 +886,7 @@ def test_mm_prefix_caching(): # Completed block should have hashes with extra keys. assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req0.request_id] + block_hashes = req0.block_hashes assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("aaa", ) assert block_hashes[1].extra_keys == ("aaa", "bbb") @@ -905,6 +919,8 @@ def test_mm_prefix_caching(): mm_hashes = common_mm_hashes + ["ccc"] req1 = make_request("1", all_token_ids, + block_size, + hash, mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) @@ -927,13 +943,13 @@ def test_cache_key_salting(): # 3 complete blocks and an incomplete block with 11 tokens. common_token_ids = [i for i in range(3) for _ in range(block_size)] token_ids = common_token_ids + [3] * 11 - req0 = make_request("0", token_ids, cache_salt="salt1") + req0 = make_request("0", token_ids, block_size, hash, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) # Completed block should have hashes with extra keys. assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req0.request_id] + block_hashes = req0.block_hashes assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("salt1", ) assert block_hashes[1].extra_keys is None @@ -959,7 +975,7 @@ def test_cache_key_salting(): # Test cache hit with a new request that has the same salt. token_ids = common_token_ids + [4] * 11 - req1 = make_request("1", token_ids, cache_salt="salt1") + req1 = make_request("1", token_ids, block_size, hash, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) # Should match only a prefix of 3 blocks. assert len(computed_blocks.blocks[0]) == 3 @@ -967,11 +983,11 @@ def test_cache_key_salting(): # Test cache miss with same content but different salt. token_ids = common_token_ids + [4] * 11 - req2 = make_request("2", token_ids, cache_salt="salt2") + req2 = make_request("2", token_ids, block_size, hash, cache_salt="salt2") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 0 assert num_computed_tokens == 0 - block_hashes = manager.req_to_block_hashes[req2.request_id] + block_hashes = req2.block_hashes assert len(block_hashes) == 3 assert block_hashes[0].extra_keys == ("salt2", ) @@ -992,7 +1008,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # Complete 3 blocks (48 tokens) # | Common-0 | Common-1 | Common-2 | ... | common_token_ids = [i for i in range(3) for _ in range(16)] - req0 = make_request("0", common_token_ids) + req0 = make_request("0", common_token_ids, block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1003,7 +1019,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): req0.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | - req1 = make_request("1", common_token_ids * 2) + req1 = make_request("1", common_token_ids * 2, block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks.blocks[0] == block_part0 assert num_computed_tokens == 3 * 16 @@ -1020,19 +1036,19 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| Req2-0 | Req2-1 | ... | - req2 = make_request("2", [7] * block_size * 2) + req2 = make_request("2", [7] * block_size * 2, block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 manager.allocate_slots(req2, block_size * 2, - len(computed_blocks.blocks[0]) * 16, + len(computed_blocks.blocks[0]) * block_size, computed_blocks) # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed, # but it cannot be allocated due to insufficient free blocks (2). # In this case, the ref_cnt of the computed blocks should not be changed. assert manager.block_pool.free_block_queue.num_free_blocks == 5 - req3 = make_request("3", common_token_ids * 3) + req3 = make_request("3", common_token_ids * 3, block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert computed_blocks.blocks[0] == block_part1 assert num_computed_tokens == 6 * 16 @@ -1047,8 +1063,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): def test_reset_prefix_cache(): + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, ) @@ -1056,15 +1073,15 @@ def test_reset_prefix_cache(): full_block_token_ids = [i for i in range(3) for _ in range(16)] unique_token_ids = [3] * 7 all_token_ids = full_block_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids) + req0 = make_request("0", all_token_ids, block_size, hash) blocks = manager.allocate_slots(req0, 55) assert blocks.get_block_ids() == ([1, 2, 3, 4], ) unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids - req1 = make_request("1", all_token_ids) + req1 = make_request("1", all_token_ids, block_size, hash) computed_blocks, _ = manager.get_computed_blocks(req1) - assert len(manager.req_to_block_hashes[req1.request_id]) == 3 + assert len(req1.block_hashes) == 3 assert len(computed_blocks.blocks[0]) == 3 blocks = manager.allocate_slots(req1, 7, len(computed_blocks.blocks[0]) * 16, @@ -1086,8 +1103,9 @@ def test_reset_prefix_cache(): def test_prefix_cache_stats_disabled(): """Test that prefix_cache_stats is None when log_stats is False.""" + block_size = 16 manager = KVCacheManager( - make_kv_cache_config(16, 11), + make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, log_stats=False, # Disable logging stats @@ -1095,7 +1113,7 @@ def test_prefix_cache_stats_disabled(): assert manager.prefix_cache_stats is None # Call all functions that check whether log_stats is disabled. - req = make_request("0", list(range(16))) + req = make_request("0", list(range(16)), block_size, hash) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1192,7 +1210,7 @@ def test_kv_cache_events(blocks_to_cache: int): ) num_tokens = block_size * blocks_to_cache - req0 = make_request("0", list(range(num_tokens))) + req0 = make_request("0", list(range(num_tokens)), block_size, hash) _ = manager.allocate_slots(req0, num_tokens) events = manager.take_events() @@ -1208,7 +1226,7 @@ def test_kv_cache_events(blocks_to_cache: int): # Should see block_to_cache number of removed block events and a new block # stored event manager.free(req0) - req1 = make_request("1", list(range(num_tokens))) + req1 = make_request("1", list(range(num_tokens)), block_size, hash) _ = manager.allocate_slots(req1, num_tokens) events = manager.take_events() @@ -1242,7 +1260,7 @@ def test_eagle_enabled_removes_last_block(): # Request with 3 full blocks (48 tokens) token_ids = [0] * (3 * block_size) - req = make_request("divisible_request", token_ids) + req = make_request("divisible_request", token_ids, block_size, hash) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1252,7 +1270,7 @@ def test_eagle_enabled_removes_last_block(): manager.free(req) # New request with same tokens + Eagle enabled - req_eagle = make_request("eagle_divisible", token_ids) + req_eagle = make_request("eagle_divisible", token_ids, block_size, hash) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Should retain 1 block: @@ -1273,7 +1291,7 @@ def test_eagle_with_partial_blocks(): ) # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) - req = make_request("partial_block_test", token_ids) + req = make_request("partial_block_test", token_ids, block_size, hash) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1283,7 +1301,7 @@ def test_eagle_with_partial_blocks(): manager.free(req) # New request with Eagle enabled - req_eagle = make_request("partial_eagle", token_ids) + req_eagle = make_request("partial_eagle", token_ids, block_size, hash) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 @@ -1314,7 +1332,7 @@ def test_eagle_with_sliding_window(): # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) - req = make_request("partial_block_test", token_ids) + req = make_request("partial_block_test", token_ids, block_size, hash) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1322,12 +1340,12 @@ def test_eagle_with_sliding_window(): len(computed_blocks.blocks[0]) * 16, computed_blocks) # record the block hash of the first block in the request for later use - block_hash_first_block = manager.req_to_block_hashes[req.request_id][0] + block_hash_first_block = req.block_hashes[0] assert block_hash_first_block is not None manager.free(req) # New request with Eagle enabled - req_eagle = make_request("partial_eagle", token_ids) + req_eagle = make_request("partial_eagle", token_ids, block_size, hash) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 @@ -1340,7 +1358,8 @@ def test_eagle_with_sliding_window(): BlockHashWithGroupId(block_hash_first_block, 0)) # New request - req_after_evict = make_request("partial_eagle_after_evict", token_ids) + req_after_evict = make_request("partial_eagle_after_evict", token_ids, + block_size, hash) computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict) # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # not considered. But after dropping the last matched block due to eagle, diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 1c7dd0ca90b7e..ac70c90d92add 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -589,7 +589,7 @@ def test_preempt_during_execution(): block_size=16, num_blocks=11, enable_prefix_caching=False) - requests = create_requests(num_requests=2, num_tokens=80) + requests = create_requests(num_requests=2, num_tokens=80, block_size=16) # Schedule the first request. scheduler.add_request(requests[0]) @@ -762,7 +762,7 @@ def _assert_right_scheduler_output( def _assert_right_kv_cache_manager( scheduler: Scheduler, - req_ids: list[str], + requests: list[Request], num_tokens: int, block_size: int, num_requests: int, @@ -772,12 +772,12 @@ def _assert_right_kv_cache_manager( # Make sure the request stats are right. EXPECTED_TOTAL_BLOCKS = num_tokens // block_size - for req_id in req_ids: + for req in requests: blocks = (scheduler.kv_cache_manager.coordinator. - single_type_managers[0].req_to_blocks[req_id]) - hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id] + single_type_managers[0].req_to_blocks[req.request_id]) + hashes = req.block_hashes assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0]. - num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS) + num_cached_block[req.request_id] == EXPECTED_TOTAL_BLOCKS) assert len(blocks) == EXPECTED_TOTAL_BLOCKS assert len(hashes) == EXPECTED_TOTAL_BLOCKS @@ -840,7 +840,8 @@ def test_kv_connector_basic(): MAX_TOKENS = 3 requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -868,7 +869,7 @@ def test_kv_connector_basic(): ) # Ensure KVCacheManager is correct. - _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE, + _assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS) # Continue Generation until done. @@ -886,7 +887,8 @@ def test_kv_connector_basic(): NUM_TOKENS = NUM_TOKENS_PREFIX * 2 requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -915,7 +917,7 @@ def test_kv_connector_basic(): NUM_MATCHED_NEW_TOKENS)) # Ensure KVCacheManager is correct. - _assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE, + _assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE, NUM_REQUESTS, NUM_TOTAL_BLOCKS) # Continue Generation until done. @@ -953,7 +955,8 @@ def test_kv_connector_unable_to_allocate(): MAX_TOKENS = 2 requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -1034,7 +1037,8 @@ def test_kv_connector_handles_preemption(): MAX_TOKENS = BLOCK_SIZE * 2 requests = create_requests(num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, - max_tokens=MAX_TOKENS) + max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE) req_ids = [] req_to_index = {} for i, request in enumerate(requests): @@ -1162,7 +1166,6 @@ def assert_scheduler_empty(scheduler: Scheduler): # KVCache Manager. assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. req_to_blocks) == 0 - assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. num_cached_block) == 0 num_free_blocks = ( diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index b67c05bd7ac10..7dcebba491fab 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -17,7 +17,6 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, def get_sliding_window_manager(sliding_window_spec, block_pool): return SlidingWindowManager(sliding_window_spec, block_pool, - caching_hash_fn=lambda x: x, kv_cache_group_id=0) @@ -25,7 +24,6 @@ def get_chunked_local_attention_manager(chunked_local_attention_spec, block_pool): return ChunkedLocalAttentionManager(chunked_local_attention_spec, block_pool, - caching_hash_fn=lambda x: x, kv_cache_group_id=0) diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 484afe61fc3fb..52093d3d381ae 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -10,6 +10,8 @@ from vllm.multimodal.inputs import (MultiModalBatchedField, MultiModalFieldElem, MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams +from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, + init_none_hash) from vllm.v1.core.sched.async_scheduler import AsyncScheduler from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -114,6 +116,9 @@ def create_scheduler( ) +_none_hash_initialized = False + + def create_requests( num_requests: int, num_tokens: int = 10, @@ -122,7 +127,14 @@ def create_requests( stop_token_ids: Optional[list[int]] = None, prompt_logprobs: Optional[int] = None, same_prompt: bool = False, + block_size: int = 16, ) -> list[Request]: + global _none_hash_initialized + if not _none_hash_initialized: + init_none_hash(hash) + _none_hash_initialized = True + + block_hasher = get_request_block_hasher(block_size, hash) sampling_params = SamplingParams(ignore_eos=False, max_tokens=max_tokens, stop_token_ids=stop_token_ids, @@ -139,9 +151,11 @@ def create_requests( ) mm_item = MultiModalKwargsItem.from_elems([mm_elem]) mm_kwargs = [mm_item] * len(mm_position) + mm_hashes = ["hash"] * len(mm_position) else: mm_position = None mm_kwargs = None + mm_hashes = None prompt_token_ids = ([0] * num_tokens if same_prompt else [i] * num_tokens) request = Request( @@ -151,8 +165,9 @@ def create_requests( pooling_params=None, multi_modal_kwargs=mm_kwargs, multi_modal_placeholders=mm_position, - multi_modal_hashes=None, + multi_modal_hashes=mm_hashes, eos_token_id=EOS_TOKEN_ID, + block_hasher=block_hasher, ) requests.append(request) return requests diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index b185936ab025f..e6859ea738277 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -147,6 +147,7 @@ def test_basic_interface(): NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) request = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True) request_id = request.request_id @@ -186,6 +187,7 @@ def test_prompt_less_than_block_size(): # Request will have 1 partial remote block. request = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True, num_remote_blocks=1) diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index 2f8228864e7b4..d8c56ac42f718 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -21,6 +21,7 @@ def test_basic_lifecycle(): NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) request = create_request(request_id=1, + block_size=BLOCK_SIZE, max_tokens=1, num_tokens=NUM_TOKENS, do_remote_decode=True) @@ -103,8 +104,10 @@ def test_short_prompt_lifecycle(): scheduler = create_scheduler(vllm_config) # Not enough tokens for full block. - NUM_TOKENS = vllm_config.cache_config.block_size // 2 + BLOCK_SIZE = vllm_config.cache_config.block_size + NUM_TOKENS = BLOCK_SIZE // 2 request = create_request(request_id=1, + block_size=BLOCK_SIZE, max_tokens=1, num_tokens=NUM_TOKENS, do_remote_decode=True) @@ -148,7 +151,9 @@ def test_prefix_cache_lifecycle(): NUM_EXTERNAL_FULL_BLOCKS = 3 NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) - request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS) + request_normal = create_request(request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS) scheduler.add_request(request_normal) scheduler_output = scheduler.schedule() @@ -166,6 +171,7 @@ def test_prefix_cache_lifecycle(): NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) request_remote = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_decode=True) diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index 87f7490698a31..21fec5344255c 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -23,6 +23,7 @@ def test_basic_lifecycle(): scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) request = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True) @@ -133,14 +134,17 @@ def test_interleaved_lifecycle(): NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) request_remote = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True) request_local_a = create_request( request_id=2, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, ) request_local_b = create_request( request_id=3, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, ) @@ -236,6 +240,7 @@ def test_no_spurious_prefix_caching(): # Both of these requests have prompts like [1,1,1,1,1, ...] request_remote = create_request( request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True, use_all_1s_for_prompt_tokens=True, @@ -243,6 +248,7 @@ def test_no_spurious_prefix_caching(): request_local = create_request( request_id=2, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=False, use_all_1s_for_prompt_tokens=True, @@ -292,6 +298,7 @@ def test_full_block_prompt(): NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS) request = create_request(request_id=1, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS, do_remote_prefill=True) @@ -364,8 +371,11 @@ def test_cannot_schedule_after_recv(): NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) - request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL) + request_normal = create_request(request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS_LOCAL) request_remote = create_request(request_id=2, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_REMOTE, do_remote_prefill=True) @@ -456,8 +466,11 @@ def test_cannot_recv(): NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5)) - request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL) + request_normal = create_request(request_id=1, + block_size=BLOCK_SIZE, + num_tokens=NUM_TOKENS_LOCAL) request_remote = create_request(request_id=2, + block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_REMOTE, do_remote_prefill=True) diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 60847c48585c6..8c5d132c00ae4 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import tempfile from collections import defaultdict -from typing import Any, Optional +from typing import Any, Callable, Optional import torch @@ -14,6 +14,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import ( from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa SharedStorageConnector) from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, + init_none_hash) from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) @@ -40,7 +42,6 @@ def assert_scheduler_empty(scheduler: Scheduler): # KVCache Manager. assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. req_to_blocks) == 0 - assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. num_cached_block) == 0 num_free_blocks = ( @@ -115,16 +116,23 @@ def create_scheduler( ) -def create_request( - request_id: int, - num_tokens: int = 10, - max_tokens: int = 16, - do_remote_decode: bool = False, - do_remote_prefill: bool = False, - use_all_1s_for_prompt_tokens: bool = False, - num_remote_blocks: int = 3, -) -> Request: +_none_hash_initialized = False + + +def create_request(request_id: int, + num_tokens: int = 10, + max_tokens: int = 16, + do_remote_decode: bool = False, + do_remote_prefill: bool = False, + use_all_1s_for_prompt_tokens: bool = False, + num_remote_blocks: int = 3, + block_size: int = 16, + hash_fn: Callable = hash) -> Request: """Make dummy request for testing.""" + global _none_hash_initialized + if not _none_hash_initialized: + init_none_hash(hash) + _none_hash_initialized = True kv_transfer_params: Optional[dict[str, Any]] = None @@ -158,6 +166,7 @@ def create_request( multi_modal_placeholders=None, multi_modal_hashes=None, eos_token_id=EOS_TOKEN_ID, + block_hasher=get_request_block_hasher(block_size, hash_fn), ) req.kv_transfer_params = kv_transfer_params return req diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index a1f8ad164762d..72857ee2abc77 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -3243,6 +3243,24 @@ def sha256_cbor_64bit(input) -> int: return full_hash & ((1 << 64) - 1) +def get_hash_fn_by_name(hash_fn_name: str) -> Callable: + """Get a hash function by name, or raise an error if + the function is not found. + Args: + hash_fn_name: Name of the hash function. + Returns: + A hash function. + """ + if hash_fn_name == "sha256": + return sha256 + if hash_fn_name == "sha256_cbor_64bit": + return sha256_cbor_64bit + if hash_fn_name == "builtin": + return hash + + raise ValueError(f"Unsupported hash function: {hash_fn_name}") + + def is_torch_equal_or_newer(target: str) -> bool: """Check if the installed torch version is >= the target version. diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index ad9854dd29c38..839297135fe0a 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -2,15 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import defaultdict from collections.abc import Iterable -from typing import Callable, Optional +from typing import Optional from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved, BlockStored, KVCacheEvent) from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - FreeKVCacheBlockQueue, KVCacheBlock, - generate_block_hash_extra_keys, - hash_block_tokens) + FreeKVCacheBlockQueue, KVCacheBlock) from vllm.v1.request import Request logger = init_logger(__name__) @@ -97,84 +95,39 @@ class BlockPool: self, request: Request, blocks: list[KVCacheBlock], - block_hashes: list[BlockHash], num_cached_blocks: int, num_full_blocks: int, block_size: int, kv_cache_group_id: int, - hash_fn: Callable, ) -> None: """Cache a list of full blocks for prefix caching. This function takes a list of blocks that will have their block hash - metadata to be updated and cached. Given a request, it computes the - block hashes for the blocks starting from `num_cached_blocks` to - `num_full_blocks`, updating the metadata for each block - and caching them in the `cached_block_hash_to_block`. + metadata to be updated and cached. Given a request, it updates the + metadata for each block and caching it in the + `cached_block_hash_to_block`. + The block hashes values are computed by the Request object immediately + when it is created and when new tokens are appended. Args: request: The request to cache the blocks. blocks: All blocks in the request. - block_hashes: Block hashes of the blocks in the request. Note that - this list may be shorter than the blocks list. In this case the - missed block hash will be computed in this function. num_cached_blocks: The number of blocks that are already cached. num_full_blocks: The number of blocks that are full and should be cached after this function. block_size: Number of tokens in each block. kv_cache_group_id: The id of the KV cache group. - hash_fn: The hash function to use for block hashes. """ if num_cached_blocks == num_full_blocks: return new_full_blocks = blocks[num_cached_blocks:num_full_blocks] - assert len(block_hashes) >= num_cached_blocks - new_block_hashes = block_hashes[num_cached_blocks:] + assert len(request.block_hashes) >= num_full_blocks + new_block_hashes = request.block_hashes[num_cached_blocks:] - # Update the new blocks with the block hashes through the chain. - if num_cached_blocks == 0: - prev_block_hash_value = None - else: - prev_block = blocks[num_cached_blocks - 1] - assert prev_block.block_hash is not None - prev_block_hash_value = prev_block.block_hash.get_hash_value() - - parent_block_hash = prev_block_hash_value new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events else None) for i, blk in enumerate(new_full_blocks): assert blk.block_hash is None - - if i < len(new_block_hashes): - # The block hash may already be computed in - # "get_computed_blocks" if the tokens are not generated by - # this request (either the prompt tokens or the previously - # generated tokens with preemption), or by other - # single_type_managers with the same block_size. - # In this case we simply reuse the block hash. - block_hash = new_block_hashes[i] - else: - # Otherwise compute the block hash and cache it in the request - # in case it will be preempted in the future. - blk_idx = num_cached_blocks + i - start_token_idx = blk_idx * block_size - end_token_idx = (blk_idx + 1) * block_size - block_tokens = request.all_token_ids[ - start_token_idx:end_token_idx] - assert len(block_tokens) == block_size, ( - f"Expected {block_size} tokens, got " - f"{len(block_tokens)} at {blk_idx}th block for request " - f"{request.request_id}({request})") - - # Generate extra keys for multi-modal inputs. Note that since - # we reach to this branch only when the block is completed with - # generated tokens, we only need to consider the last mm input. - extra_keys, _ = generate_block_hash_extra_keys( - request, start_token_idx, end_token_idx, -1) - - # Compute the hash of the current block. - block_hash = hash_block_tokens(hash_fn, prev_block_hash_value, - block_tokens, extra_keys) - block_hashes.append(block_hash) + block_hash = new_block_hashes[i] # Update and added the full block to the cache. block_hash_with_group_id = BlockHashWithGroupId( @@ -184,9 +137,15 @@ class BlockPool: blk.block_id] = blk if new_hashes is not None: new_hashes.append(block_hash.hash_value) - prev_block_hash_value = block_hash.hash_value if self.enable_kv_cache_events: + if num_cached_blocks == 0: + parent_block_hash = None + else: + parent_block = blocks[num_cached_blocks - 1] + assert parent_block.block_hash is not None + parent_block_hash = parent_block.block_hash.get_hash_value() + self.kv_event_queue.append( BlockStored( block_hashes=new_hashes, diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index f3a16d64e19fd..a0ea4d96015a2 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Callable, Optional +from typing import Optional from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock @@ -23,7 +23,6 @@ class KVCacheCoordinator(ABC): max_model_len: int, use_eagle: bool, enable_caching: bool, - caching_hash_fn: Callable, enable_kv_cache_events: bool, ): self.kv_cache_config = kv_cache_config @@ -40,7 +39,6 @@ class KVCacheCoordinator(ABC): kv_cache_spec=kv_cache_group.kv_cache_spec, block_pool=self.block_pool, kv_cache_group_id=i, - caching_hash_fn=caching_hash_fn, ) for i, kv_cache_group in enumerate( self.kv_cache_config.kv_cache_groups)) @@ -99,19 +97,17 @@ class KVCacheCoordinator(ABC): manager.allocate_new_blocks(request_id, num_tokens) for manager in self.single_type_managers) - def cache_blocks(self, request: Request, block_hashes: list[BlockHash], - num_computed_tokens: int) -> None: + def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """ Cache the blocks for the request. Args: request: The request. - block_hashes: The block hashes of the request. num_tokens: The total number of tokens that need to be cached (including tokens that are already cached). """ for manager in self.single_type_managers: - manager.cache_blocks(request, block_hashes, num_computed_tokens) + manager.cache_blocks(request, num_computed_tokens) def free(self, request_id: str) -> None: """ @@ -184,10 +180,9 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): """ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, - use_eagle: bool, caching_hash_fn: Callable, - enable_kv_cache_events: bool): + use_eagle: bool, enable_kv_cache_events: bool): super().__init__(kv_cache_config, max_model_len, use_eagle, False, - caching_hash_fn, enable_kv_cache_events) + enable_kv_cache_events) self.num_single_type_manager = len(self.single_type_managers) def get_num_common_prefix_blocks(self, request_id: str, @@ -213,10 +208,9 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, - caching_hash_fn: Callable, enable_kv_cache_events: bool): + enable_kv_cache_events: bool): super().__init__(kv_cache_config, max_model_len, use_eagle, - enable_caching, caching_hash_fn, - enable_kv_cache_events) + enable_caching, enable_kv_cache_events) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[ 0].kv_cache_spec self.block_size = self.kv_cache_spec.block_size @@ -250,10 +244,9 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, - caching_hash_fn: Callable, enable_kv_cache_events: bool): + enable_kv_cache_events: bool): super().__init__(kv_cache_config, max_model_len, use_eagle, - enable_caching, caching_hash_fn, - enable_kv_cache_events) + enable_caching, enable_kv_cache_events) self.verify_and_split_kv_cache_groups() def verify_and_split_kv_cache_groups(self) -> None: @@ -386,17 +379,15 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): def get_kv_cache_coordinator( kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, - enable_caching: bool, caching_hash_fn: Callable, + enable_caching: bool, enable_kv_cache_events: bool) -> KVCacheCoordinator: if not enable_caching: return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len, - use_eagle, caching_hash_fn, + use_eagle, enable_kv_cache_events) if len(kv_cache_config.kv_cache_groups) == 1: return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, enable_caching, - caching_hash_fn, enable_kv_cache_events) return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, - enable_caching, caching_hash_fn, - enable_kv_cache_events) + enable_caching, enable_kv_cache_events) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index ce333dbe61a19..bfaa7ab08f5cf 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,16 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import defaultdict from dataclasses import dataclass from typing import Optional from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger -from vllm.utils import sha256, sha256_cbor_64bit from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator -from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, - hash_request_tokens, init_none_hash) +from vllm.v1.core.kv_cache_utils import KVCacheBlock from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus @@ -71,23 +68,13 @@ class KVCacheManager: kv_cache_config: KVCacheConfig, max_model_len: int, enable_caching: bool = True, - caching_hash_algo: str = "builtin", use_eagle: bool = False, log_stats: bool = False, enable_kv_cache_events: bool = False, ) -> None: self.max_model_len = max_model_len - if len(kv_cache_config.kv_cache_groups) == 0: - # Attention free models don't have kv cache, - # thus don't need prefix caching. - enable_caching = False self.enable_caching = enable_caching - - self.caching_hash_fn = ( - sha256_cbor_64bit if caching_hash_algo == "sha256_cbor_64bit" else - sha256 if caching_hash_algo == "sha256" else hash) - init_none_hash(self.caching_hash_fn) self.use_eagle = use_eagle self.log_stats = log_stats # FIXME: make prefix cache stats conditional on log_stats @@ -107,19 +94,12 @@ class KVCacheManager: max_model_len=self.max_model_len, use_eagle=self.use_eagle, enable_caching=self.enable_caching, - caching_hash_fn=self.caching_hash_fn, enable_kv_cache_events=enable_kv_cache_events, ) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool self.kv_cache_config = kv_cache_config - # Mapping from request ID to kv block hashes. - # This is to avoid recomputing the block hashes for each call of - # `get_computed_blocks` or `allocate_slots`. - self.req_to_block_hashes: defaultdict[ - str, list[BlockHash]] = defaultdict(list) - @property def usage(self) -> float: """Get the KV cache usage. @@ -161,15 +141,6 @@ class KVCacheManager: and request.sampling_params.prompt_logprobs is not None)): return self.create_empty_block_list(), 0 - # The block hashes for the request may already be computed - # if the scheduler has tried to schedule the request before. - block_hashes = self.req_to_block_hashes[request.request_id] - if not block_hashes: - assert self.block_size is not None - block_hashes = hash_request_tokens(self.caching_hash_fn, - self.block_size, request) - self.req_to_block_hashes[request.request_id] = block_hashes - # NOTE: When all tokens hit the cache, we must recompute the last token # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1. # This can trigger recomputation of an entire block, rather than just @@ -178,7 +149,7 @@ class KVCacheManager: # could slightly improve performance in the future. max_cache_hit_length = request.num_tokens - 1 computed_blocks, num_new_computed_tokens = ( - self.coordinator.find_longest_cache_hit(block_hashes, + self.coordinator.find_longest_cache_hit(request.block_hashes, max_cache_hit_length)) if self.log_stats: @@ -296,11 +267,7 @@ class KVCacheManager: # at `request.num_tokens`, ensuring only "finalized" tokens are cached. num_tokens_to_cache = min(num_computed_tokens + num_new_tokens, request.num_tokens) - self.coordinator.cache_blocks( - request, - self.req_to_block_hashes[request.request_id], - num_tokens_to_cache, - ) + self.coordinator.cache_blocks(request, num_tokens_to_cache) return KVCacheBlocks(new_blocks) @@ -373,14 +340,6 @@ class KVCacheManager: return self.coordinator.get_num_common_prefix_blocks( request.request_id, num_running_requests) - def free_block_hashes(self, request: Request) -> None: - """Discard the block hashes for the request. - - NOTE: Unlike `free`, this method should be called only when the request - is finished, not when it is preempted. - """ - self.req_to_block_hashes.pop(request.request_id, None) - def take_events(self) -> list[KVCacheEvent]: """Take the KV cache events from the block pool. @@ -397,9 +356,7 @@ class KVCacheManager: def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """Cache the blocks for the request, if enabled.""" if self.enable_caching: - block_hashes = self.req_to_block_hashes[request.request_id] - self.coordinator.cache_blocks(request, block_hashes, - num_computed_tokens) + self.coordinator.cache_blocks(request, num_computed_tokens) def create_empty_block_list(self) -> KVCacheBlocks: """Creates a new KVCacheBlocks instance with no blocks.""" diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 626aa35a770c9..6a62c55fb2d5f 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -547,41 +547,61 @@ def hash_block_tokens( curr_block_token_ids_tuple, extra_keys) -def hash_request_tokens(hash_function: Any, block_size: int, - request: Request) -> list[BlockHash]: - """Computes hash values of a chain of blocks given a sequence of - token IDs. The hash value is used for prefix caching. - - Args: - block_size: The size of each block. - request: The request object. - - Returns: - The list of computed hash values. +def get_request_block_hasher( + block_size: int, + caching_hash_fn: Callable[[Any], + int]) -> Callable[[Request], list[BlockHash]]: """ - token_ids = request.all_token_ids + Returns a function which computes the list of un-computed block hashes + of a request. - req_need_extra_keys = need_extra_keys(request) - req_extra_keys = None - curr_mm_idx = 0 + Each request holds a list of its block hashes (request.block_hashes). + When a request is created, it calls the below function to compute + the hashes of all full blocks of the request's initial tokens. + The hashes are then stored in request.block_hashes. + Later, whenever new tokens are appended to the request, it calls + the below function again to compute any new full blocks of tokens. + The returned new hashes are appended to request.block_hashes. + """ - ret = [] - parent_block_hash_value = None - # Only full blocks will be hashed - for start in range(0, len(token_ids) - block_size + 1, block_size): - end = start + block_size - block_token_ids = token_ids[start:end] + def request_block_hasher(request: Request) -> list[BlockHash]: + start_token_idx = len(request.block_hashes) * block_size + num_tokens = request.num_tokens + + curr_mm_idx = 0 + if start_token_idx > 0: + # Set curr_mm_idx = -1 to indicate the last mm input. + # Note that since we reach to this branch only when the block is + # completed with generated tokens, we only need to consider the + # last mm input. + curr_mm_idx = -1 + + prev_block_hash_value = request.block_hashes[-1].hash_value \ + if request.block_hashes else None + new_block_hashes: list[BlockHash] = [] + while True: + end_token_idx = start_token_idx + block_size + if end_token_idx > num_tokens: + # We only hash full blocks + break - if req_need_extra_keys: # MM and LoRA requests need extra keys for block-hash computation. - req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys( - request, start, end, curr_mm_idx) + extra_keys, curr_mm_idx = generate_block_hash_extra_keys( + request, start_token_idx, end_token_idx, curr_mm_idx) - block_hash = hash_block_tokens(hash_function, parent_block_hash_value, - block_token_ids, req_extra_keys) - ret.append(block_hash) - parent_block_hash_value = block_hash.hash_value - return ret + # Compute the hash of the current block + block_tokens = request.all_token_ids[start_token_idx:end_token_idx] + block_hash = hash_block_tokens(caching_hash_fn, + prev_block_hash_value, block_tokens, + extra_keys) + + new_block_hashes.append(block_hash) + start_token_idx += block_size + prev_block_hash_value = block_hash.hash_value + + return new_block_hashes + + return request_block_hasher def max_memory_usage_bytes(vllm_config: VllmConfig, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index dcb9f4dd36f52..9810234090453 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -155,7 +155,6 @@ class Scheduler(SchedulerInterface): kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, enable_caching=self.cache_config.enable_prefix_caching, - caching_hash_algo=self.cache_config.prefix_caching_hash_algo, use_eagle=self.use_eagle, log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, @@ -1036,7 +1035,6 @@ class Scheduler(SchedulerInterface): def _free_blocks(self, request: Request): assert request.is_finished() self.kv_cache_manager.free(request) - self.kv_cache_manager.free_block_hashes(request) del self.requests[request.request_id] def get_num_unfinished_requests(self) -> int: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 8f310023a8cd3..82e0292522b9a 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -3,7 +3,6 @@ import itertools from abc import ABC, abstractmethod from collections import defaultdict -from typing import Callable from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool @@ -25,7 +24,6 @@ class SingleTypeKVCacheManager(ABC): kv_cache_spec: KVCacheSpec, block_pool: BlockPool, kv_cache_group_id: int, - caching_hash_fn: Callable, ) -> None: """ Initializes the SingleTypeKVCacheManager. @@ -33,7 +31,6 @@ class SingleTypeKVCacheManager(ABC): kv_cache_spec: The kv_cache_spec for this manager. block_pool: The block pool. kv_cache_group_id: The id of the kv cache group of this manager. - caching_hash_fn: The caching hash function. """ self.block_size = kv_cache_spec.block_size @@ -52,7 +49,6 @@ class SingleTypeKVCacheManager(ABC): # data for reempted ones. self.num_cached_block: dict[str, int] = {} - self.caching_hash_fn = caching_hash_fn self.kv_cache_group_id = kv_cache_group_id self._null_block = block_pool.null_block @@ -130,14 +126,12 @@ class SingleTypeKVCacheManager(ABC): req_blocks.extend(new_blocks) return new_blocks - def cache_blocks(self, request: Request, block_hashes: list[BlockHash], - num_tokens: int) -> None: + def cache_blocks(self, request: Request, num_tokens: int) -> None: """ Cache the blocks for the request. Args: request: The request. - block_hashes: The block hashes of the request. num_tokens: The total number of tokens that need to be cached (including tokens that are already cached). """ @@ -147,12 +141,10 @@ class SingleTypeKVCacheManager(ABC): self.block_pool.cache_full_blocks( request=request, blocks=self.req_to_blocks[request.request_id], - block_hashes=block_hashes, num_cached_blocks=num_cached_blocks, num_full_blocks=num_full_blocks, block_size=self.block_size, kv_cache_group_id=self.kv_cache_group_id, - hash_fn=self.caching_hash_fn, ) self.num_cached_block[request.request_id] = num_full_blocks diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index ed426f8ff452b..1e52f93a581b3 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -25,9 +25,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import (decorate_logs, make_zmq_socket, +from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket, resolve_obj_by_qualname, set_process_title) -from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, +from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_config, + get_request_block_hasher, + init_none_hash, unify_kv_cache_configs) from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput @@ -140,6 +142,19 @@ class EngineCore: self.batch_queue_size) self.batch_queue = queue.Queue(self.batch_queue_size) + self.request_block_hasher: Optional[Callable[[Request], + list[BlockHash]]] = None + if (self.vllm_config.cache_config.enable_prefix_caching + or self.scheduler.get_kv_connector() is not None): + + block_size = vllm_config.cache_config.block_size + caching_hash_fn = get_hash_fn_by_name( + vllm_config.cache_config.prefix_caching_hash_algo) + init_none_hash(caching_hash_fn) + + self.request_block_hasher = get_request_block_hasher( + block_size, caching_hash_fn) + def _initialize_kv_caches( self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: start = time.time() @@ -417,7 +432,8 @@ class EngineCore: request.mm_kwargs = self.mm_input_cache_server.get_and_update( request.mm_kwargs, request.mm_hashes) - req = Request.from_engine_core_request(request) + req = Request.from_engine_core_request(request, + self.request_block_hasher) if req.use_structured_output: # Note on thread safety: no race condition. # `grammar_init` is only invoked in input processing thread. For diff --git a/vllm/v1/request.py b/vllm/v1/request.py index d1f1c7f98755f..562925bde669e 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -3,7 +3,8 @@ import enum import time -from typing import TYPE_CHECKING, Any, Optional, Union +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Optional, Union from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.pooling_params import PoolingParams @@ -16,6 +17,7 @@ from vllm.v1.utils import ConstantList if TYPE_CHECKING: from vllm.lora.request import LoRARequest + from vllm.v1.core.kv_cache_utils import BlockHash class Request: @@ -36,6 +38,8 @@ class Request: structured_output_request: Optional["StructuredOutputRequest"] = None, cache_salt: Optional[str] = None, priority: int = 0, + block_hasher: Optional[Callable[["Request"], + list["BlockHash"]]] = None, ) -> None: self.request_id = request_id self.client_index = client_index @@ -108,8 +112,18 @@ class Request: # indicates that the output is corrupted self.num_nans_in_logits = 0 + self.block_hashes: list[BlockHash] = [] + self.get_hash_new_full_blocks: Optional[Callable[ + [], list[BlockHash]]] = None + if block_hasher is not None: + self.get_hash_new_full_blocks = partial(block_hasher, self) + self.block_hashes = self.get_hash_new_full_blocks() + @classmethod - def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": + def from_engine_core_request( + cls, request: EngineCoreRequest, + block_hasher: Optional[Callable[["Request"], list["BlockHash"]]] + ) -> "Request": if request.mm_kwargs is not None: assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), ( "mm_kwargs was not updated in EngineCore.add_request") @@ -131,6 +145,7 @@ class Request: if request.sampling_params else None, cache_salt=request.cache_salt, priority=request.priority, + block_hasher=block_hasher, ) def append_output_token_ids( @@ -144,6 +159,9 @@ class Request: self._output_token_ids.extend(token_ids) self._all_token_ids.extend(token_ids) + if self.get_hash_new_full_blocks is not None: + self.block_hashes.extend(self.get_hash_new_full_blocks()) + @property def is_output_corrupted(self) -> bool: return self.num_nans_in_logits > 0