From 48ddb02b79d7e22e2eefbf5294bf70de50afd1b2 Mon Sep 17 00:00:00 2001 From: Yifan Qiao Date: Tue, 25 Nov 2025 07:30:57 -0800 Subject: [PATCH] [Hybrid Allocator] Support KV cache groups with different block_size (#29143) Signed-off-by: Yifan Qiao Co-authored-by: Chen Zhang --- tests/v1/core/test_kv_cache_utils.py | 49 +++++- tests/v1/core/test_prefix_caching.py | 95 ++++++++++- .../core/test_single_type_kv_cache_manager.py | 22 ++- vllm/engine/arg_utils.py | 8 +- vllm/model_executor/models/config.py | 11 +- vllm/v1/core/block_pool.py | 22 ++- vllm/v1/core/kv_cache_coordinator.py | 98 ++++++++--- vllm/v1/core/kv_cache_manager.py | 25 +-- vllm/v1/core/kv_cache_utils.py | 156 +++++++++++++++--- vllm/v1/core/sched/scheduler.py | 1 + vllm/v1/core/single_type_kv_cache_manager.py | 72 +++++++- 11 files changed, 472 insertions(+), 87 deletions(-) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 12ed59b6e863b..58a7a2692bfc8 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -1248,7 +1248,9 @@ def test_allocate_with_lookahead(): ) # Test case 1: Requires additional lookahead tokens - kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) + kv_cache_manager = KVCacheManager( + kv_cache_config=config, max_model_len=100, hash_block_size=block_size + ) blocks = kv_cache_manager.allocate_slots( request, num_new_tokens=3, @@ -1257,7 +1259,9 @@ def test_allocate_with_lookahead(): assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks # Test case 2: With precomputed blocks - kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) + kv_cache_manager = KVCacheManager( + kv_cache_config=config, max_model_len=100, hash_block_size=block_size + ) # required_blocks = ceil((3 + 2) /4) = 2 blocks = kv_cache_manager.allocate_slots( request, @@ -1268,7 +1272,9 @@ def test_allocate_with_lookahead(): # Test case 3: With precomputed blocks # required_blocks = ceil((3 + 4) / 4) = 2 - kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) + kv_cache_manager = KVCacheManager( + kv_cache_config=config, max_model_len=100, hash_block_size=block_size + ) blocks = kv_cache_manager.allocate_slots( request, num_new_tokens=3, @@ -1495,7 +1501,8 @@ def test_get_kv_cache_config_one_worker(): ), ], ) - # different hidden size + + # different hidden size but same type, use UniformTypeKVCacheSpecs kv_cache_specs_hybrid = { "layer_1": new_kv_cache_spec(head_size=128), "layer_2": new_kv_cache_spec(head_size=64), @@ -1519,6 +1526,40 @@ def test_get_kv_cache_config_one_worker(): ], ) + # Different hidden size and different type, align by different block size + kv_cache_specs_hybrid = { + "layer_1": new_kv_cache_spec(head_size=64), + "layer_2": new_sliding_window_spec(head_size=32), + } + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 32] + )[0] + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[ + KVCacheTensor( + size=mem_per_block_per_layer * 32, shared_by=["layer_1", "layer_2"] + ), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1"], new_kv_cache_spec(head_size=64)), + KVCacheGroupSpec( + ["layer_2"], new_sliding_window_spec(head_size=32, block_size=32) + ), + ], + ) + + # different hidden size that cannot be aligned by using different block size + kv_cache_specs_hybrid = { + "layer_1": new_kv_cache_spec(head_size=64), + "layer_2": new_sliding_window_spec(head_size=96), 
+ } + + with pytest.raises(NotImplementedError): + get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32] + )[0] + # Test num_gpu_blocks_override vllm_config.cache_config.num_gpu_blocks_override = 16 kv_cache_config_override_blocks = get_kv_cache_configs( diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 2291f363731f2..64fd5ab1dd9aa 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -134,6 +134,7 @@ def test_prefill(hash_fn): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Complete 3 blocks (48 tokens) @@ -256,6 +257,7 @@ def test_prefill_hybrid_model(): make_kv_cache_config_hybrid_model(block_size, 21), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) hash_fn = sha256 @@ -416,6 +418,7 @@ def test_prefill_plp(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # the default hash function is sha256 hash_fn = sha256 @@ -523,6 +526,7 @@ def test_decode(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Complete 3 blocks (48 tokens) @@ -585,6 +589,7 @@ def test_evict(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) last_token_id = 5 * 16 + 7 @@ -643,6 +648,7 @@ def test_hash_block_correct_reuse(): make_kv_cache_config(16, 2), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Allocate 1 block and cache it. @@ -683,6 +689,7 @@ def test_computed_blocks_not_evicted(): make_kv_cache_config(block_size, 3), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Allocate a block and cache it. @@ -741,6 +748,7 @@ def test_basic_prefix_caching_disabled(): make_kv_cache_config(block_size, 5), max_model_len=8192, enable_caching=False, + hash_block_size=block_size, ) req1 = make_request( @@ -790,6 +798,7 @@ def test_cache_blocks(hash_fn): block_pool = BlockPool( num_gpu_blocks=5, enable_caching=True, + hash_block_size=block_size, ) # Req: # Block 0: [0, 1, 2, 3] @@ -833,7 +842,9 @@ def test_cache_blocks_multi_group(): This tests that blocks are cached correctly for different kv cache groups. """ block_size = 4 - block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True) + block_pool = BlockPool( + num_gpu_blocks=10, enable_caching=True, hash_block_size=block_size + ) # Req: # Block 0/4: [0, 1, 2, 3] @@ -921,6 +932,7 @@ def test_mm_prefix_caching(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Common prompt tokens (T is text tokens and P is image placeholder tokens) @@ -1020,6 +1032,7 @@ def test_cache_key_salting(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # 3 complete blocks and an incomplete block with 11 tokens. @@ -1101,6 +1114,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) # Complete 3 blocks (48 tokens) # | Common-0 | Common-1 | Common-2 | ... 
| @@ -1173,6 +1187,7 @@ def test_reset_prefix_cache(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, ) full_block_token_ids = [i for i in range(3) for _ in range(16)] @@ -1213,6 +1228,7 @@ def test_prefix_cache_stats_disabled(): make_kv_cache_config(block_size, 11), max_model_len=8192, enable_caching=True, + hash_block_size=block_size, log_stats=False, # Disable logging stats ) assert manager.prefix_cache_stats is None @@ -1232,7 +1248,7 @@ def test_prefix_cache_stats_disabled(): def test_maybe_evict_cached_block(): - pool = BlockPool(num_gpu_blocks=4, enable_caching=True) + pool = BlockPool(num_gpu_blocks=4, enable_caching=True, hash_block_size=16) block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000) block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000) block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000) @@ -1293,6 +1309,7 @@ def test_kv_cache_events(blocks_to_cache: int): max_model_len=8192, enable_caching=True, enable_kv_cache_events=True, + hash_block_size=block_size, ) num_tokens = block_size * blocks_to_cache @@ -1351,6 +1368,7 @@ def test_kv_cache_events_with_lora(blocks_to_cache: int): max_model_len=8192, enable_caching=True, enable_kv_cache_events=True, + hash_block_size=block_size, ) # Test with LoRA request @@ -1405,6 +1423,7 @@ def test_eagle_enabled_removes_last_block(): max_model_len=8192, enable_caching=True, use_eagle=True, + hash_block_size=block_size, ) # Request with 3 full blocks (48 tokens) @@ -1437,6 +1456,7 @@ def test_eagle_with_partial_blocks(): max_model_len=8192, enable_caching=True, use_eagle=True, + hash_block_size=block_size, ) # 2 full blocks + 5 tokens (non-divisible length) token_ids = [0] * (2 * block_size + 5) @@ -1476,6 +1496,7 @@ def test_eagle_with_sliding_window(): max_model_len=8192, enable_caching=True, use_eagle=True, + hash_block_size=block_size, ) # 2 full blocks + 5 tokens (non-divisible length) @@ -1522,6 +1543,76 @@ def test_eagle_with_sliding_window(): assert num_tokens == 0 +def test_different_block_size(): + block_size = 16 + # full attention and sliding window attention layers have the same page size: + # (32 tokens/block * float16 token, vs. 16 tokens/block * float32 token) + kv_cache_config = KVCacheConfig( + num_blocks=100, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer1"], + FullAttentionSpec(block_size * 2, 1, 1, torch.float16), + ), + KVCacheGroupSpec( + ["layer2"], + SlidingWindowSpec( + block_size, + 1, + 1, + torch.float32, + sliding_window=2 * block_size, + ), + ), + ], + ) + manager = KVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=8192, + enable_caching=True, + hash_block_size=block_size, + ) + + # 10 blocks of 16 tokens each. Token ids are not strictly aligned for each block. 
+ common_token_ids = [i for i in range(10) for _ in range(block_size)] + + req0 = make_request("0", common_token_ids, block_size, sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) + assert not computed_blocks.blocks[0] + assert not computed_blocks.blocks[1] + assert num_computed_tokens == 0 + blocks = manager.allocate_slots( + req0, 7 * block_size, len(computed_blocks.blocks[0]) * 16, computed_blocks + ) + assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11]) + req1 = make_request("1", common_token_ids[: 7 * block_size + 1], block_size, sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + assert len(computed_blocks.blocks[0]) == 3 + assert len(computed_blocks.blocks[1]) == 6 + assert num_computed_tokens == 6 * 16 + + req2 = make_request("2", common_token_ids[: 6 * block_size + 1], block_size, sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) + assert len(computed_blocks.blocks[0]) == 3 + assert len(computed_blocks.blocks[1]) == 6 + assert num_computed_tokens == 6 * 16 + + # Evict some blocks to make sliding window cache hit length 5*16 + # But should return 4 * 16 because full attention cache hit length must be + # a multiple of 32 + manager.block_pool.cached_block_hash_to_block.pop( + make_block_hash_with_group_id(req1.block_hashes[6], 1), 11 + ) + manager.block_pool.cached_block_hash_to_block.pop( + make_block_hash_with_group_id(req1.block_hashes[5], 1), 10 + ) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) + assert len(computed_blocks.blocks[0]) == 2 + assert len(computed_blocks.blocks[1]) == 4 + assert num_computed_tokens == 4 * 16 + + def test_block_lookup_cache_single_block_per_key(): cache = BlockHashToBlockMap() key0 = BlockHashWithGroupId(b"hash0") diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index a27f32938c08b..e6a69dc8a949a 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -41,7 +41,9 @@ def test_chunked_local_attention_possible_cached_prefix(): attention_chunk_size=4, ) - block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + block_pool = BlockPool( + num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size + ) manager = get_chunked_local_attention_manager( chunked_local_attention_spec, block_pool ) @@ -70,6 +72,7 @@ def test_chunked_local_attention_possible_cached_prefix(): block_pool=block_pool, kv_cache_spec=chunked_local_attention_spec, use_eagle=False, + alignment_tokens=block_size, )[0] assert len(computed_blocks) == expect_length @@ -111,7 +114,9 @@ def test_sliding_window_possible_cached_prefix(): sliding_window=4, ) - block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + block_pool = BlockPool( + num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size + ) manager = get_sliding_window_manager(sliding_window_spec, block_pool) def run_one_case(block_is_cached, expect_length): @@ -138,6 +143,7 @@ def test_sliding_window_possible_cached_prefix(): block_pool=block_pool, kv_cache_spec=sliding_window_spec, use_eagle=False, + alignment_tokens=block_size, )[0] assert len(computed_blocks) == expect_length @@ -178,7 +184,7 @@ def test_chunked_local_attention_remove_skipped_blocks(): attention_chunk_size=4, ) - block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) + block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, 
hash_block_size=2) manager = get_chunked_local_attention_manager(attention_spec, block_pool) @@ -239,7 +245,7 @@ def test_sliding_window_remove_skipped_blocks(): sliding_window=4, ) - block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) + block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2) manager = get_sliding_window_manager(sliding_window_spec, block_pool) @@ -316,7 +322,9 @@ def test_get_num_blocks_to_allocate(): sliding_window=4, # Placeholder value, not related to test result ) - block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + block_pool = BlockPool( + num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size + ) manager = get_sliding_window_manager(sliding_window_spec, block_pool) cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [ @@ -341,7 +349,9 @@ def test_chunked_local_attention_get_num_blocks_to_allocate(): attention_chunk_size=4, # Placeholder value, not related to test result ) - block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + block_pool = BlockPool( + num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size + ) manager = get_chunked_local_attention_manager(attention_spec, block_pool) cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [ diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6d5b3392baa2b..bdccb15e3f655 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1816,9 +1816,11 @@ class EngineArgs: if model_config.runner_type != "pooling": default_chunked_prefill = True - # Disable prefix caching default for hybrid models - # since the feature is still experimental. - default_prefix_caching = not model_config.is_hybrid + # Disable prefix caching default for hybrid models and mamba-only + # models since the feature is still experimental. + default_prefix_caching = not ( + model_config.is_hybrid or model_config.is_attention_free + ) else: assert model_config.pooler_config is not None diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 3cf4bf991e667..d7e802ba1aca0 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -289,9 +289,6 @@ class MambaModelConfig(VerifyAndUpdateConfig): model_config = vllm_config.model_config cache_config = vllm_config.cache_config - if cache_config.mamba_block_size is None: - cache_config.mamba_block_size = model_config.max_model_len - if cache_config.enable_prefix_caching: if model_config.supports_mamba_prefix_caching: logger.info( @@ -299,6 +296,11 @@ class MambaModelConfig(VerifyAndUpdateConfig): "Its support for Mamba layers is experimental. " "Please report any issues you may observe." ) + # By default, mamba block size will be set to max_model_len (see + # below). When enabling prefix caching, we align mamba block size + # to the block size as the basic granularity for prefix caching. 
+ if cache_config.mamba_block_size is None: + cache_config.mamba_block_size = cache_config.block_size else: logger.info( "Hybrid or mamba-based model detected without " @@ -306,6 +308,9 @@ class MambaModelConfig(VerifyAndUpdateConfig): ) cache_config.enable_prefix_caching = False + if cache_config.mamba_block_size is None: + cache_config.mamba_block_size = model_config.max_model_len + # TODO(tdoublep): remove once cascade attention is supported logger.info( "Disabling cascade attention since it is not supported for hybrid models." diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 55710ad5cc693..8b0e8fd3a2410 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -13,6 +13,8 @@ from vllm.distributed.kv_events import ( from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import ( BlockHash, + BlockHashList, + BlockHashListWithBlockSize, BlockHashWithGroupId, ExternalBlockHash, FreeKVCacheBlockQueue, @@ -133,6 +135,10 @@ class BlockPool: Args: num_gpu_blocks: The number of blocks in the pool. enable_caching: Whether to enable prefix caching. + hash_block_size: The block size of which the block hashes are computed. + The actual block size usually equals hash_block_size, but in cases + where different KV cache groups have different block sizes, the + actual block size can be a multiple of hash_block_size. enable_kv_cache_events: Whether to enable kv cache events. """ @@ -140,11 +146,13 @@ class BlockPool: self, num_gpu_blocks: int, enable_caching: bool, + hash_block_size: int, enable_kv_cache_events: bool = False, ): assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 self.num_gpu_blocks = num_gpu_blocks self.enable_caching = enable_caching + self.hash_block_size = hash_block_size # All kv-cache blocks. self.blocks: list[KVCacheBlock] = [ KVCacheBlock(idx) for idx in range(num_gpu_blocks) @@ -223,8 +231,20 @@ class BlockPool: return new_full_blocks = blocks[num_cached_blocks:num_full_blocks] assert len(request.block_hashes) >= num_full_blocks - new_block_hashes = request.block_hashes[num_cached_blocks:] + if block_size == self.hash_block_size: + # Common case. + block_hashes: BlockHashList = request.block_hashes + else: + # block_size is a multiple of hash_block_size. This happens when + # different KV cache groups have different block sizes. + assert block_size % self.hash_block_size == 0 + # Recalculate block_hashes at the granularity of block_size, using + # the original block_hashes (at the granularity of hash_block_size). 
+ block_hashes = BlockHashListWithBlockSize( + request.block_hashes, self.hash_block_size, block_size + ) + new_block_hashes = block_hashes[num_cached_blocks:] new_hashes: list[ExternalBlockHash] | None = ( [] if self.enable_kv_cache_events else None ) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 1531b61f88fe2..fd1ec8e27fba2 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -2,15 +2,25 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod from collections.abc import Sequence +from math import lcm from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock +from vllm.v1.core.kv_cache_utils import ( + BlockHash, + BlockHashList, + BlockHashListWithBlockSize, + KVCacheBlock, +) from vllm.v1.core.single_type_kv_cache_manager import ( CrossAttentionManager, FullAttentionManager, get_manager_for_kv_cache_spec, ) -from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheSpec, +) from vllm.v1.request import Request @@ -28,13 +38,17 @@ class KVCacheCoordinator(ABC): enable_kv_cache_events: bool, dcp_world_size: int, pcp_world_size: int, + hash_block_size: int, ): self.kv_cache_config = kv_cache_config self.max_model_len = max_model_len self.enable_caching = enable_caching self.block_pool = BlockPool( - kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events + kv_cache_config.num_blocks, + enable_caching, + hash_block_size, + enable_kv_cache_events, ) # Needs special handling for find_longest_cache_hit if eagle is enabled @@ -213,6 +227,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): enable_kv_cache_events: bool, dcp_world_size: int, pcp_world_size: int, + hash_block_size: int, ): super().__init__( kv_cache_config, @@ -222,6 +237,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): enable_kv_cache_events, dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, + hash_block_size=hash_block_size, ) self.num_single_type_manager = len(self.single_type_managers) @@ -255,6 +271,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): enable_kv_cache_events: bool, dcp_world_size: int, pcp_world_size: int, + hash_block_size: int, ): super().__init__( kv_cache_config, @@ -264,6 +281,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): enable_kv_cache_events, dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, + hash_block_size=hash_block_size, ) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec self.block_size = self.kv_cache_spec.block_size @@ -273,6 +291,11 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): self.block_size *= dcp_world_size if pcp_world_size > 1: self.block_size *= pcp_world_size + # For models using only Mamba, block_size is set to max_model_len when + # prefix caching is disabled, and hash_block_size validation is skipped. 
+ assert not enable_caching or (hash_block_size == self.block_size), ( + "UnitaryKVCacheCoordinator assumes hash_block_size == block_size" + ) assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "UnitaryKVCacheCoordinator assumes only one kv cache group" ) @@ -289,6 +312,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): block_pool=self.block_pool, kv_cache_spec=self.kv_cache_spec, use_eagle=self.use_eagle, + alignment_tokens=self.block_size, dcp_world_size=self.dcp_world_size, pcp_world_size=self.pcp_world_size, ) @@ -313,6 +337,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): enable_kv_cache_events: bool, dcp_world_size: int, pcp_world_size: int, + hash_block_size: int, ): super().__init__( kv_cache_config, @@ -322,7 +347,17 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): enable_kv_cache_events, dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, + hash_block_size=hash_block_size, ) + # hash_block_size: the block size used to compute block hashes. + # The actual block size usually equals hash_block_size, but in cases where + # different KV cache groups have different block sizes, the actual block size + # can be a multiple of hash_block_size. + self.hash_block_size = hash_block_size + assert all( + g.kv_cache_spec.block_size % hash_block_size == 0 + for g in kv_cache_config.kv_cache_groups + ), "block_size must be divisible by hash_block_size" assert dcp_world_size == 1, "DCP not support hybrid attn now." assert pcp_world_size == 1, "PCP not support hybrid attn now." self.verify_and_split_kv_cache_groups() @@ -373,14 +408,12 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): self.other_spec = other_spec self.full_attention_block_size = self.full_attention_spec.block_size self.other_block_size = self.other_spec.block_size - - if self.enable_caching: - # this requirement is only needed for the prefix caching logic - divisible = self.other_block_size % self.full_attention_block_size - assert divisible == 0, ( - "KVCacheCoordinator assumes the block_size of full " - "attention layers is divisible by other layers now." - ) + # The LCM of the block sizes of full attention and other attention. + # The cache hit length must be a multiple of the LCM of the block sizes + # to make sure the cache hit length is a multiple of the block size of + # each attention type. Requiring this because we don't support partial + # block cache hit yet. + self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size) if max(self.full_attention_group_ids) < min(self.other_group_ids): self.full_attn_first = True @@ -414,25 +447,48 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): - The number of tokens of the longest cache hit. """ # First, find the longest cache hit for full attention. + if self.full_attention_spec.block_size == self.hash_block_size: + # Common case. + full_attention_block_hashes: BlockHashList = block_hashes + else: + # block_size is a multiple of hash_block_size. This happens when different + # KV cache groups have different block sizes. In this case, we need to + # recalculate block_hashes at the granularity of block_size, using the + # original block_hashes (at the granularity of hash_block_size). 
+ full_attention_block_hashes = BlockHashListWithBlockSize( + block_hashes, self.hash_block_size, self.full_attention_spec.block_size + ) hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit( - block_hashes=block_hashes, + block_hashes=full_attention_block_hashes, max_length=max_cache_hit_length, kv_cache_group_ids=self.full_attention_group_ids, block_pool=self.block_pool, kv_cache_spec=self.full_attention_spec, use_eagle=self.use_eagle, + alignment_tokens=self.lcm_block_size, ) hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size # Next, find the cache hit for the other attention WITHIN # the cache hit of full attention. + if self.other_spec.block_size == self.hash_block_size: + # Common case. + other_block_hashes: BlockHashList = block_hashes + else: + # Similar to the full attention case, here we need to recalculate + # block_hashes at the granularity of block_size, using the original + # block_hashes (at the granularity of hash_block_size). + other_block_hashes = BlockHashListWithBlockSize( + block_hashes, self.hash_block_size, self.other_spec.block_size + ) hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit( - block_hashes=block_hashes, + block_hashes=other_block_hashes, max_length=hit_length, kv_cache_group_ids=self.other_group_ids, block_pool=self.block_pool, kv_cache_spec=self.other_spec, use_eagle=self.use_eagle, + alignment_tokens=self.lcm_block_size, ) hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size @@ -466,6 +522,7 @@ def get_kv_cache_coordinator( enable_kv_cache_events: bool, dcp_world_size: int, pcp_world_size: int, + hash_block_size: int, ) -> KVCacheCoordinator: if not enable_caching: return KVCacheCoordinatorNoPrefixCache( @@ -473,8 +530,9 @@ def get_kv_cache_coordinator( max_model_len, use_eagle, enable_kv_cache_events, - dcp_world_size=dcp_world_size, - pcp_world_size=pcp_world_size, + dcp_world_size, + pcp_world_size, + hash_block_size, ) if len(kv_cache_config.kv_cache_groups) == 1: return UnitaryKVCacheCoordinator( @@ -483,8 +541,9 @@ def get_kv_cache_coordinator( use_eagle, enable_caching, enable_kv_cache_events, - dcp_world_size=dcp_world_size, - pcp_world_size=pcp_world_size, + dcp_world_size, + pcp_world_size, + hash_block_size, ) return HybridKVCacheCoordinator( kv_cache_config, @@ -492,6 +551,7 @@ def get_kv_cache_coordinator( use_eagle, enable_caching, enable_kv_cache_events, - dcp_world_size=dcp_world_size, - pcp_world_size=pcp_world_size, + dcp_world_size, + pcp_world_size, + hash_block_size, ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 2012c3fef88bc..b061e5cc831dd 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -95,6 +95,7 @@ class KVCacheManager: self, kv_cache_config: KVCacheConfig, max_model_len: int, + hash_block_size: int, enable_caching: bool = True, use_eagle: bool = False, log_stats: bool = False, @@ -107,28 +108,11 @@ class KVCacheManager: self.enable_caching = enable_caching self.use_eagle = use_eagle self.log_stats = log_stats - # FIXME: make prefix cache stats conditional on log_stats + # FIXME: make prefix cache stats conditional on log_stats. We still need + # this comment because when the log stats is enabled there are still + # potential configs we could expose in the future. 
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None - self.block_size: int | None = None - if self.enable_caching: - assert ( - len( - set( - g.kv_cache_spec.block_size - for g in kv_cache_config.kv_cache_groups - ) - ) - == 1 - ), "Only one block size is supported for now" - self.block_size = kv_cache_config.kv_cache_groups[ - 0 - ].kv_cache_spec.block_size - - if dcp_world_size * pcp_world_size > 1: - assert len(kv_cache_config.kv_cache_groups) == 1 - self.block_size *= dcp_world_size * pcp_world_size - self.coordinator = get_kv_cache_coordinator( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, @@ -137,6 +121,7 @@ class KVCacheManager: enable_kv_cache_events=enable_kv_cache_events, dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, + hash_block_size=hash_block_size, ) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.block_pool = self.coordinator.block_pool diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index a0033fa650baa..602eb81beb010 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -5,9 +5,9 @@ import copy import os from collections import defaultdict -from collections.abc import Callable, Iterable, Sequence -from dataclasses import dataclass -from typing import Any, NewType, TypeAlias +from collections.abc import Callable, Iterable, Iterator, Sequence +from dataclasses import dataclass, replace +from typing import Any, NewType, TypeAlias, overload from vllm import envs from vllm.config import VllmConfig @@ -825,11 +825,11 @@ def get_num_blocks( return num_blocks -def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int: +def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int: """ Get the page size of the KV cache. """ - page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values()) + page_sizes = {layer.page_size_bytes for layer in kv_cache_specs} assert len(page_sizes) == 1 return page_sizes.pop() @@ -882,6 +882,46 @@ def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool return len(page_sizes) == 1 +def unify_kv_cache_spec_page_size( + kv_cache_spec: dict[str, KVCacheSpec], +) -> dict[str, KVCacheSpec]: + """ + Unify the page size of the given KVCacheSpec. If the page size of all layers + are the same, return the original KVCacheSpec. If not same, unify the page + size by increasing the block size of layers with smaller page size. Raise + NotImplementedError if failed to unify the page size. + + Args: + kv_cache_spec: The KVCacheSpec of each attention layer in the model + + Returns: + The updated KVCacheSpec with the same page_size_bytes. + """ + page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} + if len(page_sizes) <= 1: + # All layers have the same page size, no need to unify. + return kv_cache_spec + + max_page_size = max(page_sizes) + new_kv_cache_spec = {} + for layer_name, layer_spec in kv_cache_spec.items(): + if layer_spec.page_size_bytes == max_page_size: + new_kv_cache_spec[layer_name] = layer_spec + else: + layer_page_size = layer_spec.page_size_bytes + if max_page_size % layer_page_size != 0: + raise NotImplementedError( + "The page size of the layer is not divisible by the " + "maximum page size. Cannot unify by adjusting block_size." 
+ ) + ratio = max_page_size // layer_page_size + new_block_size = layer_spec.block_size * ratio + new_spec = replace(layer_spec, block_size=new_block_size) + assert new_spec.page_size_bytes == max_page_size + new_kv_cache_spec[layer_name] = new_spec + return new_kv_cache_spec + + def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: # kv_cache_spec is an empty dict for attention free models return not kv_cache_spec @@ -1010,7 +1050,6 @@ def _get_kv_cache_groups_uniform_page_size( def get_kv_cache_config_from_groups( vllm_config: VllmConfig, kv_cache_groups: list[KVCacheGroupSpec], - kv_cache_specs: dict[str, KVCacheSpec], available_memory: int, ) -> KVCacheConfig: """ @@ -1020,7 +1059,6 @@ def get_kv_cache_config_from_groups( Args: vllm_config: The global VllmConfig kv_cache_groups: The KV cache groups - kv_cache_specs: The KV cache spec of each attention layer in the model available_memory: Memory available for KV cache in bytes Returns: The generated KVCacheConfig @@ -1064,7 +1102,9 @@ def get_kv_cache_config_from_groups( # full.1, sw.2: share another Tensor with size=available_memory//2 group_size = max(len(group.layer_names) for group in kv_cache_groups) - page_size = get_uniform_page_size(kv_cache_specs) + page_size = get_uniform_page_size( + [group.kv_cache_spec for group in kv_cache_groups] + ) assert group_size > 0, "group_size must be greater than 0" num_blocks = get_num_blocks( vllm_config, group_size, available_memory, page_size @@ -1166,7 +1206,8 @@ def get_kv_cache_groups( # This returns an empty list to allow for the KVCacheManager to handle # attention free models. return [] - elif is_kv_cache_spec_uniform(kv_cache_spec): + + if is_kv_cache_spec_uniform(kv_cache_spec): # KV cache of all layers are the same, which is true for # most models. Allocate the same amount of memory for # each layer. @@ -1176,14 +1217,16 @@ def get_kv_cache_groups( # full attention, or all layers are sliding window attention with the # same window size). Put all layers into one group. return _get_kv_cache_groups_uniform_type(uniform_spec) - elif is_kv_cache_page_size_uniform(kv_cache_spec): - # Model contains multiple attention types, but KV cache of all layers - # have the same physical memory per block per layer. Split the layers - # into groups with the same number of layers, and thus same total page - # size. - return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) - raise NotImplementedError + # As KVCacheManager can only allocate memory of one size, we need to unify + # the page size of the layers. For cases cannot be unified, this function + # will raise an error. + kv_cache_spec = unify_kv_cache_spec_page_size(kv_cache_spec) + # Model contains multiple attention types, but KV cache of all layers + # have the same physical memory per block per layer. Split the layers + # into groups with the same number of layers, and thus same total page + # size. + return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) def generate_scheduler_kv_cache_config( @@ -1327,10 +1370,7 @@ def get_kv_cache_configs( ) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group." 
kv_cache_configs.append( get_kv_cache_config_from_groups( - vllm_config, - kv_cache_groups_one_worker, - kv_cache_spec_one_worker, - available_memory_one_worker, + vllm_config, kv_cache_groups_one_worker, available_memory_one_worker ) ) @@ -1353,3 +1393,79 @@ def get_kv_cache_configs( _report_kv_cache_config(vllm_config, kv_cache_config) return kv_cache_configs + + +class BlockHashListWithBlockSize: + """ + Convert block-hash granularity from `hash_block_size` to `target_block_size`. + Used when KV cache groups have different block sizes: `hash_block_size` + is the size used to compute the original `block_hashes`; `target_block_size` + is the group's actual block size. + + Currently, only scaling up by an integer factor is supported (i.e., + `target_block_size` is a multiple of `hash_block_size`). Conversion is + performed lazily on access for efficiency, by concatenating consecutive + hashes at `hash_block_size` to form each hash at `target_block_size`. + + Example (`hash_block_size` = 16, `target_block_size` = 32): + concatenating two 16-size hashes yields one 32-size hash: + + Block hashes with block_size 16: + | Token Range | 0-15 | 16-31 | 32-47 | 48-63 | + |-------------|------|-------|-------|-------| + | Hash | A | B | C | D | + + Block hashes with block_size 32: + | Token Range | 0-31 | 32-63 | + |-------------|------|-------| + | Hash | AB | CD | + + Args: + block_hashes: Block hashes to convert, computed at `hash_block_size`. + hash_block_size: Block size at which `block_hashes` were computed. + target_block_size: Desired block size; must be a multiple of `hash_block_size`. + """ + + def __init__( + self, + block_hashes: list[BlockHash], + hash_block_size: int, + target_block_size: int, + ): + self.block_hashes = block_hashes + assert target_block_size % hash_block_size == 0 + self.scale_factor = target_block_size // hash_block_size + + def __len__(self) -> int: + return len(self.block_hashes) // self.scale_factor + + @overload + def __getitem__(self, idx: int) -> BlockHash: ... + + @overload + def __getitem__(self, idx: slice) -> list[BlockHash]: ... 
+ + def __getitem__(self, idx): + if isinstance(idx, int): + return self._get_value_at(idx) + + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + return [self._get_value_at(i) for i in range(start, stop, step)] + + raise TypeError(f"Invalid index type: {type(idx)!r}") + + def __iter__(self) -> Iterator[BlockHash]: + for i in range(len(self)): + yield self._get_value_at(i) + + def _get_value_at(self, idx: int) -> BlockHash: + base = idx * self.scale_factor + end = base + self.scale_factor + merged_hash: bytes = self.block_hashes[base] + for i in range(base + 1, end): + merged_hash += self.block_hashes[i] + return BlockHash(merged_hash) + + +BlockHashList = list[BlockHash] | BlockHashListWithBlockSize diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 23af014c10364..bea2f865bad46 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -186,6 +186,7 @@ class Scheduler(SchedulerInterface): enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=self.dcp_world_size, pcp_world_size=self.pcp_world_size, + hash_block_size=self.block_size, ) self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index d90ec550f7666..4aeb17a156bb3 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -7,7 +7,7 @@ from collections.abc import Sequence from vllm.utils.math_utils import cdiv from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock +from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock from vllm.v1.kv_cache_interface import ( ChunkedLocalAttentionSpec, CrossAttentionSpec, @@ -207,12 +207,13 @@ class SingleTypeKVCacheManager(ABC): @abstractmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + alignment_tokens: int, dcp_world_size: int = 1, pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: @@ -232,6 +233,11 @@ class SingleTypeKVCacheManager(ABC): block_pool: The block pool. kv_cache_spec: The kv cache spec. use_eagle: Whether to use eagle. + alignment_tokens: The returned cache hit length (in tokens) should + be a multiple of this value (in tokens). By default, it should + be set to the block_size. + dcp_world_size: The world size of decode context parallelism. + pcp_world_size: The world size of prefill context parallelism. 
Returns: A list of cached blocks with skipped blocks replaced by null block @@ -299,17 +305,18 @@ class FullAttentionManager(SingleTypeKVCacheManager): @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + alignment_tokens: int, dcp_world_size: int = 1, pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: assert isinstance( - kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) + kv_cache_spec, FullAttentionSpec | ChunkedLocalAttentionSpec ), ( "FullAttentionManager can only be used for full attention " "and chunked local attention groups" @@ -333,6 +340,13 @@ class FullAttentionManager(SingleTypeKVCacheManager): else: break if use_eagle and computed_blocks[0]: + # Need to drop the last matched block if eagle is enabled. + for computed in computed_blocks: + computed.pop() + while ( + block_size != alignment_tokens # Faster for common case. + and len(computed_blocks[0]) * block_size % alignment_tokens != 0 + ): for computed in computed_blocks: computed.pop() return computed_blocks @@ -359,12 +373,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager): @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + alignment_tokens: int, dcp_world_size: int = 1, pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: @@ -396,6 +411,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager): [block_pool.null_block] * max_num_blocks for _ in range(len(kv_cache_group_ids)) ) + block_size = kv_cache_spec.block_size num_contiguous_blocks = 0 match_found = False # Search from right to left and early stop when a match is found. @@ -403,6 +419,15 @@ class SlidingWindowManager(SingleTypeKVCacheManager): if cached_block := block_pool.get_cached_block( block_hashes[i], kv_cache_group_ids ): + # Skip prefix matching check if the block is not aligned with + # `alignment_tokens`. + if ( + num_contiguous_blocks == 0 + and block_size != alignment_tokens # Faster for common case. + and (i + 1) * block_size % alignment_tokens != 0 + ): + continue + # Add the cached block to the computed blocks. for computed, cached in zip(computed_blocks, cached_block): computed[i] = cached num_contiguous_blocks += 1 @@ -421,7 +446,16 @@ class SlidingWindowManager(SingleTypeKVCacheManager): # `num_contiguous_blocks < sliding_window_contiguous_blocks`. for computed in computed_blocks: del computed[num_contiguous_blocks:] + while ( + block_size != alignment_tokens # Faster for common case. 
+ and len(computed_blocks[0]) * block_size % alignment_tokens != 0 + ): + for computed in computed_blocks: + computed.pop() if use_eagle and computed_blocks[0]: + assert kv_cache_spec.block_size == alignment_tokens, ( + "aligned_length is not compatible with eagle now" + ) for computed in computed_blocks: computed.pop() return computed_blocks @@ -475,12 +509,13 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + alignment_tokens: int, dcp_world_size: int = 1, pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: @@ -511,6 +546,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): block_pool: The block pool. kv_cache_spec: The kv cache spec. use_eagle: Whether to use eagle. + dcp_world_size: The world size of decode context parallelism. + pcp_world_size: The world size of prefill context parallelism. + alignment_tokens: The returned cache hit length (in tokens) should + be a multiple of this value (in tokens). Returns: A list of cached blocks @@ -524,6 +563,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): ) assert dcp_world_size == 1, "DCP not support chunked local attn now." assert pcp_world_size == 1, "PCP not support chunked local attn now." + assert kv_cache_spec.block_size == alignment_tokens, ( + "KV cache groups with different block sizes are not compatible with " + "chunked local attention now" + ) max_num_blocks = max_length // kv_cache_spec.block_size if max_length > 0: local_attention_start_idx = ( @@ -612,12 +655,13 @@ class MambaManager(SingleTypeKVCacheManager): @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + alignment_tokens: int, dcp_world_size: int = 1, pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]: @@ -630,12 +674,21 @@ class MambaManager(SingleTypeKVCacheManager): [] for _ in range(len(kv_cache_group_ids)) ) - max_num_blocks = max_length // kv_cache_spec.block_size + block_size = kv_cache_spec.block_size + max_num_blocks = max_length // block_size # Search from right to left and early stop when a match is found. for i in range(max_num_blocks - 1, -1, -1): if cached_block := block_pool.get_cached_block( block_hashes[i], kv_cache_group_ids ): + # When enable Mamba prefix caching, `block_size` will be aligned + # across full attention layers and Mamba layers to ensure the + # prefix hit length aligned at block + if ( + block_size != alignment_tokens # Faster for common case. + and (i + 1) * block_size % alignment_tokens != 0 + ): + continue for computed, cached in zip(computed_blocks, cached_block): # the hit length logic later assumes: # hit_length = len(hit_blocks_other_attn[0]) @@ -708,12 +761,13 @@ class CrossAttentionManager(SingleTypeKVCacheManager): @classmethod def find_longest_cache_hit( cls, - block_hashes: list[BlockHash], + block_hashes: BlockHashList, max_length: int, kv_cache_group_ids: list[int], block_pool: BlockPool, kv_cache_spec: KVCacheSpec, use_eagle: bool, + alignment_tokens: int, dcp_world_size: int = 1, pcp_world_size: int = 1, ) -> tuple[list[KVCacheBlock], ...]:
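
Note (editor's illustration, appended after the patch; not part of the diff above): the sketch below is a minimal, self-contained rendering of three of the ideas this change relies on, assuming nothing beyond the Python standard library. It shows (1) concatenating block hashes computed at hash_block_size into coarser hashes at a group's larger block_size, which BlockHashListWithBlockSize does lazily; (2) truncating a prefix-cache hit to a multiple of the LCM of all group block sizes, which the lcm_block_size / alignment_tokens checks enforce; and (3) unifying page sizes by scaling up the block size of layers with smaller pages, as unify_kv_cache_spec_page_size does. The helper names merge_hashes, align_hit_tokens, and scaled_block_sizes are hypothetical and do not exist in vLLM; the numbers mirror the new tests above.

# Illustrative sketch only -- hypothetical helpers, not vLLM APIs.
from math import lcm


def merge_hashes(
    hashes: list[bytes], hash_block_size: int, target_block_size: int
) -> list[bytes]:
    """Concatenate consecutive hashes computed at hash_block_size into one
    hash per target_block_size tokens (target must be an integer multiple);
    a trailing partial group is dropped, matching len(hashes) // scale."""
    assert target_block_size % hash_block_size == 0
    scale = target_block_size // hash_block_size
    return [
        b"".join(hashes[i : i + scale])
        for i in range(0, len(hashes) - scale + 1, scale)
    ]


def align_hit_tokens(hit_tokens: int, block_sizes: list[int]) -> int:
    """Round a cache-hit length down to a multiple of the LCM of all group
    block sizes, since partial-block cache hits are not supported."""
    alignment = lcm(*block_sizes)
    return hit_tokens // alignment * alignment


def scaled_block_sizes(
    page_bytes_per_token: dict[str, int], block_size: int
) -> dict[str, int]:
    """Pick a per-layer block size so that every layer's page
    (block_size * bytes per token) equals the largest page, raising
    NotImplementedError when the pages cannot be made equal."""
    max_page = max(b * block_size for b in page_bytes_per_token.values())
    out: dict[str, int] = {}
    for name, per_token in page_bytes_per_token.items():
        page = per_token * block_size
        if max_page % page != 0:
            raise NotImplementedError("page sizes cannot be aligned")
        out[name] = block_size * (max_page // page)
    return out


if __name__ == "__main__":
    # Hashes at block_size 16 (A, B, C, D) become AB, CD at block_size 32,
    # as in the BlockHashListWithBlockSize docstring example.
    print(merge_hashes([b"A", b"B", b"C", b"D"], 16, 32))  # [b'AB', b'CD']

    # A 5 * 16 token sliding-window hit shrinks to 4 * 16 tokens when the
    # full-attention group uses 32-token blocks, as in
    # test_different_block_size: lcm(16, 32) == 32.
    print(align_hit_tokens(5 * 16, [16, 32]))  # 64

    # Using head sizes as stand-ins for bytes per token: 64 vs 32 unify as
    # block sizes 16 vs 32 (the new get_kv_cache_configs test case), while
    # 64 vs 96 cannot be unified and would raise NotImplementedError.
    print(scaled_block_sizes({"layer_1": 64, "layer_2": 32}, 16))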