[Hybrid Allocator] Support KV cache groups with different block_size (#29143)

Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Yifan Qiao 2025-11-25 07:30:57 -08:00 committed by GitHub
parent e502098643
commit 48ddb02b79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 472 additions and 87 deletions

View File

@ -1248,7 +1248,9 @@ def test_allocate_with_lookahead():
) )
# Test case 1: Requires additional lookahead tokens # Test case 1: Requires additional lookahead tokens
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) kv_cache_manager = KVCacheManager(
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
)
blocks = kv_cache_manager.allocate_slots( blocks = kv_cache_manager.allocate_slots(
request, request,
num_new_tokens=3, num_new_tokens=3,
@ -1257,7 +1259,9 @@ def test_allocate_with_lookahead():
assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks
# Test case 2: With precomputed blocks # Test case 2: With precomputed blocks
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) kv_cache_manager = KVCacheManager(
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
)
# required_blocks = ceil((3 + 2) /4) = 2 # required_blocks = ceil((3 + 2) /4) = 2
blocks = kv_cache_manager.allocate_slots( blocks = kv_cache_manager.allocate_slots(
request, request,
@ -1268,7 +1272,9 @@ def test_allocate_with_lookahead():
# Test case 3: With precomputed blocks # Test case 3: With precomputed blocks
# required_blocks = ceil((3 + 4) / 4) = 2 # required_blocks = ceil((3 + 4) / 4) = 2
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100) kv_cache_manager = KVCacheManager(
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
)
blocks = kv_cache_manager.allocate_slots( blocks = kv_cache_manager.allocate_slots(
request, request,
num_new_tokens=3, num_new_tokens=3,
@ -1495,7 +1501,8 @@ def test_get_kv_cache_config_one_worker():
), ),
], ],
) )
# different hidden size
# different hidden size but same type, use UniformTypeKVCacheSpecs
kv_cache_specs_hybrid = { kv_cache_specs_hybrid = {
"layer_1": new_kv_cache_spec(head_size=128), "layer_1": new_kv_cache_spec(head_size=128),
"layer_2": new_kv_cache_spec(head_size=64), "layer_2": new_kv_cache_spec(head_size=64),
@ -1519,6 +1526,40 @@ def test_get_kv_cache_config_one_worker():
], ],
) )
# Different hidden size and different type, align by different block size
kv_cache_specs_hybrid = {
"layer_1": new_kv_cache_spec(head_size=64),
"layer_2": new_sliding_window_spec(head_size=32),
}
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 32]
)[0]
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(
size=mem_per_block_per_layer * 32, shared_by=["layer_1", "layer_2"]
),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1"], new_kv_cache_spec(head_size=64)),
KVCacheGroupSpec(
["layer_2"], new_sliding_window_spec(head_size=32, block_size=32)
),
],
)
# different hidden size that cannot be aligned by using different block size
kv_cache_specs_hybrid = {
"layer_1": new_kv_cache_spec(head_size=64),
"layer_2": new_sliding_window_spec(head_size=96),
}
with pytest.raises(NotImplementedError):
get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32]
)[0]
# Test num_gpu_blocks_override # Test num_gpu_blocks_override
vllm_config.cache_config.num_gpu_blocks_override = 16 vllm_config.cache_config.num_gpu_blocks_override = 16
kv_cache_config_override_blocks = get_kv_cache_configs( kv_cache_config_override_blocks = get_kv_cache_configs(

View File

@ -134,6 +134,7 @@ def test_prefill(hash_fn):
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# Complete 3 blocks (48 tokens) # Complete 3 blocks (48 tokens)
@ -256,6 +257,7 @@ def test_prefill_hybrid_model():
make_kv_cache_config_hybrid_model(block_size, 21), make_kv_cache_config_hybrid_model(block_size, 21),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
hash_fn = sha256 hash_fn = sha256
@ -416,6 +418,7 @@ def test_prefill_plp():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# the default hash function is sha256 # the default hash function is sha256
hash_fn = sha256 hash_fn = sha256
@ -523,6 +526,7 @@ def test_decode():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# Complete 3 blocks (48 tokens) # Complete 3 blocks (48 tokens)
@ -585,6 +589,7 @@ def test_evict():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
last_token_id = 5 * 16 + 7 last_token_id = 5 * 16 + 7
@ -643,6 +648,7 @@ def test_hash_block_correct_reuse():
make_kv_cache_config(16, 2), make_kv_cache_config(16, 2),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# Allocate 1 block and cache it. # Allocate 1 block and cache it.
@ -683,6 +689,7 @@ def test_computed_blocks_not_evicted():
make_kv_cache_config(block_size, 3), make_kv_cache_config(block_size, 3),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# Allocate a block and cache it. # Allocate a block and cache it.
@ -741,6 +748,7 @@ def test_basic_prefix_caching_disabled():
make_kv_cache_config(block_size, 5), make_kv_cache_config(block_size, 5),
max_model_len=8192, max_model_len=8192,
enable_caching=False, enable_caching=False,
hash_block_size=block_size,
) )
req1 = make_request( req1 = make_request(
@ -790,6 +798,7 @@ def test_cache_blocks(hash_fn):
block_pool = BlockPool( block_pool = BlockPool(
num_gpu_blocks=5, num_gpu_blocks=5,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# Req: # Req:
# Block 0: [0, 1, 2, 3] # Block 0: [0, 1, 2, 3]
@ -833,7 +842,9 @@ def test_cache_blocks_multi_group():
This tests that blocks are cached correctly for different kv cache groups. This tests that blocks are cached correctly for different kv cache groups.
""" """
block_size = 4 block_size = 4
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True) block_pool = BlockPool(
num_gpu_blocks=10, enable_caching=True, hash_block_size=block_size
)
# Req: # Req:
# Block 0/4: [0, 1, 2, 3] # Block 0/4: [0, 1, 2, 3]
@ -921,6 +932,7 @@ def test_mm_prefix_caching():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# Common prompt tokens (T is text tokens and P is image placeholder tokens) # Common prompt tokens (T is text tokens and P is image placeholder tokens)
@ -1020,6 +1032,7 @@ def test_cache_key_salting():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# 3 complete blocks and an incomplete block with 11 tokens. # 3 complete blocks and an incomplete block with 11 tokens.
@ -1101,6 +1114,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
# Complete 3 blocks (48 tokens) # Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... | # | Common-0 | Common-1 | Common-2 | ... |
@ -1173,6 +1187,7 @@ def test_reset_prefix_cache():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
) )
full_block_token_ids = [i for i in range(3) for _ in range(16)] full_block_token_ids = [i for i in range(3) for _ in range(16)]
@ -1213,6 +1228,7 @@ def test_prefix_cache_stats_disabled():
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
hash_block_size=block_size,
log_stats=False, # Disable logging stats log_stats=False, # Disable logging stats
) )
assert manager.prefix_cache_stats is None assert manager.prefix_cache_stats is None
@ -1232,7 +1248,7 @@ def test_prefix_cache_stats_disabled():
def test_maybe_evict_cached_block(): def test_maybe_evict_cached_block():
pool = BlockPool(num_gpu_blocks=4, enable_caching=True) pool = BlockPool(num_gpu_blocks=4, enable_caching=True, hash_block_size=16)
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000) block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000) block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000) block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
@ -1293,6 +1309,7 @@ def test_kv_cache_events(blocks_to_cache: int):
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
enable_kv_cache_events=True, enable_kv_cache_events=True,
hash_block_size=block_size,
) )
num_tokens = block_size * blocks_to_cache num_tokens = block_size * blocks_to_cache
@ -1351,6 +1368,7 @@ def test_kv_cache_events_with_lora(blocks_to_cache: int):
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
enable_kv_cache_events=True, enable_kv_cache_events=True,
hash_block_size=block_size,
) )
# Test with LoRA request # Test with LoRA request
@ -1405,6 +1423,7 @@ def test_eagle_enabled_removes_last_block():
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
use_eagle=True, use_eagle=True,
hash_block_size=block_size,
) )
# Request with 3 full blocks (48 tokens) # Request with 3 full blocks (48 tokens)
@ -1437,6 +1456,7 @@ def test_eagle_with_partial_blocks():
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
use_eagle=True, use_eagle=True,
hash_block_size=block_size,
) )
# 2 full blocks + 5 tokens (non-divisible length) # 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5) token_ids = [0] * (2 * block_size + 5)
@ -1476,6 +1496,7 @@ def test_eagle_with_sliding_window():
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
use_eagle=True, use_eagle=True,
hash_block_size=block_size,
) )
# 2 full blocks + 5 tokens (non-divisible length) # 2 full blocks + 5 tokens (non-divisible length)
@ -1522,6 +1543,76 @@ def test_eagle_with_sliding_window():
assert num_tokens == 0 assert num_tokens == 0
def test_different_block_size():
block_size = 16
# full attention and sliding window attention layers have the same page size:
# (32 tokens/block * float16 token, vs. 16 tokens/block * float32 token)
kv_cache_config = KVCacheConfig(
num_blocks=100,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(block_size * 2, 1, 1, torch.float16),
),
KVCacheGroupSpec(
["layer2"],
SlidingWindowSpec(
block_size,
1,
1,
torch.float32,
sliding_window=2 * block_size,
),
),
],
)
manager = KVCacheManager(
kv_cache_config=kv_cache_config,
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
)
# 10 blocks of 16 tokens each. Token ids are not strictly aligned for each block.
common_token_ids = [i for i in range(10) for _ in range(block_size)]
req0 = make_request("0", common_token_ids, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0]
assert not computed_blocks.blocks[1]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(
req0, 7 * block_size, len(computed_blocks.blocks[0]) * 16, computed_blocks
)
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11])
req1 = make_request("1", common_token_ids[: 7 * block_size + 1], block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(computed_blocks.blocks[0]) == 3
assert len(computed_blocks.blocks[1]) == 6
assert num_computed_tokens == 6 * 16
req2 = make_request("2", common_token_ids[: 6 * block_size + 1], block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks.blocks[0]) == 3
assert len(computed_blocks.blocks[1]) == 6
assert num_computed_tokens == 6 * 16
# Evict some blocks to make sliding window cache hit length 5*16
# But should return 4 * 16 because full attention cache hit length must be
# a multiple of 32
manager.block_pool.cached_block_hash_to_block.pop(
make_block_hash_with_group_id(req1.block_hashes[6], 1), 11
)
manager.block_pool.cached_block_hash_to_block.pop(
make_block_hash_with_group_id(req1.block_hashes[5], 1), 10
)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(computed_blocks.blocks[0]) == 2
assert len(computed_blocks.blocks[1]) == 4
assert num_computed_tokens == 4 * 16
def test_block_lookup_cache_single_block_per_key(): def test_block_lookup_cache_single_block_per_key():
cache = BlockHashToBlockMap() cache = BlockHashToBlockMap()
key0 = BlockHashWithGroupId(b"hash0") key0 = BlockHashWithGroupId(b"hash0")

View File

@ -41,7 +41,9 @@ def test_chunked_local_attention_possible_cached_prefix():
attention_chunk_size=4, attention_chunk_size=4,
) )
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) block_pool = BlockPool(
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
)
manager = get_chunked_local_attention_manager( manager = get_chunked_local_attention_manager(
chunked_local_attention_spec, block_pool chunked_local_attention_spec, block_pool
) )
@ -70,6 +72,7 @@ def test_chunked_local_attention_possible_cached_prefix():
block_pool=block_pool, block_pool=block_pool,
kv_cache_spec=chunked_local_attention_spec, kv_cache_spec=chunked_local_attention_spec,
use_eagle=False, use_eagle=False,
alignment_tokens=block_size,
)[0] )[0]
assert len(computed_blocks) == expect_length assert len(computed_blocks) == expect_length
@ -111,7 +114,9 @@ def test_sliding_window_possible_cached_prefix():
sliding_window=4, sliding_window=4,
) )
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) block_pool = BlockPool(
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
)
manager = get_sliding_window_manager(sliding_window_spec, block_pool) manager = get_sliding_window_manager(sliding_window_spec, block_pool)
def run_one_case(block_is_cached, expect_length): def run_one_case(block_is_cached, expect_length):
@ -138,6 +143,7 @@ def test_sliding_window_possible_cached_prefix():
block_pool=block_pool, block_pool=block_pool,
kv_cache_spec=sliding_window_spec, kv_cache_spec=sliding_window_spec,
use_eagle=False, use_eagle=False,
alignment_tokens=block_size,
)[0] )[0]
assert len(computed_blocks) == expect_length assert len(computed_blocks) == expect_length
@ -178,7 +184,7 @@ def test_chunked_local_attention_remove_skipped_blocks():
attention_chunk_size=4, attention_chunk_size=4,
) )
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
manager = get_chunked_local_attention_manager(attention_spec, block_pool) manager = get_chunked_local_attention_manager(attention_spec, block_pool)
@ -239,7 +245,7 @@ def test_sliding_window_remove_skipped_blocks():
sliding_window=4, sliding_window=4,
) )
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
manager = get_sliding_window_manager(sliding_window_spec, block_pool) manager = get_sliding_window_manager(sliding_window_spec, block_pool)
@ -316,7 +322,9 @@ def test_get_num_blocks_to_allocate():
sliding_window=4, # Placeholder value, not related to test result sliding_window=4, # Placeholder value, not related to test result
) )
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) block_pool = BlockPool(
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
)
manager = get_sliding_window_manager(sliding_window_spec, block_pool) manager = get_sliding_window_manager(sliding_window_spec, block_pool)
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [ cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
@ -341,7 +349,9 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
attention_chunk_size=4, # Placeholder value, not related to test result attention_chunk_size=4, # Placeholder value, not related to test result
) )
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) block_pool = BlockPool(
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
)
manager = get_chunked_local_attention_manager(attention_spec, block_pool) manager = get_chunked_local_attention_manager(attention_spec, block_pool)
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [ cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [

View File

@ -1816,9 +1816,11 @@ class EngineArgs:
if model_config.runner_type != "pooling": if model_config.runner_type != "pooling":
default_chunked_prefill = True default_chunked_prefill = True
# Disable prefix caching default for hybrid models # Disable prefix caching default for hybrid models and mamba-only
# since the feature is still experimental. # models since the feature is still experimental.
default_prefix_caching = not model_config.is_hybrid default_prefix_caching = not (
model_config.is_hybrid or model_config.is_attention_free
)
else: else:
assert model_config.pooler_config is not None assert model_config.pooler_config is not None

View File

@ -289,9 +289,6 @@ class MambaModelConfig(VerifyAndUpdateConfig):
model_config = vllm_config.model_config model_config = vllm_config.model_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = model_config.max_model_len
if cache_config.enable_prefix_caching: if cache_config.enable_prefix_caching:
if model_config.supports_mamba_prefix_caching: if model_config.supports_mamba_prefix_caching:
logger.info( logger.info(
@ -299,6 +296,11 @@ class MambaModelConfig(VerifyAndUpdateConfig):
"Its support for Mamba layers is experimental. " "Its support for Mamba layers is experimental. "
"Please report any issues you may observe." "Please report any issues you may observe."
) )
# By default, mamba block size will be set to max_model_len (see
# below). When enabling prefix caching, we align mamba block size
# to the block size as the basic granularity for prefix caching.
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = cache_config.block_size
else: else:
logger.info( logger.info(
"Hybrid or mamba-based model detected without " "Hybrid or mamba-based model detected without "
@ -306,6 +308,9 @@ class MambaModelConfig(VerifyAndUpdateConfig):
) )
cache_config.enable_prefix_caching = False cache_config.enable_prefix_caching = False
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = model_config.max_model_len
# TODO(tdoublep): remove once cascade attention is supported # TODO(tdoublep): remove once cascade attention is supported
logger.info( logger.info(
"Disabling cascade attention since it is not supported for hybrid models." "Disabling cascade attention since it is not supported for hybrid models."

View File

@ -13,6 +13,8 @@ from vllm.distributed.kv_events import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import ( from vllm.v1.core.kv_cache_utils import (
BlockHash, BlockHash,
BlockHashList,
BlockHashListWithBlockSize,
BlockHashWithGroupId, BlockHashWithGroupId,
ExternalBlockHash, ExternalBlockHash,
FreeKVCacheBlockQueue, FreeKVCacheBlockQueue,
@ -133,6 +135,10 @@ class BlockPool:
Args: Args:
num_gpu_blocks: The number of blocks in the pool. num_gpu_blocks: The number of blocks in the pool.
enable_caching: Whether to enable prefix caching. enable_caching: Whether to enable prefix caching.
hash_block_size: The block size of which the block hashes are computed.
The actual block size usually equals hash_block_size, but in cases
where different KV cache groups have different block sizes, the
actual block size can be a multiple of hash_block_size.
enable_kv_cache_events: Whether to enable kv cache events. enable_kv_cache_events: Whether to enable kv cache events.
""" """
@ -140,11 +146,13 @@ class BlockPool:
self, self,
num_gpu_blocks: int, num_gpu_blocks: int,
enable_caching: bool, enable_caching: bool,
hash_block_size: int,
enable_kv_cache_events: bool = False, enable_kv_cache_events: bool = False,
): ):
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
self.num_gpu_blocks = num_gpu_blocks self.num_gpu_blocks = num_gpu_blocks
self.enable_caching = enable_caching self.enable_caching = enable_caching
self.hash_block_size = hash_block_size
# All kv-cache blocks. # All kv-cache blocks.
self.blocks: list[KVCacheBlock] = [ self.blocks: list[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks) KVCacheBlock(idx) for idx in range(num_gpu_blocks)
@ -223,8 +231,20 @@ class BlockPool:
return return
new_full_blocks = blocks[num_cached_blocks:num_full_blocks] new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
assert len(request.block_hashes) >= num_full_blocks assert len(request.block_hashes) >= num_full_blocks
new_block_hashes = request.block_hashes[num_cached_blocks:] if block_size == self.hash_block_size:
# Common case.
block_hashes: BlockHashList = request.block_hashes
else:
# block_size is a multiple of hash_block_size. This happens when
# different KV cache groups have different block sizes.
assert block_size % self.hash_block_size == 0
# Recalculate block_hashes at the granularity of block_size, using
# the original block_hashes (at the granularity of hash_block_size).
block_hashes = BlockHashListWithBlockSize(
request.block_hashes, self.hash_block_size, block_size
)
new_block_hashes = block_hashes[num_cached_blocks:]
new_hashes: list[ExternalBlockHash] | None = ( new_hashes: list[ExternalBlockHash] | None = (
[] if self.enable_kv_cache_events else None [] if self.enable_kv_cache_events else None
) )

View File

@ -2,15 +2,25 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from math import lcm
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.core.kv_cache_utils import (
BlockHash,
BlockHashList,
BlockHashListWithBlockSize,
KVCacheBlock,
)
from vllm.v1.core.single_type_kv_cache_manager import ( from vllm.v1.core.single_type_kv_cache_manager import (
CrossAttentionManager, CrossAttentionManager,
FullAttentionManager, FullAttentionManager,
get_manager_for_kv_cache_spec, get_manager_for_kv_cache_spec,
) )
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheSpec,
)
from vllm.v1.request import Request from vllm.v1.request import Request
@ -28,13 +38,17 @@ class KVCacheCoordinator(ABC):
enable_kv_cache_events: bool, enable_kv_cache_events: bool,
dcp_world_size: int, dcp_world_size: int,
pcp_world_size: int, pcp_world_size: int,
hash_block_size: int,
): ):
self.kv_cache_config = kv_cache_config self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.enable_caching = enable_caching self.enable_caching = enable_caching
self.block_pool = BlockPool( self.block_pool = BlockPool(
kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events kv_cache_config.num_blocks,
enable_caching,
hash_block_size,
enable_kv_cache_events,
) )
# Needs special handling for find_longest_cache_hit if eagle is enabled # Needs special handling for find_longest_cache_hit if eagle is enabled
@ -213,6 +227,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
enable_kv_cache_events: bool, enable_kv_cache_events: bool,
dcp_world_size: int, dcp_world_size: int,
pcp_world_size: int, pcp_world_size: int,
hash_block_size: int,
): ):
super().__init__( super().__init__(
kv_cache_config, kv_cache_config,
@ -222,6 +237,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
enable_kv_cache_events, enable_kv_cache_events,
dcp_world_size=dcp_world_size, dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size, pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
) )
self.num_single_type_manager = len(self.single_type_managers) self.num_single_type_manager = len(self.single_type_managers)
@ -255,6 +271,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events: bool, enable_kv_cache_events: bool,
dcp_world_size: int, dcp_world_size: int,
pcp_world_size: int, pcp_world_size: int,
hash_block_size: int,
): ):
super().__init__( super().__init__(
kv_cache_config, kv_cache_config,
@ -264,6 +281,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events, enable_kv_cache_events,
dcp_world_size=dcp_world_size, dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size, pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
) )
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size self.block_size = self.kv_cache_spec.block_size
@ -273,6 +291,11 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
self.block_size *= dcp_world_size self.block_size *= dcp_world_size
if pcp_world_size > 1: if pcp_world_size > 1:
self.block_size *= pcp_world_size self.block_size *= pcp_world_size
# For models using only Mamba, block_size is set to max_model_len when
# prefix caching is disabled, and hash_block_size validation is skipped.
assert not enable_caching or (hash_block_size == self.block_size), (
"UnitaryKVCacheCoordinator assumes hash_block_size == block_size"
)
assert len(self.kv_cache_config.kv_cache_groups) == 1, ( assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"UnitaryKVCacheCoordinator assumes only one kv cache group" "UnitaryKVCacheCoordinator assumes only one kv cache group"
) )
@ -289,6 +312,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
block_pool=self.block_pool, block_pool=self.block_pool,
kv_cache_spec=self.kv_cache_spec, kv_cache_spec=self.kv_cache_spec,
use_eagle=self.use_eagle, use_eagle=self.use_eagle,
alignment_tokens=self.block_size,
dcp_world_size=self.dcp_world_size, dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size, pcp_world_size=self.pcp_world_size,
) )
@ -313,6 +337,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events: bool, enable_kv_cache_events: bool,
dcp_world_size: int, dcp_world_size: int,
pcp_world_size: int, pcp_world_size: int,
hash_block_size: int,
): ):
super().__init__( super().__init__(
kv_cache_config, kv_cache_config,
@ -322,7 +347,17 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
enable_kv_cache_events, enable_kv_cache_events,
dcp_world_size=dcp_world_size, dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size, pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
) )
# hash_block_size: the block size used to compute block hashes.
# The actual block size usually equals hash_block_size, but in cases where
# different KV cache groups have different block sizes, the actual block size
# can be a multiple of hash_block_size.
self.hash_block_size = hash_block_size
assert all(
g.kv_cache_spec.block_size % hash_block_size == 0
for g in kv_cache_config.kv_cache_groups
), "block_size must be divisible by hash_block_size"
assert dcp_world_size == 1, "DCP not support hybrid attn now." assert dcp_world_size == 1, "DCP not support hybrid attn now."
assert pcp_world_size == 1, "PCP not support hybrid attn now." assert pcp_world_size == 1, "PCP not support hybrid attn now."
self.verify_and_split_kv_cache_groups() self.verify_and_split_kv_cache_groups()
@ -373,14 +408,12 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
self.other_spec = other_spec self.other_spec = other_spec
self.full_attention_block_size = self.full_attention_spec.block_size self.full_attention_block_size = self.full_attention_spec.block_size
self.other_block_size = self.other_spec.block_size self.other_block_size = self.other_spec.block_size
# The LCM of the block sizes of full attention and other attention.
if self.enable_caching: # The cache hit length must be a multiple of the LCM of the block sizes
# this requirement is only needed for the prefix caching logic # to make sure the cache hit length is a multiple of the block size of
divisible = self.other_block_size % self.full_attention_block_size # each attention type. Requiring this because we don't support partial
assert divisible == 0, ( # block cache hit yet.
"KVCacheCoordinator assumes the block_size of full " self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size)
"attention layers is divisible by other layers now."
)
if max(self.full_attention_group_ids) < min(self.other_group_ids): if max(self.full_attention_group_ids) < min(self.other_group_ids):
self.full_attn_first = True self.full_attn_first = True
@ -414,25 +447,48 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
- The number of tokens of the longest cache hit. - The number of tokens of the longest cache hit.
""" """
# First, find the longest cache hit for full attention. # First, find the longest cache hit for full attention.
if self.full_attention_spec.block_size == self.hash_block_size:
# Common case.
full_attention_block_hashes: BlockHashList = block_hashes
else:
# block_size is a multiple of hash_block_size. This happens when different
# KV cache groups have different block sizes. In this case, we need to
# recalculate block_hashes at the granularity of block_size, using the
# original block_hashes (at the granularity of hash_block_size).
full_attention_block_hashes = BlockHashListWithBlockSize(
block_hashes, self.hash_block_size, self.full_attention_spec.block_size
)
hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit( hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit(
block_hashes=block_hashes, block_hashes=full_attention_block_hashes,
max_length=max_cache_hit_length, max_length=max_cache_hit_length,
kv_cache_group_ids=self.full_attention_group_ids, kv_cache_group_ids=self.full_attention_group_ids,
block_pool=self.block_pool, block_pool=self.block_pool,
kv_cache_spec=self.full_attention_spec, kv_cache_spec=self.full_attention_spec,
use_eagle=self.use_eagle, use_eagle=self.use_eagle,
alignment_tokens=self.lcm_block_size,
) )
hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size
# Next, find the cache hit for the other attention WITHIN # Next, find the cache hit for the other attention WITHIN
# the cache hit of full attention. # the cache hit of full attention.
if self.other_spec.block_size == self.hash_block_size:
# Common case.
other_block_hashes: BlockHashList = block_hashes
else:
# Similar to the full attention case, here we need to recalculate
# block_hashes at the granularity of block_size, using the original
# block_hashes (at the granularity of hash_block_size).
other_block_hashes = BlockHashListWithBlockSize(
block_hashes, self.hash_block_size, self.other_spec.block_size
)
hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit( hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit(
block_hashes=block_hashes, block_hashes=other_block_hashes,
max_length=hit_length, max_length=hit_length,
kv_cache_group_ids=self.other_group_ids, kv_cache_group_ids=self.other_group_ids,
block_pool=self.block_pool, block_pool=self.block_pool,
kv_cache_spec=self.other_spec, kv_cache_spec=self.other_spec,
use_eagle=self.use_eagle, use_eagle=self.use_eagle,
alignment_tokens=self.lcm_block_size,
) )
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
@ -466,6 +522,7 @@ def get_kv_cache_coordinator(
enable_kv_cache_events: bool, enable_kv_cache_events: bool,
dcp_world_size: int, dcp_world_size: int,
pcp_world_size: int, pcp_world_size: int,
hash_block_size: int,
) -> KVCacheCoordinator: ) -> KVCacheCoordinator:
if not enable_caching: if not enable_caching:
return KVCacheCoordinatorNoPrefixCache( return KVCacheCoordinatorNoPrefixCache(
@ -473,8 +530,9 @@ def get_kv_cache_coordinator(
max_model_len, max_model_len,
use_eagle, use_eagle,
enable_kv_cache_events, enable_kv_cache_events,
dcp_world_size=dcp_world_size, dcp_world_size,
pcp_world_size=pcp_world_size, pcp_world_size,
hash_block_size,
) )
if len(kv_cache_config.kv_cache_groups) == 1: if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator( return UnitaryKVCacheCoordinator(
@ -483,8 +541,9 @@ def get_kv_cache_coordinator(
use_eagle, use_eagle,
enable_caching, enable_caching,
enable_kv_cache_events, enable_kv_cache_events,
dcp_world_size=dcp_world_size, dcp_world_size,
pcp_world_size=pcp_world_size, pcp_world_size,
hash_block_size,
) )
return HybridKVCacheCoordinator( return HybridKVCacheCoordinator(
kv_cache_config, kv_cache_config,
@ -492,6 +551,7 @@ def get_kv_cache_coordinator(
use_eagle, use_eagle,
enable_caching, enable_caching,
enable_kv_cache_events, enable_kv_cache_events,
dcp_world_size=dcp_world_size, dcp_world_size,
pcp_world_size=pcp_world_size, pcp_world_size,
hash_block_size,
) )

View File

@ -95,6 +95,7 @@ class KVCacheManager:
self, self,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
max_model_len: int, max_model_len: int,
hash_block_size: int,
enable_caching: bool = True, enable_caching: bool = True,
use_eagle: bool = False, use_eagle: bool = False,
log_stats: bool = False, log_stats: bool = False,
@ -107,28 +108,11 @@ class KVCacheManager:
self.enable_caching = enable_caching self.enable_caching = enable_caching
self.use_eagle = use_eagle self.use_eagle = use_eagle
self.log_stats = log_stats self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats # FIXME: make prefix cache stats conditional on log_stats. We still need
# this comment because when the log stats is enabled there are still
# potential configs we could expose in the future.
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
self.block_size: int | None = None
if self.enable_caching:
assert (
len(
set(
g.kv_cache_spec.block_size
for g in kv_cache_config.kv_cache_groups
)
)
== 1
), "Only one block size is supported for now"
self.block_size = kv_cache_config.kv_cache_groups[
0
].kv_cache_spec.block_size
if dcp_world_size * pcp_world_size > 1:
assert len(kv_cache_config.kv_cache_groups) == 1
self.block_size *= dcp_world_size * pcp_world_size
self.coordinator = get_kv_cache_coordinator( self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
@ -137,6 +121,7 @@ class KVCacheManager:
enable_kv_cache_events=enable_kv_cache_events, enable_kv_cache_events=enable_kv_cache_events,
dcp_world_size=dcp_world_size, dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size, pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
) )
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool self.block_pool = self.coordinator.block_pool

View File

@ -5,9 +5,9 @@
import copy import copy
import os import os
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable, Iterable, Sequence from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import dataclass from dataclasses import dataclass, replace
from typing import Any, NewType, TypeAlias from typing import Any, NewType, TypeAlias, overload
from vllm import envs from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
@ -825,11 +825,11 @@ def get_num_blocks(
return num_blocks return num_blocks
def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int: def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int:
""" """
Get the page size of the KV cache. Get the page size of the KV cache.
""" """
page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values()) page_sizes = {layer.page_size_bytes for layer in kv_cache_specs}
assert len(page_sizes) == 1 assert len(page_sizes) == 1
return page_sizes.pop() return page_sizes.pop()
@ -882,6 +882,46 @@ def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool
return len(page_sizes) == 1 return len(page_sizes) == 1
def unify_kv_cache_spec_page_size(
kv_cache_spec: dict[str, KVCacheSpec],
) -> dict[str, KVCacheSpec]:
"""
Unify the page size of the given KVCacheSpec. If the page size of all layers
are the same, return the original KVCacheSpec. If not same, unify the page
size by increasing the block size of layers with smaller page size. Raise
NotImplementedError if failed to unify the page size.
Args:
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
The updated KVCacheSpec with the same page_size_bytes.
"""
page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
if len(page_sizes) <= 1:
# All layers have the same page size, no need to unify.
return kv_cache_spec
max_page_size = max(page_sizes)
new_kv_cache_spec = {}
for layer_name, layer_spec in kv_cache_spec.items():
if layer_spec.page_size_bytes == max_page_size:
new_kv_cache_spec[layer_name] = layer_spec
else:
layer_page_size = layer_spec.page_size_bytes
if max_page_size % layer_page_size != 0:
raise NotImplementedError(
"The page size of the layer is not divisible by the "
"maximum page size. Cannot unify by adjusting block_size."
)
ratio = max_page_size // layer_page_size
new_block_size = layer_spec.block_size * ratio
new_spec = replace(layer_spec, block_size=new_block_size)
assert new_spec.page_size_bytes == max_page_size
new_kv_cache_spec[layer_name] = new_spec
return new_kv_cache_spec
def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
# kv_cache_spec is an empty dict for attention free models # kv_cache_spec is an empty dict for attention free models
return not kv_cache_spec return not kv_cache_spec
@ -1010,7 +1050,6 @@ def _get_kv_cache_groups_uniform_page_size(
def get_kv_cache_config_from_groups( def get_kv_cache_config_from_groups(
vllm_config: VllmConfig, vllm_config: VllmConfig,
kv_cache_groups: list[KVCacheGroupSpec], kv_cache_groups: list[KVCacheGroupSpec],
kv_cache_specs: dict[str, KVCacheSpec],
available_memory: int, available_memory: int,
) -> KVCacheConfig: ) -> KVCacheConfig:
""" """
@ -1020,7 +1059,6 @@ def get_kv_cache_config_from_groups(
Args: Args:
vllm_config: The global VllmConfig vllm_config: The global VllmConfig
kv_cache_groups: The KV cache groups kv_cache_groups: The KV cache groups
kv_cache_specs: The KV cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes available_memory: Memory available for KV cache in bytes
Returns: Returns:
The generated KVCacheConfig The generated KVCacheConfig
@ -1064,7 +1102,9 @@ def get_kv_cache_config_from_groups(
# full.1, sw.2: share another Tensor with size=available_memory//2 # full.1, sw.2: share another Tensor with size=available_memory//2
group_size = max(len(group.layer_names) for group in kv_cache_groups) group_size = max(len(group.layer_names) for group in kv_cache_groups)
page_size = get_uniform_page_size(kv_cache_specs) page_size = get_uniform_page_size(
[group.kv_cache_spec for group in kv_cache_groups]
)
assert group_size > 0, "group_size must be greater than 0" assert group_size > 0, "group_size must be greater than 0"
num_blocks = get_num_blocks( num_blocks = get_num_blocks(
vllm_config, group_size, available_memory, page_size vllm_config, group_size, available_memory, page_size
@ -1166,7 +1206,8 @@ def get_kv_cache_groups(
# This returns an empty list to allow for the KVCacheManager to handle # This returns an empty list to allow for the KVCacheManager to handle
# attention free models. # attention free models.
return [] return []
elif is_kv_cache_spec_uniform(kv_cache_spec):
if is_kv_cache_spec_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for # KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for # most models. Allocate the same amount of memory for
# each layer. # each layer.
@ -1176,14 +1217,16 @@ def get_kv_cache_groups(
# full attention, or all layers are sliding window attention with the # full attention, or all layers are sliding window attention with the
# same window size). Put all layers into one group. # same window size). Put all layers into one group.
return _get_kv_cache_groups_uniform_type(uniform_spec) return _get_kv_cache_groups_uniform_type(uniform_spec)
elif is_kv_cache_page_size_uniform(kv_cache_spec):
# Model contains multiple attention types, but KV cache of all layers
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
raise NotImplementedError # As KVCacheManager can only allocate memory of one size, we need to unify
# the page size of the layers. For cases cannot be unified, this function
# will raise an error.
kv_cache_spec = unify_kv_cache_spec_page_size(kv_cache_spec)
# Model contains multiple attention types, but KV cache of all layers
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
def generate_scheduler_kv_cache_config( def generate_scheduler_kv_cache_config(
@ -1327,10 +1370,7 @@ def get_kv_cache_configs(
) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group." ) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group."
kv_cache_configs.append( kv_cache_configs.append(
get_kv_cache_config_from_groups( get_kv_cache_config_from_groups(
vllm_config, vllm_config, kv_cache_groups_one_worker, available_memory_one_worker
kv_cache_groups_one_worker,
kv_cache_spec_one_worker,
available_memory_one_worker,
) )
) )
@ -1353,3 +1393,79 @@ def get_kv_cache_configs(
_report_kv_cache_config(vllm_config, kv_cache_config) _report_kv_cache_config(vllm_config, kv_cache_config)
return kv_cache_configs return kv_cache_configs
class BlockHashListWithBlockSize:
"""
Convert block-hash granularity from `hash_block_size` to `target_block_size`.
Used when KV cache groups have different block sizes: `hash_block_size`
is the size used to compute the original `block_hashes`; `target_block_size`
is the group's actual block size.
Currently, only scaling up by an integer factor is supported (i.e.,
`target_block_size` is a multiple of `hash_block_size`). Conversion is
performed lazily on access for efficiency, by concatenating consecutive
hashes at `hash_block_size` to form each hash at `target_block_size`.
Example (`hash_block_size` = 16, `target_block_size` = 32):
concatenating two 16-size hashes yields one 32-size hash:
Block hashes with block_size 16:
| Token Range | 0-15 | 16-31 | 32-47 | 48-63 |
|-------------|------|-------|-------|-------|
| Hash | A | B | C | D |
Block hashes with block_size 32:
| Token Range | 0-31 | 32-63 |
|-------------|------|-------|
| Hash | AB | CD |
Args:
block_hashes: Block hashes to convert, computed at `hash_block_size`.
hash_block_size: Block size at which `block_hashes` were computed.
target_block_size: Desired block size; must be a multiple of `hash_block_size`.
"""
def __init__(
self,
block_hashes: list[BlockHash],
hash_block_size: int,
target_block_size: int,
):
self.block_hashes = block_hashes
assert target_block_size % hash_block_size == 0
self.scale_factor = target_block_size // hash_block_size
def __len__(self) -> int:
return len(self.block_hashes) // self.scale_factor
@overload
def __getitem__(self, idx: int) -> BlockHash: ...
@overload
def __getitem__(self, idx: slice) -> list[BlockHash]: ...
def __getitem__(self, idx):
if isinstance(idx, int):
return self._get_value_at(idx)
if isinstance(idx, slice):
start, stop, step = idx.indices(len(self))
return [self._get_value_at(i) for i in range(start, stop, step)]
raise TypeError(f"Invalid index type: {type(idx)!r}")
def __iter__(self) -> Iterator[BlockHash]:
for i in range(len(self)):
yield self._get_value_at(i)
def _get_value_at(self, idx: int) -> BlockHash:
base = idx * self.scale_factor
end = base + self.scale_factor
merged_hash: bytes = self.block_hashes[base]
for i in range(base + 1, end):
merged_hash += self.block_hashes[i]
return BlockHash(merged_hash)
BlockHashList = list[BlockHash] | BlockHashListWithBlockSize

View File

@ -186,6 +186,7 @@ class Scheduler(SchedulerInterface):
enable_kv_cache_events=self.enable_kv_cache_events, enable_kv_cache_events=self.enable_kv_cache_events,
dcp_world_size=self.dcp_world_size, dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size, pcp_world_size=self.pcp_world_size,
hash_block_size=self.block_size,
) )
self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

View File

@ -7,7 +7,7 @@ from collections.abc import Sequence
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec, ChunkedLocalAttentionSpec,
CrossAttentionSpec, CrossAttentionSpec,
@ -207,12 +207,13 @@ class SingleTypeKVCacheManager(ABC):
@abstractmethod @abstractmethod
def find_longest_cache_hit( def find_longest_cache_hit(
cls, cls,
block_hashes: list[BlockHash], block_hashes: BlockHashList,
max_length: int, max_length: int,
kv_cache_group_ids: list[int], kv_cache_group_ids: list[int],
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1, pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
@ -232,6 +233,11 @@ class SingleTypeKVCacheManager(ABC):
block_pool: The block pool. block_pool: The block pool.
kv_cache_spec: The kv cache spec. kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle. use_eagle: Whether to use eagle.
alignment_tokens: The returned cache hit length (in tokens) should
be a multiple of this value (in tokens). By default, it should
be set to the block_size.
dcp_world_size: The world size of decode context parallelism.
pcp_world_size: The world size of prefill context parallelism.
Returns: Returns:
A list of cached blocks with skipped blocks replaced by null block A list of cached blocks with skipped blocks replaced by null block
@ -299,17 +305,18 @@ class FullAttentionManager(SingleTypeKVCacheManager):
@classmethod @classmethod
def find_longest_cache_hit( def find_longest_cache_hit(
cls, cls,
block_hashes: list[BlockHash], block_hashes: BlockHashList,
max_length: int, max_length: int,
kv_cache_group_ids: list[int], kv_cache_group_ids: list[int],
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1, pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
assert isinstance( assert isinstance(
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) kv_cache_spec, FullAttentionSpec | ChunkedLocalAttentionSpec
), ( ), (
"FullAttentionManager can only be used for full attention " "FullAttentionManager can only be used for full attention "
"and chunked local attention groups" "and chunked local attention groups"
@ -333,6 +340,13 @@ class FullAttentionManager(SingleTypeKVCacheManager):
else: else:
break break
if use_eagle and computed_blocks[0]: if use_eagle and computed_blocks[0]:
# Need to drop the last matched block if eagle is enabled.
for computed in computed_blocks:
computed.pop()
while (
block_size != alignment_tokens # Faster for common case.
and len(computed_blocks[0]) * block_size % alignment_tokens != 0
):
for computed in computed_blocks: for computed in computed_blocks:
computed.pop() computed.pop()
return computed_blocks return computed_blocks
@ -359,12 +373,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
@classmethod @classmethod
def find_longest_cache_hit( def find_longest_cache_hit(
cls, cls,
block_hashes: list[BlockHash], block_hashes: BlockHashList,
max_length: int, max_length: int,
kv_cache_group_ids: list[int], kv_cache_group_ids: list[int],
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1, pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
@ -396,6 +411,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
[block_pool.null_block] * max_num_blocks [block_pool.null_block] * max_num_blocks
for _ in range(len(kv_cache_group_ids)) for _ in range(len(kv_cache_group_ids))
) )
block_size = kv_cache_spec.block_size
num_contiguous_blocks = 0 num_contiguous_blocks = 0
match_found = False match_found = False
# Search from right to left and early stop when a match is found. # Search from right to left and early stop when a match is found.
@ -403,6 +419,15 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
if cached_block := block_pool.get_cached_block( if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids block_hashes[i], kv_cache_group_ids
): ):
# Skip prefix matching check if the block is not aligned with
# `alignment_tokens`.
if (
num_contiguous_blocks == 0
and block_size != alignment_tokens # Faster for common case.
and (i + 1) * block_size % alignment_tokens != 0
):
continue
# Add the cached block to the computed blocks.
for computed, cached in zip(computed_blocks, cached_block): for computed, cached in zip(computed_blocks, cached_block):
computed[i] = cached computed[i] = cached
num_contiguous_blocks += 1 num_contiguous_blocks += 1
@ -421,7 +446,16 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
# `num_contiguous_blocks < sliding_window_contiguous_blocks`. # `num_contiguous_blocks < sliding_window_contiguous_blocks`.
for computed in computed_blocks: for computed in computed_blocks:
del computed[num_contiguous_blocks:] del computed[num_contiguous_blocks:]
while (
block_size != alignment_tokens # Faster for common case.
and len(computed_blocks[0]) * block_size % alignment_tokens != 0
):
for computed in computed_blocks:
computed.pop()
if use_eagle and computed_blocks[0]: if use_eagle and computed_blocks[0]:
assert kv_cache_spec.block_size == alignment_tokens, (
"aligned_length is not compatible with eagle now"
)
for computed in computed_blocks: for computed in computed_blocks:
computed.pop() computed.pop()
return computed_blocks return computed_blocks
@ -475,12 +509,13 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
@classmethod @classmethod
def find_longest_cache_hit( def find_longest_cache_hit(
cls, cls,
block_hashes: list[BlockHash], block_hashes: BlockHashList,
max_length: int, max_length: int,
kv_cache_group_ids: list[int], kv_cache_group_ids: list[int],
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1, pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
@ -511,6 +546,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
block_pool: The block pool. block_pool: The block pool.
kv_cache_spec: The kv cache spec. kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle. use_eagle: Whether to use eagle.
dcp_world_size: The world size of decode context parallelism.
pcp_world_size: The world size of prefill context parallelism.
alignment_tokens: The returned cache hit length (in tokens) should
be a multiple of this value (in tokens).
Returns: Returns:
A list of cached blocks A list of cached blocks
@ -524,6 +563,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
) )
assert dcp_world_size == 1, "DCP not support chunked local attn now." assert dcp_world_size == 1, "DCP not support chunked local attn now."
assert pcp_world_size == 1, "PCP not support chunked local attn now." assert pcp_world_size == 1, "PCP not support chunked local attn now."
assert kv_cache_spec.block_size == alignment_tokens, (
"KV cache groups with different block sizes are not compatible with "
"chunked local attention now"
)
max_num_blocks = max_length // kv_cache_spec.block_size max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0: if max_length > 0:
local_attention_start_idx = ( local_attention_start_idx = (
@ -612,12 +655,13 @@ class MambaManager(SingleTypeKVCacheManager):
@classmethod @classmethod
def find_longest_cache_hit( def find_longest_cache_hit(
cls, cls,
block_hashes: list[BlockHash], block_hashes: BlockHashList,
max_length: int, max_length: int,
kv_cache_group_ids: list[int], kv_cache_group_ids: list[int],
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1, pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]:
@ -630,12 +674,21 @@ class MambaManager(SingleTypeKVCacheManager):
[] for _ in range(len(kv_cache_group_ids)) [] for _ in range(len(kv_cache_group_ids))
) )
max_num_blocks = max_length // kv_cache_spec.block_size block_size = kv_cache_spec.block_size
max_num_blocks = max_length // block_size
# Search from right to left and early stop when a match is found. # Search from right to left and early stop when a match is found.
for i in range(max_num_blocks - 1, -1, -1): for i in range(max_num_blocks - 1, -1, -1):
if cached_block := block_pool.get_cached_block( if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids block_hashes[i], kv_cache_group_ids
): ):
# When enable Mamba prefix caching, `block_size` will be aligned
# across full attention layers and Mamba layers to ensure the
# prefix hit length aligned at block
if (
block_size != alignment_tokens # Faster for common case.
and (i + 1) * block_size % alignment_tokens != 0
):
continue
for computed, cached in zip(computed_blocks, cached_block): for computed, cached in zip(computed_blocks, cached_block):
# the hit length logic later assumes: # the hit length logic later assumes:
# hit_length = len(hit_blocks_other_attn[0]) # hit_length = len(hit_blocks_other_attn[0])
@ -708,12 +761,13 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
@classmethod @classmethod
def find_longest_cache_hit( def find_longest_cache_hit(
cls, cls,
block_hashes: list[BlockHash], block_hashes: BlockHashList,
max_length: int, max_length: int,
kv_cache_group_ids: list[int], kv_cache_group_ids: list[int],
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
use_eagle: bool, use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1, pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]: ) -> tuple[list[KVCacheBlock], ...]: