[Hybrid Allocator] Support KV cache groups with different block_size (#29143)

Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Yifan Qiao 2025-11-25 07:30:57 -08:00 committed by GitHub
parent e502098643
commit 48ddb02b79
11 changed files with 472 additions and 87 deletions

View File

@ -1248,7 +1248,9 @@ def test_allocate_with_lookahead():
)
# Test case 1: Requires additional lookahead tokens
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
kv_cache_manager = KVCacheManager(
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
)
blocks = kv_cache_manager.allocate_slots(
request,
num_new_tokens=3,
@ -1257,7 +1259,9 @@ def test_allocate_with_lookahead():
assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks
# Test case 2: With precomputed blocks
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
kv_cache_manager = KVCacheManager(
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
)
# required_blocks = ceil((3 + 2) /4) = 2
blocks = kv_cache_manager.allocate_slots(
request,
@ -1268,7 +1272,9 @@ def test_allocate_with_lookahead():
# Test case 3: With precomputed blocks
# required_blocks = ceil((3 + 4) / 4) = 2
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
kv_cache_manager = KVCacheManager(
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
)
blocks = kv_cache_manager.allocate_slots(
request,
num_new_tokens=3,
@ -1495,7 +1501,8 @@ def test_get_kv_cache_config_one_worker():
),
],
)
# different hidden size
# different hidden size but same type, use UniformTypeKVCacheSpecs
kv_cache_specs_hybrid = {
"layer_1": new_kv_cache_spec(head_size=128),
"layer_2": new_kv_cache_spec(head_size=64),
@ -1519,6 +1526,40 @@ def test_get_kv_cache_config_one_worker():
],
)
# Different hidden size and different type, align by different block size
kv_cache_specs_hybrid = {
"layer_1": new_kv_cache_spec(head_size=64),
"layer_2": new_sliding_window_spec(head_size=32),
}
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 32]
)[0]
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(
size=mem_per_block_per_layer * 32, shared_by=["layer_1", "layer_2"]
),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1"], new_kv_cache_spec(head_size=64)),
KVCacheGroupSpec(
["layer_2"], new_sliding_window_spec(head_size=32, block_size=32)
),
],
)
# different hidden sizes that cannot be aligned by using different block sizes
kv_cache_specs_hybrid = {
"layer_1": new_kv_cache_spec(head_size=64),
"layer_2": new_sliding_window_spec(head_size=96),
}
with pytest.raises(NotImplementedError):
get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32]
)[0]
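The two cases above come down to page-size arithmetic. A minimal sketch, assuming the test helpers default to block_size=16 with identical dtype and num_kv_heads (those defaults are assumptions, not shown in this hunk):

# Per-layer page bytes scale with block_size * head_size (times K/V and dtype size).
def page_size(block_size: int, head_size: int, bytes_per_elem: int = 2) -> int:
    return 2 * block_size * head_size * bytes_per_elem

full_attn = page_size(block_size=16, head_size=64)            # 4096 bytes
assert page_size(block_size=16, head_size=32) * 2 == full_attn
# head_size=32 can be aligned by doubling the sliding window block size to 32:
assert page_size(block_size=32, head_size=32) == full_attn
# head_size=96 cannot: the larger page (6144) is not an integer multiple of 4096,
# so unify_kv_cache_spec_page_size raises NotImplementedError.
assert page_size(block_size=16, head_size=96) % full_attn != 0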
# Test num_gpu_blocks_override
vllm_config.cache_config.num_gpu_blocks_override = 16
kv_cache_config_override_blocks = get_kv_cache_configs(

View File

@ -134,6 +134,7 @@ def test_prefill(hash_fn):
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# Complete 3 blocks (48 tokens)
@ -256,6 +257,7 @@ def test_prefill_hybrid_model():
make_kv_cache_config_hybrid_model(block_size, 21),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
hash_fn = sha256
@ -416,6 +418,7 @@ def test_prefill_plp():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# the default hash function is sha256
hash_fn = sha256
@ -523,6 +526,7 @@ def test_decode():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# Complete 3 blocks (48 tokens)
@ -585,6 +589,7 @@ def test_evict():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
last_token_id = 5 * 16 + 7
@ -643,6 +648,7 @@ def test_hash_block_correct_reuse():
make_kv_cache_config(16, 2),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# Allocate 1 block and cache it.
@ -683,6 +689,7 @@ def test_computed_blocks_not_evicted():
make_kv_cache_config(block_size, 3),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# Allocate a block and cache it.
@ -741,6 +748,7 @@ def test_basic_prefix_caching_disabled():
make_kv_cache_config(block_size, 5),
max_model_len=8192,
enable_caching=False,
hash_block_size=block_size,
)
req1 = make_request(
@ -790,6 +798,7 @@ def test_cache_blocks(hash_fn):
block_pool = BlockPool(
num_gpu_blocks=5,
enable_caching=True,
hash_block_size=block_size,
)
# Req:
# Block 0: [0, 1, 2, 3]
@ -833,7 +842,9 @@ def test_cache_blocks_multi_group():
This tests that blocks are cached correctly for different kv cache groups.
"""
block_size = 4
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
block_pool = BlockPool(
num_gpu_blocks=10, enable_caching=True, hash_block_size=block_size
)
# Req:
# Block 0/4: [0, 1, 2, 3]
@ -921,6 +932,7 @@ def test_mm_prefix_caching():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
@ -1020,6 +1032,7 @@ def test_cache_key_salting():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# 3 complete blocks and an incomplete block with 11 tokens.
@ -1101,6 +1114,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... |
@ -1173,6 +1187,7 @@ def test_reset_prefix_cache():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
full_block_token_ids = [i for i in range(3) for _ in range(16)]
@ -1213,6 +1228,7 @@ def test_prefix_cache_stats_disabled():
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
log_stats=False, # Disable logging stats
)
assert manager.prefix_cache_stats is None
@ -1232,7 +1248,7 @@ def test_prefix_cache_stats_disabled():
def test_maybe_evict_cached_block():
pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
pool = BlockPool(num_gpu_blocks=4, enable_caching=True, hash_block_size=16)
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
@ -1293,6 +1309,7 @@ def test_kv_cache_events(blocks_to_cache: int):
max_model_len=8192,
enable_caching=True,
enable_kv_cache_events=True,
hash_block_size=block_size,
)
num_tokens = block_size * blocks_to_cache
@ -1351,6 +1368,7 @@ def test_kv_cache_events_with_lora(blocks_to_cache: int):
max_model_len=8192,
enable_caching=True,
enable_kv_cache_events=True,
hash_block_size=block_size,
)
# Test with LoRA request
@ -1405,6 +1423,7 @@ def test_eagle_enabled_removes_last_block():
max_model_len=8192,
enable_caching=True,
use_eagle=True,
hash_block_size=block_size,
)
# Request with 3 full blocks (48 tokens)
@ -1437,6 +1456,7 @@ def test_eagle_with_partial_blocks():
max_model_len=8192,
enable_caching=True,
use_eagle=True,
hash_block_size=block_size,
)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5)
@ -1476,6 +1496,7 @@ def test_eagle_with_sliding_window():
max_model_len=8192,
enable_caching=True,
use_eagle=True,
hash_block_size=block_size,
)
# 2 full blocks + 5 tokens (non-divisible length)
@ -1522,6 +1543,76 @@ def test_eagle_with_sliding_window():
assert num_tokens == 0
def test_different_block_size():
block_size = 16
# full attention and sliding window attention layers have the same page size:
# 32 tokens/block with 2-byte float16 vs. 16 tokens/block with 4-byte float32
kv_cache_config = KVCacheConfig(
num_blocks=100,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(block_size * 2, 1, 1, torch.float16),
),
KVCacheGroupSpec(
["layer2"],
SlidingWindowSpec(
block_size,
1,
1,
torch.float32,
sliding_window=2 * block_size,
),
),
],
)
manager = KVCacheManager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# 10 blocks of 16 tokens each. Token ids are not strictly aligned for each block.
common_token_ids = [i for i in range(10) for _ in range(block_size)]
req0 = make_request("0", common_token_ids, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0]
assert not computed_blocks.blocks[1]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(
req0, 7 * block_size, len(computed_blocks.blocks[0]) * 16, computed_blocks
)
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11])
req1 = make_request("1", common_token_ids[: 7 * block_size + 1], block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(computed_blocks.blocks[0]) == 3
assert len(computed_blocks.blocks[1]) == 6
assert num_computed_tokens == 6 * 16
req2 = make_request("2", common_token_ids[: 6 * block_size + 1], block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks.blocks[0]) == 3
assert len(computed_blocks.blocks[1]) == 6
assert num_computed_tokens == 6 * 16
# Evict some blocks so that the sliding window cache hit length becomes 5 * 16.
# But the reported hit should be 4 * 16 because the full attention cache hit
# length must be a multiple of 32.
manager.block_pool.cached_block_hash_to_block.pop(
make_block_hash_with_group_id(req1.block_hashes[6], 1), 11
)
manager.block_pool.cached_block_hash_to_block.pop(
make_block_hash_with_group_id(req1.block_hashes[5], 1), 10
)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(computed_blocks.blocks[0]) == 2
assert len(computed_blocks.blocks[1]) == 4
assert num_computed_tokens == 4 * 16
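A quick recap of the arithmetic this test relies on, as a hedged sketch (block sizes and counts taken from the test above; the 2-byte/4-byte dtype sizes are the only added assumptions):

full_block, sw_block = 32, 16                  # tokens per block in each group
assert full_block * 2 == sw_block * 4           # float16 vs. float32: equal page size
lcm_block = 32                                  # lcm(32, 16): cache-hit alignment unit
# After eviction the sliding window layers still cover 5 * 16 = 80 tokens, but 80
# is not a multiple of 32, so the reported hit is truncated to 4 * 16 = 64 tokens.
assert (5 * sw_block) % lcm_block != 0 and (4 * sw_block) % lcm_block == 0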
def test_block_lookup_cache_single_block_per_key():
cache = BlockHashToBlockMap()
key0 = BlockHashWithGroupId(b"hash0")

View File

@ -41,7 +41,9 @@ def test_chunked_local_attention_possible_cached_prefix():
attention_chunk_size=4,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
block_pool = BlockPool(
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
)
manager = get_chunked_local_attention_manager(
chunked_local_attention_spec, block_pool
)
@ -70,6 +72,7 @@ def test_chunked_local_attention_possible_cached_prefix():
block_pool=block_pool,
kv_cache_spec=chunked_local_attention_spec,
use_eagle=False,
alignment_tokens=block_size,
)[0]
assert len(computed_blocks) == expect_length
@ -111,7 +114,9 @@ def test_sliding_window_possible_cached_prefix():
sliding_window=4,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
block_pool = BlockPool(
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
)
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
def run_one_case(block_is_cached, expect_length):
@ -138,6 +143,7 @@ def test_sliding_window_possible_cached_prefix():
block_pool=block_pool,
kv_cache_spec=sliding_window_spec,
use_eagle=False,
alignment_tokens=block_size,
)[0]
assert len(computed_blocks) == expect_length
@ -178,7 +184,7 @@ def test_chunked_local_attention_remove_skipped_blocks():
attention_chunk_size=4,
)
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
@ -239,7 +245,7 @@ def test_sliding_window_remove_skipped_blocks():
sliding_window=4,
)
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
@ -316,7 +322,9 @@ def test_get_num_blocks_to_allocate():
sliding_window=4, # Placeholder value, not related to test result
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
block_pool = BlockPool(
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
)
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
@ -341,7 +349,9 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
attention_chunk_size=4, # Placeholder value, not related to test result
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
block_pool = BlockPool(
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
)
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [

View File

@ -1816,9 +1816,11 @@ class EngineArgs:
if model_config.runner_type != "pooling":
default_chunked_prefill = True
# Disable prefix caching default for hybrid models
# since the feature is still experimental.
default_prefix_caching = not model_config.is_hybrid
# Disable prefix caching default for hybrid models and mamba-only
# models since the feature is still experimental.
default_prefix_caching = not (
model_config.is_hybrid or model_config.is_attention_free
)
else:
assert model_config.pooler_config is not None

View File

@ -289,9 +289,6 @@ class MambaModelConfig(VerifyAndUpdateConfig):
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = model_config.max_model_len
if cache_config.enable_prefix_caching:
if model_config.supports_mamba_prefix_caching:
logger.info(
@ -299,6 +296,11 @@ class MambaModelConfig(VerifyAndUpdateConfig):
"Its support for Mamba layers is experimental. "
"Please report any issues you may observe."
)
# By default, the mamba block size is set to max_model_len (see
# below). When prefix caching is enabled, we instead align the mamba
# block size to cache_config.block_size, the basic granularity of
# prefix caching.
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = cache_config.block_size
else:
logger.info(
"Hybrid or mamba-based model detected without "
@ -306,6 +308,9 @@ class MambaModelConfig(VerifyAndUpdateConfig):
)
cache_config.enable_prefix_caching = False
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = model_config.max_model_len
# TODO(tdoublep): remove once cascade attention is supported
logger.info(
"Disabling cascade attention since it is not supported for hybrid models."

View File

@ -13,6 +13,8 @@ from vllm.distributed.kv_events import (
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (
BlockHash,
BlockHashList,
BlockHashListWithBlockSize,
BlockHashWithGroupId,
ExternalBlockHash,
FreeKVCacheBlockQueue,
@ -133,6 +135,10 @@ class BlockPool:
Args:
num_gpu_blocks: The number of blocks in the pool.
enable_caching: Whether to enable prefix caching.
hash_block_size: The block size at which block hashes are computed.
The actual block size usually equals hash_block_size, but when
different KV cache groups have different block sizes, the actual
block size can be a multiple of hash_block_size.
enable_kv_cache_events: Whether to enable kv cache events.
"""
@ -140,11 +146,13 @@ class BlockPool:
self,
num_gpu_blocks: int,
enable_caching: bool,
hash_block_size: int,
enable_kv_cache_events: bool = False,
):
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
self.num_gpu_blocks = num_gpu_blocks
self.enable_caching = enable_caching
self.hash_block_size = hash_block_size
# All kv-cache blocks.
self.blocks: list[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
@ -223,8 +231,20 @@ class BlockPool:
return
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
assert len(request.block_hashes) >= num_full_blocks
new_block_hashes = request.block_hashes[num_cached_blocks:]
if block_size == self.hash_block_size:
# Common case.
block_hashes: BlockHashList = request.block_hashes
else:
# block_size is a multiple of hash_block_size. This happens when
# different KV cache groups have different block sizes.
assert block_size % self.hash_block_size == 0
# Recalculate block_hashes at the granularity of block_size, using
# the original block_hashes (at the granularity of hash_block_size).
block_hashes = BlockHashListWithBlockSize(
request.block_hashes, self.hash_block_size, block_size
)
new_block_hashes = block_hashes[num_cached_blocks:]
new_hashes: list[ExternalBlockHash] | None = (
[] if self.enable_kv_cache_events else None
)

View File

@ -2,15 +2,25 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Sequence
from math import lcm
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.kv_cache_utils import (
BlockHash,
BlockHashList,
BlockHashListWithBlockSize,
KVCacheBlock,
)
from vllm.v1.core.single_type_kv_cache_manager import (
CrossAttentionManager,
FullAttentionManager,
get_manager_for_kv_cache_spec,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheSpec,
)
from vllm.v1.request import Request
@ -28,13 +38,17 @@ class KVCacheCoordinator(ABC):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
self.enable_caching = enable_caching
self.block_pool = BlockPool(
kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events
kv_cache_config.num_blocks,
enable_caching,
hash_block_size,
enable_kv_cache_events,
)
# Needs special handling for find_longest_cache_hit if eagle is enabled
@ -213,6 +227,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
):
super().__init__(
kv_cache_config,
@ -222,6 +237,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
)
self.num_single_type_manager = len(self.single_type_managers)
@ -255,6 +271,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
):
super().__init__(
kv_cache_config,
@ -264,6 +281,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
@ -273,6 +291,11 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
self.block_size *= dcp_world_size
if pcp_world_size > 1:
self.block_size *= pcp_world_size
# For models using only Mamba, block_size is set to max_model_len when
# prefix caching is disabled, and hash_block_size validation is skipped.
assert not enable_caching or (hash_block_size == self.block_size), (
"UnitaryKVCacheCoordinator assumes hash_block_size == block_size"
)
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"UnitaryKVCacheCoordinator assumes only one kv cache group"
)
@ -289,6 +312,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
block_pool=self.block_pool,
kv_cache_spec=self.kv_cache_spec,
use_eagle=self.use_eagle,
alignment_tokens=self.block_size,
dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size,
)
@ -313,6 +337,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
):
super().__init__(
kv_cache_config,
@ -322,7 +347,17 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
)
# hash_block_size: the block size used to compute block hashes.
# The actual block size usually equals hash_block_size, but in cases where
# different KV cache groups have different block sizes, the actual block size
# can be a multiple of hash_block_size.
self.hash_block_size = hash_block_size
assert all(
g.kv_cache_spec.block_size % hash_block_size == 0
for g in kv_cache_config.kv_cache_groups
), "block_size must be divisible by hash_block_size"
assert dcp_world_size == 1, "DCP not support hybrid attn now."
assert pcp_world_size == 1, "PCP not support hybrid attn now."
self.verify_and_split_kv_cache_groups()
@ -373,14 +408,12 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
self.other_spec = other_spec
self.full_attention_block_size = self.full_attention_spec.block_size
self.other_block_size = self.other_spec.block_size
if self.enable_caching:
# this requirement is only needed for the prefix caching logic
divisible = self.other_block_size % self.full_attention_block_size
assert divisible == 0, (
"KVCacheCoordinator assumes the block_size of full "
"attention layers is divisible by other layers now."
)
# The LCM of the block sizes of full attention and other attention.
# The cache hit length must be a multiple of the LCM of the block sizes
# so that it is a whole number of blocks for every attention type. We
# require this because partial-block cache hits are not supported yet.
self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size)
if max(self.full_attention_group_ids) < min(self.other_group_ids):
self.full_attn_first = True
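A short illustration of the lcm_block_size constraint, assuming a 32-token full attention block and a 16-token sliding window block as in the tests above:

from math import lcm

full_attention_block_size, other_block_size = 32, 16
lcm_block_size = lcm(full_attention_block_size, other_block_size)   # 32
# Any hit length that is a multiple of the LCM is a whole number of blocks
# for both groups; 48 tokens, for example, would split a full attention block.
for hit_length in (0, 32, 64):
    assert hit_length % full_attention_block_size == 0
    assert hit_length % other_block_size == 0
assert 48 % full_attention_block_size != 0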
@ -414,25 +447,48 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
- The number of tokens of the longest cache hit.
"""
# First, find the longest cache hit for full attention.
if self.full_attention_spec.block_size == self.hash_block_size:
# Common case.
full_attention_block_hashes: BlockHashList = block_hashes
else:
# block_size is a multiple of hash_block_size. This happens when different
# KV cache groups have different block sizes. In this case, we need to
# recalculate block_hashes at the granularity of block_size, using the
# original block_hashes (at the granularity of hash_block_size).
full_attention_block_hashes = BlockHashListWithBlockSize(
block_hashes, self.hash_block_size, self.full_attention_spec.block_size
)
hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit(
block_hashes=block_hashes,
block_hashes=full_attention_block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=self.full_attention_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.full_attention_spec,
use_eagle=self.use_eagle,
alignment_tokens=self.lcm_block_size,
)
hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size
# Next, find the cache hit for the other attention WITHIN
# the cache hit of full attention.
if self.other_spec.block_size == self.hash_block_size:
# Common case.
other_block_hashes: BlockHashList = block_hashes
else:
# Similar to the full attention case, here we need to recalculate
# block_hashes at the granularity of block_size, using the original
# block_hashes (at the granularity of hash_block_size).
other_block_hashes = BlockHashListWithBlockSize(
block_hashes, self.hash_block_size, self.other_spec.block_size
)
hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit(
block_hashes=block_hashes,
block_hashes=other_block_hashes,
max_length=hit_length,
kv_cache_group_ids=self.other_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.other_spec,
use_eagle=self.use_eagle,
alignment_tokens=self.lcm_block_size,
)
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
@ -466,6 +522,7 @@ def get_kv_cache_coordinator(
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
) -> KVCacheCoordinator:
if not enable_caching:
return KVCacheCoordinatorNoPrefixCache(
@ -473,8 +530,9 @@ def get_kv_cache_coordinator(
max_model_len,
use_eagle,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
dcp_world_size,
pcp_world_size,
hash_block_size,
)
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(
@ -483,8 +541,9 @@ def get_kv_cache_coordinator(
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
dcp_world_size,
pcp_world_size,
hash_block_size,
)
return HybridKVCacheCoordinator(
kv_cache_config,
@ -492,6 +551,7 @@ def get_kv_cache_coordinator(
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
dcp_world_size,
pcp_world_size,
hash_block_size,
)

View File

@ -95,6 +95,7 @@ class KVCacheManager:
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
hash_block_size: int,
enable_caching: bool = True,
use_eagle: bool = False,
log_stats: bool = False,
@ -107,28 +108,11 @@ class KVCacheManager:
self.enable_caching = enable_caching
self.use_eagle = use_eagle
self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats
# FIXME: make prefix cache stats conditional on log_stats. We keep this
# FIXME because even when log_stats is enabled there are still potential
# configs we could expose in the future.
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
self.block_size: int | None = None
if self.enable_caching:
assert (
len(
set(
g.kv_cache_spec.block_size
for g in kv_cache_config.kv_cache_groups
)
)
== 1
), "Only one block size is supported for now"
self.block_size = kv_cache_config.kv_cache_groups[
0
].kv_cache_spec.block_size
if dcp_world_size * pcp_world_size > 1:
assert len(kv_cache_config.kv_cache_groups) == 1
self.block_size *= dcp_world_size * pcp_world_size
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
@ -137,6 +121,7 @@ class KVCacheManager:
enable_kv_cache_events=enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
)
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool

View File

@ -5,9 +5,9 @@
import copy
import os
from collections import defaultdict
from collections.abc import Callable, Iterable, Sequence
from dataclasses import dataclass
from typing import Any, NewType, TypeAlias
from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass, replace
from typing import Any, NewType, TypeAlias, overload
from vllm import envs
from vllm.config import VllmConfig
@ -825,11 +825,11 @@ def get_num_blocks(
return num_blocks
def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int:
"""
Get the page size of the KV cache.
"""
page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values())
page_sizes = {layer.page_size_bytes for layer in kv_cache_specs}
assert len(page_sizes) == 1
return page_sizes.pop()
@ -882,6 +882,46 @@ def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool
return len(page_sizes) == 1
def unify_kv_cache_spec_page_size(
kv_cache_spec: dict[str, KVCacheSpec],
) -> dict[str, KVCacheSpec]:
"""
Unify the page size of the given KVCacheSpecs. If all layers have the same
page size, return the original dict unchanged. Otherwise, unify the page
size by increasing the block size of the layers with smaller page sizes.
Raise NotImplementedError if the page sizes cannot be unified.
Args:
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
The updated KVCacheSpec with the same page_size_bytes.
"""
page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
if len(page_sizes) <= 1:
# All layers have the same page size, no need to unify.
return kv_cache_spec
max_page_size = max(page_sizes)
new_kv_cache_spec = {}
for layer_name, layer_spec in kv_cache_spec.items():
if layer_spec.page_size_bytes == max_page_size:
new_kv_cache_spec[layer_name] = layer_spec
else:
layer_page_size = layer_spec.page_size_bytes
if max_page_size % layer_page_size != 0:
raise NotImplementedError(
"The page size of the layer is not divisible by the "
"maximum page size. Cannot unify by adjusting block_size."
)
ratio = max_page_size // layer_page_size
new_block_size = layer_spec.block_size * ratio
new_spec = replace(layer_spec, block_size=new_block_size)
assert new_spec.page_size_bytes == max_page_size
new_kv_cache_spec[layer_name] = new_spec
return new_kv_cache_spec
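A hedged usage sketch of unify_kv_cache_spec_page_size. The constructor argument order (block_size, num_kv_heads, head_size, dtype) mirrors the tests earlier in this commit, and the import path assumes the function lives in vllm.v1.core.kv_cache_utils; treat both as assumptions.

import torch

from vllm.v1.core.kv_cache_utils import unify_kv_cache_spec_page_size
from vllm.v1.kv_cache_interface import FullAttentionSpec, SlidingWindowSpec

specs = {
    # float16 full attention: 16 tokens/block at 2 bytes per element.
    "layer_1": FullAttentionSpec(16, 1, 1, torch.float16),
    # float32 sliding window: also 16 tokens/block, but 4 bytes per element,
    # so its page is twice as large as layer_1's.
    "layer_2": SlidingWindowSpec(16, 1, 1, torch.float32, sliding_window=32),
}
unified = unify_kv_cache_spec_page_size(specs)
# layer_1's block_size is doubled so both layers end up with equal page sizes.
assert unified["layer_1"].block_size == 32
assert unified["layer_2"].block_size == 16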
def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
# kv_cache_spec is an empty dict for attention free models
return not kv_cache_spec
@ -1010,7 +1050,6 @@ def _get_kv_cache_groups_uniform_page_size(
def get_kv_cache_config_from_groups(
vllm_config: VllmConfig,
kv_cache_groups: list[KVCacheGroupSpec],
kv_cache_specs: dict[str, KVCacheSpec],
available_memory: int,
) -> KVCacheConfig:
"""
@ -1020,7 +1059,6 @@ def get_kv_cache_config_from_groups(
Args:
vllm_config: The global VllmConfig
kv_cache_groups: The KV cache groups
kv_cache_specs: The KV cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes
Returns:
The generated KVCacheConfig
@ -1064,7 +1102,9 @@ def get_kv_cache_config_from_groups(
# full.1, sw.2: share another Tensor with size=available_memory//2
group_size = max(len(group.layer_names) for group in kv_cache_groups)
page_size = get_uniform_page_size(kv_cache_specs)
page_size = get_uniform_page_size(
[group.kv_cache_spec for group in kv_cache_groups]
)
assert group_size > 0, "group_size must be greater than 0"
num_blocks = get_num_blocks(
vllm_config, group_size, available_memory, page_size
@ -1166,7 +1206,8 @@ def get_kv_cache_groups(
# This returns an empty list to allow for the KVCacheManager to handle
# attention free models.
return []
elif is_kv_cache_spec_uniform(kv_cache_spec):
if is_kv_cache_spec_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
@ -1176,14 +1217,16 @@ def get_kv_cache_groups(
# full attention, or all layers are sliding window attention with the
# same window size). Put all layers into one group.
return _get_kv_cache_groups_uniform_type(uniform_spec)
elif is_kv_cache_page_size_uniform(kv_cache_spec):
# Model contains multiple attention types, but KV cache of all layers
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
raise NotImplementedError
# As KVCacheManager can only allocate memory of one page size, we need to
# unify the page size of the layers. For cases that cannot be unified, this
# call raises NotImplementedError.
kv_cache_spec = unify_kv_cache_spec_page_size(kv_cache_spec)
# Model contains multiple attention types, but KV cache of all layers
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
def generate_scheduler_kv_cache_config(
@ -1327,10 +1370,7 @@ def get_kv_cache_configs(
) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group."
kv_cache_configs.append(
get_kv_cache_config_from_groups(
vllm_config,
kv_cache_groups_one_worker,
kv_cache_spec_one_worker,
available_memory_one_worker,
vllm_config, kv_cache_groups_one_worker, available_memory_one_worker
)
)
@ -1353,3 +1393,79 @@ def get_kv_cache_configs(
_report_kv_cache_config(vllm_config, kv_cache_config)
return kv_cache_configs
class BlockHashListWithBlockSize:
"""
Convert block-hash granularity from `hash_block_size` to `target_block_size`.
Used when KV cache groups have different block sizes: `hash_block_size`
is the size used to compute the original `block_hashes`; `target_block_size`
is the group's actual block size.
Currently, only scaling up by an integer factor is supported (i.e.,
`target_block_size` is a multiple of `hash_block_size`). Conversion is
performed lazily on access for efficiency, by concatenating consecutive
hashes at `hash_block_size` to form each hash at `target_block_size`.
Example (`hash_block_size` = 16, `target_block_size` = 32):
concatenating two 16-size hashes yields one 32-size hash:
Block hashes with block_size 16:
| Token Range | 0-15 | 16-31 | 32-47 | 48-63 |
|-------------|------|-------|-------|-------|
| Hash | A | B | C | D |
Block hashes with block_size 32:
| Token Range | 0-31 | 32-63 |
|-------------|------|-------|
| Hash | AB | CD |
Args:
block_hashes: Block hashes to convert, computed at `hash_block_size`.
hash_block_size: Block size at which `block_hashes` were computed.
target_block_size: Desired block size; must be a multiple of `hash_block_size`.
"""
def __init__(
self,
block_hashes: list[BlockHash],
hash_block_size: int,
target_block_size: int,
):
self.block_hashes = block_hashes
assert target_block_size % hash_block_size == 0
self.scale_factor = target_block_size // hash_block_size
def __len__(self) -> int:
return len(self.block_hashes) // self.scale_factor
@overload
def __getitem__(self, idx: int) -> BlockHash: ...
@overload
def __getitem__(self, idx: slice) -> list[BlockHash]: ...
def __getitem__(self, idx):
if isinstance(idx, int):
return self._get_value_at(idx)
if isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
return [self._get_value_at(i) for i in range(start, stop, step)]
raise TypeError(f"Invalid index type: {type(idx)!r}")
def __iter__(self) -> Iterator[BlockHash]:
for i in range(len(self)):
yield self._get_value_at(i)
def _get_value_at(self, idx: int) -> BlockHash:
base = idx * self.scale_factor
end = base + self.scale_factor
merged_hash: bytes = self.block_hashes[base]
for i in range(base + 1, end):
merged_hash += self.block_hashes[i]
return BlockHash(merged_hash)
BlockHashList = list[BlockHash] | BlockHashListWithBlockSize
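A minimal usage sketch of the wrapper above; real BlockHash values are digest bytes, so the short byte strings here are stand-ins:

from vllm.v1.core.kv_cache_utils import BlockHash, BlockHashListWithBlockSize

hashes_16 = [BlockHash(h) for h in (b"A", b"B", b"C", b"D", b"E")]
hashes_32 = BlockHashListWithBlockSize(
    hashes_16, hash_block_size=16, target_block_size=32
)
assert len(hashes_32) == 2                  # the trailing b"E" has no pair yet
assert hashes_32[0] == BlockHash(b"AB")     # two 16-token hashes concatenated lazily
assert hashes_32[1:] == [BlockHash(b"CD")]  # slicing materializes a plain list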

View File

@ -186,6 +186,7 @@ class Scheduler(SchedulerInterface):
enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size,
hash_block_size=self.block_size,
)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

View File

@ -7,7 +7,7 @@ from collections.abc import Sequence
from vllm.utils.math_utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock
from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec,
CrossAttentionSpec,
@ -207,12 +207,13 @@ class SingleTypeKVCacheManager(ABC):
@abstractmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
@ -232,6 +233,11 @@ class SingleTypeKVCacheManager(ABC):
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
alignment_tokens: The returned cache hit length, in tokens, must be
a multiple of this value. By default, it should be set to the
block_size.
dcp_world_size: The world size of decode context parallelism.
pcp_world_size: The world size of prefill context parallelism.
Returns:
A list of cached blocks with skipped blocks replaced by null block
@ -299,17 +305,18 @@ class FullAttentionManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
kv_cache_spec, FullAttentionSpec | ChunkedLocalAttentionSpec
), (
"FullAttentionManager can only be used for full attention "
"and chunked local attention groups"
@ -333,6 +340,13 @@ class FullAttentionManager(SingleTypeKVCacheManager):
else:
break
if use_eagle and computed_blocks[0]:
# Need to drop the last matched block if eagle is enabled.
for computed in computed_blocks:
computed.pop()
while (
block_size != alignment_tokens # Faster for common case.
and len(computed_blocks[0]) * block_size % alignment_tokens != 0
):
for computed in computed_blocks:
computed.pop()
return computed_blocks
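A hedged trace of the alignment loop just added: with block_size=16 and alignment_tokens=32 (the LCM case), a 7-block hit is trimmed to 6 blocks so the hit length becomes a multiple of 32.

block_size, alignment_tokens = 16, 32
computed = list(range(7))                     # stand-in for 7 matched blocks
while (
    block_size != alignment_tokens
    and len(computed) * block_size % alignment_tokens != 0
):
    computed.pop()
assert len(computed) == 6                     # 6 * 16 = 96 tokens, 96 % 32 == 0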
@ -359,12 +373,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
@ -396,6 +411,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
[block_pool.null_block] * max_num_blocks
for _ in range(len(kv_cache_group_ids))
)
block_size = kv_cache_spec.block_size
num_contiguous_blocks = 0
match_found = False
# Search from right to left and early stop when a match is found.
@ -403,6 +419,15 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids
):
# Skip prefix matching check if the block is not aligned with
# `alignment_tokens`.
if (
num_contiguous_blocks == 0
and block_size != alignment_tokens # Faster for common case.
and (i + 1) * block_size % alignment_tokens != 0
):
continue
# Add the cached block to the computed blocks.
for computed, cached in zip(computed_blocks, cached_block):
computed[i] = cached
num_contiguous_blocks += 1
@ -421,7 +446,16 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
for computed in computed_blocks:
del computed[num_contiguous_blocks:]
while (
block_size != alignment_tokens # Faster for common case.
and len(computed_blocks[0]) * block_size % alignment_tokens != 0
):
for computed in computed_blocks:
computed.pop()
if use_eagle and computed_blocks[0]:
assert kv_cache_spec.block_size == alignment_tokens, (
"aligned_length is not compatible with eagle now"
)
for computed in computed_blocks:
computed.pop()
return computed_blocks
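A small sketch of the alignment check in the right-to-left scan above: until the first block is accepted, only indices whose end offset is a multiple of alignment_tokens may start the contiguous run (values below are illustrative):

block_size, alignment_tokens = 16, 32
eligible_last_block_indices = [
    i
    for i in range(8)
    if block_size == alignment_tokens or (i + 1) * block_size % alignment_tokens == 0
]
# Only blocks ending on a 32-token boundary can be the rightmost matched block.
assert eligible_last_block_indices == [1, 3, 5, 7]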
@ -475,12 +509,13 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
@ -511,6 +546,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
dcp_world_size: The world size of decode context parallelism.
pcp_world_size: The world size of prefill context parallelism.
alignment_tokens: The returned cache hit length, in tokens, must be
a multiple of this value.
Returns:
A list of cached blocks
@ -524,6 +563,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
)
assert dcp_world_size == 1, "DCP not support chunked local attn now."
assert pcp_world_size == 1, "PCP not support chunked local attn now."
assert kv_cache_spec.block_size == alignment_tokens, (
"KV cache groups with different block sizes are not compatible with "
"chunked local attention now"
)
max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0:
local_attention_start_idx = (
@ -612,12 +655,13 @@ class MambaManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
@ -630,12 +674,21 @@ class MambaManager(SingleTypeKVCacheManager):
[] for _ in range(len(kv_cache_group_ids))
)
max_num_blocks = max_length // kv_cache_spec.block_size
block_size = kv_cache_spec.block_size
max_num_blocks = max_length // block_size
# Search from right to left and early stop when a match is found.
for i in range(max_num_blocks - 1, -1, -1):
if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids
):
# When Mamba prefix caching is enabled, `block_size` is aligned
# across full attention layers and Mamba layers to ensure the
# prefix hit length is block-aligned.
if (
block_size != alignment_tokens # Faster for common case.
and (i + 1) * block_size % alignment_tokens != 0
):
continue
for computed, cached in zip(computed_blocks, cached_block):
# the hit length logic later assumes:
# hit_length = len(hit_blocks_other_attn[0])
@ -708,12 +761,13 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: