[v1] Hybrid Memory Allocator (#17996)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Chen Zhang 2025-06-06 11:47:09 +08:00 committed by GitHub
parent 3465b87ef8
commit f8a1a2d108
21 changed files with 1605 additions and 440 deletions
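
For orientation, a minimal sketch (not taken from the diff) of the KVCacheConfig layout this change introduces, assuming only the constructor signatures exercised by the updated tests below: the per-layer `tensors` dict becomes a list of KVCacheTensor entries that may be shared by several layers, and layers are partitioned into `kv_cache_groups` by attention type.

import torch

from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec, KVCacheTensor,
                                        SlidingWindowSpec)

# One buffer backs both a full-attention layer and a sliding-window layer;
# sharing a single pool of blocks between layer types is the core idea of the
# hybrid allocator. The size mirrors the per-block constant used in the tests.
hybrid_config = KVCacheConfig(
    num_blocks=64,
    kv_cache_tensors=[
        KVCacheTensor(size=64 * 16384, shared_by=["layer_full", "layer_sw"]),
    ],
    kv_cache_groups=[
        KVCacheGroupSpec(["layer_full"],
                         FullAttentionSpec(16, 2, 64, torch.float32, False)),
        KVCacheGroupSpec(["layer_sw"],
                         SlidingWindowSpec(16, 2, 64, torch.float32, False,
                                           sliding_window=128)),
    ],
)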

@ -15,8 +15,8 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.kv_cache_utils import (
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
estimate_max_model_len, generate_block_hash_extra_keys,
get_max_concurrency_for_kv_cache_config, hash_block_tokens,
hash_request_tokens, unify_kv_cache_configs)
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
hash_block_tokens, hash_request_tokens, unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor,
SlidingWindowSpec)
@ -63,6 +63,20 @@ def new_kv_cache_spec(block_size=16,
sliding_window=sliding_window)
def new_sliding_window_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
use_mla=False,
sliding_window=1):
return SlidingWindowSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
use_mla=use_mla,
sliding_window=sliding_window)
def test_none_hash(monkeypatch):
import vllm.v1.core.kv_cache_utils
@ -403,10 +417,10 @@ def test_unify_kv_cache_configs():
same_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
@ -415,10 +429,10 @@ def test_unify_kv_cache_configs():
),
KVCacheConfig(
num_blocks=20,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
@ -433,10 +447,10 @@ def test_unify_kv_cache_configs():
need_sort_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
@ -445,10 +459,10 @@ def test_unify_kv_cache_configs():
),
KVCacheConfig(
num_blocks=20,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
@ -464,10 +478,10 @@ def test_unify_kv_cache_configs():
diff_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
@ -476,10 +490,10 @@ def test_unify_kv_cache_configs():
),
KVCacheConfig(
num_blocks=20,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
@ -636,7 +650,7 @@ def test_get_max_concurrency_for_kv_cache_config():
kv_cache_config_full_attention = KVCacheConfig(
num_blocks=int(1024 * 1.5),
tensors={},
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
full_attention_spec),
@ -648,7 +662,7 @@ def test_get_max_concurrency_for_kv_cache_config():
kv_cache_config_sliding_window = KVCacheConfig(
num_blocks=129 * 3,
tensors={},
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
sliding_window_spec),
@ -660,7 +674,7 @@ def test_get_max_concurrency_for_kv_cache_config():
kv_cache_config_hybrid_model = KVCacheConfig(
num_blocks=(1024 + 129) * 3,
tensors={},
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
full_attention_spec),
@ -678,9 +692,9 @@ def test_allocate_with_lookahead():
block_size = 4
config = KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
},
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"],
new_kv_cache_spec(block_size=block_size)),
@ -702,7 +716,7 @@ def test_allocate_with_lookahead():
num_new_tokens=3,
num_lookahead_tokens=2, # Total required: 3+2=5 tokens
)
assert len(blocks.blocks) == 2 # ceil(5/4)=2 blocks
assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks
# Test case 2: With precomputed blocks
kv_cache_manager = KVCacheManager(kv_cache_config=config,
@ -713,7 +727,7 @@ def test_allocate_with_lookahead():
num_new_tokens=3,
num_lookahead_tokens=2,
)
assert len(blocks.blocks) == 2
assert len(blocks.get_block_ids()[0]) == 2
# Test case 3: With precomputed blocks
# required_blocks = ceil((3 + 4) / 4) = 2
@ -724,4 +738,165 @@ def test_allocate_with_lookahead():
num_new_tokens=3,
num_lookahead_tokens=4,
)
assert len(blocks.blocks) == 2
assert len(blocks.get_block_ids()[0]) == 2
def test_get_kv_cache_config():
# set max_model_len so that check_enough_kv_cache_memory passes
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config)
mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2
# all layers are full attention -> single group
kv_cache_specs_full = {
'layer_1': new_kv_cache_spec(),
'layer_2': new_kv_cache_spec(),
}
kv_cache_config_full = get_kv_cache_config(
vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32)
assert kv_cache_config_full == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_1"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())
])
# all layers are sliding window -> single group
kv_cache_specs_sliding = {
'layer_1': new_sliding_window_spec(),
'layer_2': new_sliding_window_spec(),
}
kv_cache_config_sliding = get_kv_cache_config(
vllm_config, kv_cache_specs_sliding, mem_per_block_per_layer * 2 * 32)
assert kv_cache_config_sliding == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_1"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_sliding_window_spec())
])
# full + sliding, but disable_hybrid_kv_cache_manager
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = True
kv_cache_specs_hybrid = {
'layer_1': new_kv_cache_spec(),
'layer_2': new_sliding_window_spec(),
}
kv_cache_config_hybrid = get_kv_cache_config(
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_1"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"],
new_kv_cache_spec(sliding_window=1)),
],
)
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
# full + sliding, with hybrid_kv_cache_manager
kv_cache_specs_hybrid = {
'layer_1': new_kv_cache_spec(),
'layer_2': new_sliding_window_spec(),
}
kv_cache_config_hybrid = get_kv_cache_config(
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=64,
kv_cache_tensors=[
KVCacheTensor(size=mem_per_block_per_layer * 64,
shared_by=["layer_1", "layer_2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer_2"], new_sliding_window_spec()),
],
)
# 2 full + 4 sliding, 2 layers per group
kv_cache_specs_hybrid = {
'layer_1': new_kv_cache_spec(),
'layer_2': new_kv_cache_spec(),
'layer_3': new_sliding_window_spec(),
'layer_4': new_sliding_window_spec(),
'layer_5': new_sliding_window_spec(),
'layer_6': new_sliding_window_spec(),
}
kv_cache_config_hybrid = get_kv_cache_config(
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_1", "layer_3", "layer_5"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_2", "layer_4", "layer_6"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer_3", "layer_4"],
new_sliding_window_spec()),
KVCacheGroupSpec(["layer_5", "layer_6"],
new_sliding_window_spec()),
],
)
# 3 full + 7 sliding, pad to 3 full + 9 sliding
kv_cache_specs_hybrid = {
'layer_1': new_kv_cache_spec(),
'layer_2': new_kv_cache_spec(),
'layer_3': new_kv_cache_spec(),
'layer_4': new_sliding_window_spec(),
'layer_5': new_sliding_window_spec(),
'layer_6': new_sliding_window_spec(),
'layer_7': new_sliding_window_spec(),
'layer_8': new_sliding_window_spec(),
'layer_9': new_sliding_window_spec(),
'layer_10': new_sliding_window_spec(),
}
kv_cache_config_hybrid = get_kv_cache_config(
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 3 * 32)
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(
size=mem_per_block_per_layer * 32,
shared_by=["layer_1", "layer_4", "layer_7", "layer_10"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_2", "layer_5", "layer_8"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_3", "layer_6", "layer_9"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"],
new_kv_cache_spec()),
KVCacheGroupSpec(["layer_4", "layer_5", "layer_6"],
new_sliding_window_spec()),
KVCacheGroupSpec(["layer_7", "layer_8", "layer_9"],
new_sliding_window_spec()),
KVCacheGroupSpec(["layer_10"], new_sliding_window_spec()),
],
)
# different hidden size, unimplemented
kv_cache_specs_hybrid = {
'layer_1': new_kv_cache_spec(head_size=128),
'layer_2': new_kv_cache_spec(),
}
with pytest.raises(NotImplementedError):
get_kv_cache_config(vllm_config, kv_cache_specs_hybrid,
mem_per_block_per_layer * 2 * 32)
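The mem_per_block_per_layer constant used throughout this test can be derived from the new_kv_cache_spec defaults; a small sketch of that arithmetic, assuming float32 entries and separate K and V slabs per block:

block_size, num_kv_heads, head_size = 16, 2, 64
bytes_per_elem = 4   # torch.float32
kv_factor = 2        # one slab for K, one for V
mem_per_block_per_layer = (block_size * num_kv_heads * head_size *
                           bytes_per_elem * kv_factor)
assert mem_per_block_per_layer == 16 * 2 * 64 * 4 * 2 == 16384

With a budget of mem_per_block_per_layer * 2 * 32 bytes, two full-attention layers get 32 blocks each, while the full + sliding case fits 64 blocks because both layers are backed by a single shared tensor.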

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compare the with and without prefix caching."""
import copy
from typing import Optional
import pytest
@ -13,8 +14,8 @@ from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
hash_block_tokens)
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
KVCacheBlock, hash_block_tokens)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, SlidingWindowSpec)
@ -47,7 +48,7 @@ def make_request(request_id,
def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
return KVCacheConfig(
num_blocks=num_blocks,
tensors={},
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"],
@ -57,6 +58,38 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
)
def make_kv_cache_config_hybrid_model(block_size: int,
num_blocks: int) -> KVCacheConfig:
return KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(block_size, 1, 1, torch.float32, False),
),
KVCacheGroupSpec(
["layer2"],
SlidingWindowSpec(block_size,
1,
1,
torch.float32,
False,
sliding_window=2 * block_size),
),
KVCacheGroupSpec(
["layer3"],
SlidingWindowSpec(block_size,
1,
1,
torch.float32,
False,
sliding_window=2 * block_size),
),
],
)
@pytest.mark.parametrize("hash_algo", ["sha256", "hash"])
def test_prefill(hash_algo):
manager = KVCacheManager(
@ -79,10 +112,10 @@ def test_prefill(hash_algo):
req0 = make_request("0", all_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
@ -92,7 +125,8 @@ def test_prefill(hash_algo):
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[block_id].block_hash == block_hash
assert manager.block_pool.blocks[
block_id].block_hash.block_hash == block_hash
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
@ -111,10 +145,10 @@ def test_prefill(hash_algo):
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[5]]
for block in computed_blocks.blocks:
for block in computed_blocks.blocks[0]:
assert block.ref_cnt == 2
# At this point, we should have 5 free blocks left.
@ -145,7 +179,7 @@ def test_prefill(hash_algo):
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[6]]
@ -165,10 +199,10 @@ def test_prefill(hash_algo):
# Cache miss and eviction.
req3 = make_request("3", [99] * (16 * 10))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 10,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
# This block ID order also checks the eviction order.
assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
@ -177,6 +211,138 @@ def test_prefill(hash_algo):
assert manager.block_pool.free_block_queue.free_list_tail is None
def test_prefill_hybrid_model():
block_size = 16
manager = KVCacheManager(
make_kv_cache_config_hybrid_model(block_size, 21),
max_model_len=8192,
enable_caching=True,
)
hash_fn = hash
# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(block_size)]
# Fully cache miss
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4], [5, 6, 7, 8],
[9, 10, 11, 12]]
# Check full block metadata
parent_block_hash = None
for length, block_ids in zip((1, 2, 3),
((1, 5, 9), (2, 6, 10), (3, 7, 11))):
block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
for block_id in block_ids:
assert manager.block_pool.blocks[
block_id].block_hash.block_hash == block_hash
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
# Check partial block metadata
for block_id in (4, 8, 12):
assert manager.block_pool.blocks[block_id].block_hash is None
assert manager.block_pool.blocks[block_id].ref_cnt == 1
# Cache hit in the common prefix
# Incomplete 1 block (5 tokens)
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert computed_blocks.get_block_ids() == [[1, 2, 3], [0, 6, 7],
[0, 10, 11]]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[13], [14], [15]]
for block_per_group in computed_blocks.blocks:
for block in block_per_group:
if block != manager.block_pool.null_block:
assert block.ref_cnt == 2
block_hashes = manager.req_to_block_hashes[req1.request_id]
manager.free(req0)
manager.free(req1)
cached_block_hash_to_block_bak = copy.copy(
manager.block_pool.cached_block_hash_to_block)
def test_partial_request_hit(request_id: str,
hash_to_evict: list[BlockHashWithGroupId],
expect_hit_length: int):
req = make_request(request_id, common_token_ids + unique_token_ids)
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block.pop(
hash_with_group_id)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert len(manager.req_to_block_hashes[req.request_id]) == 3
assert num_computed_tokens == expect_hit_length * block_size
for block_per_group in computed_blocks.blocks:
assert len(block_per_group) == num_computed_tokens // block_size
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block[
hash_with_group_id] = cached_block_hash_to_block_bak[
hash_with_group_id]
manager.free(req)
# Evicting the blocks outside the sliding window does not affect the hit length.
test_partial_request_hit("2", [
BlockHashWithGroupId(block_hashes[0], 1),
BlockHashWithGroupId(block_hashes[0], 2)
], 3)
# Evict the first block of full attention, makes total cache miss.
test_partial_request_hit("3", [
BlockHashWithGroupId(block_hashes[0], 0),
], 0)
# Evict the last block of all layers, reduces the hit length to 2.
test_partial_request_hit("4", [
BlockHashWithGroupId(block_hashes[2], 0),
BlockHashWithGroupId(block_hashes[2], 1),
BlockHashWithGroupId(block_hashes[2], 2),
], 2)
# Evict the last block of full attention, reduces the hit length to 2.
test_partial_request_hit("5", [BlockHashWithGroupId(block_hashes[2], 0)],
2)
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit("6", [BlockHashWithGroupId(block_hashes[2], 1)],
2)
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit("7", [BlockHashWithGroupId(block_hashes[2], 2)],
2)
# Evicting different sets of blocks for full attention and sliding window
# makes a total cache miss.
# The cache hit length of full attention is 1 * block_size.
# The cache hit length of sliding window is 2 * block_size.
# It is then a cache miss because the two types of layers have different hit
# lengths.
test_partial_request_hit("8", [
BlockHashWithGroupId(block_hashes[2], 0),
BlockHashWithGroupId(block_hashes[0], 1),
BlockHashWithGroupId(block_hashes[0], 2),
], 0)
def test_prefill_plp():
'''Test prefill with APC and some prompt logprobs (plp) requests.
@ -203,13 +369,13 @@ def test_prefill_plp():
req0 = make_request("0", all_token_ids, prompt_logprobs=5)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(manager.req_to_block_hashes[req0.request_id]) == 0
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0_block_hashes = [b.block_hash for b in blocks.blocks]
req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
# Check full block metadata
parent_block_hash = None
@ -217,7 +383,8 @@ def test_prefill_plp():
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[block_id].block_hash == block_hash
assert manager.block_pool.blocks[
block_id].block_hash.block_hash == block_hash
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
@ -237,10 +404,10 @@ def test_prefill_plp():
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[5]]
for block in computed_blocks.blocks:
for block in computed_blocks.blocks[0]:
assert block.ref_cnt == 2
# At this point, we should have 5 free blocks left.
@ -269,14 +436,14 @@ def test_prefill_plp():
prompt_logprobs=5)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 0
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 55,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
assert block_ids != [[1, 2, 3, 4]]
# Request #2 block hashes are valid since request #0 hashes are.
@ -302,10 +469,10 @@ def test_decode():
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
@ -314,10 +481,10 @@ def test_decode():
for _ in range(4):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 4,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0
assert manager.single_type_manager.req_to_blocks[
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
assert manager.coordinator.single_type_managers[0].req_to_blocks[
req0.request_id][-1].block_hash is None
# Append slots with allocating a new block.
@ -327,12 +494,12 @@ def test_decode():
for _ in range(9 + 10):
req0.append_output_token_ids(7)
new_blocks = manager.allocate_slots(req0, 19,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 1
assert manager.single_type_manager.req_to_blocks[
assert new_blocks is not None and len(new_blocks.blocks[0]) == 1
assert manager.coordinator.single_type_managers[0].req_to_blocks[
req0.request_id][-2].block_hash is not None
assert manager.single_type_manager.req_to_blocks[
assert manager.coordinator.single_type_managers[0].req_to_blocks[
req0.request_id][-1].block_hash is None
@ -346,23 +513,23 @@ def test_evict():
last_token_id = 5 * 16 + 7
req0 = make_request("0", list(range(last_token_id)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 5 * 16 + 7,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks) == 6 # 5 full + 1 partial
assert len(blocks.blocks[0]) == 6 # 5 full + 1 partial
# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
last_token_id + 3 * 16)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 3 * 16,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks) == 3 # 3 full blocks
assert len(blocks.blocks[0]) == 3 # 3 full blocks
last_token_id += 3 * 16
# 10 - (6 + 3) == 1
@ -382,7 +549,7 @@ def test_evict():
assert computed_blocks.get_block_ids() == [[1, 2]]
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[10]]
assert manager.block_pool.free_block_queue.num_free_blocks == 7
@ -404,12 +571,12 @@ def test_hash_block_correct_reuse():
num_tokens = block_size * 1
req = make_request("0", list(range(num_tokens)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert len(blocks.blocks[0]) == 1
# Deallocate the block.
manager.free(req)
@ -418,15 +585,15 @@ def test_hash_block_correct_reuse():
# block is cleared.
req = make_request("1", list(range(num_tokens - 1)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens - 1,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert len(blocks.blocks[0]) == 1
assert manager.block_pool.blocks[
blocks.blocks[0].block_id].block_hash is None
assert manager.block_pool.blocks[blocks.blocks[0]
[0].block_id].block_hash is None
def test_computed_blocks_not_evicted():
@ -445,24 +612,24 @@ def test_computed_blocks_not_evicted():
num_tokens = block_size * 1
req0 = make_request("0", list(range(num_tokens)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, num_tokens,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 1
assert len(blocks.blocks[0]) == 1
assert blocks.blocks[0][0].block_id == 1
# Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, num_tokens,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 2
assert len(blocks.blocks[0]) == 1
assert blocks.blocks[0][0].block_id == 2
# Free the blocks.
manager.free(req0)
@ -472,15 +639,15 @@ def test_computed_blocks_not_evicted():
# cached block rather than the first one.
req2 = make_request("2", list(range(num_tokens * 2)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks.blocks) == 1
assert computed_blocks.blocks[0].block_id == 1
assert len(computed_blocks.blocks[0]) == 1
assert computed_blocks.blocks[0][0].block_id == 1
assert num_computed_tokens == block_size
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 2
assert len(blocks.blocks[0]) == 1
assert blocks.blocks[0][0].block_id == 2
def test_basic_prefix_caching_disabled():
@ -497,12 +664,12 @@ def test_basic_prefix_caching_disabled():
req1 = make_request("1", list(range(10))) # 2 blocks and some more
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 10,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks) == 3
assert len(blocks.blocks[0]) == 3
# Free the blocks.
manager.free(req1)
@ -510,20 +677,20 @@ def test_basic_prefix_caching_disabled():
# No caching.
req2 = make_request("2", list(range(16))) # shared prefix
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 16,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert len(blocks.blocks) == 4
assert len(blocks.blocks[0]) == 4
# New requests should not have any blocks.
req3 = make_request("3", list(range(4)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 4,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert not blocks
@ -558,6 +725,7 @@ def test_cache_blocks(hash_fn):
num_full_blocks=2,
block_size=block_size,
hash_fn=hash_fn,
kv_cache_group_id=0,
)
assert len(block_pool.cached_block_hash_to_block) == 2
@ -573,11 +741,83 @@ def test_cache_blocks(hash_fn):
num_full_blocks=3,
block_size=block_size,
hash_fn=hash_fn,
kv_cache_group_id=0,
)
assert len(block_pool.cached_block_hash_to_block) == 3
assert blocks[0].block_hash is not None
def test_cache_blocks_multi_group():
"""
This tests that blocks are cached correctly for different kv cache groups.
"""
block_size = 4
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
# Req:
# Block 0/4: [0, 1, 2, 3]
# Block 1/5: [4, 5, 6, 7]
# Block 2/6: [8, 9, 10, 11]
# Block 3/7: [12, 13]
req = make_request("0", list(range(14)))
# Cache the blocks for group 0.
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
block_hashes: list[BlockHash] = []
block_pool.cache_full_blocks(
request=req,
blocks=blocks,
block_hashes=block_hashes,
num_cached_blocks=0,
num_full_blocks=2,
block_size=block_size,
hash_fn=hash,
kv_cache_group_id=0,
)
assert len(block_pool.cached_block_hash_to_block) == 2
assert len(block_hashes) == 2
assert all([block.block_hash is not None for block in blocks])
# Cache the blocks for group 1.
blocks = [KVCacheBlock(block_id=i) for i in range(3)]
block_pool.cache_full_blocks(
request=req,
blocks=blocks,
block_hashes=block_hashes,
num_cached_blocks=0,
num_full_blocks=3,
block_size=block_size,
hash_fn=hash,
kv_cache_group_id=1,
)
assert len(block_pool.cached_block_hash_to_block) == 5
assert len(block_hashes) == 3
assert all([block.block_hash is not None for block in blocks])
# Block hash 0: hit for group 0 and 1
# Block hash 1: hit for group 0 and 1
# Block hash 2: hit for group 1
assert block_pool.get_cached_block(block_hashes[0],
kv_cache_group_ids=[0]) is not None
assert block_pool.get_cached_block(block_hashes[1],
kv_cache_group_ids=[0]) is not None
assert block_pool.get_cached_block(block_hashes[2],
kv_cache_group_ids=[0]) is None
assert block_pool.get_cached_block(block_hashes[0],
kv_cache_group_ids=[1]) is not None
assert block_pool.get_cached_block(block_hashes[1],
kv_cache_group_ids=[1]) is not None
assert block_pool.get_cached_block(block_hashes[2],
kv_cache_group_ids=[1]) is not None
assert block_pool.get_cached_block(block_hashes[0],
kv_cache_group_ids=[0, 1]) is not None
assert block_pool.get_cached_block(block_hashes[1],
kv_cache_group_ids=[0, 1]) is not None
assert block_pool.get_cached_block(block_hashes[2],
kv_cache_group_ids=[0, 1]) is None
def test_mm_prefix_caching():
"""
This tests that the multi-modal prefix caching is correct.
@ -614,7 +854,7 @@ def test_mm_prefix_caching():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
# Completed block should have hashes with extra keys.
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req0.request_id]
assert len(block_hashes) == 3
@ -623,7 +863,7 @@ def test_mm_prefix_caching():
assert block_hashes[2].extra_keys == ("bbb", )
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0.num_computed_tokens = 59
@ -632,9 +872,9 @@ def test_mm_prefix_caching():
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
# The just completed block should have hashes with extra keys.
assert len(block_hashes) == 4
@ -652,7 +892,7 @@ def test_mm_prefix_caching():
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(computed_blocks.blocks) == 3
assert len(computed_blocks.blocks[0]) == 3
assert num_computed_tokens == 3 * 16
@ -675,7 +915,7 @@ def test_cache_key_salting():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
# Completed block should have hashes with extra keys.
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req0.request_id]
assert len(block_hashes) == 3
@ -684,7 +924,7 @@ def test_cache_key_salting():
assert block_hashes[2].extra_keys is None
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
req0.num_computed_tokens = 59
@ -693,9 +933,9 @@ def test_cache_key_salting():
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
# Now one more block that should not have extra keys.
assert len(block_hashes) == 4
@ -706,14 +946,14 @@ def test_cache_key_salting():
req1 = make_request("1", token_ids, cache_salt="salt1")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
# Should match only a prefix of 3 blocks.
assert len(computed_blocks.blocks) == 3
assert len(computed_blocks.blocks[0]) == 3
assert num_computed_tokens == 3 * block_size
# Test cache miss with same content but different salt.
token_ids = common_token_ids + [4] * 11
req2 = make_request("2", token_ids, cache_salt="salt2")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks.blocks) == 0
assert len(computed_blocks.blocks[0]) == 0
assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req2.request_id]
assert len(block_hashes) == 3
@ -738,20 +978,24 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
common_token_ids = [i for i in range(3) for _ in range(16)]
req0 = make_request("0", common_token_ids)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
manager.allocate_slots(req0, 48,
len(computed_blocks.blocks) * 16, computed_blocks)
block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id]
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
block_part0 = manager.coordinator.single_type_managers[0].req_to_blocks[
req0.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1 = make_request("1", common_token_ids * 2)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert computed_blocks.blocks == block_part0
assert computed_blocks.blocks[0] == block_part0
assert num_computed_tokens == 3 * 16
manager.allocate_slots(req1, 48,
len(computed_blocks.blocks) * 16, computed_blocks)
block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id]
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
block_part1 = manager.coordinator.single_type_managers[0].req_to_blocks[
req1.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| ... |
manager.free(req1)
@ -762,10 +1006,11 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2 = make_request("2", [7] * block_size * 2)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
manager.allocate_slots(req2, block_size * 2,
len(computed_blocks.blocks) * 16, computed_blocks)
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# but it cannot be allocated due to insufficient free blocks (2).
@ -773,11 +1018,11 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
assert manager.block_pool.free_block_queue.num_free_blocks == 5
req3 = make_request("3", common_token_ids * 3)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert computed_blocks.blocks == block_part1
assert computed_blocks.blocks[0] == block_part1
assert num_computed_tokens == 6 * 16
# Req3 cannot be allocated.
assert manager.allocate_slots(req3, 48,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks) is None
# Block 0-2 are used by Req 1.
assert {block.ref_cnt for block in block_part1[:3]} == {1}
@ -804,9 +1049,9 @@ def test_reset_prefix_cache():
req1 = make_request("1", all_token_ids)
computed_blocks, _ = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert len(computed_blocks.blocks) == 3
assert len(computed_blocks.blocks[0]) == 3
blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks) * 16,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks.get_block_ids() == [[5]]
@ -836,10 +1081,11 @@ def test_prefix_cache_stats_disabled():
# Call all functions that check whether log_stats is disabled.
req = make_request("0", list(range(16)))
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
manager.allocate_slots(req, 16,
len(computed_blocks.blocks) * 16, computed_blocks)
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
manager.reset_prefix_cache()
# Ensure prefix_cache_stats remains None
@ -918,7 +1164,8 @@ def test_eagle_enabled_removes_last_block():
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
manager.free(req)
# New request with same tokens + Eagle enabled
@ -928,7 +1175,7 @@ def test_eagle_enabled_removes_last_block():
# Should retain 1 block:
# 1. Original 3 blocks → pop last hash → 2 matched blocks
# 2. drop last matched block → 1 remaining block
assert len(computed_blocks.blocks) == 1
assert len(computed_blocks.blocks[0]) == 1
assert num_tokens == 1 * block_size # 16 tokens
@ -948,14 +1195,15 @@ def test_eagle_with_partial_blocks():
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
manager.free(req)
# New request with Eagle enabled
req_eagle = make_request("partial_eagle", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks.blocks) == 1
assert len(computed_blocks.blocks[0]) == 1
assert num_tokens == 1 * block_size
@ -973,7 +1221,7 @@ def test_eagle_with_sliding_window():
manager = KVCacheManager(
KVCacheConfig(
num_blocks=10,
tensors={},
kv_cache_tensors=[],
kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)],
),
max_model_len=8192,
@ -988,7 +1236,8 @@ def test_eagle_with_sliding_window():
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
# record the block hash of the first block in the request for later use
block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
assert block_hash_first_block is not None
@ -998,13 +1247,14 @@ def test_eagle_with_sliding_window():
req_eagle = make_request("partial_eagle", token_ids)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks.blocks) == 1
assert len(computed_blocks.blocks[0]) == 1
assert num_tokens == 1 * block_size
# Evict the first block in the request
assert manager.block_pool.get_cached_block(
block_hash_first_block) is not None
manager.block_pool.cached_block_hash_to_block.pop(block_hash_first_block)
block_hash_first_block, kv_cache_group_ids=[0]) is not None
manager.block_pool.cached_block_hash_to_block.pop(
BlockHashWithGroupId(block_hash_first_block, 0))
# New request
req_after_evict = make_request("partial_eagle_after_evict", token_ids)
@ -1012,5 +1262,5 @@ def test_eagle_with_sliding_window():
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# not considered. But after dropping the last matched block due to eagle,
# there will be no matched prefix.
assert len(computed_blocks.blocks) == 0
assert len(computed_blocks.blocks[0]) == 0
assert num_tokens == 0

@ -97,7 +97,7 @@ def create_scheduler(
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
tensors={},
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,
@ -814,10 +814,10 @@ def _assert_right_kv_cache_manager(
# Make sure the request stats are right.
EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
for req_id in req_ids:
blocks = (scheduler.kv_cache_manager.single_type_manager.
req_to_blocks[req_id])
blocks = (scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[req_id])
hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
assert (scheduler.kv_cache_manager.single_type_manager.
assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS)
assert len(blocks) == EXPECTED_TOTAL_BLOCKS
assert len(hashes) == EXPECTED_TOTAL_BLOCKS
@ -1198,11 +1198,11 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert len(
scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
assert len(
scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block) == 0
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
assert num_free_blocks == (

@ -4,7 +4,8 @@
import torch
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
KVCacheBlock)
from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager
from vllm.v1.kv_cache_interface import SlidingWindowSpec
@ -12,9 +13,8 @@ from vllm.v1.kv_cache_interface import SlidingWindowSpec
def get_sliding_window_manager(sliding_window_spec, block_pool):
return SlidingWindowManager(sliding_window_spec,
block_pool,
use_eagle=False,
num_kv_cache_groups=1,
caching_hash_fn=lambda x: x)
caching_hash_fn=lambda x: x,
kv_cache_group_id=0)
def test_sliding_window_possible_cached_prefix():
@ -42,13 +42,18 @@ def test_sliding_window_possible_cached_prefix():
for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached:
block_pool.cached_block_hash_to_block[block_hash] = {
i: block_pool.blocks[i + 10]
}
block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
block_hash, 0)] = {
i: block_pool.blocks[i + 10],
}
computed_blocks = manager.find_longest_cache_hit(
block_hash_list,
len(block_hash_list) * block_size)
block_hashes=block_hash_list,
max_length=len(block_hash_list) * block_size,
kv_cache_group_ids=[0],
block_pool=block_pool,
kv_cache_spec=sliding_window_spec,
use_eagle=False)[0]
assert len(computed_blocks) == expect_length
assert all(block == block_pool.null_block
@ -95,13 +100,13 @@ def test_sliding_window_remove_skipped_blocks():
null_block_id = block_pool.null_block.block_id
def id_to_block_table(ids):
def id_to_block_table(ids) -> list[KVCacheBlock]:
return [
KVCacheBlock(id_)
if id_ != null_block_id else block_pool.null_block for id_ in ids
]
def assert_block_id(block_table, ids):
def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]):
for block, id_ in zip(block_table, ids):
if id_ == null_block_id:
assert block == block_pool.null_block

@ -18,7 +18,7 @@ class TestConfig:
model_config = {
"bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)),
"google/gemma-2-2b-it": TestConfig(4096, (400, 800)),
"google/gemma-3-1b-it": TestConfig(4096, (400, 800)),
}
@ -26,7 +26,7 @@ model_config = {
"model",
[
"bigcode/starcoder2-3b", # sliding window only
"google/gemma-2-2b-it", # sliding window + full attention
"google/gemma-3-1b-it", # sliding window + full attention
])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])

@ -36,8 +36,8 @@ def test_basic_inferface():
req_meta = kv_connector_metadata.requests[request_id]
for block_id, block in zip(
req_meta.local_block_ids, scheduler.kv_cache_manager.
single_type_manager.req_to_blocks[request_id]):
req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[request_id]):
assert block_id == block.block_id

@ -54,8 +54,8 @@ def test_basic_lifecycle():
assert len(scheduler.waiting) == 0
# ... but blocks should not be freed.
blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
request_id]
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_id]
for block in blocks:
assert block.ref_cnt == 1

@ -51,8 +51,8 @@ def test_basic_lifecycle():
assert (block_pool.free_block_queue.num_free_blocks
< START_FREE_BLOCK_QUEUE_SIZE)
assert len(block_pool.cached_block_hash_to_block) == 0
blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
request_id]
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_id]
for block in blocks:
assert block._block_hash is None
@ -87,8 +87,8 @@ def test_basic_lifecycle():
# Confirm the block are actually allocated.
num_hashed_blocks = 0
blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
request_id]
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_id]
for block in blocks:
assert block.ref_cnt == 1
num_hashed_blocks += (1 if block._block_hash is not None else 0)
@ -261,10 +261,10 @@ def test_no_spurious_prefix_caching():
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
local_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
request_local.request_id]
remote_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ # noqa: E501
request_remote.request_id]
local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_local.request_id]
remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_remote.request_id]
# Local should have cached blocks (but not all due to preallocate).
num_hashed_blocks = 0
@ -300,8 +300,8 @@ def test_full_block_prompt():
# STEP (1): Initialize a recv.
scheduler_output = scheduler.schedule()
# All blocks should be allocated.
num_blocks = len(scheduler.kv_cache_manager.single_type_manager.
req_to_blocks[request_id])
num_blocks = len(scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[request_id])
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
scheduler.update_from_output(scheduler_output, model_runner_output)
@ -319,8 +319,8 @@ def test_full_block_prompt():
# We need to recompute the final token of the prompt to generate
# the first new token, so we should not have a new block.
num_blocks = len(scheduler.kv_cache_manager.single_type_manager.
req_to_blocks[request_id])
num_blocks = len(scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[request_id])
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
NUM_TOKENS - 1)

@ -32,11 +32,11 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert len(
scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
assert len(
scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block) == 0
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
assert num_free_blocks == (
@ -96,7 +96,7 @@ def create_scheduler(
block_size = vllm_config.cache_config.block_size
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
tensors={},
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,

@ -40,12 +40,13 @@ def initialize_kv_cache(runner: GPUModelRunner):
tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
kv_cache_config = KVCacheConfig(
num_blocks=NUM_BLOCKS,
tensors={
"layer.0": KVCacheTensor(size=tensor_size),
},
kv_cache_tensors=[
KVCacheTensor(size=tensor_size, shared_by=["layer.0"]),
],
kv_cache_groups=[
KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
])
],
)
runner.kv_cache_config = kv_cache_config
runner.input_batch = InputBatch(
max_num_reqs=runner.max_num_reqs,
@ -518,9 +519,9 @@ def test_init_kv_cache_without_kv_sharing():
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.tensors) == 2
assert kv_cache_config.tensors[layer_0].size == available_memory // 2
assert kv_cache_config.tensors[layer_1].size == available_memory // 2
assert len(kv_cache_config.kv_cache_tensors) == 2
assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2
max_context_len =\
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
@ -530,9 +531,9 @@ def test_init_kv_cache_without_kv_sharing():
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 2 block worth of memory (2 * 32kb)
kv_cache_config.num_blocks = 1
for layer in kv_cache_config.tensors:
kv_cache_config.tensors[layer].size =\
kv_cache_spec[layer].page_size_bytes
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
kv_cache_tensor.size = (
kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes)
runner.initialize_kv_cache(kv_cache_config)
@ -589,10 +590,10 @@ def test_init_kv_cache_with_kv_sharing_valid():
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.tensors) == 1
assert len(kv_cache_config.kv_cache_tensors) == 1
# Each layer now has twice the available memory for KV cache
# compared to no KV sharing
assert kv_cache_config.tensors[layer_0].size == available_memory
assert kv_cache_config.kv_cache_tensors[0].size == available_memory
max_context_len =\
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
@ -602,7 +603,7 @@ def test_init_kv_cache_with_kv_sharing_valid():
# important: override tensor size to prevent large mem alloc during test
# this will only allocate 1 block worth of memory (32kb)
kv_cache_config.num_blocks = 1
kv_cache_config.tensors[layer_0].size =\
kv_cache_config.kv_cache_tensors[0].size =\
kv_cache_spec[layer_0].page_size_bytes
runner.initialize_kv_cache(kv_cache_config)

@ -2104,6 +2104,12 @@ class SchedulerConfig:
default scheduler. Can be a class directly or the path to a class of form
"mod.custom_class"."""
disable_hybrid_kv_cache_manager: bool = False
"""If set to True, KV cache manager will allocate the same size of KV cache
for all attention layers even if there are multiple type of attention layers
like full attention and sliding window attention.
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@ -4465,6 +4471,21 @@ class VllmConfig:
if not self.instance_id:
self.instance_id = random_uuid()[:5]
if (envs.VLLM_USE_V1
and not self.scheduler_config.disable_hybrid_kv_cache_manager):
# The logger should only print a warning message for hybrid models. As we
# can't yet know whether the model is hybrid, we don't log the warning here
# and will log it later.
if not (current_platform.is_cuda() or current_platform.is_rocm()):
# Hybrid KV cache manager is not supported on non-GPU platforms.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_transfer_config is not None:
# Hybrid KV cache manager is not compatible with KV transfer.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
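A hedged usage sketch of the new knob (not part of the diff): the flag is added to EngineArgs below, so it should be reachable from the Python entrypoint like other scheduler options; whether the LLM constructor forwards it this way is an assumption.

from vllm import LLM

llm = LLM(
    model="google/gemma-3-1b-it",          # sliding window + full attention
    disable_hybrid_kv_cache_manager=True,  # force uniform per-layer allocation
)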
def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:
# remove the sizes that not multiple of tp_size when

@ -387,6 +387,9 @@ class EngineArgs:
bool] = SchedulerConfig.enable_chunked_prefill
disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
disable_hybrid_kv_cache_manager: bool = (
SchedulerConfig.disable_hybrid_kv_cache_manager)
guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend
guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback
guided_decoding_disable_any_whitespace: bool = \
@ -849,6 +852,9 @@ class EngineArgs:
**scheduler_kwargs["disable_chunked_mm_input"])
scheduler_group.add_argument("--scheduler-cls",
**scheduler_kwargs["scheduler_cls"])
scheduler_group.add_argument(
"--disable-hybrid-kv-cache-manager",
**scheduler_kwargs["disable_hybrid_kv_cache_manager"])
# vLLM arguments
vllm_kwargs = get_kwargs(VllmConfig)
@ -1174,6 +1180,8 @@ class EngineArgs:
max_num_partial_prefills=self.max_num_partial_prefills,
max_long_partial_prefills=self.max_long_partial_prefills,
long_prefill_token_threshold=self.long_prefill_token_threshold,
disable_hybrid_kv_cache_manager=self.
disable_hybrid_kv_cache_manager,
)
lora_config = LoRAConfig(

@ -7,8 +7,8 @@ from typing import Callable, Optional
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
BlockStored, KVCacheEvent)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, FreeKVCacheBlockQueue,
KVCacheBlock,
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
FreeKVCacheBlockQueue, KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens)
from vllm.v1.request import Request
@ -27,6 +27,7 @@ class BlockPool:
Args:
num_gpu_blocks: The number of blocks in the pool.
enable_caching: Whether to enable prefix caching.
enable_kv_cache_events: Whether to enable kv cache events.
"""
def __init__(
@ -56,7 +57,7 @@ class BlockPool:
# if there is already an identical block in the cache. This is because
# we want to make sure the allocated block IDs won't change so that
# block tables are append-only.
self.cached_block_hash_to_block: dict[BlockHash, dict[
self.cached_block_hash_to_block: dict[BlockHashWithGroupId, dict[
int, KVCacheBlock]] = defaultdict(dict)
# To represent a placeholder block with block_id=0.
@ -68,22 +69,29 @@ class BlockPool:
self.enable_kv_cache_events = enable_kv_cache_events
self.kv_event_queue: list[KVCacheEvent] = []
def get_cached_block(self,
block_hash: BlockHash) -> Optional[KVCacheBlock]:
"""Get a cached block by the block hash, or None if cache miss.
def get_cached_block(
self, block_hash: BlockHash,
kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]:
"""Get the cached block by the block hash for each group in
`kv_cache_group_ids`, or None if cache miss for any group.
If there are duplicated blocks, we return the first block in the cache.
Args:
block_hash: The hash value of the block.
kv_cache_group_ids: The ids of the KV cache groups.
Returns:
The cached block if it exists, or None.
The cached blocks if they exist, or None.
"""
cached_blocks = self.cached_block_hash_to_block.get(block_hash)
if not cached_blocks:
return None
first_block_id = next(iter(cached_blocks))
return cached_blocks[first_block_id]
cached_blocks = []
for group_id in kv_cache_group_ids:
cached_blocks_one_group = self.cached_block_hash_to_block.get(
BlockHashWithGroupId(block_hash, group_id))
if not cached_blocks_one_group:
return None
first_block_id = next(iter(cached_blocks_one_group))
cached_blocks.append(cached_blocks_one_group[first_block_id])
return cached_blocks
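To make the per-group hit semantics above concrete, here is a small self-contained toy (plain dicts and strings instead of BlockPool, BlockHashWithGroupId, and KVCacheBlock; all names are illustrative): a hash only counts as a hit if every requested group has a cached block for it.

from collections import defaultdict
from typing import Optional

# (hash_value, group_id) -> {block_id: block}
cached: dict[tuple[int, int], dict[int, str]] = defaultdict(dict)
cached[(0xabc, 0)][1] = "block-1-of-group-0"
cached[(0xabc, 1)][7] = "block-7-of-group-1"

def get_cached(block_hash: int, group_ids: list[int]) -> Optional[list[str]]:
    hits = []
    for gid in group_ids:
        blocks = cached.get((block_hash, gid))
        if not blocks:
            return None  # a miss in any group is a miss overall
        hits.append(next(iter(blocks.values())))
    return hits

assert get_cached(0xabc, [0, 1]) == ["block-1-of-group-0", "block-7-of-group-1"]
assert get_cached(0xdef, [0, 1]) is None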
def cache_full_blocks(
self,
@ -93,6 +101,7 @@ class BlockPool:
num_cached_blocks: int,
num_full_blocks: int,
block_size: int,
kv_cache_group_id: int,
hash_fn: Callable,
) -> None:
"""Cache a list of full blocks for prefix caching.
@ -112,6 +121,7 @@ class BlockPool:
num_full_blocks: The number of blocks that are full and should
be cached after this function.
block_size: Number of tokens in each block.
kv_cache_group_id: The id of the KV cache group.
hash_fn: The hash function to use for block hashes.
"""
if num_cached_blocks == num_full_blocks:
@ -126,7 +136,7 @@ class BlockPool:
else:
prev_block = blocks[num_cached_blocks - 1]
assert prev_block.block_hash is not None
prev_block_hash_value = prev_block.block_hash.hash_value
prev_block_hash_value = prev_block.block_hash.get_hash_value()
parent_block_hash = prev_block_hash_value
new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
@ -138,8 +148,9 @@ class BlockPool:
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# reuse the block hash.
# generated tokens with preemption), or by other
# single_type_managers with the same block_size.
# In this case we simply reuse the block hash.
block_hash = new_block_hashes[i]
else:
# Otherwise compute the block hash and cache it in the request
@ -166,8 +177,11 @@ class BlockPool:
block_hashes.append(block_hash)
# Update and add the full block to the cache.
blk.block_hash = block_hash
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
block_hash_with_group_id = BlockHashWithGroupId(
block_hash, kv_cache_group_id)
blk.block_hash = block_hash_with_group_id
self.cached_block_hash_to_block[block_hash_with_group_id][
blk.block_id] = blk
if new_hashes is not None:
new_hashes.append(block_hash.hash_value)
prev_block_hash_value = block_hash.hash_value
@ -237,12 +251,16 @@ class BlockPool:
del self.cached_block_hash_to_block[block_hash]
if self.enable_kv_cache_events:
# FIXME (Chen): Not sure whether we should return `hash_value`
# or `(hash_value, group_id)` here. But it's fine for now because
# we disable the hybrid kv cache manager when kv cache events are
# enabled, so there is only one group.
self.kv_event_queue.append(
BlockRemoved(block_hashes=[block_hash.hash_value]))
BlockRemoved(block_hashes=[block_hash.get_hash_value()]))
return True
return False
def touch(self, blocks: list[KVCacheBlock]) -> None:
def touch(self, blocks: list[list[KVCacheBlock]]) -> None:
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
@ -250,12 +268,13 @@ class BlockPool:
Args:
blocks: A list of blocks to touch, grouped by kv cache group.
"""
for block in blocks:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.remove(block)
block.incr_ref()
for blocks_per_group in blocks:
for block in blocks_per_group:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.remove(block)
block.incr_ref()
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
"""Free a list of blocks. The blocks should be ordered by their

View File

@ -0,0 +1,358 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Callable, Optional
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.single_type_kv_cache_manager import (
FullAttentionManager, SingleTypeKVCacheManager,
get_manager_for_kv_cache_spec)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
from vllm.v1.request import Request
class KVCacheCoordinator(ABC):
"""
Coordinate the KV cache of different KV cache groups.
"""
def __init__(
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
use_eagle: bool,
enable_caching: bool,
caching_hash_fn: Callable,
enable_kv_cache_events: bool,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
enable_kv_cache_events)
self.single_type_managers: list[SingleTypeKVCacheManager] = []
# Needs special handling for find_longest_cache_hit if eagle is enabled
self.use_eagle = use_eagle
for i in range(len(self.kv_cache_config.kv_cache_groups)):
kv_cache_spec = self.kv_cache_config.kv_cache_groups[
i].kv_cache_spec
self.single_type_managers.append(
get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=i,
caching_hash_fn=caching_hash_fn,
))
def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int,
new_computed_blocks: list[list[KVCacheBlock]]) -> int:
"""
Get the number of blocks that need to be allocated for the request.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
Returns:
The number of blocks.
"""
num_blocks_to_allocate = 0
for i, manager in enumerate(self.single_type_managers):
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks[i])
return num_blocks_to_allocate
def save_new_computed_blocks(
self, request_id: str,
new_computed_blocks: list[list[KVCacheBlock]]) -> None:
"""
Add the new computed blocks to the request.
Args:
request_id: The request ID.
new_computed_blocks: The new computed blocks just hitting the
prefix cache.
"""
for i, manager in enumerate(self.single_type_managers):
manager.save_new_computed_blocks(request_id,
new_computed_blocks[i])
def allocate_new_blocks(self, request_id: str,
num_tokens: int) -> list[list[KVCacheBlock]]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
Returns:
The new allocated blocks.
"""
new_blocks = []
for manager in self.single_type_managers:
new_blocks.append(
manager.allocate_new_blocks(request_id, num_tokens))
return new_blocks
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
num_computed_tokens: int) -> None:
"""
Cache the blocks for the request.
Args:
request: The request.
block_hashes: The block hashes of the request.
num_computed_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
"""
for manager in self.single_type_managers:
manager.cache_blocks(request, block_hashes, num_computed_tokens)
def free(self, request_id: str) -> None:
"""
Free the blocks for the request.
Args:
request_id: The request ID.
"""
for manager in self.single_type_managers:
manager.free(request_id)
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> list[int]:
"""
Get the number of common prefix blocks for a request.
Args:
request_id: The request ID.
num_running_requests: The total number of requests in the RUNNING state.
Returns:
The number of common prefix blocks for each kv cache group.
"""
num_blocks_per_group = [
manager.get_num_common_prefix_blocks(request_id,
num_running_requests)
for manager in self.single_type_managers
]
return num_blocks_per_group
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
"""
Remove the blocks that are no longer needed from `blocks` and replace
the removed blocks with null_block.
Args:
request_id: The request ID.
num_computed_tokens: The number of tokens that have been computed.
"""
for manager in self.single_type_managers:
manager.remove_skipped_blocks(request_id, num_computed_tokens)
def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]:
"""
Get the blocks for the request.
"""
return [
manager.req_to_blocks[request_id]
for manager in self.single_type_managers
]
@abstractmethod
def find_longest_cache_hit(
self, block_hashes: list[BlockHash],
max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
pass
class UnitaryKVCacheCoordinator(KVCacheCoordinator):
"""
KV cache coordinator for models with only one KV cache group. This is the
case for models with only one KV cache type, e.g., all attention layers use
full attention or all attention layers use sliding window attention.
"""
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
caching_hash_fn: Callable, enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn,
enable_kv_cache_events)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"UnitaryKVCacheCoordinator assumes only one kv cache group")
def find_longest_cache_hit(
self, block_hashes: list[BlockHash],
max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
hit_blocks = self.single_type_managers[0].find_longest_cache_hit(
block_hashes=block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=[0],
block_pool=self.block_pool,
kv_cache_spec=self.kv_cache_spec,
use_eagle=self.use_eagle,
)
return hit_blocks, len(hit_blocks[0]) * self.block_size
class HybridKVCacheCoordinator(KVCacheCoordinator):
"""
KV cache coordinator for hybrid models with multiple KV cache types, and
thus multiple kv cache groups.
To simplify `find_longest_cache_hit`, it only supports the combination of
two types of KV cache groups, one of which must be full attention.
This may be extended to more general cases in the future.
"""
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
caching_hash_fn: Callable, enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn,
enable_kv_cache_events)
self.verify_and_split_kv_cache_groups()
def verify_and_split_kv_cache_groups(self) -> None:
"""
Verifies that the model has exactly two types of KV cache groups, one of
which is full attention. Then, splits the kv cache groups into full
attention groups and other groups.
"""
full_attention_type_id: Optional[str] = None
other_type_id: Optional[str] = None
self.full_attention_group_ids: list[int] = []
self.other_group_ids: list[int] = []
for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
if isinstance(g.kv_cache_spec, FullAttentionSpec):
if full_attention_type_id is None:
full_attention_type_id = g.kv_cache_spec.type_id
else:
assert full_attention_type_id == g.kv_cache_spec.type_id, (
"HybridKVCacheCoordinator assumes exactly one type of "
"full attention groups now.")
self.full_attention_group_ids.append(i)
else:
if other_type_id is None:
other_type_id = g.kv_cache_spec.type_id
else:
assert other_type_id == g.kv_cache_spec.type_id, (
"HybridKVCacheCoordinator assumes "
"exactly one other type of groups now.")
self.other_group_ids.append(i)
assert full_attention_type_id is not None, (
"HybridKVCacheCoordinator assumes exactly one type of full "
"attention groups now.")
assert other_type_id is not None, (
"HybridKVCacheCoordinator assumes exactly one type of other "
"groups now.")
self.full_attention_manager_cls = FullAttentionManager
self.other_attention_cls = self.single_type_managers[
self.other_group_ids[0]].__class__
self.full_attention_spec = self.kv_cache_config.kv_cache_groups[
self.full_attention_group_ids[0]].kv_cache_spec
self.other_spec = self.kv_cache_config.kv_cache_groups[
self.other_group_ids[0]].kv_cache_spec
self.full_attention_block_size = self.full_attention_spec.block_size
self.other_block_size = self.other_spec.block_size
assert self.other_block_size % self.full_attention_block_size == 0, (
"KVCacheCoordinator assumes the block_size of full attention "
"layers is divisible by other layers now.")
def find_longest_cache_hit(
self,
block_hashes: list[BlockHash],
max_cache_hit_length: int,
) -> tuple[list[list[KVCacheBlock]], int]:
"""
Find the longest cache hit for the request.
Args:
block_hashes: The block hashes of the request.
max_cache_hit_length: The maximum length of the cache hit.
Returns:
A tuple containing:
- A list of the cache hit blocks for each single type manager.
- The number of tokens of the longest cache hit.
"""
# First, find the longest cache hit for full attention.
hit_blocks_full_attn = (
self.full_attention_manager_cls.find_longest_cache_hit(
block_hashes=block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=self.full_attention_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.full_attention_spec,
use_eagle=self.use_eagle,
))
hit_length = len(
hit_blocks_full_attn[0]) * self.full_attention_block_size
# Next, find the cache hit for the other attention WITHIN
# the cache hit of full attention.
hit_blocks_other_attn = (
self.other_attention_cls.find_longest_cache_hit(
block_hashes=block_hashes,
max_length=hit_length,
kv_cache_group_ids=self.other_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.other_spec,
use_eagle=self.use_eagle,
))
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
# NOTE: the prefix cache hit length must be a multiple of block_size as
# we don't support partial block cache hits yet. The cache hit length
# of other attention is guaranteed to be a multiple of the block size of
# full attention layers in the current implementation, because hit_length
# is a multiple of other attention's block size, and other attention's
# block size is a multiple of full attention's block size (verified in
# `verify_and_split_kv_cache_groups`).
assert hit_length % self.full_attention_block_size == 0
# Truncate the full attention cache hit to the length of the
# cache hit of the other attention.
for i in range(len(hit_blocks_full_attn)):
del hit_blocks_full_attn[i][hit_length //
self.full_attention_block_size:]
# Merge the hit blocks of full attention and other attention.
hit_blocks = hit_blocks_other_attn
for group_id, blocks in enumerate(hit_blocks_full_attn):
# NOTE: there is only one full attention group in most cases. So
# the time complexity of insert is fine.
hit_blocks.insert(group_id, blocks)
return hit_blocks, hit_length
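A small numeric walk-through of the two-phase lookup above (the block sizes are made up for illustration and are not tied to any real model):

# Assume: full attention block_size = 16, other (e.g. sliding window)
# block_size = 32, so other_block_size % full_attention_block_size == 0.
full_block_size, other_block_size = 16, 32

# Phase 1: full attention hits 10 blocks -> a 160-token prefix.
hit_len_full = 10 * full_block_size           # 160
# Phase 2: within those 160 tokens, the other groups hit 4 blocks.
hit_len_other = 4 * other_block_size          # 128
# 128 is a multiple of 16, so the full attention hit is truncated to
# 128 // 16 = 8 blocks and both group types agree on a 128-token prefix.
assert hit_len_other % full_block_size == 0
kept_full_blocks = hit_len_other // full_block_size
assert kept_full_blocks == 8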
def get_kv_cache_coordinator(
kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
enable_caching: bool, caching_hash_fn: Callable,
enable_kv_cache_events: bool) -> KVCacheCoordinator:
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
use_eagle, enable_caching,
caching_hash_fn,
enable_kv_cache_events)
else:
return HybridKVCacheCoordinator(kv_cache_config, max_model_len,
use_eagle, enable_caching,
caching_hash_fn,
enable_kv_cache_events)

View File

@ -8,11 +8,9 @@ from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
from vllm.utils import sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
hash_request_tokens)
from vllm.v1.core.single_type_kv_cache_manager import (
get_manager_for_kv_cache_spec)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus
@ -22,16 +20,24 @@ logger = init_logger(__name__)
@dataclass
class KVCacheBlocks:
blocks: list[KVCacheBlock]
"""
The allocation result of KVCacheManager. It works as the interface between
the Scheduler and KVCacheManager, hiding KVCacheManager's internal data
structure from the Scheduler.
"""
blocks: list[list[KVCacheBlock]]
"""
blocks[i][j] refers to the j-th block of tokens in the i-th kv_cache_group.
We don't use the block of tokens as the outer dimension because that would
assume all kv_cache_groups have the same number of blocks, which is true for
now but would break if we gave different block_sizes to different
kv_cache_groups in the future.
"""
def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
"""Adds two KVCacheBlocks instances."""
return KVCacheBlocks(self.blocks + other.blocks)
@classmethod
def create_empty(cls) -> "KVCacheBlocks":
"""Creates a new KVCacheBlocks instance with no blocks."""
return cls([])
return KVCacheBlocks(
[blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)])
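Since __add__ now concatenates per group, a quick sketch with plain lists standing in for KVCacheBlock objects (toy names only):

a = [["g0_b0"], ["g1_b0", "g1_b1"]]   # blocks[i][j]: group i, j-th block
b = [["g0_b1"], ["g1_b2"]]
merged = [x + y for x, y in zip(a, b)]
assert merged == [["g0_b0", "g0_b1"], ["g1_b0", "g1_b1", "g1_b2"]]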
def get_block_ids(self) -> list[list[int]]:
"""
@ -39,15 +45,20 @@ class KVCacheBlocks:
Returns:
list[list[int]]: A two-level list where
* the outer list corresponds to KV cache groups (only 1 group now)
* the outer list corresponds to KV cache groups
* each inner list contains the block_ids of the blocks in that group
"""
return [[block.block_id for block in self.blocks]]
block_ids = []
for group in self.blocks:
block_ids.append([blk.block_id for blk in group])
return block_ids
def get_unhashed_block_ids(self) -> list[int]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
assert len(self.blocks) == 1, "Only one group is supported"
return [
block.block_id for block in self.blocks if block.block_hash is None
block.block_id for block in self.blocks[0]
if block.block_hash is None
]
@ -63,12 +74,6 @@ class KVCacheManager:
log_stats: bool = False,
enable_kv_cache_events: bool = False,
) -> None:
assert len(kv_cache_config.kv_cache_groups) == 1, (
"KVCacheManager does not support hybrid models with more than 1 "
"kv cache group")
kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
self.block_size = kv_cache_spec.block_size
self.num_gpu_blocks = kv_cache_config.num_blocks
self.max_model_len = max_model_len
self.enable_caching = enable_caching
@ -77,17 +82,24 @@ class KVCacheManager:
self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
assert len(
set(g.kv_cache_spec.block_size
for g in kv_cache_config.kv_cache_groups)
) == 1, "Only one block size is supported for now"
self.block_size = kv_cache_config.kv_cache_groups[
0].kv_cache_spec.block_size
self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching,
enable_kv_cache_events)
self.single_type_manager = get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
use_eagle=self.use_eagle,
num_kv_cache_groups=1,
enable_caching=enable_caching,
caching_hash_fn=self.caching_hash_fn,
enable_kv_cache_events=enable_kv_cache_events,
)
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool
self.kv_cache_config = kv_cache_config
# Mapping from request ID to kv block hashes.
# This is to avoid recomputing the block hashes for each call of
@ -133,7 +145,7 @@ class KVCacheManager:
# When the request requires prompt logprobs, we skip prefix caching.
if (not self.enable_caching
or request.sampling_params.prompt_logprobs is not None):
return KVCacheBlocks.create_empty(), 0
return self.create_empty_block_list(), 0
# The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before.
@ -154,20 +166,16 @@ class KVCacheManager:
# num_computed_tokens to be block-size aligned. Removing this limitation
# could slightly improve performance in the future.
max_cache_hit_length = request.num_tokens - 1
computed_blocks = self.single_type_manager.find_longest_cache_hit(
block_hashes, max_cache_hit_length)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size
computed_blocks, num_new_computed_tokens = (
self.coordinator.find_longest_cache_hit(block_hashes,
max_cache_hit_length))
if self.log_stats:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.queries += request.num_tokens
self.prefix_cache_stats.hits += num_computed_tokens
self.prefix_cache_stats.hits += num_new_computed_tokens
return KVCacheBlocks(computed_blocks), num_computed_tokens
return KVCacheBlocks(computed_blocks), num_new_computed_tokens
def allocate_slots(
self,
@ -220,7 +228,9 @@ class KVCacheManager:
if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks
else:
new_computed_block_list = []
new_computed_block_list = [
[] for _ in range(len(self.kv_cache_config.kv_cache_groups))
]
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
@ -228,8 +238,8 @@ class KVCacheManager:
# insufficient free blocks.
# Should call this function before allocating new blocks to reduce
# the number of evicted blocks.
self.single_type_manager.remove_skipped_blocks(
request.request_id, request.num_computed_tokens)
self.coordinator.remove_skipped_blocks(request.request_id,
request.num_computed_tokens)
# The total number of computed tokens is the number of already computed
# tokens plus the new prefix caching hits.
@ -238,12 +248,12 @@ class KVCacheManager:
num_tokens_need_slot = min(
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len)
num_blocks_to_allocate = (
self.single_type_manager.get_num_blocks_to_allocate(
request_id=request.request_id,
num_tokens=num_tokens_need_slot,
new_computed_blocks=new_computed_block_list,
))
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
request_id=request.request_id,
num_tokens=num_tokens_need_slot,
new_computed_blocks=new_computed_block_list,
)
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
# Cannot allocate new blocks
@ -253,16 +263,16 @@ class KVCacheManager:
if self.enable_caching:
self.block_pool.touch(new_computed_block_list)
else:
assert not new_computed_block_list, (
assert all(not blocks for blocks in new_computed_block_list), (
"Computed blocks should be empty when "
"prefix caching is disabled")
# Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
self.single_type_manager.save_new_computed_blocks(
request.request_id, new_computed_block_list)
self.coordinator.save_new_computed_blocks(request.request_id,
new_computed_block_list)
new_blocks = self.single_type_manager.allocate_new_blocks(
new_blocks = self.coordinator.allocate_new_blocks(
request.request_id, num_tokens_need_slot)
# P/D: delay caching blocks if we have to recv from
@ -273,7 +283,7 @@ class KVCacheManager:
# Speculated tokens might be rejected in the future, so we do
# not cache any speculated tokens. We only cache blocks with
# generated (accepted) tokens.
self.single_type_manager.cache_blocks(
self.coordinator.cache_blocks(
request, self.req_to_block_hashes[request.request_id],
num_computed_tokens + num_new_tokens - num_draft_tokens)
@ -287,7 +297,7 @@ class KVCacheManager:
Args:
request: The request to free the blocks.
"""
self.single_type_manager.free(request.request_id)
self.coordinator.free(request.request_id)
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
@ -345,10 +355,8 @@ class KVCacheManager:
group.
"""
assert request.status == RequestStatus.RUNNING
return [
self.single_type_manager.get_num_common_prefix_blocks(
request.request_id, num_running_requests)
]
return self.coordinator.get_num_common_prefix_blocks(
request.request_id, num_running_requests)
def free_block_hashes(self, request: Request) -> None:
"""Discard the block hashes for the request.
@ -368,6 +376,15 @@ class KVCacheManager:
def get_block_ids(self, request_id: str) -> list[list[int]]:
"""Get the block ids of a request."""
assert request_id in self.single_type_manager.req_to_blocks
return KVCacheBlocks(self.single_type_manager.req_to_blocks[request_id]
).get_block_ids()
return KVCacheBlocks(
self.coordinator.get_blocks(request_id)).get_block_ids()
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
num_computed_tokens: int) -> None:
"""Cache the blocks for the request."""
self.coordinator.cache_blocks(request, block_hashes,
num_computed_tokens)
def create_empty_block_list(self) -> KVCacheBlocks:
"""Creates a new KVCacheBlocks instance with no blocks."""
return KVCacheBlocks([[] for _ in range(self.num_kv_cache_groups)])

View File

@ -1,8 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""KV-Cache Utilities."""
import os
from collections import deque
from collections import defaultdict, deque
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple, Optional
@ -33,6 +34,18 @@ class BlockHash(NamedTuple):
extra_keys: Optional[Any] = None
class BlockHashWithGroupId(NamedTuple):
# The hash value for the contents (e.g., token_ids) of a block without the
# group ID. The value is the same for blocks representing the same tokens
# but belonging to different groups.
block_hash: BlockHash
# The KV cache group ID.
group_id: int
def get_hash_value(self) -> int:
return self.block_hash.hash_value
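For illustration (assuming BlockHash's usual hash_value and token_ids fields and importing both tuples from vllm.v1.core.kv_cache_utils), two groups caching the same token block get distinct cache keys while sharing the underlying hash value:

from vllm.v1.core.kv_cache_utils import BlockHash, BlockHashWithGroupId

h = BlockHash(hash_value=12345, token_ids=(1, 2, 3))
key_g0 = BlockHashWithGroupId(block_hash=h, group_id=0)
key_g1 = BlockHashWithGroupId(block_hash=h, group_id=1)
assert key_g0 != key_g1                       # separate cache entries per group
assert key_g0.get_hash_value() == key_g1.get_hash_value() == 12345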
# The hash seed for the first block of the prefix block sequence.
#
# Even if the hash function is the builtin hash(), we use sha256 to generate
@ -44,7 +57,7 @@ class BlockHash(NamedTuple):
# This aligns with the behavior of Python's hash() function, which also uses
# a random seed if PYTHONHASHSEED is not set.
NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv(
'PYTHONHASHSEED') is None else sha256(os.getenv('PYTHONHASHSEED'))
"PYTHONHASHSEED") is None else sha256(os.getenv("PYTHONHASHSEED"))
class PrefixCachingMetrics:
@ -118,7 +131,7 @@ class KVCacheBlock:
ref_cnt: int = 0
# The hash of the block composed of (block hash, tuple of token IDs).
# It is only available when the block is full.
_block_hash: Optional[BlockHash] = None
_block_hash: Optional[BlockHashWithGroupId] = None
# Used to construct a doubly linked list for free blocks.
# These two attributes should only be manipulated by FreeKVCacheBlockQueue.
@ -135,11 +148,11 @@ class KVCacheBlock:
self.ref_cnt -= 1
@property
def block_hash(self) -> Optional[BlockHash]:
def block_hash(self) -> Optional[BlockHashWithGroupId]:
return self._block_hash
@block_hash.setter
def block_hash(self, block_hash: BlockHash):
def block_hash(self, block_hash: BlockHashWithGroupId):
assert self.block_hash is None, (
"The block already has a hash. This should not happen.")
self._block_hash = block_hash
@ -151,10 +164,10 @@ class KVCacheBlock:
def __repr__(self) -> str:
# Use block_id instead of KVCacheBlock object to avoid calling __repr__
# on KVCacheBlock object recursively.
prev_block_id = self.prev_free_block.block_id \
if self.prev_free_block else None
next_block_id = self.next_free_block.block_id \
if self.next_free_block else None
prev_block_id = (self.prev_free_block.block_id
if self.prev_free_block else None)
next_block_id = (self.next_free_block.block_id
if self.next_free_block else None)
return (f"KVCacheBlock(block_id={self.block_id}, "
f"ref_cnt={self.ref_cnt}, "
f"_block_hash={self._block_hash}, "
@ -570,20 +583,20 @@ def create_kv_cache_group_specs(
kv_cache_spec: dict[str, KVCacheSpec],
grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
"""
Create KVCacheGroupSpec object for each kv cache group layer.
The layers in the same group should share the same
KVCacheSpec.
Create KVCacheGroupSpec object for each kv cache group layer.
The layers in the same group should share the same
KVCacheSpec.
Args:
kv_cache_spec:
A mapping from each layer name to its corresponding KVCacheSpec.
grouped_layer_names:
A list of kv cache groups, where each element is a list of layer
names that belong to the same group and should share the same
KVCacheSpec.
Returns:
A list of KVCacheGroupSpec objects, one for each group.
"""
Args:
kv_cache_spec:
A mapping from each layer name to its corresponding KVCacheSpec.
grouped_layer_names:
A list of kv cache groups, where each element is a list of layer
names that belong to the same group and should share the same
KVCacheSpec.
Returns:
A list of KVCacheGroupSpec objects, one for each group.
"""
kv_cache_groups = []
for layer_names_one_group in grouped_layer_names:
layer_specs = [
@ -628,6 +641,37 @@ def get_max_concurrency_for_kv_cache_config(
return max_concurrency
def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
available_memory: int, page_size: int) -> int:
"""
Get the number of kv cache blocks.
Args:
vllm_config: The global VllmConfig
num_layers: The number of layers
available_memory: Memory available for KV cache in bytes.
page_size: The page size of the KV cache.
Returns:
The number of KV cache blocks.
"""
num_blocks = int(available_memory // page_size // num_layers)
num_blocks = max(num_blocks, 0)
if vllm_config.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = \
vllm_config.cache_config.num_gpu_blocks_override
logger.info(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
return num_blocks
def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
"""
Get the page size of the KV cache.
"""
page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values())
assert len(page_sizes) == 1
return page_sizes.pop()
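Back-of-the-envelope sketch of what a uniform page size means (assuming the usual K-plus-V layout; the numbers are illustrative, and MLA layers would use a different factor):

# bytes per block per layer = 2 (K and V) * block_size * num_kv_heads
#                             * head_size * dtype_size
block_size, num_kv_heads, head_size, dtype_size = 16, 8, 128, 2  # fp16
page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_size
assert page_size_bytes == 65536  # 64 KiB per block per layer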
def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
@ -644,32 +688,24 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
The generated KVCacheConfig
"""
page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
assert len(page_sizes) == 1
page_size = page_sizes.pop()
num_blocks = int(available_memory // page_size // len(kv_cache_spec))
num_blocks = max(num_blocks, 0)
if vllm_config.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = \
vllm_config.cache_config.num_gpu_blocks_override
logger.info(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
num_blocks = num_gpu_blocks_override
page_size = get_uniform_page_size(kv_cache_spec)
num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec),
available_memory, page_size)
per_layer_size = page_size * num_blocks
# All layers have the same KV cache spec, so we create one kv cache group
# for all layers.
grouped_layer_names = [list(kv_cache_spec.keys())]
# Each layer uses a separate Tensor to store its KV cache.
kv_cache_tensors = [
KVCacheTensor(size=per_layer_size, shared_by=[layer_name])
for layer_name in kv_cache_spec
]
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
tensors={
layer_name: KVCacheTensor(size=per_layer_size)
for layer_name in kv_cache_spec
},
kv_cache_tensors=kv_cache_tensors,
kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
grouped_layer_names),
)
@ -685,17 +721,185 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
return kv_cache_config
def is_kv_cache_page_size_uniform(
kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
"""
Whether all layers in the given KVCacheSpec have the same page size.
Args:
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
True if all layers have the same page size, False otherwise.
"""
page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
return len(page_sizes) == 1
def _get_kv_cache_config_uniform_page_size(
vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
"""
Generates the KV cache configuration for hybrid models with multiple
attention types but still with a uniform page size (physical memory per
block per layer) for all layers.
Detailed explanation about kv cache management of hybrid models:
The layers in the models are repeated with some patterns, e.g., a model
with 10 full attention layers and 20 sliding window attention layers can be
regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
The KVCacheManager allocates different block tables for each of the 3 layers
in the pattern, and repeats each of them 10 times to generate the
block_table for the 30 layers in the model.
Therefore, we can group the layers in the model into 3 kv_cache_groups, each
of which contains 10 layers in the model.
The KVCacheManager allocates the block_table for each group based on its
kv_cache spec, and the model runner applies the block table to each layer
in the group.
For example:
1. A model only uses full attention. The pattern is
(num_hidden_layers * full), so there is only one group and the block table
is shared by all layers. It is already handled by
`_get_kv_cache_config_uniform_type`.
2. A model with 10 full attention layers and 20 sliding window
attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
there are 3 kv_cache_groups, each of which represents 10 layers.
To simplify the implementation, we make the following assumptions:
1. Physical memory per block: Must be the same across all KV cache groups.
Breaking this assumption is non-trivial due to memory fragmentation concerns
when allocating blocks of different sizes.
2. Tokens per block (block_size): Currently, we directly use
`CacheConfig.block_size` for all layers. It can be extended to vary by KV
cache group, but within each KV cache group, all layers must share the same
block size.
3. Physical memory per token per layer: This property is decided by model
config. Currently we only support models that have the same physical memory
per token per layer for all layers. Can be relaxed with a simple extension,
but still need to keep physical memory per block the same for all groups.
4. Number of layers per group: Currently assumed to be the same for all
groups. Can be relaxed with a simple extension, but we still need to keep
physical memory per block the same for all groups.
5. Attention type within groups: All layers in a group must share the same
attention type. One exception is that, when
`--disable-hybrid-kv-cache-manager` is true, the single group for full
attention layers may also include attention layers using sliding window or
LLaMA 4 local attention. See `unify_hybrid_kv_cache_specs` for more details.
6. Support for multiple attention types: The design of most components is
general to an arbitrary number of attention types. But
`find_longest_cache_hit` only supports one attention type, or full
attention plus exactly one other attention type. A general
implementation of this function is feasible, but we don't know how to
implement it cleanly yet.
As we assume tokens per block, physical memory per token per layer, and
number of layers per group are the same now, we can ensure that physical
memory per block is the same for all groups.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The KVCacheSpec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Returns:
The generated KVCacheConfig
"""
# Group all layers by type_id.
# E.g., 2 full attention layers and 3 sliding window attention layers,
# -> (full.0, full.1), (sw.0, sw.1, sw.2).
same_type_layers: dict[str, list[str]] = defaultdict(list)
for layer_name, layer_spec in kv_cache_spec.items():
same_type_layers[layer_spec.type_id].append(layer_name)
# Split each group into smaller groups, to make the number of layers in each
# group identical. Add padding to the last group of each type if necessary.
# E.g., (full.0, full.1), (sw.0, sw.1, sw.2)
# split to 3 groups with 2 layers each:
# (full.0, full.1), (sw.0, sw.1), (sw.2, padding).
# FIXME(Chen): At the moment of writing this code (2025-06-02), all
# open-source hybrid models follow an n:1 pattern between different attention
# types (e.g., Gemma3 5:1 between sw and full, LLaMA4 3:1 between local and
# full), so we can use the "1" in the n:1 pattern as the group size, which
# is the minimum number of layers among all attention types. We need a better
# strategy if we want to support more complex patterns (e.g., 20 full + 30
# sw, where the group size should be 10).
group_size = min([len(layers) for layers in same_type_layers.values()])
grouped_layers = []
for layers in same_type_layers.values():
num_padding_layers = group_size - len(layers) % group_size
if num_padding_layers != group_size:
logger.warning(
"Add %d padding layers, may waste at most %.2f%% KV cache memory", # noqa
num_padding_layers,
num_padding_layers / len(layers) * 100,
)
for i in range(0, len(layers), group_size):
grouped_layers.append(layers[i:i + group_size])
kv_cache_groups = create_kv_cache_group_specs(kv_cache_spec,
grouped_layers)
# Determine how model runners should initialize the KV cache tensors.
# We will have group_size memory pools, each shared by one layer from
# each group. As layers of different groups have different block tables,
# they will use different parts of the shared Tensor.
# The memory layout in the example will be:
# full.0, sw.0, sw.2: share a Tensor with size=available_memory//2
# full.1, sw.1: share another Tensor with size=available_memory//2
page_size = get_uniform_page_size(kv_cache_spec)
num_blocks = get_num_blocks(vllm_config, group_size, available_memory,
page_size)
per_memory_pool_size = page_size * num_blocks
kv_cache_tensors = []
for i in range(group_size):
shared_by = []
for j in range(len(kv_cache_groups)):
if i < len(grouped_layers[j]):
shared_by.append(grouped_layers[j][i])
kv_cache_tensors.append(
KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by))
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=kv_cache_tensors,
kv_cache_groups=kv_cache_groups,
)
# Print the KV cache size and maximum concurrency.
num_tokens = num_blocks // len(
grouped_layers) * vllm_config.cache_config.block_size
num_tokens_str = f"{num_tokens:,}"
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
max_concurrency = get_max_concurrency_for_kv_cache_config(
vllm_config, kv_cache_config)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
max_model_len_str, max_concurrency)
return kv_cache_config
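The grouping and tensor-sharing scheme above can be reproduced in a standalone sketch (toy layer names, not tied to any real model): 2 full attention layers and 4 sliding window layers give group_size = 2 and three groups of two layers each, and each of the two memory pools is shared by one layer from every group.

from collections import defaultdict

layer_types = {
    "full.0": "full", "full.1": "full",
    "sw.0": "sw", "sw.1": "sw", "sw.2": "sw", "sw.3": "sw",
}
same_type_layers: dict[str, list[str]] = defaultdict(list)
for name, type_id in layer_types.items():
    same_type_layers[type_id].append(name)

group_size = min(len(layers) for layers in same_type_layers.values())  # 2
grouped_layers = []
for layers in same_type_layers.values():
    for i in range(0, len(layers), group_size):
        grouped_layers.append(layers[i:i + group_size])
assert grouped_layers == [["full.0", "full.1"], ["sw.0", "sw.1"],
                          ["sw.2", "sw.3"]]

# Tensor i is shared by the i-th layer of every group.
shared_by = [[g[i] for g in grouped_layers if i < len(g)]
             for i in range(group_size)]
assert shared_by == [["full.0", "sw.0", "sw.2"], ["full.1", "sw.1", "sw.3"]]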
def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
"""
Only models with one type of KV cache are supported yet. This function tries
to convert the KV cache specs to one type if the model is a hybrid model
with multiple type of KV cache. It will convert all SlidingWindowSpec to
FullAttentionSpec if both types are present.
This function tries to convert the KV cache specs to one type if the model
is a hybrid model with multiple types of KV cache. It will convert all
SlidingWindowSpec to FullAttentionSpec if both types are present.
Args:
kv_cache_spec: The kv cache spec of each attention layer in the model
"""
def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
type_ids = set(layer_spec.type_id
for layer_spec in kv_cache_spec.values())
return len(type_ids) > 1
if not is_hybrid(kv_cache_spec):
return
logger.warning(
"Hybrid KV cache manager is disabled for this hybrid model, "
"This means we do not enable any optimizations for saving KV cache "
"memory (e.g., dropping the KV cache outside the sliding window). "
"The compute of layers like sliding window is still saved.")
has_full_attention = any(
isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
has_sliding_window = any(
@ -712,13 +916,18 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
sliding_window=spec.sliding_window,
)
if is_hybrid(kv_cache_spec):
raise ValueError("Hybrid KV cache manager is disabled but failed to "
"convert the KV cache specs to one unified type.")
def get_kv_cache_config(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
def get_kv_cache_config(
vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int,
) -> KVCacheConfig:
"""
Generates the KV cache configuration for a model
TODO: support hybrid models with more than one type of KV cache.
Generates the KV cache configuration for a model.
Args:
vllm_config: The global VllmConfig
@ -728,14 +937,25 @@ def get_kv_cache_config(vllm_config: VllmConfig,
Returns:
The generated KVCacheConfigs
"""
unify_hybrid_kv_cache_specs(kv_cache_spec)
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
unify_hybrid_kv_cache_specs(kv_cache_spec)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory)
elif is_kv_cache_page_size_uniform(kv_cache_spec):
# The model contains multiple attention types, but the KV cache of all
# layers has the same physical memory per block per layer. Split the
# layers into groups with the same number of layers, and thus the same
# total page size.
return _get_kv_cache_config_uniform_page_size(vllm_config,
kv_cache_spec,
available_memory)
raise NotImplementedError

View File

@ -18,7 +18,7 @@ from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
@ -377,7 +377,8 @@ class Scheduler(SchedulerInterface):
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
else:
new_computed_blocks = KVCacheBlocks.create_empty()
new_computed_blocks = (
self.kv_cache_manager.create_empty_block_list())
num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens
@ -1010,7 +1011,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens = len(block_ids) * self.block_size
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1
self.kv_cache_manager.single_type_manager.cache_blocks(
self.kv_cache_manager.cache_blocks(
request,
self.kv_cache_manager.req_to_block_hashes[request.request_id],
num_computed_tokens,

View File

@ -22,8 +22,7 @@ class SingleTypeKVCacheManager(ABC):
self,
kv_cache_spec: KVCacheSpec,
block_pool: BlockPool,
use_eagle: bool,
num_kv_cache_groups: int,
kv_cache_group_id: int,
caching_hash_fn: Callable,
) -> None:
"""
@ -31,9 +30,7 @@ class SingleTypeKVCacheManager(ABC):
Args:
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
use_eagle: Whether to use eagle.
num_kv_cache_groups: The number of kv cache groups managed by this
manager.
kv_cache_group_id: The id of the kv cache group of this manager.
caching_hash_fn: The caching hash function.
"""
@ -41,9 +38,6 @@ class SingleTypeKVCacheManager(ABC):
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
# Needs special handling for find_longest_cache_hit if eagle is enabled
self.use_eagle = use_eagle
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
# is finished.
@ -56,8 +50,8 @@ class SingleTypeKVCacheManager(ABC):
# data for reempted ones.
self.num_cached_block: dict[str, int] = {}
self.num_kv_cache_groups = num_kv_cache_groups
self.caching_hash_fn = caching_hash_fn
self.kv_cache_group_id = kv_cache_group_id
def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int,
@ -86,8 +80,7 @@ class SingleTypeKVCacheManager(ABC):
num_evictable_computed_blocks = sum(
blk.ref_cnt == 0 and not blk.is_null
for blk in new_computed_blocks)
return ((num_new_blocks + num_evictable_computed_blocks) *
self.num_kv_cache_groups)
return num_new_blocks + num_evictable_computed_blocks
def save_new_computed_blocks(
self, request_id: str,
@ -130,8 +123,7 @@ class SingleTypeKVCacheManager(ABC):
if num_new_blocks <= 0:
return []
else:
new_blocks = self.block_pool.get_new_blocks(
num_new_blocks * self.num_kv_cache_groups)
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
return new_blocks
@ -156,12 +148,19 @@ class SingleTypeKVCacheManager(ABC):
num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks,
block_size=self.block_size,
kv_cache_group_id=self.kv_cache_group_id,
hash_fn=self.caching_hash_fn,
)
self.num_cached_block[request.request_id] = num_full_blocks
def free(self, request_id: str) -> None:
"""
Free the blocks for the request.
Args:
request_id: The request ID.
"""
# Default to [] in case a request is freed (aborted) before alloc.
req_blocks = self.req_to_blocks.pop(request_id, [])
@ -188,12 +187,22 @@ class SingleTypeKVCacheManager(ABC):
raise NotImplementedError
@classmethod
@abstractmethod
def find_longest_cache_hit(self, block_hashes: list[BlockHash],
max_length: int) -> list[KVCacheBlock]:
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> list[list[KVCacheBlock]]:
"""
Get the longest cache hit prefix of the blocks that is not longer than
`max_length`. If no cache hit is found, return an empty list.
`max_length`. The prefix should be a common prefix hit for all the
kv cache groups in `kv_cache_group_ids`. If no cache hit is found,
return an empty list.
If eagle is enabled, drop the last matched block to force recompute the
last block to get the required hidden states for eagle drafting head.
Need to be customized for each attention type.
@ -201,12 +210,20 @@ class SingleTypeKVCacheManager(ABC):
Args:
block_hashes: The block hashes of the request.
max_length: The maximum length of the cache hit prefix.
kv_cache_group_ids: The ids of the kv cache groups.
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
Returns:
A list of cached blocks with skipped blocks replaced by null block.
A list of cached blocks with skipped blocks replaced by null block
for each kv cache group in `kv_cache_group_ids`.
Return a list of length `len(kv_cache_group_ids)`, where the i-th
element is a list of cached blocks for the i-th kv cache group
in `kv_cache_group_ids`.
For example, sliding window manager should return a list like
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and
sliding window 8.
[[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]] for block size 4
and sliding window 8 and len(kv_cache_group_ids) = 1.
"""
raise NotImplementedError
@ -215,11 +232,9 @@ class SingleTypeKVCacheManager(ABC):
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
"""
Remove the blocks that are no longer needed from `blocks`. The removed
blocks should be replaced by null_block. Return the removed blocks in
eviction order, where the first returned block should be evicted first.
Don't free the removed blocks in this function. Need to be customized
for each attention type.
Remove the blocks that are no longer needed from `blocks` and free the
blocks. The removed blocks should be replaced by null_block.
Need to be customized for each attention type.
Args:
request_id: The request ID.
@ -230,21 +245,36 @@ class SingleTypeKVCacheManager(ABC):
class FullAttentionManager(SingleTypeKVCacheManager):
def find_longest_cache_hit(self, block_hashes: list[BlockHash],
max_length: int) -> list[KVCacheBlock]:
computed_blocks: list[KVCacheBlock] = []
max_num_blocks = max_length // self.block_size
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> list[list[KVCacheBlock]]:
assert isinstance(kv_cache_spec, FullAttentionSpec), (
"FullAttentionManager can only be used for full attention groups")
computed_blocks: list[list[KVCacheBlock]] = [
[] for _ in range(len(kv_cache_group_ids))
]
max_num_blocks = max_length // kv_cache_spec.block_size
for i in range(max_num_blocks):
block_hash = block_hashes[i]
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if cached_block := self.block_pool.get_cached_block(block_hash):
computed_blocks.append(cached_block)
if cached_block := block_pool.get_cached_block(
block_hash, kv_cache_group_ids):
for j in range(len(kv_cache_group_ids)):
computed_blocks[j].append(cached_block[j])
else:
break
if self.use_eagle and len(computed_blocks) > 0:
computed_blocks.pop()
if use_eagle and len(computed_blocks[0]) > 0:
for j in range(len(kv_cache_group_ids)):
computed_blocks[j].pop()
return computed_blocks
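A tiny toy of the "stop at the first miss" walk above (strings instead of block hashes and blocks): because block hashes form a chain, a miss at one position implies every later position is also uncached.

cached = {"h0", "h1", "h2"}
block_hashes = ["h0", "h1", "hX", "h2"]   # "hX" is not cached
prefix = []
for h in block_hashes:
    if h in cached:
        prefix.append(h)
    else:
        break
assert prefix == ["h0", "h1"]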
def remove_skipped_blocks(self, request_id: str,
@ -267,45 +297,58 @@ class FullAttentionManager(SingleTypeKVCacheManager):
class SlidingWindowManager(SingleTypeKVCacheManager):
def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
use_eagle: bool, **kwargs) -> None:
super().__init__(kv_cache_spec, block_pool, use_eagle, **kwargs)
**kwargs) -> None:
super().__init__(kv_cache_spec, block_pool, **kwargs)
self.sliding_window = kv_cache_spec.sliding_window
self._null_block = block_pool.null_block
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> list[list[KVCacheBlock]]:
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
"SlidingWindowManager can only be used for sliding window groups")
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
self.sliding_window_contiguous_blocks = cdiv(
(kv_cache_spec.sliding_window - 1), self.block_size)
if self.use_eagle:
sliding_window_contiguous_blocks = cdiv(
kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size)
if use_eagle:
# Need to drop the last matched block if eagle is enabled. For
# sliding window layer, we achieve this by increasing the number of
# contiguous blocks needed for prefix cache hit by one and dropping
# the last matched block.
self.sliding_window_contiguous_blocks += 1
self._null_block = block_pool.null_block
sliding_window_contiguous_blocks += 1
def find_longest_cache_hit(self, block_hashes: list[BlockHash],
max_length: int) -> list[KVCacheBlock]:
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
# optimize the time complexity from O(max_num_blocks) to
# O(max_num_blocks / sliding_window_contiguous_blocks +
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
max_num_blocks = max_length // self.block_size
computed_blocks = [self._null_block] * max_num_blocks
max_num_blocks = max_length // kv_cache_spec.block_size
computed_blocks = [[block_pool.null_block] * max_num_blocks
for _ in range(len(kv_cache_group_ids))]
num_contiguous_blocks = 0
match_found = False
# Search from right to left and early stop when a match is found.
for i in range(max_num_blocks - 1, -1, -1):
if cached_block := self.block_pool.get_cached_block(
block_hashes[i]):
computed_blocks[i] = cached_block
if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids):
for j in range(len(kv_cache_group_ids)):
computed_blocks[j][i] = cached_block[j]
num_contiguous_blocks += 1
if (num_contiguous_blocks
>= self.sliding_window_contiguous_blocks):
if (num_contiguous_blocks >= sliding_window_contiguous_blocks):
# Trim the trailing blocks.
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
del computed_blocks[i + num_contiguous_blocks:]
for j in range(len(kv_cache_group_ids)):
del computed_blocks[j][i + num_contiguous_blocks:]
match_found = True
break
else:
@ -313,9 +356,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
if not match_found:
# The first `num_contiguous_blocks` blocks are a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
del computed_blocks[num_contiguous_blocks:]
if self.use_eagle and len(computed_blocks) > 0:
computed_blocks.pop()
for j in range(len(kv_cache_group_ids)):
del computed_blocks[j][num_contiguous_blocks:]
if use_eagle and len(computed_blocks[0]) > 0:
for j in range(len(kv_cache_group_ids)):
computed_blocks[j].pop()
return computed_blocks
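A self-contained toy of the right-to-left search above, with strings in place of KVCacheBlock and a plain set standing in for the block pool (illustrative numbers: block_size = 4, sliding_window = 8, so 2 contiguous cached blocks suffice; the no-match fallback and eagle handling of the real code are omitted):

from math import ceil

block_size, sliding_window = 4, 8
needed = ceil((sliding_window - 1) / block_size)   # 2 contiguous blocks

cached_indices = {0, 2, 3, 5}        # positions whose block hash is cached
max_num_blocks = 6
computed = ["NULL"] * max_num_blocks

num_contiguous = 0
for i in range(max_num_blocks - 1, -1, -1):
    if i in cached_indices:
        computed[i] = f"blk{i}"
        num_contiguous += 1
        if num_contiguous >= needed:
            del computed[i + num_contiguous:]      # trim trailing blocks
            break
    else:
        num_contiguous = 0

assert computed == ["NULL", "NULL", "blk2", "blk3"]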
def remove_skipped_blocks(self, request_id: str,

View File

@ -157,11 +157,10 @@ class SlidingWindowSpec(AttentionSpec):
@dataclass
class KVCacheTensor:
"""
A dataclass for specifying how the workers should initialize the KV cache
for a layer. Only contains the size of KV cache for that layer for now. Will
be extended to support multiple layers sharing the same memory pool.
A class for specifying how the workers should initialize the KV cache.
"""
size: int # The size of KV cache Tensor in bytes
size: int # size of the KV cache tensor in bytes
shared_by: list[str] # layer names that share the same KV cache tensor
@dataclass
@ -183,27 +182,13 @@ class KVCacheConfig:
"""
"""The number of KV cache blocks"""
num_blocks: int
"""layer_name -> how to initialize KV cache for that layer"""
tensors: dict[str, KVCacheTensor]
"""How should model runner initialize the KV cache tensors for each layer"""
kv_cache_tensors: list[KVCacheTensor]
"""
The kv cache groups of the model.
The layers in the models are repeated with some patterns, e.g., a model
with 10 full attention layers and 20 sliding window attention layers can be
regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
The KVCacheManager allocates different block tables for each of the 3 layers
in the pattern, and repeats each of them 10 times to generate the
block_table for the 30 layers in the model.
Therefore, we can group the layers in the model into 3 groups, each of which
contains 10 layers in the model.
The KVCacheManager allocates the block_table for each group based on its
kv_cache spec, and the model runner applies the block table to each layer
in the group.
For example:
1. A model only uses full attention. The pattern is
(num_hidden_layers * full), so there is only one group and the block table
is shared by all layers.
2. (WIP) A model with 10 full attention layers and 20 sliding window
attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
there are 3 groups, each of which represents 10 layers in the model.
For models with only one type of attention, there is only one group that
contains all layers.
For models with multiple types of attention, there will be multiple groups;
see `_get_kv_cache_config_uniform_page_size` for more details.
"""
kv_cache_groups: list[KVCacheGroupSpec]

View File

@ -2088,33 +2088,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_sizes=block_sizes,
)
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
def _allocate_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
"""
Initialize KV cache based on `kv_cache_config`.
Initializes the KV cache buffers with the correct sizes. The buffers
need to be reshaped to the desired shape before being used by the model.
Args:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
kv_cache_config: The KV cache config
Returns:
dict[str, torch.Tensor]: A map from layer name to the
corresponding KV cache memory buffer.
"""
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
tensor = torch.zeros(kv_cache_tensor.size,
dtype=torch.int8,
device=self.device)
for layer_name in kv_cache_tensor.shared_by:
kv_cache_raw_tensors[layer_name] = tensor
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
layer_names.update(group.layer_names)
assert layer_names == set(kv_cache_raw_tensors.keys(
)), "Some layers are not correctly initialized"
return kv_cache_raw_tensors
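One way to see what this mapping produces: every layer named in one KVCacheTensor points at the same flat int8 buffer. The standalone sketch below reproduces that pattern with plain dicts on CPU; it is illustrative and does not use the model-runner class.
import torch

# Sketch of the allocation step: one flat int8 buffer per KV cache tensor,
# handed to every layer that shares it.
def allocate_raw_tensors(kv_cache_tensors, device="cpu"):
    raw = {}
    for t in kv_cache_tensors:
        buf = torch.zeros(t["size"], dtype=torch.int8, device=device)
        for layer_name in t["shared_by"]:
            raw[layer_name] = buf
    return raw

raw = allocate_raw_tensors([{"size": 4096, "shared_by": ["attn.0", "attn.1"]}])
assert raw["attn.0"].data_ptr() == raw["attn.1"].data_ptr()  # same storage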
def _reshape_kv_cache_tensors(
self,
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
"""
self.kv_cache_config = kv_cache_config
self.may_reinitialize_input_batch(kv_cache_config)
self.initialize_attn_backend(kv_cache_config)
Reshape the KV cache tensors to the desired shape and dtype.
Args:
kv_cache_config: The KV cache config
kv_cache_raw_tensors: The flat KV cache buffer of each layer, with
the correct size but not yet reshaped.
Returns:
dict[str, torch.Tensor]: A map from layer name to the
corresponding reshaped KV cache tensor.
"""
kv_caches: dict[str, torch.Tensor] = {}
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group.kv_cache_spec
for layer_name in kv_cache_group.layer_names:
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
# `num_blocks` is the number of blocks the model runner can use.
# `kv_cache_config.num_blocks` is the number of blocks that
# KVCacheManager may allocate.
# Since different GPUs may have different numbers of layers and
# different memory capacities, `num_blocks` can differ across GPUs,
# and `kv_cache_config.num_blocks` is set to the minimum of all
# `num_blocks`. Verify that here.
assert num_blocks >= kv_cache_config.num_blocks
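A small worked example of that invariant, with made-up numbers: each worker derives its own block count from its raw buffer, the config stores the minimum across workers, so every per-worker value is at least the config value.
# Hypothetical numbers for two workers with different memory budgets.
page_size_bytes = 1024
raw_tensor_bytes = [8 * 1024, 6 * 1024]
per_worker_blocks = [b // page_size_bytes for b in raw_tensor_bytes]  # [8, 6]
config_num_blocks = min(per_worker_blocks)  # what the KVCacheManager sees: 6
assert all(n >= config_num_blocks for n in per_worker_blocks)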
for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
for layer_name in kv_cache_group_spec.layer_names:
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = (raw_tensor.numel() //
kv_cache_spec.page_size_bytes)
if isinstance(kv_cache_spec, AttentionSpec):
kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
@ -2140,13 +2165,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
kv_caches[layer_name] = torch.zeros(
kv_cache_shape, dtype=dtype,
device=self.device).permute(*inv_order)
kv_caches[layer_name] = kv_cache_raw_tensors[
layer_name].view(dtype).view(kv_cache_shape).permute(
*inv_order)
else:
# TODO: add new branches when introducing more types of
# KV cache specs.
raise ValueError("Unknown KV cache spec type.")
raise NotImplementedError
return kv_caches
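The reshaping above reinterprets the flat int8 buffer in place rather than copying it: `view(dtype)` changes the element type, `view(kv_cache_shape)` imposes the paged layout, and `permute` applies the backend's stride order. A minimal standalone illustration with arbitrary shapes and no vLLM dependencies:
import torch

num_blocks, block_size, num_heads, head_size = 2, 16, 2, 8
# Bytes per block: K and V planes, fp16 (2 bytes per element).
page_bytes = 2 * block_size * num_heads * head_size * 2
raw = torch.zeros(num_blocks * page_bytes, dtype=torch.int8)

kv_cache_shape = (2, num_blocks, block_size, num_heads, head_size)
kv_cache = raw.view(torch.float16).view(kv_cache_shape).permute(1, 0, 2, 3, 4)
assert kv_cache.data_ptr() == raw.data_ptr()  # same memory, no copy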
def initialize_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
"""
Initialize the memory buffer for KV cache.
Args:
kv_cache_config: The KV cache config
Returns:
dict[str, torch.Tensor]: A map from layer name to the
corresponding KV cache memory buffer.
"""
# Initialize the memory buffer for KV cache
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
# Reshape the memory buffers to the desired shape
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
kv_cache_raw_tensors)
# Set up `kv_cache_config` and `kv_caches` for models
# with cross-layer KV sharing
@ -2157,17 +2198,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_caches,
)
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
return kv_caches
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
Args:
kv_cache_config: Configuration for the KV cache, including the KV
cache size of each layer
"""
self.kv_cache_config = kv_cache_config
self.may_reinitialize_input_batch(kv_cache_config)
self.initialize_attn_backend(kv_cache_config)
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
# Validate that all draft model layers belong to the same KV cache
# group
self.drafter.validate_same_kv_cache_group(kv_cache_config)
bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)

View File

@ -1365,14 +1365,20 @@ class TPUModelRunner(LoRAModelRunnerMixin):
assert self.block_table_cpu.dtype == self.input_batch.block_table[
0].get_cpu_tensor().dtype
kv_caches: dict[str, torch.Tensor] = {}
kv_cache_sizes = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
assert len(kv_cache_tensor.shared_by) == 1, (
"KV cache tensor shared by multiple layers is not supported in "
"TPU.")
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group.kv_cache_spec
for layer_name in kv_cache_group.layer_names:
tensor_config = kv_cache_config.tensors[layer_name]
assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
tensor_size = kv_cache_sizes[layer_name]
assert tensor_size % kv_cache_spec.page_size_bytes == 0
num_blocks = tensor_size // kv_cache_spec.page_size_bytes # noqa
if isinstance(kv_cache_spec, AttentionSpec):
if self.use_spmd:
num_kv_heads = kv_cache_spec.num_kv_heads