mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:06:10 +08:00
[v1] Hybrid Memory Allocator (#17996)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent
3465b87ef8
commit
f8a1a2d108
@ -15,8 +15,8 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager
|
||||
from vllm.v1.core.kv_cache_utils import (
|
||||
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
|
||||
estimate_max_model_len, generate_block_hash_extra_keys,
|
||||
get_max_concurrency_for_kv_cache_config, hash_block_tokens,
|
||||
hash_request_tokens, unify_kv_cache_configs)
|
||||
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
|
||||
hash_block_tokens, hash_request_tokens, unify_kv_cache_configs)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheTensor,
|
||||
SlidingWindowSpec)
|
||||
@ -63,6 +63,20 @@ def new_kv_cache_spec(block_size=16,
|
||||
sliding_window=sliding_window)
|
||||
|
||||
|
||||
def new_sliding_window_spec(block_size=16,
|
||||
num_kv_heads=2,
|
||||
head_size=64,
|
||||
dtype=torch.float32,
|
||||
use_mla=False,
|
||||
sliding_window=1):
|
||||
return SlidingWindowSpec(block_size=block_size,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
use_mla=use_mla,
|
||||
sliding_window=sliding_window)
|
||||
|
||||
|
||||
def test_none_hash(monkeypatch):
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
|
||||
@ -403,10 +417,10 @@ def test_unify_kv_cache_configs():
|
||||
same_kv_cache_config = [
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
tensors={
|
||||
"layer1": KVCacheTensor(100),
|
||||
"layer2": KVCacheTensor(100),
|
||||
},
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=100, shared_by=["layer1"]),
|
||||
KVCacheTensor(size=100, shared_by=["layer2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer2"],
|
||||
@ -415,10 +429,10 @@ def test_unify_kv_cache_configs():
|
||||
),
|
||||
KVCacheConfig(
|
||||
num_blocks=20,
|
||||
tensors={
|
||||
"layer1": KVCacheTensor(100),
|
||||
"layer2": KVCacheTensor(100),
|
||||
},
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=100, shared_by=["layer1"]),
|
||||
KVCacheTensor(size=100, shared_by=["layer2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer2"],
|
||||
@ -433,10 +447,10 @@ def test_unify_kv_cache_configs():
|
||||
need_sort_kv_cache_config = [
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
tensors={
|
||||
"layer1": KVCacheTensor(100),
|
||||
"layer2": KVCacheTensor(100),
|
||||
},
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=100, shared_by=["layer1"]),
|
||||
KVCacheTensor(size=100, shared_by=["layer2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer2"],
|
||||
@ -445,10 +459,10 @@ def test_unify_kv_cache_configs():
|
||||
),
|
||||
KVCacheConfig(
|
||||
num_blocks=20,
|
||||
tensors={
|
||||
"layer1": KVCacheTensor(100),
|
||||
"layer2": KVCacheTensor(100),
|
||||
},
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=100, shared_by=["layer1"]),
|
||||
KVCacheTensor(size=100, shared_by=["layer2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer2"],
|
||||
new_kv_cache_spec(num_kv_heads=4)),
|
||||
@ -464,10 +478,10 @@ def test_unify_kv_cache_configs():
|
||||
diff_kv_cache_config = [
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
tensors={
|
||||
"layer1": KVCacheTensor(100),
|
||||
"layer2": KVCacheTensor(100),
|
||||
},
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=100, shared_by=["layer1"]),
|
||||
KVCacheTensor(size=100, shared_by=["layer2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer2"],
|
||||
@ -476,10 +490,10 @@ def test_unify_kv_cache_configs():
|
||||
),
|
||||
KVCacheConfig(
|
||||
num_blocks=20,
|
||||
tensors={
|
||||
"layer1": KVCacheTensor(100),
|
||||
"layer2": KVCacheTensor(100),
|
||||
},
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=100, shared_by=["layer1"]),
|
||||
KVCacheTensor(size=100, shared_by=["layer2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer2"],
|
||||
@ -636,7 +650,7 @@ def test_get_max_concurrency_for_kv_cache_config():
|
||||
|
||||
kv_cache_config_full_attention = KVCacheConfig(
|
||||
num_blocks=int(1024 * 1.5),
|
||||
tensors={},
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
|
||||
full_attention_spec),
|
||||
@ -648,7 +662,7 @@ def test_get_max_concurrency_for_kv_cache_config():
|
||||
|
||||
kv_cache_config_sliding_window = KVCacheConfig(
|
||||
num_blocks=129 * 3,
|
||||
tensors={},
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
|
||||
sliding_window_spec),
|
||||
@ -660,7 +674,7 @@ def test_get_max_concurrency_for_kv_cache_config():
|
||||
|
||||
kv_cache_config_hybrid_model = KVCacheConfig(
|
||||
num_blocks=(1024 + 129) * 3,
|
||||
tensors={},
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
|
||||
full_attention_spec),
|
||||
@ -678,9 +692,9 @@ def test_allocate_with_lookahead():
|
||||
block_size = 4
|
||||
config = KVCacheConfig(
|
||||
num_blocks=10,
|
||||
tensors={
|
||||
"layer1": KVCacheTensor(100),
|
||||
},
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=100, shared_by=["layer1"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1"],
|
||||
new_kv_cache_spec(block_size=block_size)),
|
||||
@ -702,7 +716,7 @@ def test_allocate_with_lookahead():
|
||||
num_new_tokens=3,
|
||||
num_lookahead_tokens=2, # Total required: 3+2=5 tokens
|
||||
)
|
||||
assert len(blocks.blocks) == 2 # ceil(5/4)=2 blocks
|
||||
assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks
|
||||
|
||||
# Test case 2: With precomputed blocks
|
||||
kv_cache_manager = KVCacheManager(kv_cache_config=config,
|
||||
@ -713,7 +727,7 @@ def test_allocate_with_lookahead():
|
||||
num_new_tokens=3,
|
||||
num_lookahead_tokens=2,
|
||||
)
|
||||
assert len(blocks.blocks) == 2
|
||||
assert len(blocks.get_block_ids()[0]) == 2
|
||||
|
||||
# Test case 3: With precomputed blocks
|
||||
# required_blocks = ceil((3 + 4) / 4) = 2
|
||||
@ -724,4 +738,165 @@ def test_allocate_with_lookahead():
|
||||
num_new_tokens=3,
|
||||
num_lookahead_tokens=4,
|
||||
)
|
||||
assert len(blocks.blocks) == 2
|
||||
assert len(blocks.get_block_ids()[0]) == 2
|
||||
|
||||
|
||||
def test_get_kv_cache_config():
|
||||
# pass max_model_len to pass check_enough_kv_cache_memory
|
||||
model_config = ModelConfig(max_model_len=16)
|
||||
vllm_config = VllmConfig(model_config=model_config)
|
||||
|
||||
mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2
|
||||
# all layers are full attention -> single group
|
||||
kv_cache_specs_full = {
|
||||
'layer_1': new_kv_cache_spec(),
|
||||
'layer_2': new_kv_cache_spec(),
|
||||
}
|
||||
kv_cache_config_full = get_kv_cache_config(
|
||||
vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32)
|
||||
assert kv_cache_config_full == KVCacheConfig(
|
||||
num_blocks=32,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_1"]),
|
||||
KVCacheTensor(size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())
|
||||
])
|
||||
|
||||
# all layers are sliding window -> single group
|
||||
kv_cache_specs_sliding = {
|
||||
'layer_1': new_sliding_window_spec(),
|
||||
'layer_2': new_sliding_window_spec(),
|
||||
}
|
||||
kv_cache_config_sliding = get_kv_cache_config(
|
||||
vllm_config, kv_cache_specs_sliding, mem_per_block_per_layer * 2 * 32)
|
||||
assert kv_cache_config_sliding == KVCacheConfig(
|
||||
num_blocks=32,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_1"]),
|
||||
KVCacheTensor(size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer_1", "layer_2"], new_sliding_window_spec())
|
||||
])
|
||||
|
||||
# full + sliding, but disable_hybrid_kv_cache_manager
|
||||
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = True
|
||||
kv_cache_specs_hybrid = {
|
||||
'layer_1': new_kv_cache_spec(),
|
||||
'layer_2': new_sliding_window_spec(),
|
||||
}
|
||||
kv_cache_config_hybrid = get_kv_cache_config(
|
||||
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
|
||||
assert kv_cache_config_hybrid == KVCacheConfig(
|
||||
num_blocks=32,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_1"]),
|
||||
KVCacheTensor(size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer_1", "layer_2"],
|
||||
new_kv_cache_spec(sliding_window=1)),
|
||||
],
|
||||
)
|
||||
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False
|
||||
|
||||
# full + sliding, with hybrid_kv_cache_manager
|
||||
kv_cache_specs_hybrid = {
|
||||
'layer_1': new_kv_cache_spec(),
|
||||
'layer_2': new_sliding_window_spec(),
|
||||
}
|
||||
kv_cache_config_hybrid = get_kv_cache_config(
|
||||
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
|
||||
assert kv_cache_config_hybrid == KVCacheConfig(
|
||||
num_blocks=64,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=mem_per_block_per_layer * 64,
|
||||
shared_by=["layer_1", "layer_2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer_1"], new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer_2"], new_sliding_window_spec()),
|
||||
],
|
||||
)
|
||||
|
||||
# 2 full + 4 sliding, 2 layers per group
|
||||
kv_cache_specs_hybrid = {
|
||||
'layer_1': new_kv_cache_spec(),
|
||||
'layer_2': new_kv_cache_spec(),
|
||||
'layer_3': new_sliding_window_spec(),
|
||||
'layer_4': new_sliding_window_spec(),
|
||||
'layer_5': new_sliding_window_spec(),
|
||||
'layer_6': new_sliding_window_spec(),
|
||||
}
|
||||
kv_cache_config_hybrid = get_kv_cache_config(
|
||||
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
|
||||
assert kv_cache_config_hybrid == KVCacheConfig(
|
||||
num_blocks=32,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_1", "layer_3", "layer_5"]),
|
||||
KVCacheTensor(size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_2", "layer_4", "layer_6"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer_3", "layer_4"],
|
||||
new_sliding_window_spec()),
|
||||
KVCacheGroupSpec(["layer_5", "layer_6"],
|
||||
new_sliding_window_spec()),
|
||||
],
|
||||
)
|
||||
|
||||
# 3 full + 7 sliding, pad to 3 full + 9 sliding
|
||||
kv_cache_specs_hybrid = {
|
||||
'layer_1': new_kv_cache_spec(),
|
||||
'layer_2': new_kv_cache_spec(),
|
||||
'layer_3': new_kv_cache_spec(),
|
||||
'layer_4': new_sliding_window_spec(),
|
||||
'layer_5': new_sliding_window_spec(),
|
||||
'layer_6': new_sliding_window_spec(),
|
||||
'layer_7': new_sliding_window_spec(),
|
||||
'layer_8': new_sliding_window_spec(),
|
||||
'layer_9': new_sliding_window_spec(),
|
||||
'layer_10': new_sliding_window_spec(),
|
||||
}
|
||||
kv_cache_config_hybrid = get_kv_cache_config(
|
||||
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 3 * 32)
|
||||
assert kv_cache_config_hybrid == KVCacheConfig(
|
||||
num_blocks=32,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(
|
||||
size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_1", "layer_4", "layer_7", "layer_10"]),
|
||||
KVCacheTensor(size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_2", "layer_5", "layer_8"]),
|
||||
KVCacheTensor(size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_3", "layer_6", "layer_9"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"],
|
||||
new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer_4", "layer_5", "layer_6"],
|
||||
new_sliding_window_spec()),
|
||||
KVCacheGroupSpec(["layer_7", "layer_8", "layer_9"],
|
||||
new_sliding_window_spec()),
|
||||
KVCacheGroupSpec(["layer_10"], new_sliding_window_spec()),
|
||||
],
|
||||
)
|
||||
|
||||
# different hidden size, unimplemented
|
||||
kv_cache_specs_hybrid = {
|
||||
'layer_1': new_kv_cache_spec(head_size=128),
|
||||
'layer_2': new_kv_cache_spec(),
|
||||
}
|
||||
with pytest.raises(NotImplementedError):
|
||||
get_kv_cache_config(vllm_config, kv_cache_specs_hybrid,
|
||||
mem_per_block_per_layer * 2 * 32)
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Compare the with and without prefix caching."""
|
||||
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
@ -13,8 +14,8 @@ from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
hash_block_tokens)
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
KVCacheBlock, hash_block_tokens)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, SlidingWindowSpec)
|
||||
|
||||
@ -47,7 +48,7 @@ def make_request(request_id,
|
||||
def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
|
||||
return KVCacheConfig(
|
||||
num_blocks=num_blocks,
|
||||
tensors={},
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer"],
|
||||
@ -57,6 +58,38 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
|
||||
)
|
||||
|
||||
|
||||
def make_kv_cache_config_hybrid_model(block_size: int,
|
||||
num_blocks: int) -> KVCacheConfig:
|
||||
return KVCacheConfig(
|
||||
num_blocks=num_blocks,
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer1"],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32, False),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["layer2"],
|
||||
SlidingWindowSpec(block_size,
|
||||
1,
|
||||
1,
|
||||
torch.float32,
|
||||
False,
|
||||
sliding_window=2 * block_size),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["layer3"],
|
||||
SlidingWindowSpec(block_size,
|
||||
1,
|
||||
1,
|
||||
torch.float32,
|
||||
False,
|
||||
sliding_window=2 * block_size),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_algo", ["sha256", "hash"])
|
||||
def test_prefill(hash_algo):
|
||||
manager = KVCacheManager(
|
||||
@ -79,10 +112,10 @@ def test_prefill(hash_algo):
|
||||
req0 = make_request("0", all_token_ids)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
|
||||
@ -92,7 +125,8 @@ def test_prefill(hash_algo):
|
||||
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
assert manager.block_pool.blocks[block_id].block_hash == block_hash
|
||||
assert manager.block_pool.blocks[
|
||||
block_id].block_hash.block_hash == block_hash
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash.hash_value
|
||||
|
||||
@ -111,10 +145,10 @@ def test_prefill(hash_algo):
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[5]]
|
||||
for block in computed_blocks.blocks:
|
||||
for block in computed_blocks.blocks[0]:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
# At this point, we should have 5 free blocks left.
|
||||
@ -145,7 +179,7 @@ def test_prefill(hash_algo):
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req2, num_new_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[6]]
|
||||
|
||||
@ -165,10 +199,10 @@ def test_prefill(hash_algo):
|
||||
# Cache miss and eviction.
|
||||
req3 = make_request("3", [99] * (16 * 10))
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req3, 16 * 10,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
# This block ID order also checks the eviction order.
|
||||
assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
|
||||
@ -177,6 +211,138 @@ def test_prefill(hash_algo):
|
||||
assert manager.block_pool.free_block_queue.free_list_tail is None
|
||||
|
||||
|
||||
def test_prefill_hybrid_model():
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
make_kv_cache_config_hybrid_model(block_size, 21),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
)
|
||||
|
||||
hash_fn = hash
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
common_token_ids = [i for i in range(3) for _ in range(block_size)]
|
||||
|
||||
# Fully cache miss
|
||||
# Incomplete 1 block (7 tokens)
|
||||
unique_token_ids = [3] * 7
|
||||
all_token_ids = common_token_ids + unique_token_ids
|
||||
req0 = make_request("0", all_token_ids)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4], [5, 6, 7, 8],
|
||||
[9, 10, 11, 12]]
|
||||
|
||||
# Check full block metadata
|
||||
parent_block_hash = None
|
||||
for length, block_ids in zip((1, 2, 3),
|
||||
((1, 5, 9), (2, 6, 10), (3, 7, 11))):
|
||||
block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16])
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
for block_id in block_ids:
|
||||
assert manager.block_pool.blocks[
|
||||
block_id].block_hash.block_hash == block_hash
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash.hash_value
|
||||
|
||||
# Check partial block metadata
|
||||
for block_id in (4, 8, 12):
|
||||
assert manager.block_pool.blocks[block_id].block_hash is None
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
|
||||
# Cache hit in the common prefix
|
||||
# Incomplete 1 block (5 tokens)
|
||||
unique_token_ids = [3] * 5
|
||||
req1 = make_request("1", common_token_ids + unique_token_ids)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||
assert computed_blocks.get_block_ids() == [[1, 2, 3], [0, 6, 7],
|
||||
[0, 10, 11]]
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[13], [14], [15]]
|
||||
for block_per_group in computed_blocks.blocks:
|
||||
for block in block_per_group:
|
||||
if block != manager.block_pool.null_block:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
block_hashes = manager.req_to_block_hashes[req1.request_id]
|
||||
manager.free(req0)
|
||||
manager.free(req1)
|
||||
|
||||
cached_block_hash_to_block_bak = copy.copy(
|
||||
manager.block_pool.cached_block_hash_to_block)
|
||||
|
||||
def test_partial_request_hit(request_id: str,
|
||||
hash_to_evict: list[BlockHashWithGroupId],
|
||||
expect_hit_length: int):
|
||||
req = make_request(request_id, common_token_ids + unique_token_ids)
|
||||
for hash_with_group_id in hash_to_evict:
|
||||
manager.block_pool.cached_block_hash_to_block.pop(
|
||||
hash_with_group_id)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert len(manager.req_to_block_hashes[req.request_id]) == 3
|
||||
assert num_computed_tokens == expect_hit_length * block_size
|
||||
for block_per_group in computed_blocks.blocks:
|
||||
assert len(block_per_group) == num_computed_tokens // block_size
|
||||
for hash_with_group_id in hash_to_evict:
|
||||
manager.block_pool.cached_block_hash_to_block[
|
||||
hash_with_group_id] = cached_block_hash_to_block_bak[
|
||||
hash_with_group_id]
|
||||
manager.free(req)
|
||||
|
||||
# Evict the blocks outside sliding window, does not affect the hit length.
|
||||
test_partial_request_hit("2", [
|
||||
BlockHashWithGroupId(block_hashes[0], 1),
|
||||
BlockHashWithGroupId(block_hashes[0], 2)
|
||||
], 3)
|
||||
|
||||
# Evict the first block of full attention, makes total cache miss.
|
||||
test_partial_request_hit("3", [
|
||||
BlockHashWithGroupId(block_hashes[0], 0),
|
||||
], 0)
|
||||
|
||||
# Evict the last block of all layers, reduces the hit length to 2.
|
||||
test_partial_request_hit("4", [
|
||||
BlockHashWithGroupId(block_hashes[2], 0),
|
||||
BlockHashWithGroupId(block_hashes[2], 1),
|
||||
BlockHashWithGroupId(block_hashes[2], 2),
|
||||
], 2)
|
||||
|
||||
# Evict the last block of full attention, reduces the hit length to 2.
|
||||
test_partial_request_hit("5", [BlockHashWithGroupId(block_hashes[2], 0)],
|
||||
2)
|
||||
|
||||
# Evict the last block of sliding window, reduces the hit length to 2.
|
||||
test_partial_request_hit("6", [BlockHashWithGroupId(block_hashes[2], 1)],
|
||||
2)
|
||||
|
||||
# Evict the last block of sliding window, reduces the hit length to 2.
|
||||
test_partial_request_hit("7", [BlockHashWithGroupId(block_hashes[2], 2)],
|
||||
2)
|
||||
|
||||
# Evict different set of blocks for full attention and sliding window makes
|
||||
# total cache miss.
|
||||
# The cache hit length of full attention is 1 * block_size.
|
||||
# The cache hit length of sliding window is 2 * block_size.
|
||||
# Then it is cache miss as the two type of layers have different hit length.
|
||||
test_partial_request_hit("8", [
|
||||
BlockHashWithGroupId(block_hashes[2], 0),
|
||||
BlockHashWithGroupId(block_hashes[0], 1),
|
||||
BlockHashWithGroupId(block_hashes[0], 2),
|
||||
], 0)
|
||||
|
||||
|
||||
def test_prefill_plp():
|
||||
'''Test prefill with APC and some prompt logprobs (plp) requests.
|
||||
|
||||
@ -203,13 +369,13 @@ def test_prefill_plp():
|
||||
req0 = make_request("0", all_token_ids, prompt_logprobs=5)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert len(manager.req_to_block_hashes[req0.request_id]) == 0
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
req0_block_hashes = [b.block_hash for b in blocks.blocks]
|
||||
req0_block_hashes = [b.block_hash for b in blocks.blocks[0]]
|
||||
|
||||
# Check full block metadata
|
||||
parent_block_hash = None
|
||||
@ -217,7 +383,8 @@ def test_prefill_plp():
|
||||
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
assert manager.block_pool.blocks[block_id].block_hash == block_hash
|
||||
assert manager.block_pool.blocks[
|
||||
block_id].block_hash.block_hash == block_hash
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash.hash_value
|
||||
|
||||
@ -237,10 +404,10 @@ def test_prefill_plp():
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[5]]
|
||||
for block in computed_blocks.blocks:
|
||||
for block in computed_blocks.blocks[0]:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
# At this point, we should have 5 free blocks left.
|
||||
@ -269,14 +436,14 @@ def test_prefill_plp():
|
||||
prompt_logprobs=5)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(manager.req_to_block_hashes[req2.request_id]) == 0
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req2, 55,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
block_ids = blocks.get_block_ids()
|
||||
# Duplicate cached blocks have different ids but same hashes vs request #0
|
||||
assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
|
||||
assert [b.block_hash for b in blocks.blocks[0]] == req0_block_hashes
|
||||
assert block_ids != [[1, 2, 3, 4]]
|
||||
|
||||
# Request #2 block hashes are valid since request #0 hashes are.
|
||||
@ -302,10 +469,10 @@ def test_decode():
|
||||
unique_token_ids = [3] * 7
|
||||
req0 = make_request("0", common_token_ids + unique_token_ids)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
|
||||
@ -314,10 +481,10 @@ def test_decode():
|
||||
for _ in range(4):
|
||||
req0.append_output_token_ids(8)
|
||||
new_blocks = manager.allocate_slots(req0, 4,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert new_blocks is not None and len(new_blocks.blocks) == 0
|
||||
assert manager.single_type_manager.req_to_blocks[
|
||||
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
|
||||
assert manager.coordinator.single_type_managers[0].req_to_blocks[
|
||||
req0.request_id][-1].block_hash is None
|
||||
|
||||
# Append slots with allocating a new block.
|
||||
@ -327,12 +494,12 @@ def test_decode():
|
||||
for _ in range(9 + 10):
|
||||
req0.append_output_token_ids(7)
|
||||
new_blocks = manager.allocate_slots(req0, 19,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert new_blocks is not None and len(new_blocks.blocks) == 1
|
||||
assert manager.single_type_manager.req_to_blocks[
|
||||
assert new_blocks is not None and len(new_blocks.blocks[0]) == 1
|
||||
assert manager.coordinator.single_type_managers[0].req_to_blocks[
|
||||
req0.request_id][-2].block_hash is not None
|
||||
assert manager.single_type_manager.req_to_blocks[
|
||||
assert manager.coordinator.single_type_managers[0].req_to_blocks[
|
||||
req0.request_id][-1].block_hash is None
|
||||
|
||||
|
||||
@ -346,23 +513,23 @@ def test_evict():
|
||||
last_token_id = 5 * 16 + 7
|
||||
req0 = make_request("0", list(range(last_token_id)))
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 5 * 16 + 7,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 6 # 5 full + 1 partial
|
||||
assert len(blocks.blocks[0]) == 6 # 5 full + 1 partial
|
||||
|
||||
# 3 blocks.
|
||||
req1 = make_request("1", list(range(last_token_id,
|
||||
last_token_id + 3 * 16)))
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req1, 3 * 16,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 3 # 3 full blocks
|
||||
assert len(blocks.blocks[0]) == 3 # 3 full blocks
|
||||
last_token_id += 3 * 16
|
||||
|
||||
# 10 - (6 + 3) == 1
|
||||
@ -382,7 +549,7 @@ def test_evict():
|
||||
assert computed_blocks.get_block_ids() == [[1, 2]]
|
||||
assert num_computed_tokens == 2 * 16
|
||||
blocks = manager.allocate_slots(req2, 3,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[10]]
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 7
|
||||
@ -404,12 +571,12 @@ def test_hash_block_correct_reuse():
|
||||
num_tokens = block_size * 1
|
||||
req = make_request("0", list(range(num_tokens)))
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req, num_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 1
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
|
||||
# Deallocate the block.
|
||||
manager.free(req)
|
||||
@ -418,15 +585,15 @@ def test_hash_block_correct_reuse():
|
||||
# block is cleared.
|
||||
req = make_request("1", list(range(num_tokens - 1)))
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req, num_tokens - 1,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 1
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
|
||||
assert manager.block_pool.blocks[
|
||||
blocks.blocks[0].block_id].block_hash is None
|
||||
assert manager.block_pool.blocks[blocks.blocks[0]
|
||||
[0].block_id].block_hash is None
|
||||
|
||||
|
||||
def test_computed_blocks_not_evicted():
|
||||
@ -445,24 +612,24 @@ def test_computed_blocks_not_evicted():
|
||||
num_tokens = block_size * 1
|
||||
req0 = make_request("0", list(range(num_tokens)))
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, num_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 1
|
||||
assert blocks.blocks[0].block_id == 1
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks.blocks[0][0].block_id == 1
|
||||
|
||||
# Allocate another block.
|
||||
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req1, num_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 1
|
||||
assert blocks.blocks[0].block_id == 2
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks.blocks[0][0].block_id == 2
|
||||
|
||||
# Free the blocks.
|
||||
manager.free(req0)
|
||||
@ -472,15 +639,15 @@ def test_computed_blocks_not_evicted():
|
||||
# cached block rather than the first one.
|
||||
req2 = make_request("2", list(range(num_tokens * 2)))
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(computed_blocks.blocks) == 1
|
||||
assert computed_blocks.blocks[0].block_id == 1
|
||||
assert len(computed_blocks.blocks[0]) == 1
|
||||
assert computed_blocks.blocks[0][0].block_id == 1
|
||||
assert num_computed_tokens == block_size
|
||||
|
||||
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 1
|
||||
assert blocks.blocks[0].block_id == 2
|
||||
assert len(blocks.blocks[0]) == 1
|
||||
assert blocks.blocks[0][0].block_id == 2
|
||||
|
||||
|
||||
def test_basic_prefix_caching_disabled():
|
||||
@ -497,12 +664,12 @@ def test_basic_prefix_caching_disabled():
|
||||
req1 = make_request("1", list(range(10))) # 2 blocks and some more
|
||||
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req1, 10,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 3
|
||||
assert len(blocks.blocks[0]) == 3
|
||||
|
||||
# Free the blocks.
|
||||
manager.free(req1)
|
||||
@ -510,20 +677,20 @@ def test_basic_prefix_caching_disabled():
|
||||
# No caching.
|
||||
req2 = make_request("2", list(range(16))) # shared prefix
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req2, 16,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 4
|
||||
assert len(blocks.blocks[0]) == 4
|
||||
|
||||
# New requests should not have any blocks.
|
||||
req3 = make_request("3", list(range(4)))
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req3, 4,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert not blocks
|
||||
|
||||
@ -558,6 +725,7 @@ def test_cache_blocks(hash_fn):
|
||||
num_full_blocks=2,
|
||||
block_size=block_size,
|
||||
hash_fn=hash_fn,
|
||||
kv_cache_group_id=0,
|
||||
)
|
||||
|
||||
assert len(block_pool.cached_block_hash_to_block) == 2
|
||||
@ -573,11 +741,83 @@ def test_cache_blocks(hash_fn):
|
||||
num_full_blocks=3,
|
||||
block_size=block_size,
|
||||
hash_fn=hash_fn,
|
||||
kv_cache_group_id=0,
|
||||
)
|
||||
assert len(block_pool.cached_block_hash_to_block) == 3
|
||||
assert blocks[0].block_hash is not None
|
||||
|
||||
|
||||
def test_cache_blocks_multi_group():
|
||||
"""
|
||||
This tests that blocks are cached correctly for different kv cache groups.
|
||||
"""
|
||||
block_size = 4
|
||||
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
|
||||
|
||||
# Req:
|
||||
# Block 0/4: [0, 1, 2, 3]
|
||||
# Block 1/5: [4, 5, 6, 7]
|
||||
# Block 2/6: [8, 9, 10, 11]
|
||||
# Block 3/7: [12, 13]
|
||||
req = make_request("0", list(range(14)))
|
||||
|
||||
# Cache the blocks for group 0.
|
||||
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
|
||||
block_hashes: list[BlockHash] = []
|
||||
block_pool.cache_full_blocks(
|
||||
request=req,
|
||||
blocks=blocks,
|
||||
block_hashes=block_hashes,
|
||||
num_cached_blocks=0,
|
||||
num_full_blocks=2,
|
||||
block_size=block_size,
|
||||
hash_fn=hash,
|
||||
kv_cache_group_id=0,
|
||||
)
|
||||
assert len(block_pool.cached_block_hash_to_block) == 2
|
||||
assert len(block_hashes) == 2
|
||||
assert all([block.block_hash is not None for block in blocks])
|
||||
|
||||
# Cache the blocks for group 1.
|
||||
blocks = [KVCacheBlock(block_id=i) for i in range(3)]
|
||||
block_pool.cache_full_blocks(
|
||||
request=req,
|
||||
blocks=blocks,
|
||||
block_hashes=block_hashes,
|
||||
num_cached_blocks=0,
|
||||
num_full_blocks=3,
|
||||
block_size=block_size,
|
||||
hash_fn=hash,
|
||||
kv_cache_group_id=1,
|
||||
)
|
||||
assert len(block_pool.cached_block_hash_to_block) == 5
|
||||
assert len(block_hashes) == 3
|
||||
assert all([block.block_hash is not None for block in blocks])
|
||||
|
||||
# Block hash 0: hit for group 0 and 1
|
||||
# Block hash 1: hit for group 0 and 1
|
||||
# Block hash 2: hit for group 1
|
||||
|
||||
assert block_pool.get_cached_block(block_hashes[0],
|
||||
kv_cache_group_ids=[0]) is not None
|
||||
assert block_pool.get_cached_block(block_hashes[1],
|
||||
kv_cache_group_ids=[0]) is not None
|
||||
assert block_pool.get_cached_block(block_hashes[2],
|
||||
kv_cache_group_ids=[0]) is None
|
||||
assert block_pool.get_cached_block(block_hashes[0],
|
||||
kv_cache_group_ids=[1]) is not None
|
||||
assert block_pool.get_cached_block(block_hashes[1],
|
||||
kv_cache_group_ids=[1]) is not None
|
||||
assert block_pool.get_cached_block(block_hashes[2],
|
||||
kv_cache_group_ids=[1]) is not None
|
||||
assert block_pool.get_cached_block(block_hashes[0],
|
||||
kv_cache_group_ids=[0, 1]) is not None
|
||||
assert block_pool.get_cached_block(block_hashes[1],
|
||||
kv_cache_group_ids=[0, 1]) is not None
|
||||
assert block_pool.get_cached_block(block_hashes[2],
|
||||
kv_cache_group_ids=[0, 1]) is None
|
||||
|
||||
|
||||
def test_mm_prefix_caching():
|
||||
"""
|
||||
This tests that the multi-modal prefix caching is correct.
|
||||
@ -614,7 +854,7 @@ def test_mm_prefix_caching():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
|
||||
# Completed block should have hashes with extra keys.
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
block_hashes = manager.req_to_block_hashes[req0.request_id]
|
||||
assert len(block_hashes) == 3
|
||||
@ -623,7 +863,7 @@ def test_mm_prefix_caching():
|
||||
assert block_hashes[2].extra_keys == ("bbb", )
|
||||
|
||||
blocks = manager.allocate_slots(req0, 59,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
req0.num_computed_tokens = 59
|
||||
@ -632,9 +872,9 @@ def test_mm_prefix_caching():
|
||||
for _ in range(5):
|
||||
req0.append_output_token_ids(8)
|
||||
new_blocks = manager.allocate_slots(req0, 5,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert new_blocks is not None and len(new_blocks.blocks) == 0
|
||||
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
|
||||
|
||||
# The just completed block should have hashes with extra keys.
|
||||
assert len(block_hashes) == 4
|
||||
@ -652,7 +892,7 @@ def test_mm_prefix_caching():
|
||||
mm_positions=mm_positions,
|
||||
mm_hashes=mm_hashes)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(computed_blocks.blocks) == 3
|
||||
assert len(computed_blocks.blocks[0]) == 3
|
||||
assert num_computed_tokens == 3 * 16
|
||||
|
||||
|
||||
@ -675,7 +915,7 @@ def test_cache_key_salting():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
|
||||
# Completed block should have hashes with extra keys.
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
block_hashes = manager.req_to_block_hashes[req0.request_id]
|
||||
assert len(block_hashes) == 3
|
||||
@ -684,7 +924,7 @@ def test_cache_key_salting():
|
||||
assert block_hashes[2].extra_keys is None
|
||||
|
||||
blocks = manager.allocate_slots(req0, 59,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[1, 2, 3, 4]]
|
||||
req0.num_computed_tokens = 59
|
||||
@ -693,9 +933,9 @@ def test_cache_key_salting():
|
||||
for _ in range(5):
|
||||
req0.append_output_token_ids(8)
|
||||
new_blocks = manager.allocate_slots(req0, 5,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert new_blocks is not None and len(new_blocks.blocks) == 0
|
||||
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
|
||||
|
||||
# Now one more block that should not have extra keys.
|
||||
assert len(block_hashes) == 4
|
||||
@ -706,14 +946,14 @@ def test_cache_key_salting():
|
||||
req1 = make_request("1", token_ids, cache_salt="salt1")
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
# Should match only a prefix of 3 blocks.
|
||||
assert len(computed_blocks.blocks) == 3
|
||||
assert len(computed_blocks.blocks[0]) == 3
|
||||
assert num_computed_tokens == 3 * block_size
|
||||
|
||||
# Test cache miss with same content but different salt.
|
||||
token_ids = common_token_ids + [4] * 11
|
||||
req2 = make_request("2", token_ids, cache_salt="salt2")
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(computed_blocks.blocks) == 0
|
||||
assert len(computed_blocks.blocks[0]) == 0
|
||||
assert num_computed_tokens == 0
|
||||
block_hashes = manager.req_to_block_hashes[req2.request_id]
|
||||
assert len(block_hashes) == 3
|
||||
@ -738,20 +978,24 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
req0 = make_request("0", common_token_ids)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
manager.allocate_slots(req0, 48,
|
||||
len(computed_blocks.blocks) * 16, computed_blocks)
|
||||
block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id]
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
block_part0 = manager.coordinator.single_type_managers[0].req_to_blocks[
|
||||
req0.request_id]
|
||||
|
||||
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
|
||||
req1 = make_request("1", common_token_ids * 2)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert computed_blocks.blocks == block_part0
|
||||
assert computed_blocks.blocks[0] == block_part0
|
||||
assert num_computed_tokens == 3 * 16
|
||||
manager.allocate_slots(req1, 48,
|
||||
len(computed_blocks.blocks) * 16, computed_blocks)
|
||||
block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id]
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
block_part1 = manager.coordinator.single_type_managers[0].req_to_blocks[
|
||||
req1.request_id]
|
||||
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
|
||||
# | Req1-5(F)| ... |
|
||||
manager.free(req1)
|
||||
@ -762,10 +1006,11 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
|
||||
req2 = make_request("2", [7] * block_size * 2)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
manager.allocate_slots(req2, block_size * 2,
|
||||
len(computed_blocks.blocks) * 16, computed_blocks)
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
|
||||
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
|
||||
# but it cannot be allocated due to insufficient free blocks (2).
|
||||
@ -773,11 +1018,11 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 5
|
||||
req3 = make_request("3", common_token_ids * 3)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||
assert computed_blocks.blocks == block_part1
|
||||
assert computed_blocks.blocks[0] == block_part1
|
||||
assert num_computed_tokens == 6 * 16
|
||||
# Req3 cannot be allocated.
|
||||
assert manager.allocate_slots(req3, 48,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks) is None
|
||||
# Block 0-2 are used by Req 1.
|
||||
assert {block.ref_cnt for block in block_part1[:3]} == {1}
|
||||
@ -804,9 +1049,9 @@ def test_reset_prefix_cache():
|
||||
req1 = make_request("1", all_token_ids)
|
||||
computed_blocks, _ = manager.get_computed_blocks(req1)
|
||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||
assert len(computed_blocks.blocks) == 3
|
||||
assert len(computed_blocks.blocks[0]) == 3
|
||||
blocks = manager.allocate_slots(req1, 7,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [[5]]
|
||||
|
||||
@ -836,10 +1081,11 @@ def test_prefix_cache_stats_disabled():
|
||||
# Call all functions that check whether log_stats is disabled.
|
||||
req = make_request("0", list(range(16)))
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks.blocks
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
manager.allocate_slots(req, 16,
|
||||
len(computed_blocks.blocks) * 16, computed_blocks)
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
manager.reset_prefix_cache()
|
||||
|
||||
# Ensure prefix_cache_stats remains None
|
||||
@ -918,7 +1164,8 @@ def test_eagle_enabled_removes_last_block():
|
||||
# Prime the cache
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
manager.allocate_slots(req, len(token_ids),
|
||||
len(computed_blocks.blocks) * 16, computed_blocks)
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
manager.free(req)
|
||||
|
||||
# New request with same tokens + Eagle enabled
|
||||
@ -928,7 +1175,7 @@ def test_eagle_enabled_removes_last_block():
|
||||
# Should retain 1 block:
|
||||
# 1. Original 3 blocks → pop last hash → 2 matched blocks
|
||||
# 2. drop last matched block → 1 remaining block
|
||||
assert len(computed_blocks.blocks) == 1
|
||||
assert len(computed_blocks.blocks[0]) == 1
|
||||
assert num_tokens == 1 * block_size # 16 tokens
|
||||
|
||||
|
||||
@ -948,14 +1195,15 @@ def test_eagle_with_partial_blocks():
|
||||
# Prime the cache
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
manager.allocate_slots(req, len(token_ids),
|
||||
len(computed_blocks.blocks) * 16, computed_blocks)
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
manager.free(req)
|
||||
|
||||
# New request with Eagle enabled
|
||||
req_eagle = make_request("partial_eagle", token_ids)
|
||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
||||
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
||||
assert len(computed_blocks.blocks) == 1
|
||||
assert len(computed_blocks.blocks[0]) == 1
|
||||
assert num_tokens == 1 * block_size
|
||||
|
||||
|
||||
@ -973,7 +1221,7 @@ def test_eagle_with_sliding_window():
|
||||
manager = KVCacheManager(
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
tensors={},
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[KVCacheGroupSpec(['layer'], sliding_window_spec)],
|
||||
),
|
||||
max_model_len=8192,
|
||||
@ -988,7 +1236,8 @@ def test_eagle_with_sliding_window():
|
||||
# Prime the cache
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
manager.allocate_slots(req, len(token_ids),
|
||||
len(computed_blocks.blocks) * 16, computed_blocks)
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
# record the block hash of the first block in the request for later use
|
||||
block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
|
||||
assert block_hash_first_block is not None
|
||||
@ -998,13 +1247,14 @@ def test_eagle_with_sliding_window():
|
||||
req_eagle = make_request("partial_eagle", token_ids)
|
||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
||||
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
||||
assert len(computed_blocks.blocks) == 1
|
||||
assert len(computed_blocks.blocks[0]) == 1
|
||||
assert num_tokens == 1 * block_size
|
||||
|
||||
# Evict the first block in the request
|
||||
assert manager.block_pool.get_cached_block(
|
||||
block_hash_first_block) is not None
|
||||
manager.block_pool.cached_block_hash_to_block.pop(block_hash_first_block)
|
||||
block_hash_first_block, kv_cache_group_ids=[0]) is not None
|
||||
manager.block_pool.cached_block_hash_to_block.pop(
|
||||
BlockHashWithGroupId(block_hash_first_block, 0))
|
||||
|
||||
# New request
|
||||
req_after_evict = make_request("partial_eagle_after_evict", token_ids)
|
||||
@ -1012,5 +1262,5 @@ def test_eagle_with_sliding_window():
|
||||
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
|
||||
# not considered. But after dropping the last matched block due to eagle,
|
||||
# there will be no matched prefix.
|
||||
assert len(computed_blocks.blocks) == 0
|
||||
assert len(computed_blocks.blocks[0]) == 0
|
||||
assert num_tokens == 0
|
||||
|
||||
@ -97,7 +97,7 @@ def create_scheduler(
|
||||
)
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=num_blocks, # A large number of blocks to hold all requests
|
||||
tensors={},
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(['layer'],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32,
|
||||
@ -814,10 +814,10 @@ def _assert_right_kv_cache_manager(
|
||||
# Make sure the request stats are right.
|
||||
EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
|
||||
for req_id in req_ids:
|
||||
blocks = (scheduler.kv_cache_manager.single_type_manager.
|
||||
req_to_blocks[req_id])
|
||||
blocks = (scheduler.kv_cache_manager.coordinator.
|
||||
single_type_managers[0].req_to_blocks[req_id])
|
||||
hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
|
||||
assert (scheduler.kv_cache_manager.single_type_manager.
|
||||
assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS)
|
||||
assert len(blocks) == EXPECTED_TOTAL_BLOCKS
|
||||
assert len(hashes) == EXPECTED_TOTAL_BLOCKS
|
||||
@ -1198,11 +1198,11 @@ def assert_scheduler_empty(scheduler: Scheduler):
|
||||
assert len(scheduler.encoder_cache_manager.cached) == 0
|
||||
|
||||
# KVCache Manager.
|
||||
assert len(
|
||||
scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0
|
||||
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
req_to_blocks) == 0
|
||||
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
|
||||
assert len(
|
||||
scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0
|
||||
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
num_cached_block) == 0
|
||||
num_free_blocks = (
|
||||
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
|
||||
assert num_free_blocks == (
|
||||
|
||||
@ -4,7 +4,8 @@
|
||||
import torch
|
||||
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
KVCacheBlock)
|
||||
from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager
|
||||
from vllm.v1.kv_cache_interface import SlidingWindowSpec
|
||||
|
||||
@ -12,9 +13,8 @@ from vllm.v1.kv_cache_interface import SlidingWindowSpec
|
||||
def get_sliding_window_manager(sliding_window_spec, block_pool):
|
||||
return SlidingWindowManager(sliding_window_spec,
|
||||
block_pool,
|
||||
use_eagle=False,
|
||||
num_kv_cache_groups=1,
|
||||
caching_hash_fn=lambda x: x)
|
||||
caching_hash_fn=lambda x: x,
|
||||
kv_cache_group_id=0)
|
||||
|
||||
|
||||
def test_sliding_window_possible_cached_prefix():
|
||||
@ -42,13 +42,18 @@ def test_sliding_window_possible_cached_prefix():
|
||||
for i, (block_hash,
|
||||
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
|
||||
if is_cached:
|
||||
block_pool.cached_block_hash_to_block[block_hash] = {
|
||||
i: block_pool.blocks[i + 10]
|
||||
}
|
||||
block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
|
||||
block_hash, 0)] = {
|
||||
i: block_pool.blocks[i + 10],
|
||||
}
|
||||
|
||||
computed_blocks = manager.find_longest_cache_hit(
|
||||
block_hash_list,
|
||||
len(block_hash_list) * block_size)
|
||||
block_hashes=block_hash_list,
|
||||
max_length=len(block_hash_list) * block_size,
|
||||
kv_cache_group_ids=[0],
|
||||
block_pool=block_pool,
|
||||
kv_cache_spec=sliding_window_spec,
|
||||
use_eagle=False)[0]
|
||||
assert len(computed_blocks) == expect_length
|
||||
|
||||
assert all(block == block_pool.null_block
|
||||
@ -95,13 +100,13 @@ def test_sliding_window_remove_skipped_blocks():
|
||||
|
||||
null_block_id = block_pool.null_block.block_id
|
||||
|
||||
def id_to_block_table(ids):
|
||||
def id_to_block_table(ids) -> list[KVCacheBlock]:
|
||||
return [
|
||||
KVCacheBlock(id_)
|
||||
if id_ != null_block_id else block_pool.null_block for id_ in ids
|
||||
]
|
||||
|
||||
def assert_block_id(block_table, ids):
|
||||
def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]):
|
||||
for block, id_ in zip(block_table, ids):
|
||||
if id_ == null_block_id:
|
||||
assert block == block_pool.null_block
|
||||
|
||||
@ -18,7 +18,7 @@ class TestConfig:
|
||||
|
||||
model_config = {
|
||||
"bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)),
|
||||
"google/gemma-2-2b-it": TestConfig(4096, (400, 800)),
|
||||
"google/gemma-3-1b-it": TestConfig(4096, (400, 800)),
|
||||
}
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ model_config = {
|
||||
"model",
|
||||
[
|
||||
"bigcode/starcoder2-3b", # sliding window only
|
||||
"google/gemma-2-2b-it", # sliding window + full attention
|
||||
"google/gemma-3-1b-it", # sliding window + full attention
|
||||
])
|
||||
@pytest.mark.parametrize("batch_size", [5])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
|
||||
@ -36,8 +36,8 @@ def test_basic_inferface():
|
||||
req_meta = kv_connector_metadata.requests[request_id]
|
||||
|
||||
for block_id, block in zip(
|
||||
req_meta.local_block_ids, scheduler.kv_cache_manager.
|
||||
single_type_manager.req_to_blocks[request_id]):
|
||||
req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator.
|
||||
single_type_managers[0].req_to_blocks[request_id]):
|
||||
assert block_id == block.block_id
|
||||
|
||||
|
||||
|
||||
@ -54,8 +54,8 @@ def test_basic_lifecycle():
|
||||
assert len(scheduler.waiting) == 0
|
||||
|
||||
# ... but blocks should not be freed.
|
||||
blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
|
||||
request_id]
|
||||
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_id]
|
||||
for block in blocks:
|
||||
assert block.ref_cnt == 1
|
||||
|
||||
|
||||
@ -51,8 +51,8 @@ def test_basic_lifecycle():
|
||||
assert (block_pool.free_block_queue.num_free_blocks
|
||||
< START_FREE_BLOCK_QUEUE_SIZE)
|
||||
assert len(block_pool.cached_block_hash_to_block) == 0
|
||||
blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
|
||||
request_id]
|
||||
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_id]
|
||||
for block in blocks:
|
||||
assert block._block_hash is None
@ -87,8 +87,8 @@ def test_basic_lifecycle():
|
||||
|
||||
# Confirm the blocks are actually allocated.
|
||||
num_hashed_blocks = 0
|
||||
blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
|
||||
request_id]
|
||||
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_id]
|
||||
for block in blocks:
|
||||
assert block.ref_cnt == 1
|
||||
num_hashed_blocks += (1 if block._block_hash is not None else 0)
|
||||
@ -261,10 +261,10 @@ def test_no_spurious_prefix_caching():
|
||||
assert len(scheduler.running) == 1
|
||||
assert len(scheduler.waiting) == 1
|
||||
|
||||
local_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[
|
||||
request_local.request_id]
|
||||
remote_blocks = scheduler.kv_cache_manager.single_type_manager.req_to_blocks[ # noqa: E501
|
||||
request_remote.request_id]
|
||||
local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_local.request_id]
|
||||
remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_remote.request_id]
# Local should have cached blocks (but not all due to preallocate).
|
||||
num_hashed_blocks = 0
|
||||
@ -300,8 +300,8 @@ def test_full_block_prompt():
|
||||
# STEP (1): Initialize a recv.
|
||||
scheduler_output = scheduler.schedule()
|
||||
# All blocks should be allocated.
|
||||
num_blocks = len(scheduler.kv_cache_manager.single_type_manager.
|
||||
req_to_blocks[request_id])
|
||||
num_blocks = len(scheduler.kv_cache_manager.coordinator.
|
||||
single_type_managers[0].req_to_blocks[request_id])
|
||||
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
|
||||
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
@ -319,8 +319,8 @@ def test_full_block_prompt():
|
||||
|
||||
# We need to recompute the final token of the prompt to generate
|
||||
# the first new token, so we should not have a new block.
|
||||
num_blocks = len(scheduler.kv_cache_manager.single_type_manager.
|
||||
req_to_blocks[request_id])
|
||||
num_blocks = len(scheduler.kv_cache_manager.coordinator.
|
||||
single_type_managers[0].req_to_blocks[request_id])
|
||||
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
|
||||
assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
|
||||
NUM_TOKENS - 1)
@ -32,11 +32,11 @@ def assert_scheduler_empty(scheduler: Scheduler):
|
||||
assert len(scheduler.encoder_cache_manager.cached) == 0
|
||||
|
||||
# KVCache Manager.
|
||||
assert len(
|
||||
scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0
|
||||
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
req_to_blocks) == 0
|
||||
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
|
||||
assert len(
|
||||
scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0
|
||||
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
num_cached_block) == 0
|
||||
num_free_blocks = (
|
||||
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
|
||||
assert num_free_blocks == (
|
||||
@ -96,7 +96,7 @@ def create_scheduler(
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=num_blocks, # A large number of blocks to hold all requests
|
||||
tensors={},
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(['layer'],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32,
|
||||
|
||||
@ -40,12 +40,13 @@ def initialize_kv_cache(runner: GPUModelRunner):
|
||||
tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=NUM_BLOCKS,
|
||||
tensors={
|
||||
"layer.0": KVCacheTensor(size=tensor_size),
|
||||
},
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=tensor_size, shared_by=["layer.0"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
|
||||
])
|
||||
],
|
||||
)
|
||||
runner.kv_cache_config = kv_cache_config
|
||||
runner.input_batch = InputBatch(
|
||||
max_num_reqs=runner.max_num_reqs,
|
||||
@ -518,9 +519,9 @@ def test_init_kv_cache_without_kv_sharing():
|
||||
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
assert kv_cache_config.num_blocks == num_expected_blocks
|
||||
assert len(kv_cache_config.tensors) == 2
|
||||
assert kv_cache_config.tensors[layer_0].size == available_memory // 2
|
||||
assert kv_cache_config.tensors[layer_1].size == available_memory // 2
|
||||
assert len(kv_cache_config.kv_cache_tensors) == 2
|
||||
assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
|
||||
assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2
|
||||
|
||||
max_context_len =\
|
||||
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
|
||||
@ -530,9 +531,9 @@ def test_init_kv_cache_without_kv_sharing():
|
||||
# important: override tensor size to prevent large mem alloc during test
|
||||
# this will only allocate 2 block worth of memory (2 * 32kb)
|
||||
kv_cache_config.num_blocks = 1
|
||||
for layer in kv_cache_config.tensors:
|
||||
kv_cache_config.tensors[layer].size =\
|
||||
kv_cache_spec[layer].page_size_bytes
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
kv_cache_tensor.size = (
|
||||
kv_cache_spec[kv_cache_tensor.shared_by[0]].page_size_bytes)
|
||||
|
||||
runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
@ -589,10 +590,10 @@ def test_init_kv_cache_with_kv_sharing_valid():
|
||||
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
assert kv_cache_config.num_blocks == num_expected_blocks
|
||||
assert len(kv_cache_config.tensors) == 1
|
||||
assert len(kv_cache_config.kv_cache_tensors) == 1
|
||||
# Each layer now has twice the available memory for KV cache
|
||||
# compared to no KV sharing
|
||||
assert kv_cache_config.tensors[layer_0].size == available_memory
|
||||
assert kv_cache_config.kv_cache_tensors[0].size == available_memory
|
||||
|
||||
max_context_len =\
|
||||
estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
|
||||
@ -602,7 +603,7 @@ def test_init_kv_cache_with_kv_sharing_valid():
|
||||
# important: override tensor size to prevent large mem alloc during test
|
||||
# this will only allocate 1 block worth of memory (32kb)
|
||||
kv_cache_config.num_blocks = 1
|
||||
kv_cache_config.tensors[layer_0].size =\
|
||||
kv_cache_config.kv_cache_tensors[0].size =\
|
||||
kv_cache_spec[layer_0].page_size_bytes
|
||||
|
||||
runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
@ -2104,6 +2104,12 @@ class SchedulerConfig:
|
||||
default scheduler. Can be a class directly or the path to a class of form
|
||||
"mod.custom_class"."""
|
||||
|
||||
disable_hybrid_kv_cache_manager: bool = False
|
||||
"""If set to True, KV cache manager will allocate the same size of KV cache
|
||||
for all attention layers even if there are multiple type of attention layers
|
||||
like full attention and sliding window attention.
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
@ -4465,6 +4471,21 @@ class VllmConfig:
|
||||
if not self.instance_id:
|
||||
self.instance_id = random_uuid()[:5]
|
||||
|
||||
if (envs.VLLM_USE_V1
|
||||
and not self.scheduler_config.disable_hybrid_kv_cache_manager):
|
||||
# The logger should only print a warning message for hybrid models. Since
# we can't know yet whether the model is hybrid, we don't log the warning
# here and will log it later.
|
||||
if not (current_platform.is_cuda() or current_platform.is_rocm()):
# Hybrid KV cache manager is not supported on non-GPU platforms.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_transfer_config is not None:
# Hybrid KV cache manager is not compatible with KV transfer.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
|
||||
|
||||
def update_sizes_for_sequence_parallelism(self,
|
||||
possible_sizes: list) -> list:
|
||||
# remove the sizes that not multiple of tp_size when
|
||||
|
||||
@ -387,6 +387,9 @@ class EngineArgs:
|
||||
bool] = SchedulerConfig.enable_chunked_prefill
|
||||
disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
|
||||
|
||||
disable_hybrid_kv_cache_manager: bool = (
|
||||
SchedulerConfig.disable_hybrid_kv_cache_manager)
|
||||
|
||||
guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend
|
||||
guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback
|
||||
guided_decoding_disable_any_whitespace: bool = \
|
||||
@ -849,6 +852,9 @@ class EngineArgs:
|
||||
**scheduler_kwargs["disable_chunked_mm_input"])
|
||||
scheduler_group.add_argument("--scheduler-cls",
|
||||
**scheduler_kwargs["scheduler_cls"])
|
||||
scheduler_group.add_argument(
|
||||
"--disable-hybrid-kv-cache-manager",
|
||||
**scheduler_kwargs["disable_hybrid_kv_cache_manager"])
|
||||
|
||||
# vLLM arguments
|
||||
vllm_kwargs = get_kwargs(VllmConfig)
|
||||
@ -1174,6 +1180,8 @@ class EngineArgs:
|
||||
max_num_partial_prefills=self.max_num_partial_prefills,
|
||||
max_long_partial_prefills=self.max_long_partial_prefills,
|
||||
long_prefill_token_threshold=self.long_prefill_token_threshold,
|
||||
disable_hybrid_kv_cache_manager=self.
|
||||
disable_hybrid_kv_cache_manager,
|
||||
)
|
||||
|
||||
lora_config = LoRAConfig(
|
||||
|
||||
@ -7,8 +7,8 @@ from typing import Callable, Optional
|
||||
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
|
||||
BlockStored, KVCacheEvent)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, FreeKVCacheBlockQueue,
|
||||
KVCacheBlock,
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
FreeKVCacheBlockQueue, KVCacheBlock,
|
||||
generate_block_hash_extra_keys,
|
||||
hash_block_tokens)
|
||||
from vllm.v1.request import Request
|
||||
@ -27,6 +27,7 @@ class BlockPool:
|
||||
Args:
|
||||
num_gpu_blocks: The number of blocks in the pool.
|
||||
enable_caching: Whether to enable prefix caching.
|
||||
enable_kv_cache_events: Whether to enable kv cache events.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -56,7 +57,7 @@ class BlockPool:
|
||||
# if there is already an identical block in the cache. This is because
|
||||
# we want to make sure the allocated block IDs won't change so that
|
||||
# block tables are append-only.
|
||||
self.cached_block_hash_to_block: dict[BlockHash, dict[
|
||||
self.cached_block_hash_to_block: dict[BlockHashWithGroupId, dict[
|
||||
int, KVCacheBlock]] = defaultdict(dict)
|
||||
|
||||
# To represent a placeholder block with block_id=0.
|
||||
@ -68,22 +69,29 @@ class BlockPool:
|
||||
self.enable_kv_cache_events = enable_kv_cache_events
|
||||
self.kv_event_queue: list[KVCacheEvent] = []
|
||||
|
||||
def get_cached_block(self,
|
||||
block_hash: BlockHash) -> Optional[KVCacheBlock]:
|
||||
"""Get a cached block by the block hash, or None if cache miss.
|
||||
def get_cached_block(
|
||||
self, block_hash: BlockHash,
|
||||
kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]:
|
||||
"""Get the cached block by the block hash for each group in
|
||||
`kv_cache_group_ids`, or None if cache miss for any group.
|
||||
If there are duplicated blocks, we return the first block in the cache.
|
||||
|
||||
Args:
|
||||
block_hash: The hash value of the block.
|
||||
kv_cache_group_ids: The ids of the KV cache groups.
|
||||
|
||||
Returns:
|
||||
The cached block if it exists, or None.
|
||||
The cached blocks if all groups have a cache hit, or None.
|
||||
"""
|
||||
cached_blocks = self.cached_block_hash_to_block.get(block_hash)
|
||||
if not cached_blocks:
|
||||
return None
|
||||
first_block_id = next(iter(cached_blocks))
|
||||
return cached_blocks[first_block_id]
|
||||
cached_blocks = []
|
||||
for group_id in kv_cache_group_ids:
|
||||
cached_blocks_one_group = self.cached_block_hash_to_block.get(
|
||||
BlockHashWithGroupId(block_hash, group_id))
|
||||
if not cached_blocks_one_group:
|
||||
return None
|
||||
first_block_id = next(iter(cached_blocks_one_group))
|
||||
cached_blocks.append(cached_blocks_one_group[first_block_id])
|
||||
return cached_blocks
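A minimal sketch of the new group-aware lookup, with a made-up hash and two hypothetical groups; a miss in any single group yields None, matching the docstring above.

blocks = block_pool.get_cached_block(block_hash, kv_cache_group_ids=[0, 1])
if blocks is None:
    pass  # at least one group missed; the caller falls back to allocation
else:
    full_attn_block, sliding_window_block = blocks  # one KVCacheBlock per group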
|
||||
|
||||
def cache_full_blocks(
|
||||
self,
|
||||
@ -93,6 +101,7 @@ class BlockPool:
|
||||
num_cached_blocks: int,
|
||||
num_full_blocks: int,
|
||||
block_size: int,
|
||||
kv_cache_group_id: int,
|
||||
hash_fn: Callable,
|
||||
) -> None:
|
||||
"""Cache a list of full blocks for prefix caching.
|
||||
@ -112,6 +121,7 @@ class BlockPool:
|
||||
num_full_blocks: The number of blocks that are full and should
|
||||
be cached after this function.
|
||||
block_size: Number of tokens in each block.
|
||||
kv_cache_group_id: The id of the KV cache group.
|
||||
hash_fn: The hash function to use for block hashes.
|
||||
"""
|
||||
if num_cached_blocks == num_full_blocks:
|
||||
@ -126,7 +136,7 @@ class BlockPool:
|
||||
else:
|
||||
prev_block = blocks[num_cached_blocks - 1]
|
||||
assert prev_block.block_hash is not None
|
||||
prev_block_hash_value = prev_block.block_hash.hash_value
|
||||
prev_block_hash_value = prev_block.block_hash.get_hash_value()
|
||||
|
||||
parent_block_hash = prev_block_hash_value
|
||||
new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
|
||||
@ -138,8 +148,9 @@ class BlockPool:
|
||||
# The block hash may already be computed in
|
||||
# "get_computed_blocks" if the tokens are not generated by
|
||||
# this request (either the prompt tokens or the previously
|
||||
# generated tokens with preemption). In this case we simply
|
||||
# reuse the block hash.
|
||||
# generated tokens with preemption), or by other
|
||||
# single_type_managers with the same block_size.
|
||||
# In this case we simply reuse the block hash.
|
||||
block_hash = new_block_hashes[i]
|
||||
else:
|
||||
# Otherwise compute the block hash and cache it in the request
|
||||
@ -166,8 +177,11 @@ class BlockPool:
|
||||
block_hashes.append(block_hash)
|
||||
|
||||
# Update and add the full block to the cache.
|
||||
blk.block_hash = block_hash
|
||||
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
|
||||
block_hash_with_group_id = BlockHashWithGroupId(
|
||||
block_hash, kv_cache_group_id)
|
||||
blk.block_hash = block_hash_with_group_id
|
||||
self.cached_block_hash_to_block[block_hash_with_group_id][
|
||||
blk.block_id] = blk
|
||||
if new_hashes is not None:
|
||||
new_hashes.append(block_hash.hash_value)
|
||||
prev_block_hash_value = block_hash.hash_value
|
||||
@ -237,12 +251,16 @@ class BlockPool:
|
||||
del self.cached_block_hash_to_block[block_hash]
|
||||
|
||||
if self.enable_kv_cache_events:
|
||||
# FIXME (Chen): Not sure whether we should return `hash_value`
|
||||
# or `(hash_value, group_id)` here. But it's fine now because
|
||||
# we disable hybrid kv cache manager when kv cache event is
|
||||
# enabled, so there is only one group.
|
||||
self.kv_event_queue.append(
|
||||
BlockRemoved(block_hashes=[block_hash.hash_value]))
|
||||
BlockRemoved(block_hashes=[block_hash.get_hash_value()]))
|
||||
return True
|
||||
return False
|
||||
|
||||
def touch(self, blocks: list[KVCacheBlock]) -> None:
|
||||
def touch(self, blocks: list[list[KVCacheBlock]]) -> None:
|
||||
"""Touch a block increases its reference count by 1, and may remove
|
||||
the block from the free queue. This is used when a block is hit by
|
||||
another request with the same prefix.
|
||||
@ -250,12 +268,13 @@ class BlockPool:
|
||||
Args:
|
||||
blocks: A list of blocks to touch.
|
||||
"""
|
||||
for block in blocks:
|
||||
# ref_cnt=0 means this block is in the free list (i.e. eviction
|
||||
# candidate), so remove it.
|
||||
if block.ref_cnt == 0 and not block.is_null:
|
||||
self.free_block_queue.remove(block)
|
||||
block.incr_ref()
|
||||
for blocks_per_group in blocks:
|
||||
for block in blocks_per_group:
|
||||
# ref_cnt=0 means this block is in the free list (i.e. eviction
|
||||
# candidate), so remove it.
|
||||
if block.ref_cnt == 0 and not block.is_null:
|
||||
self.free_block_queue.remove(block)
|
||||
block.incr_ref()
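Hedged call-shape example for the updated signature (the group lists are hypothetical): touch() now receives one block list per KV cache group instead of a flat list.

block_pool.touch([blocks_of_group0, blocks_of_group1])  # list[list[KVCacheBlock]]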
|
||||
|
||||
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
|
||||
"""Free a list of blocks. The blocks should be ordered by their
|
||||
358 vllm/v1/core/kv_cache_coordinator.py Normal file
@ -0,0 +1,358 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Optional
|
||||
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
FullAttentionManager, SingleTypeKVCacheManager,
|
||||
get_manager_for_kv_cache_spec)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
class KVCacheCoordinator(ABC):
|
||||
"""
|
||||
Coordinate the KV cache of different KV cache groups.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
max_model_len: int,
|
||||
use_eagle: bool,
|
||||
enable_caching: bool,
|
||||
caching_hash_fn: Callable,
|
||||
enable_kv_cache_events: bool,
|
||||
):
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.max_model_len = max_model_len
|
||||
|
||||
self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
|
||||
enable_kv_cache_events)
|
||||
self.single_type_managers: list[SingleTypeKVCacheManager] = []
|
||||
|
||||
# Needs special handling for find_longest_cache_hit if eagle is enabled
|
||||
self.use_eagle = use_eagle
|
||||
|
||||
for i in range(len(self.kv_cache_config.kv_cache_groups)):
|
||||
kv_cache_spec = self.kv_cache_config.kv_cache_groups[
|
||||
i].kv_cache_spec
|
||||
self.single_type_managers.append(
|
||||
get_manager_for_kv_cache_spec(
|
||||
kv_cache_spec=kv_cache_spec,
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_group_id=i,
|
||||
caching_hash_fn=caching_hash_fn,
|
||||
))
|
||||
|
||||
def get_num_blocks_to_allocate(
|
||||
self, request_id: str, num_tokens: int,
|
||||
new_computed_blocks: list[list[KVCacheBlock]]) -> int:
|
||||
"""
|
||||
Get the number of blocks needed to be allocated for the request.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
num_tokens: The total number of tokens that need a slot (including
|
||||
tokens that are already allocated).
|
||||
new_computed_blocks: The new computed blocks just hitting the
|
||||
prefix caching.
|
||||
|
||||
Returns:
|
||||
The number of blocks.
|
||||
"""
|
||||
num_blocks_to_allocate = 0
|
||||
for i, manager in enumerate(self.single_type_managers):
|
||||
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
|
||||
request_id, num_tokens, new_computed_blocks[i])
|
||||
return num_blocks_to_allocate
|
||||
|
||||
def save_new_computed_blocks(
|
||||
self, request_id: str,
|
||||
new_computed_blocks: list[list[KVCacheBlock]]) -> None:
|
||||
"""
|
||||
Add the new computed blocks to the request.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
new_computed_blocks: The new computed blocks just hitting the
|
||||
prefix cache.
|
||||
"""
|
||||
for i, manager in enumerate(self.single_type_managers):
|
||||
manager.save_new_computed_blocks(request_id,
|
||||
new_computed_blocks[i])
|
||||
|
||||
def allocate_new_blocks(self, request_id: str,
|
||||
num_tokens: int) -> list[list[KVCacheBlock]]:
|
||||
"""
|
||||
Allocate new blocks for the request to give it at least `num_tokens`
|
||||
token slots.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
num_tokens: The total number of tokens that need a slot (including
|
||||
tokens that are already allocated).
|
||||
|
||||
Returns:
|
||||
The new allocated blocks.
|
||||
"""
|
||||
new_blocks = []
|
||||
for manager in self.single_type_managers:
|
||||
new_blocks.append(
|
||||
manager.allocate_new_blocks(request_id, num_tokens))
|
||||
return new_blocks
|
||||
|
||||
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
|
||||
num_computed_tokens: int) -> None:
|
||||
"""
|
||||
Cache the blocks for the request.
|
||||
|
||||
Args:
|
||||
request: The request.
|
||||
block_hashes: The block hashes of the request.
|
||||
num_computed_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
|
||||
"""
|
||||
for manager in self.single_type_managers:
|
||||
manager.cache_blocks(request, block_hashes, num_computed_tokens)
|
||||
|
||||
def free(self, request_id: str) -> None:
|
||||
"""
|
||||
Free the blocks for the request.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
"""
|
||||
for manager in self.single_type_managers:
|
||||
manager.free(request_id)
|
||||
|
||||
def get_num_common_prefix_blocks(self, request_id: str,
|
||||
num_running_requests: int) -> list[int]:
|
||||
"""
|
||||
Get the number of common prefix blocks for a request.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
num_running_requests: The total number of requests in the RUNNING state.
|
||||
|
||||
Returns:
|
||||
The number of common prefix blocks.
|
||||
"""
|
||||
num_blocks_per_group = [
|
||||
manager.get_num_common_prefix_blocks(request_id,
|
||||
num_running_requests)
|
||||
for manager in self.single_type_managers
|
||||
]
|
||||
return num_blocks_per_group
|
||||
|
||||
def remove_skipped_blocks(self, request_id: str,
|
||||
num_computed_tokens: int) -> None:
|
||||
"""
|
||||
Remove the blocks that are no longer needed from the request's blocks and
replace the removed blocks with null_block.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
num_computed_tokens: The number of tokens that have been computed.
|
||||
"""
|
||||
for manager in self.single_type_managers:
|
||||
manager.remove_skipped_blocks(request_id, num_computed_tokens)
|
||||
|
||||
def get_blocks(self, request_id: str) -> list[list[KVCacheBlock]]:
|
||||
"""
|
||||
Get the blocks for the request.
|
||||
"""
|
||||
return [
|
||||
manager.req_to_blocks[request_id]
|
||||
for manager in self.single_type_managers
|
||||
]
|
||||
|
||||
@abstractmethod
|
||||
def find_longest_cache_hit(
|
||||
self, block_hashes: list[BlockHash],
|
||||
max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
|
||||
pass
|
||||
|
||||
|
||||
class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
"""
|
||||
KV cache coordinator for models with only one KV cache group. This is the
|
||||
case for models with only one KV cache type, e.g., all attention layers use
|
||||
full attention or all attention layers use sliding window attention.
|
||||
"""
|
||||
|
||||
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
|
||||
use_eagle: bool, enable_caching: bool,
|
||||
caching_hash_fn: Callable, enable_kv_cache_events: bool):
|
||||
super().__init__(kv_cache_config, max_model_len, use_eagle,
|
||||
enable_caching, caching_hash_fn,
|
||||
enable_kv_cache_events)
|
||||
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
|
||||
0].kv_cache_spec
|
||||
self.block_size = self.kv_cache_spec.block_size
|
||||
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
|
||||
"UnitaryKVCacheCoordinator assumes only one kv cache group")
|
||||
|
||||
def find_longest_cache_hit(
|
||||
self, block_hashes: list[BlockHash],
|
||||
max_cache_hit_length: int) -> tuple[list[list[KVCacheBlock]], int]:
|
||||
hit_blocks = self.single_type_managers[0].find_longest_cache_hit(
|
||||
block_hashes=block_hashes,
|
||||
max_length=max_cache_hit_length,
|
||||
kv_cache_group_ids=[0],
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=self.kv_cache_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
)
|
||||
return hit_blocks, len(hit_blocks[0]) * self.block_size
|
||||
|
||||
|
||||
class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
"""
|
||||
KV cache coordinator for hybrid models with multiple KV cache types, and
|
||||
thus multiple kv cache groups.
|
||||
To simplify `find_longest_cache_hit`, it only supports the combination of
|
||||
two types of KV cache groups, and one of them must be full attention.
|
||||
This may be extended to more general cases in the future.
|
||||
"""
|
||||
|
||||
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
|
||||
use_eagle: bool, enable_caching: bool,
|
||||
caching_hash_fn: Callable, enable_kv_cache_events: bool):
|
||||
super().__init__(kv_cache_config, max_model_len, use_eagle,
|
||||
enable_caching, caching_hash_fn,
|
||||
enable_kv_cache_events)
|
||||
self.verify_and_split_kv_cache_groups()
|
||||
|
||||
def verify_and_split_kv_cache_groups(self) -> None:
|
||||
"""
|
||||
Verifies that the model has exactly two types of KV cache groups, and
|
||||
one of them is full attention. Then, split the kv cache groups into full
|
||||
attention groups and other groups.
|
||||
"""
|
||||
full_attention_type_id: Optional[str] = None
|
||||
other_type_id: Optional[str] = None
|
||||
self.full_attention_group_ids: list[int] = []
|
||||
self.other_group_ids: list[int] = []
|
||||
for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
|
||||
if isinstance(g.kv_cache_spec, FullAttentionSpec):
|
||||
if full_attention_type_id is None:
|
||||
full_attention_type_id = g.kv_cache_spec.type_id
|
||||
else:
|
||||
assert full_attention_type_id == g.kv_cache_spec.type_id, (
|
||||
"HybridKVCacheCoordinator assumes exactly one type of "
|
||||
"full attention groups now.")
|
||||
self.full_attention_group_ids.append(i)
|
||||
else:
|
||||
if other_type_id is None:
|
||||
other_type_id = g.kv_cache_spec.type_id
|
||||
else:
|
||||
assert other_type_id == g.kv_cache_spec.type_id, (
|
||||
"HybridKVCacheCoordinator assumes "
|
||||
"exactly one other type of groups now.")
|
||||
self.other_group_ids.append(i)
|
||||
|
||||
assert full_attention_type_id is not None, (
|
||||
"HybridKVCacheCoordinator assumes exactly one type of full "
|
||||
"attention groups now.")
|
||||
assert other_type_id is not None, (
|
||||
"HybridKVCacheCoordinator assumes exactly one type of other "
|
||||
"groups now.")
|
||||
|
||||
self.full_attention_manager_cls = FullAttentionManager
|
||||
self.other_attention_cls = self.single_type_managers[
|
||||
self.other_group_ids[0]].__class__
|
||||
|
||||
self.full_attention_spec = self.kv_cache_config.kv_cache_groups[
|
||||
self.full_attention_group_ids[0]].kv_cache_spec
|
||||
self.other_spec = self.kv_cache_config.kv_cache_groups[
|
||||
self.other_group_ids[0]].kv_cache_spec
|
||||
|
||||
self.full_attention_block_size = self.full_attention_spec.block_size
|
||||
self.other_block_size = self.other_spec.block_size
|
||||
assert self.other_block_size % self.full_attention_block_size == 0, (
|
||||
"KVCacheCoordinator assumes the block_size of full attention "
|
||||
"layers is divisible by other layers now.")
|
||||
|
||||
def find_longest_cache_hit(
|
||||
self,
|
||||
block_hashes: list[BlockHash],
|
||||
max_cache_hit_length: int,
|
||||
) -> tuple[list[list[KVCacheBlock]], int]:
|
||||
"""
|
||||
Find the longest cache hit for the request.
|
||||
|
||||
Args:
|
||||
block_hashes: The block hashes of the request.
|
||||
max_cache_hit_length: The maximum length of the cache hit.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- A list of the cache hit blocks for each single type manager.
|
||||
- The number of tokens of the longest cache hit.
|
||||
"""
|
||||
# First, find the longest cache hit for full attention.
|
||||
hit_blocks_full_attn = (
|
||||
self.full_attention_manager_cls.find_longest_cache_hit(
|
||||
block_hashes=block_hashes,
|
||||
max_length=max_cache_hit_length,
|
||||
kv_cache_group_ids=self.full_attention_group_ids,
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=self.full_attention_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
))
|
||||
hit_length = len(
|
||||
hit_blocks_full_attn[0]) * self.full_attention_block_size
|
||||
|
||||
# Next, find the cache hit for the other attention WITHIN
|
||||
# the cache hit of full attention.
|
||||
hit_blocks_other_attn = (
|
||||
self.other_attention_cls.find_longest_cache_hit(
|
||||
block_hashes=block_hashes,
|
||||
max_length=hit_length,
|
||||
kv_cache_group_ids=self.other_group_ids,
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_spec=self.other_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
))
|
||||
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
|
||||
|
||||
# NOTE: the prefix cache hit length must be a multiple of block_size, as
# we don't support partial-block cache hits yet. In the current
# implementation, the cache hit length of the other attention type is
# guaranteed to be a multiple of the block size of the full attention
# layers, because hit_length is a multiple of the other attention's block
# size, and the other attention's block size is a multiple of the full
# attention's block size (verified in `verify_and_split_kv_cache_groups`).
|
||||
assert hit_length % self.full_attention_block_size == 0
|
||||
|
||||
# Truncate the full attention cache hit to the length of the
|
||||
# cache hit of the other attention.
|
||||
for i in range(len(hit_blocks_full_attn)):
|
||||
del hit_blocks_full_attn[i][hit_length //
|
||||
self.full_attention_block_size:]
|
||||
|
||||
# Merge the hit blocks of full attention and other attention.
|
||||
hit_blocks = hit_blocks_other_attn
|
||||
for group_id, blocks in enumerate(hit_blocks_full_attn):
|
||||
# NOTE: there is only one full attention group in most cases. So
|
||||
# the time complexity of insert is fine.
|
||||
hit_blocks.insert(group_id, blocks)
|
||||
return hit_blocks, hit_length
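A small numeric walk-through of the two-pass search above, with hypothetical block sizes and hit counts (not taken from the commit):

full_block_size, other_block_size = 16, 32    # e.g. sliding-window group uses 32
full_hit_blocks, other_hit_blocks = 10, 4     # 160 tokens vs. 128 tokens hit

hit_length = other_hit_blocks * other_block_size       # 128 tokens
assert hit_length % full_block_size == 0               # guaranteed by the divisibility check
kept_full_blocks = hit_length // full_block_size       # full-attention hit truncated to 8 blocks
print(kept_full_blocks, hit_length)                    # 8 128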
|
||||
|
||||
|
||||
def get_kv_cache_coordinator(
|
||||
kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
|
||||
enable_caching: bool, caching_hash_fn: Callable,
|
||||
enable_kv_cache_events: bool) -> KVCacheCoordinator:
|
||||
if len(kv_cache_config.kv_cache_groups) == 1:
|
||||
return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
|
||||
use_eagle, enable_caching,
|
||||
caching_hash_fn,
|
||||
enable_kv_cache_events)
|
||||
else:
|
||||
return HybridKVCacheCoordinator(kv_cache_config, max_model_len,
|
||||
use_eagle, enable_caching,
|
||||
caching_hash_fn,
|
||||
enable_kv_cache_events)
|
||||
@ -8,11 +8,9 @@ from typing import Optional
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
hash_request_tokens)
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
get_manager_for_kv_cache_spec)
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
@ -22,16 +20,24 @@ logger = init_logger(__name__)
|
||||
|
||||
@dataclass
|
||||
class KVCacheBlocks:
|
||||
blocks: list[KVCacheBlock]
|
||||
"""
|
||||
The allocation result of KVCacheManager, working as the interface between
|
||||
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
|
||||
structure from the Scheduler.
|
||||
"""
|
||||
blocks: list[list[KVCacheBlock]]
|
||||
"""
|
||||
blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
|
||||
We don't make the token-block index the outer dimension because that would assume all
|
||||
kv_cache_groups have the same number of blocks, which is true for now but
|
||||
will be broken if we want to give different block_size to different
|
||||
kv_cache_groups in the future.
|
||||
"""
|
||||
|
||||
def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
|
||||
"""Adds two KVCacheBlocks instances."""
|
||||
return KVCacheBlocks(self.blocks + other.blocks)
|
||||
|
||||
@classmethod
|
||||
def create_empty(cls) -> "KVCacheBlocks":
|
||||
"""Creates a new KVCacheBlocks instance with no blocks."""
|
||||
return cls([])
|
||||
return KVCacheBlocks(
|
||||
[blk1 + blk2 for blk1, blk2 in zip(self.blocks, other.blocks)])
|
||||
|
||||
def get_block_ids(self) -> list[list[int]]:
|
||||
"""
|
||||
@ -39,15 +45,20 @@ class KVCacheBlocks:
|
||||
|
||||
Returns:
|
||||
list[list[int]]: A two-level list where
|
||||
* the outer list corresponds to KV cache groups (only 1 group now)
|
||||
* the outer list corresponds to KV cache groups
|
||||
* each inner list contains the block_ids of the blocks in that group
|
||||
"""
|
||||
return [[block.block_id for block in self.blocks]]
|
||||
block_ids = []
|
||||
for group in self.blocks:
|
||||
block_ids.append([blk.block_id for blk in group])
|
||||
return block_ids
|
||||
|
||||
def get_unhashed_block_ids(self) -> list[int]:
|
||||
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
|
||||
assert len(self.blocks) == 1, "Only one group is supported"
|
||||
return [
|
||||
block.block_id for block in self.blocks if block.block_hash is None
|
||||
block.block_id for block in self.blocks[0]
|
||||
if block.block_hash is None
|
||||
]
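A stand-in sketch of the two-level layout described above (Block here is a hypothetical stub, not the real KVCacheBlock):

from dataclasses import dataclass

@dataclass
class Block:
    block_id: int
    block_hash: object = None

group0 = [Block(3), Block(7)]      # e.g. full attention group
group1 = [Block(12), Block(15)]    # e.g. sliding window group
blocks = [group0, group1]          # blocks[i][j]: group i, j-th block of tokens

block_ids = [[b.block_id for b in g] for g in blocks]
print(block_ids)                   # [[3, 7], [12, 15]]

# Per-group concatenation, mirroring how __add__ zips the group lists:
merged = [a + b for a, b in zip(blocks, [[Block(20)], [Block(21)]])]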
|
||||
|
||||
|
||||
@ -63,12 +74,6 @@ class KVCacheManager:
|
||||
log_stats: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
) -> None:
|
||||
assert len(kv_cache_config.kv_cache_groups) == 1, (
|
||||
"KVCacheManager does not support hybrid models with more than 1 "
|
||||
"kv cache group")
|
||||
kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.num_gpu_blocks = kv_cache_config.num_blocks
|
||||
self.max_model_len = max_model_len
|
||||
|
||||
self.enable_caching = enable_caching
|
||||
@ -77,17 +82,24 @@ class KVCacheManager:
|
||||
self.log_stats = log_stats
|
||||
# FIXME: make prefix cache stats conditional on log_stats
|
||||
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
|
||||
assert len(
|
||||
set(g.kv_cache_spec.block_size
|
||||
for g in kv_cache_config.kv_cache_groups)
|
||||
) == 1, "Only one block size is supported for now"
|
||||
self.block_size = kv_cache_config.kv_cache_groups[
|
||||
0].kv_cache_spec.block_size
|
||||
|
||||
self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching,
|
||||
enable_kv_cache_events)
|
||||
|
||||
self.single_type_manager = get_manager_for_kv_cache_spec(
|
||||
kv_cache_spec=kv_cache_spec,
|
||||
block_pool=self.block_pool,
|
||||
self.coordinator = get_kv_cache_coordinator(
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=self.max_model_len,
|
||||
use_eagle=self.use_eagle,
|
||||
num_kv_cache_groups=1,
|
||||
enable_caching=enable_caching,
|
||||
caching_hash_fn=self.caching_hash_fn,
|
||||
enable_kv_cache_events=enable_kv_cache_events,
|
||||
)
|
||||
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
|
||||
self.block_pool = self.coordinator.block_pool
|
||||
self.kv_cache_config = kv_cache_config
|
||||
|
||||
# Mapping from request ID to kv block hashes.
|
||||
# This is to avoid recomputing the block hashes for each call of
|
||||
@ -133,7 +145,7 @@ class KVCacheManager:
|
||||
# When the request requires prompt logprobs, we skip prefix caching.
|
||||
if (not self.enable_caching
|
||||
or request.sampling_params.prompt_logprobs is not None):
|
||||
return KVCacheBlocks.create_empty(), 0
|
||||
return self.create_empty_block_list(), 0
|
||||
|
||||
# The block hashes for the request may already be computed
|
||||
# if the scheduler has tried to schedule the request before.
|
||||
@ -154,20 +166,16 @@ class KVCacheManager:
|
||||
# num_computed_tokens to be block-size aligned. Removing this limitation
|
||||
# could slightly improve performance in the future.
|
||||
max_cache_hit_length = request.num_tokens - 1
|
||||
|
||||
computed_blocks = self.single_type_manager.find_longest_cache_hit(
|
||||
block_hashes, max_cache_hit_length)
|
||||
# NOTE(woosuk): Since incomplete blocks are not eligible for
|
||||
# sharing, `num_computed_tokens` is always a multiple of
|
||||
# `block_size`.
|
||||
num_computed_tokens = len(computed_blocks) * self.block_size
|
||||
computed_blocks, num_new_computed_tokens = (
|
||||
self.coordinator.find_longest_cache_hit(block_hashes,
|
||||
max_cache_hit_length))
|
||||
|
||||
if self.log_stats:
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.queries += request.num_tokens
|
||||
self.prefix_cache_stats.hits += num_computed_tokens
|
||||
self.prefix_cache_stats.hits += num_new_computed_tokens
|
||||
|
||||
return KVCacheBlocks(computed_blocks), num_computed_tokens
|
||||
return KVCacheBlocks(computed_blocks), num_new_computed_tokens
|
||||
|
||||
def allocate_slots(
|
||||
self,
|
||||
@ -220,7 +228,9 @@ class KVCacheManager:
|
||||
if new_computed_blocks is not None:
|
||||
new_computed_block_list = new_computed_blocks.blocks
|
||||
else:
|
||||
new_computed_block_list = []
|
||||
new_computed_block_list = [
|
||||
[] for _ in range(len(self.kv_cache_config.kv_cache_groups))
|
||||
]
|
||||
|
||||
# Free the blocks that are skipped during the attention computation
|
||||
# (e.g., tokens outside the sliding window).
|
||||
@ -228,8 +238,8 @@ class KVCacheManager:
|
||||
# insufficient free blocks.
|
||||
# Should call this function before allocating new blocks to reduce
|
||||
# the number of evicted blocks.
|
||||
self.single_type_manager.remove_skipped_blocks(
|
||||
request.request_id, request.num_computed_tokens)
|
||||
self.coordinator.remove_skipped_blocks(request.request_id,
|
||||
request.num_computed_tokens)
|
||||
|
||||
# The number of computed tokens is the number of computed tokens plus
|
||||
# the new prefix caching hits
|
||||
@ -238,12 +248,12 @@ class KVCacheManager:
|
||||
num_tokens_need_slot = min(
|
||||
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
|
||||
self.max_model_len)
|
||||
num_blocks_to_allocate = (
|
||||
self.single_type_manager.get_num_blocks_to_allocate(
|
||||
request_id=request.request_id,
|
||||
num_tokens=num_tokens_need_slot,
|
||||
new_computed_blocks=new_computed_block_list,
|
||||
))
|
||||
|
||||
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
|
||||
request_id=request.request_id,
|
||||
num_tokens=num_tokens_need_slot,
|
||||
new_computed_blocks=new_computed_block_list,
|
||||
)
|
||||
|
||||
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
|
||||
# Cannot allocate new blocks
|
||||
@ -253,16 +263,16 @@ class KVCacheManager:
|
||||
if self.enable_caching:
|
||||
self.block_pool.touch(new_computed_block_list)
|
||||
else:
|
||||
assert not new_computed_block_list, (
|
||||
assert all(not blocks for blocks in new_computed_block_list), (
|
||||
"Computed blocks should be empty when "
|
||||
"prefix caching is disabled")
|
||||
|
||||
# Append the new computed blocks to the request blocks until now to
|
||||
# avoid the case where the new blocks cannot be allocated.
|
||||
self.single_type_manager.save_new_computed_blocks(
|
||||
request.request_id, new_computed_block_list)
|
||||
self.coordinator.save_new_computed_blocks(request.request_id,
|
||||
new_computed_block_list)
|
||||
|
||||
new_blocks = self.single_type_manager.allocate_new_blocks(
|
||||
new_blocks = self.coordinator.allocate_new_blocks(
|
||||
request.request_id, num_tokens_need_slot)
|
||||
|
||||
# P/D: delay caching blocks if we have to recv from
|
||||
@ -273,7 +283,7 @@ class KVCacheManager:
|
||||
# Speculated tokens might be rejected in the future, so we do
|
||||
# not cache any speculated tokens. We only cache blocks with
|
||||
# generated (accepted) tokens.
|
||||
self.single_type_manager.cache_blocks(
|
||||
self.coordinator.cache_blocks(
|
||||
request, self.req_to_block_hashes[request.request_id],
|
||||
num_computed_tokens + num_new_tokens - num_draft_tokens)
|
||||
|
||||
@ -287,7 +297,7 @@ class KVCacheManager:
|
||||
Args:
|
||||
request: The request to free the blocks.
|
||||
"""
|
||||
self.single_type_manager.free(request.request_id)
|
||||
self.coordinator.free(request.request_id)
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""Reset prefix cache. This function may be used in RLHF
|
||||
@ -345,10 +355,8 @@ class KVCacheManager:
|
||||
group.
|
||||
"""
|
||||
assert request.status == RequestStatus.RUNNING
|
||||
return [
|
||||
self.single_type_manager.get_num_common_prefix_blocks(
|
||||
request.request_id, num_running_requests)
|
||||
]
|
||||
return self.coordinator.get_num_common_prefix_blocks(
|
||||
request.request_id, num_running_requests)
|
||||
|
||||
def free_block_hashes(self, request: Request) -> None:
|
||||
"""Discard the block hashes for the request.
|
||||
@ -368,6 +376,15 @@ class KVCacheManager:
|
||||
|
||||
def get_block_ids(self, request_id: str) -> list[list[int]]:
|
||||
"""Get the block ids of a request."""
|
||||
assert request_id in self.single_type_manager.req_to_blocks
|
||||
return KVCacheBlocks(self.single_type_manager.req_to_blocks[request_id]
|
||||
).get_block_ids()
|
||||
return KVCacheBlocks(
|
||||
self.coordinator.get_blocks(request_id)).get_block_ids()
|
||||
|
||||
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
|
||||
num_computed_tokens: int) -> None:
|
||||
"""Cache the blocks for the request."""
|
||||
self.coordinator.cache_blocks(request, block_hashes,
|
||||
num_computed_tokens)
|
||||
|
||||
def create_empty_block_list(self) -> KVCacheBlocks:
|
||||
"""Creates a new KVCacheBlocks instance with no blocks."""
|
||||
return KVCacheBlocks([[] for _ in range(self.num_kv_cache_groups)])
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""KV-Cache Utilities."""
|
||||
|
||||
import os
|
||||
from collections import deque
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, NamedTuple, Optional
|
||||
@ -33,6 +34,18 @@ class BlockHash(NamedTuple):
|
||||
extra_keys: Optional[Any] = None
|
||||
|
||||
|
||||
class BlockHashWithGroupId(NamedTuple):
|
||||
# The hash value for the contents (e.g., token_ids) of a block without group
|
||||
# ID. The value is the same for blocks representing the same tokens but for
|
||||
# different groups.
|
||||
block_hash: BlockHash
|
||||
# The KV cache group ID.
|
||||
group_id: int
|
||||
|
||||
def get_hash_value(self) -> int:
|
||||
return self.block_hash.hash_value
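For illustration, the same block contents hash once and are then keyed per group; the hash and token values are placeholders, and the field order follows the definitions above.

bh = BlockHash(0x1234, (1, 2, 3))        # hash_value, token_ids
key_g0 = BlockHashWithGroupId(bh, 0)
key_g1 = BlockHashWithGroupId(bh, 1)
assert key_g0.get_hash_value() == key_g1.get_hash_value() == 0x1234
assert key_g0 != key_g1                  # distinct cache keys per KV cache group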
|
||||
|
||||
|
||||
# The hash seed for the first block of the prefix block sequence.
|
||||
#
|
||||
# Even if the hash function is the builtin hash(), we use sha256 to generate
|
||||
@ -44,7 +57,7 @@ class BlockHash(NamedTuple):
|
||||
# This aligns with the behavior of Python's hash() function, which also uses
|
||||
# a random seed if PYTHONHASHSEED is not set.
|
||||
NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv(
|
||||
'PYTHONHASHSEED') is None else sha256(os.getenv('PYTHONHASHSEED'))
|
||||
"PYTHONHASHSEED") is None else sha256(os.getenv("PYTHONHASHSEED"))
|
||||
|
||||
|
||||
class PrefixCachingMetrics:
|
||||
@ -118,7 +131,7 @@ class KVCacheBlock:
|
||||
ref_cnt: int = 0
|
||||
# The hash of the block composed of (block hash, tuple of token IDs).
|
||||
# It is only available when the block is full.
|
||||
_block_hash: Optional[BlockHash] = None
|
||||
_block_hash: Optional[BlockHashWithGroupId] = None
|
||||
|
||||
# Used to construct a doubly linked list for free blocks.
|
||||
# These two attributes should only be manipulated by FreeKVCacheBlockQueue.
|
||||
@ -135,11 +148,11 @@ class KVCacheBlock:
|
||||
self.ref_cnt -= 1
|
||||
|
||||
@property
|
||||
def block_hash(self) -> Optional[BlockHash]:
|
||||
def block_hash(self) -> Optional[BlockHashWithGroupId]:
|
||||
return self._block_hash
|
||||
|
||||
@block_hash.setter
|
||||
def block_hash(self, block_hash: BlockHash):
|
||||
def block_hash(self, block_hash: BlockHashWithGroupId):
|
||||
assert self.block_hash is None, (
|
||||
"The block already has a hash. This should not happen.")
|
||||
self._block_hash = block_hash
|
||||
@ -151,10 +164,10 @@ class KVCacheBlock:
|
||||
def __repr__(self) -> str:
|
||||
# Use block_id instead of KVCacheBlock object to avoid calling __repr__
|
||||
# on KVCacheBlock object recursively.
|
||||
prev_block_id = self.prev_free_block.block_id \
|
||||
if self.prev_free_block else None
|
||||
next_block_id = self.next_free_block.block_id \
|
||||
if self.next_free_block else None
|
||||
prev_block_id = (self.prev_free_block.block_id
|
||||
if self.prev_free_block else None)
|
||||
next_block_id = (self.next_free_block.block_id
|
||||
if self.next_free_block else None)
|
||||
return (f"KVCacheBlock(block_id={self.block_id}, "
|
||||
f"ref_cnt={self.ref_cnt}, "
|
||||
f"_block_hash={self._block_hash}, "
|
||||
@ -570,20 +583,20 @@ def create_kv_cache_group_specs(
|
||||
kv_cache_spec: dict[str, KVCacheSpec],
|
||||
grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
|
||||
"""
|
||||
Create KVCacheGroupSpec object for each kv cache group layer.
|
||||
The layers in the same group should share the same
|
||||
KVCacheSpec.
|
||||
Create KVCacheGroupSpec object for each kv cache group layer.
|
||||
The layers in the same group should share the same
|
||||
KVCacheSpec.
|
||||
|
||||
Args:
|
||||
kv_cache_spec:
|
||||
A mapping from each layer name to its corresponding KVCacheSpec.
|
||||
grouped_layer_names:
|
||||
A list of kv cache groups, where each element is a list of layer
|
||||
names that belong to the same group and should share the same
|
||||
KVCacheSpec.
|
||||
Returns:
|
||||
A list of KVCacheGroupSpec objects, one for each group.
|
||||
"""
|
||||
Args:
|
||||
kv_cache_spec:
|
||||
A mapping from each layer name to its corresponding KVCacheSpec.
|
||||
grouped_layer_names:
|
||||
A list of kv cache groups, where each element is a list of layer
|
||||
names that belong to the same group and should share the same
|
||||
KVCacheSpec.
|
||||
Returns:
|
||||
A list of KVCacheGroupSpec objects, one for each group.
|
||||
"""
|
||||
kv_cache_groups = []
|
||||
for layer_names_one_group in grouped_layer_names:
|
||||
layer_specs = [
|
||||
@ -628,6 +641,37 @@ def get_max_concurrency_for_kv_cache_config(
|
||||
return max_concurrency
|
||||
|
||||
|
||||
def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
|
||||
available_memory: int, page_size: int) -> int:
|
||||
"""
|
||||
Get the number of kv cache blocks.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
num_layers: The number of layers
|
||||
available_memory: Memory available for KV cache in bytes.
|
||||
page_size: The page size of the KV cache.
|
||||
"""
|
||||
num_blocks = int(available_memory // page_size // num_layers)
|
||||
num_blocks = max(num_blocks, 0)
|
||||
if vllm_config.cache_config.num_gpu_blocks_override is not None:
|
||||
num_gpu_blocks_override = \
|
||||
vllm_config.cache_config.num_gpu_blocks_override
|
||||
logger.info(
|
||||
"Overriding num_gpu_blocks=%d with "
|
||||
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
|
||||
return num_blocks
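A quick arithmetic check of the formula above with hypothetical numbers (8 GiB free for KV cache, 2 MiB per page, 32 layers):

available_memory = 8 * 1024**3   # bytes
page_size = 2 * 1024**2          # bytes per block per layer
num_layers = 32
print(available_memory // page_size // num_layers)   # 128 blocks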
|
||||
|
||||
|
||||
def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
|
||||
"""
|
||||
Get the page size of the KV cache.
|
||||
"""
|
||||
page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values())
|
||||
assert len(page_sizes) == 1
|
||||
return page_sizes.pop()
|
||||
|
||||
|
||||
def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
|
||||
kv_cache_spec: dict[str, KVCacheSpec],
|
||||
available_memory: int) -> KVCacheConfig:
|
||||
@ -644,32 +688,24 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
|
||||
The generated KVCacheConfig
|
||||
"""
|
||||
|
||||
page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
|
||||
assert len(page_sizes) == 1
|
||||
page_size = page_sizes.pop()
|
||||
|
||||
num_blocks = int(available_memory // page_size // len(kv_cache_spec))
|
||||
num_blocks = max(num_blocks, 0)
|
||||
|
||||
if vllm_config.cache_config.num_gpu_blocks_override is not None:
|
||||
num_gpu_blocks_override = \
|
||||
vllm_config.cache_config.num_gpu_blocks_override
|
||||
logger.info(
|
||||
"Overriding num_gpu_blocks=%d with "
|
||||
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
|
||||
num_blocks = num_gpu_blocks_override
|
||||
page_size = get_uniform_page_size(kv_cache_spec)
|
||||
num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec),
|
||||
available_memory, page_size)
|
||||
|
||||
per_layer_size = page_size * num_blocks
|
||||
# All layers have the same KV cache spec, so we create one kv cache group
|
||||
# for all layers.
|
||||
grouped_layer_names = [list(kv_cache_spec.keys())]
|
||||
|
||||
# Each layer uses a separate Tensor to store its KV cache.
|
||||
kv_cache_tensors = [
|
||||
KVCacheTensor(size=per_layer_size, shared_by=[layer_name])
|
||||
for layer_name in kv_cache_spec
|
||||
]
|
||||
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=num_blocks,
|
||||
tensors={
|
||||
layer_name: KVCacheTensor(size=per_layer_size)
|
||||
for layer_name in kv_cache_spec
|
||||
},
|
||||
kv_cache_tensors=kv_cache_tensors,
|
||||
kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
|
||||
grouped_layer_names),
|
||||
)
|
||||
@ -685,17 +721,185 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
|
||||
return kv_cache_config
|
||||
|
||||
|
||||
def is_kv_cache_page_size_uniform(
|
||||
kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
|
||||
"""
|
||||
Whether all layers in the given KVCacheSpec have the same page size.
|
||||
Args:
|
||||
kv_cache_spec: The KVCacheSpec of each attention layer in the model
|
||||
|
||||
Returns:
|
||||
True if all layers have the same page size, False otherwise.
|
||||
"""
|
||||
|
||||
page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
|
||||
return len(page_sizes) == 1
|
||||
|
||||
|
||||
def _get_kv_cache_config_uniform_page_size(
|
||||
vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec],
|
||||
available_memory: int) -> KVCacheConfig:
|
||||
"""
|
||||
Generates the KV cache configuration for hybrid models with multiple
|
||||
attention types but still with a uniform page size (physical memory per
|
||||
block per layer) for all layers.
|
||||
|
||||
Detailed explanation about kv cache management of hybrid models:
|
||||
The layers in the models are repeated with some patterns, e.g., a model
|
||||
with 10 full attention layers and 20 sliding window attention layers can be
|
||||
regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
|
||||
The KVCacheManager allocates different block tables for each of the 3 layers
|
||||
in the pattern, and repeats each of them 10 times to generate the
|
||||
block_table for the 30 layers in the model.
|
||||
Therefore, we can group the layers in the model into 3 kv_cache_groups, each
|
||||
of which contains 10 layers in the model.
|
||||
The KVCacheManager allocates the block_table for each group based on its
|
||||
kv_cache spec, and the model runner applies the block table to each layer
|
||||
in the group.
|
||||
For example:
|
||||
1. A model only uses full attention. The pattern is
|
||||
(num_hidden_layers * full), so there is only one group and the block table
|
||||
is shared by all layers. It is already handled by
|
||||
`_get_kv_cache_config_uniform_type`.
|
||||
2. A model with 10 full attention layers and 20 sliding window
|
||||
attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
|
||||
there are 3 kv_cache_groups, each of which represents 10 layers.
|
||||
|
||||
To simplify the implementation, we make the following assumptions:
|
||||
1. Physical memory per block: Must be the same across all KV cache groups.
|
||||
Breaking this assumption is non-trivial due to memory fragmentation concerns
|
||||
when allocating blocks of different sizes.
|
||||
2. Tokens per block (block_size): Currently, we directly use
|
||||
`CacheConfig.block_size` for all layers. It can be extended to vary by KV
|
||||
cache group, but within each KV cache group, all layers must share the same
|
||||
block size.
|
||||
3. Physical memory per token per layer: This property is decided by model
|
||||
config. Currently we only support models that have the same physical memory
|
||||
per token per layer for all layers. Can be relaxed with a simple extension,
|
||||
but still need to keep physical memory per block the same for all groups.
|
||||
4. Number of layers per group: Currently assumed the same for all layers.
|
||||
Can be relaxed with a simple extension, but still need to keep physical
|
||||
memory per block the same for all groups.
|
||||
5. Attention type within groups: All layers in a group must share the same
|
||||
attention type. One exception is that, when
|
||||
`--disable-hybrid-kv-cache-manager` is true, the single group for full
|
||||
attention layers may also include attention layers using sliding window or
|
||||
LLaMA 4 local attention. See `unify_hybrid_kv_cache_specs` for more details.
|
||||
6. Support for multiple attention types: The design for most components is
|
||||
general to an arbitrary number of attention types. But
|
||||
`find_longest_cache_hit` only supports one attention type or two
|
||||
types of full-attention plus exactly one another type. The general
|
||||
implementation of this function is feasible but we don't know how to
|
||||
implement it cleanly yet.
|
||||
|
||||
As we assume tokens per block, physical memory per token per layer, and
|
||||
number of layers per group are the same now, we can ensure that physical
|
||||
memory per block is the same for all groups.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
kv_cache_spec: The KVCacheSpec of each attention layer in the model
|
||||
available_memory: Memory available for KV cache in bytes.
|
||||
Returns:
|
||||
The generated KVCacheConfig
|
||||
"""
|
||||
# Group all layers by type_id.
|
||||
# E.g., 2 full attention layers and 3 sliding window attention layers,
|
||||
# -> (full.0, full.1), (sw.0, sw.1, sw.2).
|
||||
same_type_layers: dict[str, list[str]] = defaultdict(list)
|
||||
for layer_name, layer_spec in kv_cache_spec.items():
|
||||
same_type_layers[layer_spec.type_id].append(layer_name)
|
||||
|
||||
# Split each group into smaller groups, to make the number of layers in each
|
||||
# group identical. Add padding to the last group of each type if necessary.
|
||||
# E.g., (full.0, full.1), (sw.0, sw.1, sw.2)
|
||||
# split to 3 groups with 2 layers each:
|
||||
# (full.0, full.1), (sw.0, sw.1), (sw.2, padding).
|
||||
# FIXME(Chen): At the moment of writing this code (2025-06-02), all
|
||||
# open-source hybrid models follow an n:1 pattern between different attention
|
||||
# types (e.g., Gemma3 5:1 between sw and full, LLaMA4 3:1 between local and
|
||||
# full), so we can use the "1" in the n:1 pattern as the group size, which
|
||||
# is the minimum number of layers among all attention types. Need a better
|
||||
# strategy if we want to support more complex patterns (e.g., 20 full + 30
|
||||
# sw, where the group size should be 10).
|
||||
group_size = min([len(layers) for layers in same_type_layers.values()])
|
||||
grouped_layers = []
|
||||
for layers in same_type_layers.values():
|
||||
num_padding_layers = group_size - len(layers) % group_size
|
||||
if num_padding_layers != group_size:
|
||||
logger.warning(
|
||||
"Add %d padding layers, may waste at most %.2f%% KV cache memory", # noqa
|
||||
num_padding_layers,
|
||||
num_padding_layers / len(layers) * 100,
|
||||
)
|
||||
for i in range(0, len(layers), group_size):
|
||||
grouped_layers.append(layers[i:i + group_size])
|
||||
kv_cache_groups = create_kv_cache_group_specs(kv_cache_spec,
|
||||
grouped_layers)
|
||||
|
||||
# Determine how model runners should initialize the KV cache tensors.
|
||||
# We will have group_size memory pools, each shared by one layer from
|
||||
# each group. As layers in different groups have different block tables,
|
||||
# they will use different parts of the shared Tensor.
|
||||
# The memory layout in the example will be:
|
||||
# full.0, sw.0, sw.2: share a Tensor with size=available_memory//2
|
||||
# full.1, sw.1: share another Tensor with size=available_memory//2
|
||||
page_size = get_uniform_page_size(kv_cache_spec)
|
||||
num_blocks = get_num_blocks(vllm_config, group_size, available_memory,
|
||||
page_size)
|
||||
per_memory_pool_size = page_size * num_blocks
|
||||
kv_cache_tensors = []
|
||||
for i in range(group_size):
|
||||
shared_by = []
|
||||
for j in range(len(kv_cache_groups)):
|
||||
if i < len(grouped_layers[j]):
|
||||
shared_by.append(grouped_layers[j][i])
|
||||
kv_cache_tensors.append(
|
||||
KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by))
|
||||
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=num_blocks,
|
||||
kv_cache_tensors=kv_cache_tensors,
|
||||
kv_cache_groups=kv_cache_groups,
|
||||
)
|
||||
|
||||
# Print the KV cache size and maximum concurrency.
|
||||
num_tokens = num_blocks // len(
|
||||
grouped_layers) * vllm_config.cache_config.block_size
|
||||
num_tokens_str = f"{num_tokens:,}"
|
||||
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
|
||||
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
|
||||
max_concurrency = get_max_concurrency_for_kv_cache_config(
|
||||
vllm_config, kv_cache_config)
|
||||
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
|
||||
max_model_len_str, max_concurrency)
|
||||
return kv_cache_config
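A minimal standalone sketch of the grouping and memory-pool layout computed above, using plain strings as layer names and a toy byte budget; the variable names below are illustrative, not part of vLLM.

from collections import defaultdict

layer_types = {"full.0": "full", "full.1": "full",
               "sw.0": "sw", "sw.1": "sw", "sw.2": "sw"}

same_type_layers = defaultdict(list)
for name, type_id in layer_types.items():
    same_type_layers[type_id].append(name)

group_size = min(len(layers) for layers in same_type_layers.values())  # 2
grouped_layers = []
for layers in same_type_layers.values():
    for i in range(0, len(layers), group_size):
        grouped_layers.append(layers[i:i + group_size])
# grouped_layers == [['full.0', 'full.1'], ['sw.0', 'sw.1'], ['sw.2']]

available_memory = 1024  # toy value, in bytes
pool_size = available_memory // group_size
for i in range(group_size):
    shared_by = [group[i] for group in grouped_layers if i < len(group)]
    print(f"pool {i} (size={pool_size}): shared by {shared_by}")
# pool 0 (size=512): shared by ['full.0', 'sw.0', 'sw.2']
# pool 1 (size=512): shared by ['full.1', 'sw.1']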
|
||||
|
||||
|
||||
def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
||||
"""
|
||||
Only models with one type of KV cache are supported so far. This function tries
|
||||
to convert the KV cache specs to one type if the model is a hybrid model
|
||||
with multiple types of KV cache. It will convert all SlidingWindowSpec to
|
||||
FullAttentionSpec if both types are present.
|
||||
This function tries to convert the KV cache specs to one type if the model
|
||||
is a hybrid model with multiple types of KV cache. It will convert all
|
||||
SlidingWindowSpec to FullAttentionSpec if both types are present.
|
||||
|
||||
Args:
|
||||
kv_cache_spec: The kv cache spec of each attention layer in the model
|
||||
"""
|
||||
|
||||
def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
|
||||
type_ids = set(layer_spec.type_id
|
||||
for layer_spec in kv_cache_spec.values())
|
||||
return len(type_ids) > 1
|
||||
|
||||
if not is_hybrid(kv_cache_spec):
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
"Hybrid KV cache manager is disabled for this hybrid model, "
|
||||
"This means we do not enable any optimizations for saving KV cache "
|
||||
"memory (e.g., dropping the KV cache outside the sliding window). "
|
||||
"The compute of layers like sliding window is still saved.")
|
||||
|
||||
has_full_attention = any(
|
||||
isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
|
||||
has_sliding_window = any(
|
||||
@ -712,13 +916,18 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
||||
sliding_window=spec.sliding_window,
|
||||
)
|
||||
|
||||
if is_hybrid(kv_cache_spec):
|
||||
raise ValueError("Hybrid KV cache manager is disabled but failed to "
|
||||
"convert the KV cache specs to one unified type.")
|
||||
|
||||
def get_kv_cache_config(vllm_config: VllmConfig,
|
||||
kv_cache_spec: dict[str, KVCacheSpec],
|
||||
available_memory: int) -> KVCacheConfig:
|
||||
|
||||
def get_kv_cache_config(
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_spec: dict[str, KVCacheSpec],
|
||||
available_memory: int,
|
||||
) -> KVCacheConfig:
|
||||
"""
|
||||
Generates the KV cache configuration for a model
|
||||
TODO: support hybrid models with more than one type of KV cache.
|
||||
Generates the KV cache configuration for a model.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
@ -728,14 +937,25 @@ def get_kv_cache_config(vllm_config: VllmConfig,
|
||||
Returns:
|
||||
The generated KVCacheConfig
|
||||
"""
|
||||
unify_hybrid_kv_cache_specs(kv_cache_spec)
|
||||
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
|
||||
|
||||
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
|
||||
unify_hybrid_kv_cache_specs(kv_cache_spec)
|
||||
|
||||
if is_kv_cache_type_uniform(kv_cache_spec):
|
||||
# The KV cache spec of all layers is the same, which is true for
|
||||
# most models. Allocate the same amount of memory for
|
||||
# each layer.
|
||||
return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
elif is_kv_cache_page_size_uniform(kv_cache_spec):
|
||||
# The model contains multiple attention types, but the KV caches of all layers
|
||||
# have the same physical memory per block per layer. Split the layers
|
||||
# into groups with the same number of layers, and thus the same total page
|
||||
# size.
|
||||
return _get_kv_cache_config_uniform_page_size(vllm_config,
|
||||
kv_cache_spec,
|
||||
available_memory)
|
||||
|
||||
raise NotImplementedError
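A hedged sketch of what the two uniformity checks above test; the helper names here are hypothetical stand-ins for the real predicates in kv_cache_utils.

from vllm.v1.kv_cache_interface import KVCacheSpec


def _is_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
    # All layers share one type_id (e.g. all full attention), so a single
    # KV cache group is enough.
    return len({spec.type_id for spec in kv_cache_spec.values()}) == 1


def _is_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
    # Attention types differ, but every layer needs the same number of bytes
    # per block, so blocks from different groups can share memory pools.
    return len({spec.page_size_bytes for spec in kv_cache_spec.values()}) == 1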
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ from vllm.logger import init_logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
|
||||
compute_encoder_budget)
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager
|
||||
from vllm.v1.core.sched.interface import SchedulerInterface
|
||||
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
||||
SchedulerOutput)
|
||||
@ -377,7 +377,8 @@ class Scheduler(SchedulerInterface):
|
||||
# KVTransfer: WAITING reqs have num_computed_tokens > 0
|
||||
# after async KV recvs are completed.
|
||||
else:
|
||||
new_computed_blocks = KVCacheBlocks.create_empty()
|
||||
new_computed_blocks = (
|
||||
self.kv_cache_manager.create_empty_block_list())
|
||||
num_new_local_computed_tokens = 0
|
||||
num_computed_tokens = request.num_computed_tokens
|
||||
|
||||
@ -1010,7 +1011,7 @@ class Scheduler(SchedulerInterface):
|
||||
num_computed_tokens = len(block_ids) * self.block_size
|
||||
if num_computed_tokens == request.num_tokens:
|
||||
num_computed_tokens -= 1
|
||||
self.kv_cache_manager.single_type_manager.cache_blocks(
|
||||
self.kv_cache_manager.cache_blocks(
|
||||
request,
|
||||
self.kv_cache_manager.req_to_block_hashes[request.request_id],
|
||||
num_computed_tokens,
|
||||
|
||||
@ -22,8 +22,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
self,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
block_pool: BlockPool,
|
||||
use_eagle: bool,
|
||||
num_kv_cache_groups: int,
|
||||
kv_cache_group_id: int,
|
||||
caching_hash_fn: Callable,
|
||||
) -> None:
|
||||
"""
|
||||
@ -31,9 +30,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
Args:
|
||||
kv_cache_spec: The kv_cache_spec for this manager.
|
||||
block_pool: The block pool.
|
||||
use_eagle: Whether to use eagle.
|
||||
num_kv_cache_groups: The number of kv cache groups managed by this
|
||||
manager.
|
||||
kv_cache_group_id: The id of the kv cache group of this manager.
|
||||
caching_hash_fn: The caching hash function.
|
||||
"""
|
||||
|
||||
@ -41,9 +38,6 @@ class SingleTypeKVCacheManager(ABC):
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_pool = block_pool
|
||||
|
||||
# Needs special handling for find_longest_cache_hit if eagle is enabled
|
||||
self.use_eagle = use_eagle
|
||||
|
||||
# Mapping from request ID to blocks to track the blocks allocated
|
||||
# for each request, so that we can free the blocks when the request
|
||||
# is finished.
|
||||
@ -56,8 +50,8 @@ class SingleTypeKVCacheManager(ABC):
|
||||
# data for preempted ones.
|
||||
self.num_cached_block: dict[str, int] = {}
|
||||
|
||||
self.num_kv_cache_groups = num_kv_cache_groups
|
||||
self.caching_hash_fn = caching_hash_fn
|
||||
self.kv_cache_group_id = kv_cache_group_id
|
||||
|
||||
def get_num_blocks_to_allocate(
|
||||
self, request_id: str, num_tokens: int,
|
||||
@ -86,8 +80,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
num_evictable_computed_blocks = sum(
|
||||
blk.ref_cnt == 0 and not blk.is_null
|
||||
for blk in new_computed_blocks)
|
||||
return ((num_new_blocks + num_evictable_computed_blocks) *
|
||||
self.num_kv_cache_groups)
|
||||
return num_new_blocks + num_evictable_computed_blocks
|
||||
|
||||
def save_new_computed_blocks(
|
||||
self, request_id: str,
|
||||
@ -130,8 +123,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
if num_new_blocks <= 0:
|
||||
return []
|
||||
else:
|
||||
new_blocks = self.block_pool.get_new_blocks(
|
||||
num_new_blocks * self.num_kv_cache_groups)
|
||||
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
|
||||
req_blocks.extend(new_blocks)
|
||||
return new_blocks
|
||||
|
||||
@ -156,12 +148,19 @@ class SingleTypeKVCacheManager(ABC):
|
||||
num_cached_blocks=num_cached_blocks,
|
||||
num_full_blocks=num_full_blocks,
|
||||
block_size=self.block_size,
|
||||
kv_cache_group_id=self.kv_cache_group_id,
|
||||
hash_fn=self.caching_hash_fn,
|
||||
)
|
||||
|
||||
self.num_cached_block[request.request_id] = num_full_blocks
|
||||
|
||||
def free(self, request_id: str) -> None:
|
||||
"""
|
||||
Free the blocks for the request.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
"""
|
||||
# Default to [] in case a request is freed (aborted) before alloc.
|
||||
req_blocks = self.req_to_blocks.pop(request_id, [])
|
||||
|
||||
@ -188,12 +187,22 @@ class SingleTypeKVCacheManager(ABC):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def find_longest_cache_hit(self, block_hashes: list[BlockHash],
|
||||
max_length: int) -> list[KVCacheBlock]:
|
||||
def find_longest_cache_hit(
|
||||
cls,
|
||||
block_hashes: list[BlockHash],
|
||||
max_length: int,
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
) -> list[list[KVCacheBlock]]:
|
||||
"""
|
||||
Get the longest cache hit prefix of the blocks that is not longer than
|
||||
`max_length`. If no cache hit is found, return an empty list.
|
||||
`max_length`. The prefix should be a common prefix hit for all the
|
||||
kv cache groups in `kv_cache_group_ids`. If no cache hit is found,
|
||||
return an empty list.
|
||||
If eagle is enabled, drop the last matched block to force recomputation of
|
||||
the last block, which yields the hidden states required by the eagle drafting head.
|
||||
Need to be customized for each attention type.
|
||||
@ -201,12 +210,20 @@ class SingleTypeKVCacheManager(ABC):
|
||||
Args:
|
||||
block_hashes: The block hashes of the request.
|
||||
max_length: The maximum length of the cache hit prefix.
|
||||
kv_cache_group_ids: The ids of the kv cache groups.
|
||||
block_pool: The block pool.
|
||||
kv_cache_spec: The kv cache spec.
|
||||
use_eagle: Whether to use eagle.
|
||||
|
||||
Returns:
|
||||
A list of cached blocks with skipped blocks replaced by null block.
|
||||
A list of cached blocks with skipped blocks replaced by null block
|
||||
for each kv cache group in `kv_cache_group_ids`.
|
||||
Return a list of length `len(kv_cache_group_ids)`, where the i-th
|
||||
element is a list of cached blocks for the i-th kv cache group
|
||||
in `kv_cache_group_ids`.
|
||||
For example, sliding window manager should return a list like
|
||||
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and
|
||||
sliding window 8.
|
||||
[[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]] for block size 4
|
||||
and sliding window 8 and len(kv_cache_group_ids) = 1.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
@ -215,11 +232,9 @@ class SingleTypeKVCacheManager(ABC):
|
||||
def remove_skipped_blocks(self, request_id: str,
|
||||
num_computed_tokens: int) -> None:
|
||||
"""
|
||||
Remove the blocks that are no longer needed from `blocks`. The removed
|
||||
blocks should be replaced by null_block. Return the removed blocks in
|
||||
eviction order, where the first returned block should be evicted first.
|
||||
Don't free the removed blocks in this function. Need to be customized
|
||||
for each attention type.
|
||||
Remove the blocks that are no longer needed from `blocks` and free the
|
||||
blocks. The removed blocks should be replaced by null_block.
|
||||
Need to be customized for each attention type.
|
||||
|
||||
Args:
|
||||
request_id: The request ID.
|
||||
@ -230,21 +245,36 @@ class SingleTypeKVCacheManager(ABC):
|
||||
|
||||
class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
|
||||
def find_longest_cache_hit(self, block_hashes: list[BlockHash],
|
||||
max_length: int) -> list[KVCacheBlock]:
|
||||
computed_blocks: list[KVCacheBlock] = []
|
||||
max_num_blocks = max_length // self.block_size
|
||||
@classmethod
|
||||
def find_longest_cache_hit(
|
||||
cls,
|
||||
block_hashes: list[BlockHash],
|
||||
max_length: int,
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
) -> list[list[KVCacheBlock]]:
|
||||
assert isinstance(kv_cache_spec, FullAttentionSpec), (
|
||||
"FullAttentionManager can only be used for full attention groups")
|
||||
computed_blocks: list[list[KVCacheBlock]] = [
|
||||
[] for _ in range(len(kv_cache_group_ids))
|
||||
]
|
||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||
for i in range(max_num_blocks):
|
||||
block_hash = block_hashes[i]
|
||||
# block_hashes is a chain of block hashes. If a block hash is not
|
||||
# in the cached_block_hash_to_id, the following block hashes are
|
||||
# certainly not computed yet.
|
||||
if cached_block := self.block_pool.get_cached_block(block_hash):
|
||||
computed_blocks.append(cached_block)
|
||||
if cached_block := block_pool.get_cached_block(
|
||||
block_hash, kv_cache_group_ids):
|
||||
for j in range(len(kv_cache_group_ids)):
|
||||
computed_blocks[j].append(cached_block[j])
|
||||
else:
|
||||
break
|
||||
if self.use_eagle and len(computed_blocks) > 0:
|
||||
computed_blocks.pop()
|
||||
if use_eagle and len(computed_blocks[0]) > 0:
|
||||
for j in range(len(kv_cache_group_ids)):
|
||||
computed_blocks[j].pop()
|
||||
return computed_blocks
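A toy, self-contained illustration of the prefix-matching loop above, with a plain dict standing in for BlockPool.get_cached_block and integers standing in for KVCacheBlock; all values here are made up.

block_size = 4
max_length = 20
block_hashes = ["h0", "h1", "h2", "h3", "h4"]
# h0..h2 are cached for both groups 0 and 1; h3 is not cached.
cached = {"h0": [10, 20], "h1": [11, 21], "h2": [12, 22]}

kv_cache_group_ids = [0, 1]
computed_blocks = [[] for _ in kv_cache_group_ids]
for i in range(max_length // block_size):
    hit = cached.get(block_hashes[i])
    if hit is None:
        # Hashes form a chain, so later blocks cannot be cached either.
        break
    for j in range(len(kv_cache_group_ids)):
        computed_blocks[j].append(hit[j])
print(computed_blocks)  # [[10, 11, 12], [20, 21, 22]]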
|
||||
|
||||
def remove_skipped_blocks(self, request_id: str,
|
||||
@ -267,45 +297,58 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
|
||||
def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
|
||||
use_eagle: bool, **kwargs) -> None:
|
||||
super().__init__(kv_cache_spec, block_pool, use_eagle, **kwargs)
|
||||
**kwargs) -> None:
|
||||
super().__init__(kv_cache_spec, block_pool, **kwargs)
|
||||
self.sliding_window = kv_cache_spec.sliding_window
|
||||
self._null_block = block_pool.null_block
|
||||
|
||||
@classmethod
|
||||
def find_longest_cache_hit(
|
||||
cls,
|
||||
block_hashes: list[BlockHash],
|
||||
max_length: int,
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
) -> list[list[KVCacheBlock]]:
|
||||
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
|
||||
"SlidingWindowManager can only be used for sliding window groups")
|
||||
|
||||
# The number of contiguous blocks needed for prefix cache hit.
|
||||
# -1 since the input token itself is also included in the window
|
||||
self.sliding_window_contiguous_blocks = cdiv(
|
||||
(kv_cache_spec.sliding_window - 1), self.block_size)
|
||||
if self.use_eagle:
|
||||
sliding_window_contiguous_blocks = cdiv(
|
||||
kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size)
|
||||
if use_eagle:
|
||||
# Need to drop the last matched block if eagle is enabled. For
|
||||
# sliding window layer, we achieve this by increasing the number of
|
||||
# contiguous blocks needed for prefix cache hit by one and dropping
|
||||
# the last matched block.
|
||||
self.sliding_window_contiguous_blocks += 1
|
||||
self._null_block = block_pool.null_block
|
||||
sliding_window_contiguous_blocks += 1
|
||||
|
||||
def find_longest_cache_hit(self, block_hashes: list[BlockHash],
|
||||
max_length: int) -> list[KVCacheBlock]:
|
||||
# TODO: reduce i by sliding_window_contiguous_blocks on a cache miss, to
|
||||
# optimize the time complexity from O(max_num_blocks) to
|
||||
# O(max_num_blocks / sliding_window_contiguous_blocks +
|
||||
# sliding_window_contiguous_blocks),
|
||||
# which is good for low cache hit rate scenarios.
|
||||
max_num_blocks = max_length // self.block_size
|
||||
computed_blocks = [self._null_block] * max_num_blocks
|
||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||
computed_blocks = [[block_pool.null_block] * max_num_blocks
|
||||
for _ in range(len(kv_cache_group_ids))]
|
||||
num_contiguous_blocks = 0
|
||||
|
||||
match_found = False
|
||||
# Search from right to left and early stop when a match is found.
|
||||
for i in range(max_num_blocks - 1, -1, -1):
|
||||
if cached_block := self.block_pool.get_cached_block(
|
||||
block_hashes[i]):
|
||||
computed_blocks[i] = cached_block
|
||||
if cached_block := block_pool.get_cached_block(
|
||||
block_hashes[i], kv_cache_group_ids):
|
||||
for j in range(len(kv_cache_group_ids)):
|
||||
computed_blocks[j][i] = cached_block[j]
|
||||
num_contiguous_blocks += 1
|
||||
if (num_contiguous_blocks
|
||||
>= self.sliding_window_contiguous_blocks):
|
||||
if (num_contiguous_blocks >= sliding_window_contiguous_blocks):
|
||||
# Trim the trailing blocks.
|
||||
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
|
||||
# when sliding_window_contiguous_blocks=2.
|
||||
del computed_blocks[i + num_contiguous_blocks:]
|
||||
for j in range(len(kv_cache_group_ids)):
|
||||
del computed_blocks[j][i + num_contiguous_blocks:]
|
||||
match_found = True
|
||||
break
|
||||
else:
|
||||
@ -313,9 +356,11 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
if not match_found:
|
||||
# The first `num_contiguous_blocks` is a cache hit even if
|
||||
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
|
||||
del computed_blocks[num_contiguous_blocks:]
|
||||
if self.use_eagle and len(computed_blocks) > 0:
|
||||
computed_blocks.pop()
|
||||
for j in range(len(kv_cache_group_ids)):
|
||||
del computed_blocks[j][num_contiguous_blocks:]
|
||||
if use_eagle and len(computed_blocks[0]) > 0:
|
||||
for j in range(len(kv_cache_group_ids)):
|
||||
computed_blocks[j].pop()
|
||||
return computed_blocks
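A toy, self-contained illustration of the right-to-left scan above, with integers standing in for KVCacheBlock and None for the null block; it reproduces the [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3] example from the comment.

sliding_window_contiguous_blocks = 2
max_num_blocks = 6
cached = {2: 8, 3: 3, 5: 9}  # block index -> cached block id (toy values)

computed = [None] * max_num_blocks
num_contiguous = 0
match_found = False
for i in range(max_num_blocks - 1, -1, -1):
    if i in cached:
        computed[i] = cached[i]
        num_contiguous += 1
        if num_contiguous >= sliding_window_contiguous_blocks:
            del computed[i + num_contiguous:]  # trim the trailing blocks
            match_found = True
            break
    else:
        num_contiguous = 0
if not match_found:
    del computed[num_contiguous:]
print(computed)  # [None, None, 8, 3]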
|
||||
|
||||
def remove_skipped_blocks(self, request_id: str,
|
||||
|
||||
@ -157,11 +157,10 @@ class SlidingWindowSpec(AttentionSpec):
|
||||
@dataclass
|
||||
class KVCacheTensor:
|
||||
"""
|
||||
A dataclass for specifying how the workers should initialize the KV cache
|
||||
for a layer. Only contains the size of KV cache for that layer for now. Will
|
||||
be extended to support multiple layers sharing the same memory pool.
|
||||
A class for specifying how the workers should initialize the KV cache.
|
||||
"""
|
||||
size: int # The size of KV cache Tensor in bytes
|
||||
size: int # size of the KV cache tensor in bytes
|
||||
shared_by: list[str] # layer names that share the same KV cache tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -183,27 +182,13 @@ class KVCacheConfig:
|
||||
"""
|
||||
"""The number of KV cache blocks"""
|
||||
num_blocks: int
|
||||
"""layer_name -> how to initialize KV cache for that layer"""
|
||||
tensors: dict[str, KVCacheTensor]
|
||||
"""How should model runner initialize the KV cache tensors for each layer"""
|
||||
kv_cache_tensors: list[KVCacheTensor]
|
||||
"""
|
||||
The kv cache groups of the model.
|
||||
The layers in the model are repeated with some pattern, e.g., a model
|
||||
with 10 full attention layers and 20 sliding window attention layers can be
|
||||
regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
|
||||
The KVCacheManager allocates different block tables for each of the 3 layers
|
||||
in the pattern, and repeats each of them 10 times to generate the
|
||||
block_table for the 30 layers in the model.
|
||||
Therefore, we can group the layers in the model into 3 groups, each of which
|
||||
contains 10 layers in the model.
|
||||
The KVCacheManager allocates the block_table for each group based on its
|
||||
kv_cache spec, and the model runner applies the block table to each layer
|
||||
in the group.
|
||||
For example:
|
||||
1. A model only uses full attention. The pattern is
|
||||
(num_hidden_layers * full), so there is only one group and the block table
|
||||
is shared by all layers.
|
||||
2. (WIP) A model with 10 full attention layers and 20 sliding window
|
||||
attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
|
||||
there are 3 groups, each of which represents 10 layers in the model.
|
||||
For models with only one type of attention, there is only one group that
|
||||
contains all layers.
|
||||
For models with multiple types of attention, there will be multiple groups,
|
||||
see `_get_kv_cache_config_uniform_page_size` for more details.
|
||||
"""
|
||||
kv_cache_groups: list[KVCacheGroupSpec]
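A toy sketch of the grouping described in the docstring above: 10 full-attention layers and 20 sliding-window layers with a (1 * full, 2 * sw) pattern yield 3 groups of 10 layers each; the layer names are illustrative placeholders.

full_layers = [f"full.{i}" for i in range(10)]
sw_layers = [f"sw.{i}" for i in range(20)]

groups = [
    full_layers,      # group 0: the full-attention layer of each pattern
    sw_layers[0::2],  # group 1: the first sw layer of each pattern
    sw_layers[1::2],  # group 2: the second sw layer of each pattern
]
assert [len(g) for g in groups] == [10, 10, 10]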
|
||||
|
||||
@ -2088,33 +2088,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
block_sizes=block_sizes,
|
||||
)
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
def _allocate_kv_cache_tensors(
|
||||
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Initialize KV cache based on `kv_cache_config`.
|
||||
Initializes the KV cache buffer with the correct size. The buffer needs
|
||||
to be reshaped to the desired shape before being used by the model.
|
||||
|
||||
Args:
|
||||
kv_cache_config: Configuration for the KV cache, including the KV
|
||||
cache size of each layer
|
||||
kv_cache_config: The KV cache config
|
||||
Returns:
|
||||
dict[str, torch.Tensor]: A map from layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
tensor = torch.zeros(kv_cache_tensor.size,
|
||||
dtype=torch.int8,
|
||||
device=self.device)
|
||||
for layer_name in kv_cache_tensor.shared_by:
|
||||
kv_cache_raw_tensors[layer_name] = tensor
|
||||
|
||||
layer_names = set()
|
||||
for group in kv_cache_config.kv_cache_groups:
|
||||
layer_names.update(group.layer_names)
|
||||
assert layer_names == set(kv_cache_raw_tensors.keys(
|
||||
)), "Some layers are not correctly initialized"
|
||||
return kv_cache_raw_tensors
|
||||
|
||||
def _reshape_kv_cache_tensors(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.may_reinitialize_input_batch(kv_cache_config)
|
||||
self.initialize_attn_backend(kv_cache_config)
|
||||
Reshape the KV cache tensors to the desired shape and dtype.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
kv_cache_raw_tensors: The KV cache buffer of each layer, with
|
||||
the correct size but not yet reshaped.
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A map from layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
|
||||
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
|
||||
kv_cache_spec = kv_cache_group.kv_cache_spec
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
tensor_config = kv_cache_config.tensors[layer_name]
|
||||
assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
|
||||
# `num_blocks` is the number of blocks the model runner can use.
|
||||
# `kv_cache_config.num_blocks` is the number of blocks that
|
||||
# KVCacheManager may allocate.
|
||||
# Since different GPUs may have different number of layers and
|
||||
# different memory capacities, `num_blocks` can be different on
|
||||
# different GPUs, and `kv_cache_config.num_blocks` is set to
|
||||
# the min of all `num_blocks`. Verify it here.
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
for i, kv_cache_group_spec in enumerate(
|
||||
kv_cache_config.kv_cache_groups):
|
||||
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = (raw_tensor.numel() //
|
||||
kv_cache_spec.page_size_bytes)
|
||||
if isinstance(kv_cache_spec, AttentionSpec):
|
||||
kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
@ -2140,13 +2165,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
kv_cache_stride_order.index(i)
|
||||
for i in range(len(kv_cache_stride_order))
|
||||
]
|
||||
kv_caches[layer_name] = torch.zeros(
|
||||
kv_cache_shape, dtype=dtype,
|
||||
device=self.device).permute(*inv_order)
|
||||
kv_caches[layer_name] = kv_cache_raw_tensors[
|
||||
layer_name].view(dtype).view(kv_cache_shape).permute(
|
||||
*inv_order)
|
||||
else:
|
||||
# TODO: add new branches when introducing more types of
|
||||
# KV cache specs.
|
||||
raise ValueError("Unknown KV cache spec type.")
|
||||
raise NotImplementedError
|
||||
return kv_caches
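A minimal standalone sketch of the allocate-then-reshape pattern used by the two methods above: a flat int8 buffer is allocated once and then viewed as the KV dtype and a per-layer cache shape without copying. The shape and sizes below are toy values, not vLLM's actual layout.

import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 64
dtype = torch.float16
dtype_size = torch.tensor([], dtype=dtype).element_size()  # 2 bytes

# 2x for the separate K and V planes in this toy layout.
kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_size

raw = torch.zeros(num_blocks * page_size_bytes, dtype=torch.int8)
kv_cache = raw.view(dtype).view(kv_cache_shape)

assert kv_cache.shape == kv_cache_shape
assert kv_cache.data_ptr() == raw.data_ptr()  # same memory, no copy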
|
||||
|
||||
def initialize_kv_cache_tensors(
|
||||
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Initialize the memory buffer for KV cache.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A map from layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
# Initialize the memory buffer for KV cache
|
||||
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
|
||||
# Change the memory buffer to the desired shape
|
||||
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
|
||||
kv_cache_raw_tensors)
|
||||
|
||||
# Setup `kv_cache_config` and `kv_caches` for models
|
||||
# with cross-layer KV sharing
|
||||
@ -2157,17 +2198,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
kv_caches,
|
||||
)
|
||||
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
self.vllm_config.compilation_config.static_forward_context,
|
||||
self.kv_caches)
|
||||
return kv_caches
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Initialize KV cache based on `kv_cache_config`.
|
||||
Args:
|
||||
kv_cache_config: Configuration for the KV cache, including the KV
|
||||
cache size of each layer
|
||||
"""
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.may_reinitialize_input_batch(kv_cache_config)
|
||||
self.initialize_attn_backend(kv_cache_config)
|
||||
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
|
||||
|
||||
if self.speculative_config and self.speculative_config.use_eagle():
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
# validate all draft model layers belong to the same kv cache
|
||||
# group
|
||||
self.drafter.validate_same_kv_cache_group(kv_cache_config)
|
||||
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
self.vllm_config.compilation_config.static_forward_context,
|
||||
self.kv_caches)
|
||||
|
||||
if has_kv_transfer_group():
|
||||
get_kv_transfer_group().register_kv_caches(kv_caches)
|
||||
|
||||
|
||||
@ -1365,14 +1365,20 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
assert self.block_table_cpu.dtype == self.input_batch.block_table[
|
||||
0].get_cpu_tensor().dtype
|
||||
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
kv_cache_sizes = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
assert len(kv_cache_tensor.shared_by) == 1, (
|
||||
"KV cache tensor shared by multiple layers is not supported in "
|
||||
"TPU.")
|
||||
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
|
||||
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_group in kv_cache_config.kv_cache_groups:
|
||||
kv_cache_spec = kv_cache_group.kv_cache_spec
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
tensor_config = kv_cache_config.tensors[layer_name]
|
||||
assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
|
||||
tensor_size = kv_cache_sizes[layer_name]
|
||||
assert tensor_size % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = tensor_size // kv_cache_spec.page_size_bytes # noqa
|
||||
if isinstance(kv_cache_spec, AttentionSpec):
|
||||
if self.use_spmd:
|
||||
num_kv_heads = kv_cache_spec.num_kv_heads
|
||||
|
||||