From 82dfb12e52ded608da2a458b2da5f48ba1acecdf Mon Sep 17 00:00:00 2001 From: Zebing Lin Date: Tue, 9 Sep 2025 00:34:37 -0400 Subject: [PATCH] [Core] Use sha256 bytes instead of BlockHash to reduce GC overhead (#23673) Signed-off-by: linzebing --- .../online_serving/kv_events_subscriber.py | 8 +- tests/utils_/test_utils.py | 20 +- tests/v1/core/test_kv_cache_utils.py | 65 +++-- tests/v1/core/test_prefix_caching.py | 225 ++++++++++-------- .../core/test_single_type_kv_cache_manager.py | 16 +- tests/v1/core/utils.py | 5 +- tests/v1/engine/test_engine_args.py | 13 +- tests/v1/kv_connector/unit/utils.py | 5 +- vllm/config/cache.py | 17 +- vllm/distributed/kv_events.py | 7 +- vllm/engine/arg_utils.py | 20 +- vllm/envs.py | 6 + vllm/utils/__init__.py | 27 +-- vllm/v1/core/block_pool.py | 27 ++- vllm/v1/core/kv_cache_utils.py | 120 +++++----- 15 files changed, 298 insertions(+), 283 deletions(-) diff --git a/examples/online_serving/kv_events_subscriber.py b/examples/online_serving/kv_events_subscriber.py index f238c66234dcc..9fd55fc9ddc94 100644 --- a/examples/online_serving/kv_events_subscriber.py +++ b/examples/online_serving/kv_events_subscriber.py @@ -6,6 +6,8 @@ import msgspec import zmq from msgspec.msgpack import Decoder +from vllm.v1.core.kv_cache_utils import BlockHash + # # Types copied from vllm.distributed.kv_events @@ -22,8 +24,8 @@ class KVCacheEvent( class BlockStored(KVCacheEvent): - block_hashes: list[int] - parent_block_hash: Optional[int] + block_hashes: list[BlockHash] + parent_block_hash: Optional[BlockHash] token_ids: list[int] block_size: int lora_id: Optional[int] @@ -31,7 +33,7 @@ class BlockStored(KVCacheEvent): class BlockRemoved(KVCacheEvent): - block_hashes: list[int] + block_hashes: list[BlockHash] medium: Optional[str] diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index 66124dd854ee0..6dbba18b4dcfa 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -835,22 +835,20 @@ def test_model_specification(parser_with_config, cli_config_file, @pytest.mark.parametrize("input", [(), ("abc", ), (None, ), (None, bool, [1, 2, 3])]) -@pytest.mark.parametrize("output", [0, 1, 2]) -def test_sha256(input: tuple, output: int): - hash = sha256(input) - assert hash is not None - assert isinstance(hash, int) - assert hash != 0 +def test_sha256(input: tuple): + digest = sha256(input) + assert digest is not None + assert isinstance(digest, bytes) + assert digest != b"" - bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) - assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), - byteorder="big") + input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) + assert digest == hashlib.sha256(input_bytes).digest() # hashing again, returns the same value - assert hash == sha256(input) + assert digest == sha256(input) # hashing different input, returns different value - assert hash != sha256(input + (1, )) + assert digest != sha256(input + (1, )) @pytest.mark.parametrize( diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 4d0a26f76e98e..44e479098ad5d 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -6,20 +6,22 @@ from typing import Callable, Optional import pytest import torch +import vllm.v1.core.kv_cache_utils as kv_cache_utils from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.multimodal.inputs import (MultiModalFeatureSpec, MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams -from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit +from vllm.utils import GiB_bytes, sha256, sha256_cbor from vllm.v1.core.kv_cache_manager import KVCacheManager # disable yapf here as it formats differently than isort such that both fail # yapf: disable from vllm.v1.core.kv_cache_utils import ( - FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, + BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, estimate_max_model_len, generate_block_hash_extra_keys, get_kv_cache_config, get_max_concurrency_for_kv_cache_config, get_request_block_hasher, hash_block_tokens, init_none_hash, - is_kv_cache_type_uniform, unify_kv_cache_configs) + is_kv_cache_type_uniform, make_block_hash_with_group_id, + unify_kv_cache_configs) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheTensor, SlidingWindowSpec) @@ -88,7 +90,7 @@ def new_sliding_window_spec(block_size=16, sliding_window=sliding_window) -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_none_hash(monkeypatch, hash_fn): import vllm.v1.core.kv_cache_utils @@ -98,8 +100,8 @@ def test_none_hash(monkeypatch, hash_fn): reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils.init_none_hash(hash_fn) assert reloaded_kv_cache_utils.NONE_HASH is not None - assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) - assert reloaded_kv_cache_utils.NONE_HASH != 0 + assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes) + assert reloaded_kv_cache_utils.NONE_HASH != b"" # case 2: PYTHONHASHSEED is set, use the seed and hash_fn with monkeypatch.context() as m: @@ -107,12 +109,11 @@ def test_none_hash(monkeypatch, hash_fn): reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils.init_none_hash(hash_fn) assert reloaded_kv_cache_utils.NONE_HASH is not None - assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) + assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes) assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH def test_kv_cache_block(): - import vllm.v1.core.kv_cache_utils # Test KVCacheBlock initialization block = KVCacheBlock(block_id=0) @@ -127,8 +128,7 @@ def test_kv_cache_block(): assert block.ref_cnt == 0 # Test block hash setting and resetting - block_hash = vllm.v1.core.kv_cache_utils.BlockHash(hash_value=123, - token_ids=(1, 2, 3)) + block_hash = make_block_hash_with_group_id(BlockHash(b"abc"), 0) block.block_hash = block_hash assert block.block_hash == block_hash @@ -407,27 +407,23 @@ def test_generate_block_hash_extra_keys_cache_salt(): assert next_mm_idx == 1 -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_hash_block_tokens(hash_fn): - import vllm.v1.core.kv_cache_utils init_none_hash(hash_fn) - parent_block_hash = 123 + parent_block_hash = BlockHash(b"123") curr_block_token_ids = (1, 2, 3) extra_keys = ("key1", "key2") block_hash = hash_block_tokens(hash_fn, parent_block_hash, curr_block_token_ids, extra_keys) - assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHash) - assert block_hash.hash_value == hash_fn( - (parent_block_hash, curr_block_token_ids, extra_keys)) - assert block_hash.token_ids == curr_block_token_ids - assert block_hash.extra_keys == extra_keys + expected = hash_fn((parent_block_hash, curr_block_token_ids, extra_keys)) + assert block_hash == expected -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_request_block_hasher(hash_fn): - import vllm.v1.core.kv_cache_utils - init_none_hash(hash_fn) + kv_cache_utils.init_none_hash(hash_fn) + request = make_request( request_id="0", prompt_token_ids=[_ for _ in range(6)], @@ -442,19 +438,13 @@ def test_request_block_hasher(hash_fn): block_hashes = request.block_hashes assert len(block_hashes) == 2 - assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash) - assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash) - - # Check the first block - assert block_hashes[0].token_ids == (0, 1, 2) - assert block_hashes[0].extra_keys == ("hash1", ) - - # Check the second block - assert block_hashes[1].token_ids == (3, 4, 5) - assert block_hashes[1].extra_keys == ("hash2", ) + assert block_hashes[0] == hash_fn( + (kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1", ))) + assert block_hashes[1] == hash_fn( + (block_hashes[0], (3, 4, 5), ("hash2", ))) -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_hash_tokens_different_mm_input(hash_fn): init_none_hash(hash_fn) @@ -484,9 +474,9 @@ def test_hash_tokens_different_mm_input(hash_fn): assert block_hashes1[1] != block_hashes2[1] -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_hash_request_tokens_no_mm_inputs(hash_fn): - init_none_hash(hash_fn) + kv_cache_utils.init_none_hash(hash_fn) request = make_request( request_id="0", @@ -500,10 +490,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn): block_hashes = request.block_hashes assert len(block_hashes) == 2 - assert block_hashes[0].token_ids == (0, 1, 2) - assert block_hashes[0].extra_keys is None - assert block_hashes[1].token_ids == (3, 4, 5) - assert block_hashes[1].extra_keys is None + assert block_hashes[0] == hash_fn( + (kv_cache_utils.NONE_HASH, (0, 1, 2), None)) + assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None)) def test_metrics(): diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index e7a8f63702b30..659d768bcf2e9 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -8,17 +8,19 @@ from typing import Callable, Optional import pytest import torch +import vllm.v1.core.kv_cache_utils as kv_cache_utils from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved from vllm.multimodal.inputs import (MultiModalFeatureSpec, MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams -from vllm.utils import sha256, sha256_cbor_64bit +from vllm.utils import sha256, sha256_cbor from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request -from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - KVCacheBlock, +from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, + get_block_hash, get_group_id, get_request_block_hasher, - hash_block_tokens, init_none_hash) + hash_block_tokens, init_none_hash, + make_block_hash_with_group_id) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, SlidingWindowSpec) @@ -101,8 +103,10 @@ def make_kv_cache_config_hybrid_model(block_size: int, ) -@pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"]) -def test_prefill(hash_algo): +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) +def test_prefill(hash_fn): + init_none_hash(hash_fn) + block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 11), @@ -110,10 +114,6 @@ def test_prefill(hash_algo): enable_caching=True, ) - # choose the hash function according to the parameter - hash_fn = (sha256_cbor_64bit if hash_algo == "sha256_cbor_64bit" else - sha256 if hash_algo == "sha256" else hash) - # Complete 3 blocks (48 tokens) common_token_ids = [i for i in range(3) for _ in range(16)] @@ -137,10 +137,12 @@ def test_prefill(hash_algo): block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) - assert manager.block_pool.blocks[ - block_id].block_hash.block_hash == block_hash + blk_hash = manager.block_pool.blocks[block_id].block_hash + assert blk_hash is not None + assert get_block_hash(blk_hash) == block_hash + assert get_group_id(blk_hash) == 0 assert manager.block_pool.blocks[block_id].ref_cnt == 1 - parent_block_hash = block_hash.hash_value + parent_block_hash = block_hash # Check partial block metadata for block_id in (4, ): @@ -233,7 +235,7 @@ def test_prefill_hybrid_model(): enable_caching=True, ) - hash_fn = hash + hash_fn = sha256 # Complete 3 blocks (48 tokens) common_token_ids = [i for i in range(3) for _ in range(block_size)] @@ -260,11 +262,13 @@ def test_prefill_hybrid_model(): block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) - for block_id in block_ids: - assert manager.block_pool.blocks[ - block_id].block_hash.block_hash == block_hash + for group_id, block_id in enumerate(block_ids): + blk_hash = manager.block_pool.blocks[block_id].block_hash + assert blk_hash is not None + assert get_block_hash(blk_hash) == block_hash + assert get_group_id(blk_hash) == group_id assert manager.block_pool.blocks[block_id].ref_cnt == 1 - parent_block_hash = block_hash.hash_value + parent_block_hash = block_hash # Check partial block metadata for block_id in (4, 8, 12): @@ -298,11 +302,10 @@ def test_prefill_hybrid_model(): cached_block_hash_to_block_bak = copy.copy( manager.block_pool.cached_block_hash_to_block) - def test_partial_request_hit(request_id: str, - hash_to_evict: list[BlockHashWithGroupId], + def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes], expect_hit_length: int): req = make_request(request_id, common_token_ids + unique_token_ids, - block_size, hash) + block_size, sha256) for hash_with_group_id in hash_to_evict: manager.block_pool.cached_block_hash_to_block.pop( hash_with_group_id) @@ -319,33 +322,32 @@ def test_prefill_hybrid_model(): # Evict the blocks outside sliding window, does not affect the hit length. test_partial_request_hit("2", [ - BlockHashWithGroupId(block_hashes[0], 1), - BlockHashWithGroupId(block_hashes[0], 2) + make_block_hash_with_group_id(block_hashes[0], 1), + make_block_hash_with_group_id(block_hashes[0], 2) ], 3) # Evict the first block of full attention, makes total cache miss. - test_partial_request_hit("3", [ - BlockHashWithGroupId(block_hashes[0], 0), - ], 0) + test_partial_request_hit( + "3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0) # Evict the last block of all layers, reduces the hit length to 2. test_partial_request_hit("4", [ - BlockHashWithGroupId(block_hashes[2], 0), - BlockHashWithGroupId(block_hashes[2], 1), - BlockHashWithGroupId(block_hashes[2], 2), + make_block_hash_with_group_id(block_hashes[2], 0), + make_block_hash_with_group_id(block_hashes[2], 1), + make_block_hash_with_group_id(block_hashes[2], 2), ], 2) # Evict the last block of full attention, reduces the hit length to 2. - test_partial_request_hit("5", [BlockHashWithGroupId(block_hashes[2], 0)], - 2) + test_partial_request_hit( + "5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2) # Evict the last block of sliding window, reduces the hit length to 2. - test_partial_request_hit("6", [BlockHashWithGroupId(block_hashes[2], 1)], - 2) + test_partial_request_hit( + "6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2) # Evict the last block of sliding window, reduces the hit length to 2. - test_partial_request_hit("7", [BlockHashWithGroupId(block_hashes[2], 2)], - 2) + test_partial_request_hit( + "7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2) # Evict different set of blocks for full attention and sliding window makes # total cache miss. @@ -353,9 +355,9 @@ def test_prefill_hybrid_model(): # The cache hit length of sliding window is 2 * block_size. # Then it is cache miss as the two type of layers have different hit length. test_partial_request_hit("8", [ - BlockHashWithGroupId(block_hashes[2], 0), - BlockHashWithGroupId(block_hashes[0], 1), - BlockHashWithGroupId(block_hashes[0], 2), + make_block_hash_with_group_id(block_hashes[2], 0), + make_block_hash_with_group_id(block_hashes[0], 1), + make_block_hash_with_group_id(block_hashes[0], 2), ], 0) @@ -372,8 +374,8 @@ def test_prefill_plp(): max_model_len=8192, enable_caching=True, ) - # the default hash function is hash - hash_fn = hash + # the default hash function is sha256 + hash_fn = sha256 # Complete 3 blocks (48 tokens) common_token_ids = [i for i in range(3) for _ in range(16)] @@ -404,10 +406,12 @@ def test_prefill_plp(): block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens) - assert manager.block_pool.blocks[ - block_id].block_hash.block_hash == block_hash + blk_hash = (manager.block_pool.blocks[block_id].block_hash) + assert blk_hash is not None + assert get_block_hash(blk_hash) == block_hash + assert get_group_id(blk_hash) == 0 assert manager.block_pool.blocks[block_id].ref_cnt == 1 - parent_block_hash = block_hash.hash_value + parent_block_hash = block_hash # Check partial block metadata for block_id in (4, ): @@ -493,7 +497,7 @@ def test_decode(): # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 req0 = make_request("0", common_token_ids + unique_token_ids, block_size, - hash) + sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -538,7 +542,7 @@ def test_evict(): ) last_token_id = 5 * 16 + 7 - req0 = make_request("0", list(range(last_token_id)), block_size, hash) + req0 = make_request("0", list(range(last_token_id)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -550,7 +554,7 @@ def test_evict(): # 3 blocks. req1 = make_request("1", list(range(last_token_id, last_token_id + 3 * 16)), block_size, - hash) + sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -572,7 +576,7 @@ def test_evict(): ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7] # Touch the first 2 blocks. - req2 = make_request("2", list(range(2 * 16 + 3)), block_size, hash) + req2 = make_request("2", list(range(2 * 16 + 3)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert computed_blocks.get_block_ids() == ([1, 2], ) assert num_computed_tokens == 2 * 16 @@ -597,7 +601,7 @@ def test_hash_block_correct_reuse(): # Allocate 1 block and cache it. num_tokens = block_size * 1 - req = make_request("0", list(range(num_tokens)), block_size, hash) + req = make_request("0", list(range(num_tokens)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -611,7 +615,7 @@ def test_hash_block_correct_reuse(): # Allocate a new block that's not full, make sure hash info on the # block is cleared. - req = make_request("1", list(range(num_tokens - 1)), block_size, hash) + req = make_request("1", list(range(num_tokens - 1)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -638,7 +642,7 @@ def test_computed_blocks_not_evicted(): # Allocate a block and cache it. num_tokens = block_size * 1 - req0 = make_request("0", list(range(num_tokens)), block_size, hash) + req0 = make_request("0", list(range(num_tokens)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -650,7 +654,7 @@ def test_computed_blocks_not_evicted(): # Allocate another block. req1 = make_request("1", list(range(num_tokens, num_tokens * 2)), - block_size, hash) + block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -666,7 +670,7 @@ def test_computed_blocks_not_evicted(): # Now if we have a cache hit on the first block, we should evict the second # cached block rather than the first one. - req2 = make_request("2", list(range(num_tokens * 2)), block_size, hash) + req2 = make_request("2", list(range(num_tokens * 2)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 1 assert computed_blocks.blocks[0][0].block_id == 1 @@ -691,7 +695,7 @@ def test_basic_prefix_caching_disabled(): ) req1 = make_request("1", list(range(10)), block_size, - hash) # 2 blocks and some more + sha256) # 2 blocks and some more computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert not computed_blocks.blocks[0] @@ -706,7 +710,7 @@ def test_basic_prefix_caching_disabled(): # No caching. req2 = make_request("2", list(range(16)), block_size, - hash) # shared prefix + sha256) # shared prefix computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -716,7 +720,7 @@ def test_basic_prefix_caching_disabled(): assert len(blocks.blocks[0]) == 4 # New requests should not have any blocks. - req3 = make_request("3", list(range(4)), block_size, hash) + req3 = make_request("3", list(range(4)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -726,7 +730,7 @@ def test_basic_prefix_caching_disabled(): assert not blocks -@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) def test_cache_blocks(hash_fn): """ This is a unit test that tests the correctness of the _cache_full_blocks @@ -787,7 +791,7 @@ def test_cache_blocks_multi_group(): # Block 1/5: [4, 5, 6, 7] # Block 2/6: [8, 9, 10, 11] # Block 3/7: [12, 13] - req = make_request("0", list(range(14)), block_size, hash) + req = make_request("0", list(range(14)), block_size, sha256) # Cache the blocks for group 0. blocks = [KVCacheBlock(block_id=i) for i in range(2)] @@ -845,6 +849,8 @@ def test_mm_prefix_caching(): """ This tests that the multi-modal prefix caching is correct. """ + kv_cache_utils.init_none_hash(sha256) + block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 11), @@ -874,23 +880,30 @@ def test_mm_prefix_caching(): req0 = make_request("0", all_token_ids, block_size, - hash, + sha256, mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - # Completed block should have hashes with extra keys. + # Completed block should have hashes assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 block_hashes = req0.block_hashes assert len(block_hashes) == 3 - assert block_hashes[0].extra_keys == ("aaa", ) - assert block_hashes[1].extra_keys == ("aaa", "bbb") - assert block_hashes[2].extra_keys == ("bbb", ) + assert block_hashes[0] == sha256( + (kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]), + ("aaa", ))) + assert block_hashes[1] == sha256( + (block_hashes[0], tuple(all_token_ids[block_size:block_size * 2]), + ("aaa", "bbb"))) + assert block_hashes[2] == sha256( + (block_hashes[1], tuple(all_token_ids[block_size * 2:block_size * 3]), + ("bbb", ))) blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks[0]) * 16, computed_blocks) + assert blocks is not None assert blocks.get_block_ids() == ([1, 2, 3, 4], ) req0.num_computed_tokens = 59 @@ -901,10 +914,10 @@ def test_mm_prefix_caching(): len(computed_blocks.blocks[0]) * 16, computed_blocks) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 - - # The just completed block should have hashes with extra keys. assert len(block_hashes) == 4 - assert block_hashes[3].extra_keys == ("ccc", ) + assert block_hashes[3] == sha256( + (block_hashes[2], tuple(all_token_ids[3 * block_size:] + [8] * 5), + ("ccc", ))) # Cache hit. unique_token_ids = [-1] * 7 + [200] * 5 @@ -916,7 +929,7 @@ def test_mm_prefix_caching(): req1 = make_request("1", all_token_ids, block_size, - hash, + sha256, mm_positions=mm_positions, mm_hashes=mm_hashes) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) @@ -929,6 +942,8 @@ def test_cache_key_salting(): This tests that cache salts are applied during hashing and the cache is separated cache as expected. """ + kv_cache_utils.init_none_hash(sha256) + block_size = 16 manager = KVCacheManager( make_kv_cache_config(block_size, 11), @@ -939,21 +954,26 @@ def test_cache_key_salting(): # 3 complete blocks and an incomplete block with 11 tokens. common_token_ids = [i for i in range(3) for _ in range(block_size)] token_ids = common_token_ids + [3] * 11 - req0 = make_request("0", token_ids, block_size, hash, cache_salt="salt1") + req0 = make_request("0", token_ids, block_size, sha256, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - # Completed block should have hashes with extra keys. + # Completed block should have hashes assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 block_hashes = req0.block_hashes assert len(block_hashes) == 3 - assert block_hashes[0].extra_keys == ("salt1", ) - assert block_hashes[1].extra_keys is None - assert block_hashes[2].extra_keys is None + assert block_hashes[0] == sha256( + (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1", ))) + assert block_hashes[1] == sha256( + (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None)) + assert block_hashes[2] == sha256( + (block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]), + None)) blocks = manager.allocate_slots(req0, 59, len(computed_blocks.blocks[0]) * 16, computed_blocks) + assert blocks is not None assert blocks.get_block_ids() == ([1, 2, 3, 4], ) req0.num_computed_tokens = 59 @@ -964,14 +984,13 @@ def test_cache_key_salting(): len(computed_blocks.blocks[0]) * 16, computed_blocks) assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 - - # Now one more block that should not have extra keys. assert len(block_hashes) == 4 - assert block_hashes[3].extra_keys is None + assert block_hashes[3] == sha256( + (block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None)) # Test cache hit with a new request that has the same salt. token_ids = common_token_ids + [4] * 11 - req1 = make_request("1", token_ids, block_size, hash, cache_salt="salt1") + req1 = make_request("1", token_ids, block_size, sha256, cache_salt="salt1") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) # Should match only a prefix of 3 blocks. assert len(computed_blocks.blocks[0]) == 3 @@ -979,13 +998,19 @@ def test_cache_key_salting(): # Test cache miss with same content but different salt. token_ids = common_token_ids + [4] * 11 - req2 = make_request("2", token_ids, block_size, hash, cache_salt="salt2") + req2 = make_request("2", token_ids, block_size, sha256, cache_salt="salt2") computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert len(computed_blocks.blocks[0]) == 0 assert num_computed_tokens == 0 block_hashes = req2.block_hashes assert len(block_hashes) == 3 - assert block_hashes[0].extra_keys == ("salt2", ) + assert block_hashes[0] == sha256( + (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2", ))) + assert block_hashes[1] == sha256( + (block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None)) + assert block_hashes[2] == sha256( + (block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]), + None)) def test_prefill_not_enough_free_blocks_with_computed_blocks(): @@ -1004,7 +1029,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # Complete 3 blocks (48 tokens) # | Common-0 | Common-1 | Common-2 | ... | common_token_ids = [i for i in range(3) for _ in range(16)] - req0 = make_request("0", common_token_ids, block_size, hash) + req0 = make_request("0", common_token_ids, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1015,7 +1040,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): req0.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | - req1 = make_request("1", common_token_ids * 2, block_size, hash) + req1 = make_request("1", common_token_ids * 2, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) assert computed_blocks.blocks[0] == block_part0 assert num_computed_tokens == 3 * 16 @@ -1032,7 +1057,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| Req2-0 | Req2-1 | ... | - req2 = make_request("2", [7] * block_size * 2, block_size, hash) + req2 = make_request("2", [7] * block_size * 2, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1044,7 +1069,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): # but it cannot be allocated due to insufficient free blocks (2). # In this case, the ref_cnt of the computed blocks should not be changed. assert manager.block_pool.free_block_queue.num_free_blocks == 5 - req3 = make_request("3", common_token_ids * 3, block_size, hash) + req3 = make_request("3", common_token_ids * 3, block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) assert computed_blocks.blocks[0] == block_part1 assert num_computed_tokens == 6 * 16 @@ -1069,13 +1094,13 @@ def test_reset_prefix_cache(): full_block_token_ids = [i for i in range(3) for _ in range(16)] unique_token_ids = [3] * 7 all_token_ids = full_block_token_ids + unique_token_ids - req0 = make_request("0", all_token_ids, block_size, hash) + req0 = make_request("0", all_token_ids, block_size, sha256) blocks = manager.allocate_slots(req0, 55) assert blocks.get_block_ids() == ([1, 2, 3, 4], ) unique_token_ids = [4] * 7 all_token_ids = full_block_token_ids + unique_token_ids - req1 = make_request("1", all_token_ids, block_size, hash) + req1 = make_request("1", all_token_ids, block_size, sha256) computed_blocks, _ = manager.get_computed_blocks(req1) assert len(req1.block_hashes) == 3 assert len(computed_blocks.blocks[0]) == 3 @@ -1109,7 +1134,7 @@ def test_prefix_cache_stats_disabled(): assert manager.prefix_cache_stats is None # Call all functions that check whether log_stats is disabled. - req = make_request("0", list(range(16)), block_size, hash) + req = make_request("0", list(range(16)), block_size, sha256) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) assert not computed_blocks.blocks[0] assert num_computed_tokens == 0 @@ -1124,15 +1149,9 @@ def test_prefix_cache_stats_disabled(): def test_maybe_evict_cached_block(): pool = BlockPool(num_gpu_blocks=4, enable_caching=True) - block_hash0 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=10, - token_ids=(100, )), - group_id=1000) - block_hash1 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=20, - token_ids=(200, )), - group_id=2000) - block_hash2 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=30, - token_ids=(300, )), - group_id=3000) + block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000) + block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000) + block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000) block_hashes = [ block_hash0, block_hash1, @@ -1206,7 +1225,7 @@ def test_kv_cache_events(blocks_to_cache: int): ) num_tokens = block_size * blocks_to_cache - req0 = make_request("0", list(range(num_tokens)), block_size, hash) + req0 = make_request("0", list(range(num_tokens)), block_size, sha256) _ = manager.allocate_slots(req0, num_tokens) events = manager.take_events() @@ -1222,7 +1241,7 @@ def test_kv_cache_events(blocks_to_cache: int): # Should see block_to_cache number of removed block events and a new block # stored event manager.free(req0) - req1 = make_request("1", list(range(num_tokens)), block_size, hash) + req1 = make_request("1", list(range(num_tokens)), block_size, sha256) _ = manager.allocate_slots(req1, num_tokens) events = manager.take_events() @@ -1256,7 +1275,7 @@ def test_eagle_enabled_removes_last_block(): # Request with 3 full blocks (48 tokens) token_ids = [0] * (3 * block_size) - req = make_request("divisible_request", token_ids, block_size, hash) + req = make_request("divisible_request", token_ids, block_size, sha256) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1266,7 +1285,7 @@ def test_eagle_enabled_removes_last_block(): manager.free(req) # New request with same tokens + Eagle enabled - req_eagle = make_request("eagle_divisible", token_ids, block_size, hash) + req_eagle = make_request("eagle_divisible", token_ids, block_size, sha256) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Should retain 1 block: @@ -1287,7 +1306,7 @@ def test_eagle_with_partial_blocks(): ) # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) - req = make_request("partial_block_test", token_ids, block_size, hash) + req = make_request("partial_block_test", token_ids, block_size, sha256) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1297,7 +1316,7 @@ def test_eagle_with_partial_blocks(): manager.free(req) # New request with Eagle enabled - req_eagle = make_request("partial_eagle", token_ids, block_size, hash) + req_eagle = make_request("partial_eagle", token_ids, block_size, sha256) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 @@ -1328,7 +1347,7 @@ def test_eagle_with_sliding_window(): # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) - req = make_request("partial_block_test", token_ids, block_size, hash) + req = make_request("partial_block_test", token_ids, block_size, sha256) # Prime the cache computed_blocks, _ = manager.get_computed_blocks(req) @@ -1341,7 +1360,7 @@ def test_eagle_with_sliding_window(): manager.free(req) # New request with Eagle enabled - req_eagle = make_request("partial_eagle", token_ids, block_size, hash) + req_eagle = make_request("partial_eagle", token_ids, block_size, sha256) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) # Original match: 2 full blocks → Eagle removes 1 → 1 remaining assert len(computed_blocks.blocks[0]) == 1 @@ -1351,11 +1370,11 @@ def test_eagle_with_sliding_window(): assert manager.block_pool.get_cached_block( block_hash_first_block, kv_cache_group_ids=[0]) is not None manager.block_pool.cached_block_hash_to_block.pop( - BlockHashWithGroupId(block_hash_first_block, 0)) + make_block_hash_with_group_id(block_hash_first_block, 0)) # New request req_after_evict = make_request("partial_eagle_after_evict", token_ids, - block_size, hash) + block_size, sha256) computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict) # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # not considered. But after dropping the last matched block due to eagle, diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index 7dcebba491fab..b70850a9bcff9 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -6,8 +6,8 @@ import random import torch from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - KVCacheBlock) +from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, + make_block_hash_with_group_id) from vllm.v1.core.single_type_kv_cache_manager import ( ChunkedLocalAttentionManager, SlidingWindowManager) from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, @@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix(): def run_one_case(block_is_cached, tail_token, expect_length): block_hash_list = [ - BlockHash(i, ()) for i in range(len(block_is_cached)) + BlockHash(str(i).encode()) for i in range(len(block_is_cached)) ] block_pool.cached_block_hash_to_block.clear() @@ -53,8 +53,8 @@ def test_chunked_local_attention_possible_cached_prefix(): for i, (block_hash, is_cached) in enumerate(zip(block_hash_list, block_is_cached)): if is_cached: - block_pool.cached_block_hash_to_block[BlockHashWithGroupId( - block_hash, 0)] = { + block_pool.cached_block_hash_to_block[ + make_block_hash_with_group_id(block_hash, 0)] = { i: block_pool.blocks[i + 10], } @@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix(): def run_one_case(block_is_cached, expect_length): block_hash_list = [ - BlockHash(i, ()) for i in range(len(block_is_cached)) + BlockHash(str(i).encode()) for i in range(len(block_is_cached)) ] block_pool.cached_block_hash_to_block.clear() @@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix(): for i, (block_hash, is_cached) in enumerate(zip(block_hash_list, block_is_cached)): if is_cached: - block_pool.cached_block_hash_to_block[BlockHashWithGroupId( - block_hash, 0)] = { + block_pool.cached_block_hash_to_block[ + make_block_hash_with_group_id(block_hash, 0)] = { i: block_pool.blocks[i + 10], } diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index e392c2c336e9b..d343141cdf4cb 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, from vllm.multimodal.inputs import (MultiModalFeatureSpec, MultiModalKwargsItem, PlaceholderRange) from vllm.sampling_params import SamplingParams +from vllm.utils import sha256 from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, init_none_hash) from vllm.v1.core.sched.async_scheduler import AsyncScheduler @@ -130,10 +131,10 @@ def create_requests( ) -> list[Request]: global _none_hash_initialized if not _none_hash_initialized: - init_none_hash(hash) + init_none_hash(sha256) _none_hash_initialized = True - block_hasher = get_request_block_hasher(block_size, hash) + block_hasher = get_request_block_hasher(block_size, sha256) sampling_params = SamplingParams(ignore_eos=False, max_tokens=max_tokens, stop_token_ids=stop_token_ids, diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py index f70a3ce147ff2..23ec3673b10b4 100644 --- a/tests/v1/engine/test_engine_args.py +++ b/tests/v1/engine/test_engine_args.py @@ -36,18 +36,19 @@ def test_prefix_caching_from_cli(): assert vllm_config.cache_config.enable_prefix_caching # default hash algorithm is "builtin" - assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin" + assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256" + + # set hash algorithm to sha256_cbor + args = parser.parse_args(["--prefix-caching-hash-algo", "sha256_cbor"]) + vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() + assert vllm_config.cache_config.prefix_caching_hash_algo == \ + "sha256_cbor" # set hash algorithm to sha256 args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"]) vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256" - # set hash algorithm to builtin - args = parser.parse_args(["--prefix-caching-hash-algo", "builtin"]) - vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() - assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin" - # an invalid hash algorithm raises an error parser.exit_on_error = False with pytest.raises(ArgumentError): diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 3f068d5e8c7eb..0cae1c7bc0518 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -13,6 +13,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa SharedStorageConnector) +from vllm.utils import sha256 from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, init_none_hash) @@ -127,11 +128,11 @@ def create_request(request_id: int, use_all_1s_for_prompt_tokens: bool = False, num_remote_blocks: int = 3, block_size: int = 16, - hash_fn: Callable = hash) -> Request: + hash_fn: Callable = sha256) -> Request: """Make dummy request for testing.""" global _none_hash_initialized if not _none_hash_initialized: - init_none_hash(hash) + init_none_hash(hash_fn) _none_hash_initialized = True kv_transfer_params: Optional[dict[str, Any]] = None diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 5cc630b72846d..bf85aad452d0f 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -24,7 +24,7 @@ logger = init_logger(__name__) BlockSize = Literal[1, 8, 16, 32, 64, 128] CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] MambaDType = Literal["auto", "float32"] -PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"] +PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] @config @@ -63,17 +63,12 @@ class CacheConfig: """Sliding window size for the KV cache. This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" enable_prefix_caching: Optional[bool] = None - """Whether to enable prefix caching. Disabled by default for V0. Enabled by - default for V1.""" - prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin" + """Whether to enable prefix caching. Enabled by default for V1.""" + prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256" """Set the hash algorithm for prefix caching:\n - - "builtin" is Python's built-in hash.\n - - "sha256" is collision resistant but with certain overheads. - This option uses Pickle for object serialization before hashing.\n - - "sha256_cbor_64bit" provides a reproducible, cross-language compatible - hash. It serializes objects using canonical CBOR and hashes them with - SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256 - digest.""" + - "sha256" uses Pickle for object serialization before hashing.\n + - "sha256_cbor" provides a reproducible, cross-language compatible hash. It + serializes objects using canonical CBOR and hashes them with SHA-256.""" cpu_offload_gb: float = 0 """The space in GiB to offload to CPU, per GPU. Default is 0, which means no offloading. Intuitively, this argument can be seen as a virtual way to diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 09f42b550fe2b..46f0cd9289b23 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -16,6 +16,7 @@ import zmq from vllm.config.kv_events import KVEventsConfig from vllm.logger import init_logger +from vllm.v1.core.kv_cache_utils import ExternalBlockHash logger = init_logger(__name__) @@ -44,8 +45,8 @@ MEDIUM_GPU = "GPU" class BlockStored(KVCacheEvent): - block_hashes: list[int] - parent_block_hash: Optional[int] + block_hashes: list[ExternalBlockHash] + parent_block_hash: Optional[ExternalBlockHash] token_ids: list[int] block_size: int lora_id: Optional[int] @@ -53,7 +54,7 @@ class BlockStored(KVCacheEvent): class BlockRemoved(KVCacheEvent): - block_hashes: list[int] + block_hashes: list[ExternalBlockHash] medium: Optional[str] diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bee97f4cd04d8..94c984116131d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1592,20 +1592,12 @@ class EngineArgs: "in low performance due to small KV cache size. Consider " "setting --max-model-len to a smaller value.", max_model_len) - # if using prefix caching, we must set a hash algo - if self.enable_prefix_caching: - # Disable prefix caching for multimodal models for VLLM_V0. - if model_config.is_multimodal_model: - logger.warning( - "--enable-prefix-caching is not supported for multimodal " - "models in V0 and has been disabled.") - self.enable_prefix_caching = False - - # VLLM_V0 only supports builtin hash algo for prefix caching. - if self.prefix_caching_hash_algo == "sha256": - raise ValueError( - "sha256 is not supported for prefix caching in V0 engine. " - "Please use 'builtin'.") + # Disable prefix caching for multimodal models for VLLM_V0. + if self.enable_prefix_caching and model_config.is_multimodal_model: + logger.warning( + "--enable-prefix-caching is not supported for multimodal " + "models in V0 and has been disabled.") + self.enable_prefix_caching = False # Set max_num_seqs to 256 for VLLM_V0. if self.max_num_seqs is None: diff --git a/vllm/envs.py b/vllm/envs.py index 927bea3bf9538..8d199da45b082 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -171,6 +171,7 @@ if TYPE_CHECKING: VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False + VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True def get_default_cache_root(): @@ -1215,6 +1216,11 @@ environment_variables: dict[str, Callable[[], Any]] = { # Add optional custom scopes for profiling, disable to avoid overheads "VLLM_CUSTOM_SCOPES_FOR_PROFILING": lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))), + + # Represent block hashes in KV cache events as 64-bit integers instead of + # raw bytes. Defaults to True for backward compatibility. + "VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES": + lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))), } # --8<-- [end:env-vars-definition] diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 9c78e56d580e0..6d0cb3710bb93 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -3249,7 +3249,7 @@ def check_use_alibi(model_config: ModelConfig) -> bool: and getattr(cfg.attn_config, "alibi", False))))) -def sha256(input) -> int: +def sha256(input) -> bytes: """Hash any picklable Python object using SHA-256. The input is serialized using pickle before hashing, which allows @@ -3260,16 +3260,15 @@ def sha256(input) -> int: input: Any picklable Python object. Returns: - An integer representing the SHA-256 hash of the serialized input. + Bytes representing the SHA-256 hash of the serialized input. """ input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) - return int.from_bytes(hashlib.sha256(input_bytes).digest(), - byteorder="big") + return hashlib.sha256(input_bytes).digest() -def sha256_cbor_64bit(input) -> int: +def sha256_cbor(input) -> bytes: """ - Hash objects using CBOR serialization and SHA-256, then truncate to 64bits. + Hash objects using CBOR serialization and SHA-256. This option is useful for non-Python-dependent serialization and hashing. @@ -3280,17 +3279,13 @@ def sha256_cbor_64bit(input) -> int: Custom classes must implement CBOR serialization methods. Returns: - An integer in the range [0, 2^64-1] representing the lower 64 bits - of the SHA-256 hash of the CBOR serialized input. + Bytes representing the SHA-256 hash of the CBOR serialized input. """ input_bytes = cbor2.dumps(input, canonical=True) - full_hash = int.from_bytes(hashlib.sha256(input_bytes).digest(), - byteorder="big") - - return full_hash & ((1 << 64) - 1) + return hashlib.sha256(input_bytes).digest() -def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]: +def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: """Get a hash function by name, or raise an error if the function is not found. Args: @@ -3300,10 +3295,8 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]: """ if hash_fn_name == "sha256": return sha256 - if hash_fn_name == "sha256_cbor_64bit": - return sha256_cbor_64bit - if hash_fn_name == "builtin": - return hash + if hash_fn_name == "sha256_cbor": + return sha256_cbor raise ValueError(f"Unsupported hash function: {hash_fn_name}") diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index b537cac8e1d72..d1e1c1c8d0382 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -9,7 +9,11 @@ from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared, KVCacheEvent) from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - FreeKVCacheBlockQueue, KVCacheBlock) + ExternalBlockHash, + FreeKVCacheBlockQueue, KVCacheBlock, + get_block_hash, + make_block_hash_with_group_id, + maybe_convert_block_hash) from vllm.v1.request import Request logger = init_logger(__name__) @@ -84,8 +88,10 @@ class BlockPool: """ cached_blocks = [] for group_id in kv_cache_group_ids: + block_hash_with_group_id = make_block_hash_with_group_id( + block_hash, group_id) cached_blocks_one_group = self.cached_block_hash_to_block.get( - BlockHashWithGroupId(block_hash, group_id)) + block_hash_with_group_id) if not cached_blocks_one_group: return None first_block = next(iter(cached_blocks_one_group.values())) @@ -124,28 +130,29 @@ class BlockPool: assert len(request.block_hashes) >= num_full_blocks new_block_hashes = request.block_hashes[num_cached_blocks:] - new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events - else None) + new_hashes: Optional[list[ExternalBlockHash]] = ( + [] if self.enable_kv_cache_events else None) for i, blk in enumerate(new_full_blocks): assert blk.block_hash is None block_hash = new_block_hashes[i] # Update and added the full block to the cache. - block_hash_with_group_id = BlockHashWithGroupId( + block_hash_with_group_id = make_block_hash_with_group_id( block_hash, kv_cache_group_id) blk.block_hash = block_hash_with_group_id self.cached_block_hash_to_block[block_hash_with_group_id][ blk.block_id] = blk if new_hashes is not None: - new_hashes.append(block_hash.hash_value) + new_hashes.append(maybe_convert_block_hash(block_hash)) if self.enable_kv_cache_events: if num_cached_blocks == 0: - parent_block_hash = None + parent_block_hash: Optional[ExternalBlockHash] = None else: parent_block = blocks[num_cached_blocks - 1] assert parent_block.block_hash is not None - parent_block_hash = parent_block.block_hash.get_hash_value() + parent_block_hash = maybe_convert_block_hash( + get_block_hash(parent_block.block_hash)) self.kv_event_queue.append( BlockStored( @@ -220,7 +227,9 @@ class BlockPool: # we disable hybrid kv cache manager when kv cache event is # enabled, so there is only one group. self.kv_event_queue.append( - BlockRemoved(block_hashes=[block_hash.get_hash_value()], + BlockRemoved(block_hashes=[ + maybe_convert_block_hash(get_block_hash(block_hash)) + ], medium=MEDIUM_GPU)) return True diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index aff1183e499a4..2c0eac3ddd79d 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -6,11 +6,12 @@ import os from collections import defaultdict, deque from collections.abc import Iterable, Sequence from dataclasses import astuple, dataclass -from typing import Any, Callable, NamedTuple, Optional +from typing import Any, Callable, NewType, Optional, Union +from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit +from vllm.utils import GiB_bytes, cdiv, sha256_cbor from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, @@ -18,59 +19,78 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request -logger = init_logger(__name__) +# BlockHash represents the hash of a single KV-cache block used for +# prefix caching. Treating it as a distinct type from ``bytes`` helps +# catch accidental misuse when passing around raw byte strings. +BlockHash = NewType("BlockHash", bytes) + +# ``BlockHashWithGroupId`` combines a ``BlockHash`` with its KV cache group ID. +# It is represented as raw bytes for compactness and efficiency. The helper +# functions below pack/unpack the ``BlockHash`` and group id into/from the key. +BlockHashWithGroupId = NewType("BlockHashWithGroupId", bytes) + +# ExternalBlockHash is used for reproducible prefix-cache block hashing. +# It's a union of ``bytes`` and ``int`` to keep backward compatibility +# after we default block hashing to use sha256 bytes. +ExternalBlockHash = Union[bytes, int] -class BlockHash(NamedTuple): - """Hash value of a block (int), the token IDs in the block, and extra keys. - We keep a tuple of token IDs and extra keys to reduce the likelihood of - hash collisions when the hash value is the same. By using SHA256 however, - hash collisions are practically impossible. +def make_block_hash_with_group_id(block_hash: BlockHash, + group_id: int) -> BlockHashWithGroupId: + """Pack a ``BlockHash`` and group id into a ``BlockHashWithGroupId``. + + The group id is encoded using 4 bytes in big-endian order and appended to + the block hash bytes. This representation avoids creating tuples while + still allowing us to recover both components when needed. """ - # Hash value of the block in an integer. - hash_value: int - # Token IDs in the block. - token_ids: tuple[int, ...] - # Extra keys for the block. - extra_keys: Optional[Any] = None + return BlockHashWithGroupId(block_hash + + group_id.to_bytes(4, "big", signed=False)) -class BlockHashWithGroupId(NamedTuple): - # The hash value for the contents (e.g., token_ids) of a block without group - # ID. The value is the same for blocks representing the same tokens but for - # different groups. - block_hash: BlockHash - # The KV cache group ID. - group_id: int +def get_block_hash(key: BlockHashWithGroupId) -> BlockHash: + """Extract the ``BlockHash`` from a ``BlockHashWithGroupId``.""" + return BlockHash(key[:-4]) - def get_hash_value(self) -> int: - return self.block_hash.hash_value +def get_group_id(key: BlockHashWithGroupId) -> int: + """Extract the group id from a ``BlockHashWithGroupId``.""" + return int.from_bytes(key[-4:], "big", signed=False) + + +def maybe_convert_block_hash(hash_bytes: BlockHash) -> ExternalBlockHash: + if not envs.VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: + return hash_bytes + return int.from_bytes(hash_bytes, byteorder="big") & ((1 << 64) - 1) + + +logger = init_logger(__name__) # The hash seed for the first block of any prefix block sequence. # # We use a random value to avoid hash collisions or PYTHONHASHSEED environment -# variable if set such that processes can share the seed if needed. -# This aligns with the behavior of Python's hash() function, which also uses -# a random seed if PYTHONHASHSEED is not set. +# variable if set such that processes can share the seed if needed. This aligns +# with the behavior of Python's hash() function, which also uses a random seed +# if PYTHONHASHSEED is not set. # # The function `init_none_hash` initializes this variable globally. -NONE_HASH: int +NONE_HASH: BlockHash -def init_none_hash(hash_fn: Callable): +def init_none_hash(hash_fn: Callable[[Any], bytes]): global NONE_HASH hash_seed = os.getenv("PYTHONHASHSEED") - if hash_seed is None and hash_fn is sha256_cbor_64bit: + if hash_seed is None and hash_fn is sha256_cbor: logger.warning( "PYTHONHASHSEED is not set. This will lead to non-reproducible " - "block-hashes when using sha256_cbor_64bit as the hash function." + "block-hashes when using sha256_cbor as the hash function." "Consider setting PYTHONHASHSEED to a fixed value for " "reproducibility.") - NONE_HASH = (int.from_bytes(os.urandom(32), byteorder="big") - if hash_seed is None else hash_fn(hash_seed)) + if hash_seed is None: + NONE_HASH = BlockHash(os.urandom(32)) + else: + NONE_HASH = BlockHash(hash_fn(hash_seed)) class PrefixCachingMetrics: @@ -142,8 +162,8 @@ class KVCacheBlock: block_id: int # Reference count. ref_cnt: int = 0 - # The hash of the block composed of (block hash, tuple of token IDs). - # It is only available when the block is full. + # The hash key (block hash + group id) of the block, only available + # when the block is full and cached. _block_hash: Optional[BlockHashWithGroupId] = None # Used to construct a doubly linked list for free blocks. @@ -177,7 +197,7 @@ class KVCacheBlock: if self.next_free_block else None) return (f"KVCacheBlock(block_id={self.block_id}, " f"ref_cnt={self.ref_cnt}, " - f"_block_hash={self._block_hash}, " + f"_block_hash={self._block_hash!r}, " f"prev_free_block={prev_block_id}, " f"next_free_block={next_block_id})") @@ -517,15 +537,14 @@ def generate_block_hash_extra_keys( def hash_block_tokens( - hash_function: Callable, - parent_block_hash: Optional[int], + hash_function: Callable[[Any], bytes], + parent_block_hash: Optional[BlockHash], curr_block_token_ids: Sequence[int], extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash: """Computes a hash value corresponding to the contents of a block and the contents of the preceding block(s). The hash value is used for prefix caching. We use LRU cache for this function to avoid recomputing hash values for the same block contents. - Args: hash_function: The hash function used to compute block hash. parent_block_hash: The hash of the parent block. None @@ -533,7 +552,6 @@ def hash_block_tokens( curr_block_token_ids: A list of token ids in the current block. The current block is assumed to be full. extra_keys: Extra keys for the block. - Returns: The hash value of the block and the token ids in the block. The entire tuple is used as the hash key of the block. @@ -544,26 +562,16 @@ def hash_block_tokens( curr_block_token_ids_tuple = tuple(curr_block_token_ids) return BlockHash( hash_function( - (parent_block_hash, curr_block_token_ids_tuple, extra_keys)), - curr_block_token_ids_tuple, extra_keys) + (parent_block_hash, curr_block_token_ids_tuple, extra_keys))) def get_request_block_hasher( block_size: int, - caching_hash_fn: Callable[[Any], - int]) -> Callable[[Request], list[BlockHash]]: + caching_hash_fn: Callable[[Any], bytes], +) -> Callable[[Request], list[BlockHash]]: """ Returns a function which computes the list of un-computed block hashes - of a request. - - Each request holds a list of its block hashes (request.block_hashes). - When a request is created, it calls the below function to compute - the hashes of all full blocks of the request's initial tokens. - The hashes are then stored in request.block_hashes. - Later, whenever new tokens are appended to the request, it calls - the below function again to compute any new full blocks of tokens. - The returned new hashes are appended to request.block_hashes. - """ + of a request.""" def request_block_hasher(request: Request) -> list[BlockHash]: start_token_idx = len(request.block_hashes) * block_size @@ -577,8 +585,8 @@ def get_request_block_hasher( # last mm input. curr_mm_idx = -1 - prev_block_hash_value = request.block_hashes[-1].hash_value \ - if request.block_hashes else None + prev_block_hash_value = (request.block_hashes[-1] + if request.block_hashes else None) new_block_hashes: list[BlockHash] = [] while True: end_token_idx = start_token_idx + block_size @@ -598,7 +606,7 @@ def get_request_block_hasher( new_block_hashes.append(block_hash) start_token_idx += block_size - prev_block_hash_value = block_hash.hash_value + prev_block_hash_value = block_hash return new_block_hashes