mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 01:51:20 +08:00
[Hybrid Allocator] Support KV cache groups with different block_size (#29143)
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu> Co-authored-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent
e502098643
commit
48ddb02b79
@ -1248,7 +1248,9 @@ def test_allocate_with_lookahead():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Test case 1: Requires additional lookahead tokens
|
# Test case 1: Requires additional lookahead tokens
|
||||||
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
|
kv_cache_manager = KVCacheManager(
|
||||||
|
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
|
||||||
|
)
|
||||||
blocks = kv_cache_manager.allocate_slots(
|
blocks = kv_cache_manager.allocate_slots(
|
||||||
request,
|
request,
|
||||||
num_new_tokens=3,
|
num_new_tokens=3,
|
||||||
@ -1257,7 +1259,9 @@ def test_allocate_with_lookahead():
|
|||||||
assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks
|
assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks
|
||||||
|
|
||||||
# Test case 2: With precomputed blocks
|
# Test case 2: With precomputed blocks
|
||||||
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
|
kv_cache_manager = KVCacheManager(
|
||||||
|
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
|
||||||
|
)
|
||||||
# required_blocks = ceil((3 + 2) /4) = 2
|
# required_blocks = ceil((3 + 2) /4) = 2
|
||||||
blocks = kv_cache_manager.allocate_slots(
|
blocks = kv_cache_manager.allocate_slots(
|
||||||
request,
|
request,
|
||||||
@ -1268,7 +1272,9 @@ def test_allocate_with_lookahead():
|
|||||||
|
|
||||||
# Test case 3: With precomputed blocks
|
# Test case 3: With precomputed blocks
|
||||||
# required_blocks = ceil((3 + 4) / 4) = 2
|
# required_blocks = ceil((3 + 4) / 4) = 2
|
||||||
kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
|
kv_cache_manager = KVCacheManager(
|
||||||
|
kv_cache_config=config, max_model_len=100, hash_block_size=block_size
|
||||||
|
)
|
||||||
blocks = kv_cache_manager.allocate_slots(
|
blocks = kv_cache_manager.allocate_slots(
|
||||||
request,
|
request,
|
||||||
num_new_tokens=3,
|
num_new_tokens=3,
|
||||||
@ -1495,7 +1501,8 @@ def test_get_kv_cache_config_one_worker():
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
# different hidden size
|
|
||||||
|
# different hidden size but same type, use UniformTypeKVCacheSpecs
|
||||||
kv_cache_specs_hybrid = {
|
kv_cache_specs_hybrid = {
|
||||||
"layer_1": new_kv_cache_spec(head_size=128),
|
"layer_1": new_kv_cache_spec(head_size=128),
|
||||||
"layer_2": new_kv_cache_spec(head_size=64),
|
"layer_2": new_kv_cache_spec(head_size=64),
|
||||||
@ -1519,6 +1526,40 @@ def test_get_kv_cache_config_one_worker():
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Different hidden size and different type, align by different block size
|
||||||
|
kv_cache_specs_hybrid = {
|
||||||
|
"layer_1": new_kv_cache_spec(head_size=64),
|
||||||
|
"layer_2": new_sliding_window_spec(head_size=32),
|
||||||
|
}
|
||||||
|
kv_cache_config_hybrid = get_kv_cache_configs(
|
||||||
|
vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 32]
|
||||||
|
)[0]
|
||||||
|
assert kv_cache_config_hybrid == KVCacheConfig(
|
||||||
|
num_blocks=32,
|
||||||
|
kv_cache_tensors=[
|
||||||
|
KVCacheTensor(
|
||||||
|
size=mem_per_block_per_layer * 32, shared_by=["layer_1", "layer_2"]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
kv_cache_groups=[
|
||||||
|
KVCacheGroupSpec(["layer_1"], new_kv_cache_spec(head_size=64)),
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["layer_2"], new_sliding_window_spec(head_size=32, block_size=32)
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# different hidden size that cannot be aligned by using different block size
|
||||||
|
kv_cache_specs_hybrid = {
|
||||||
|
"layer_1": new_kv_cache_spec(head_size=64),
|
||||||
|
"layer_2": new_sliding_window_spec(head_size=96),
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
get_kv_cache_configs(
|
||||||
|
vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32]
|
||||||
|
)[0]
|
||||||
|
|
||||||
# Test num_gpu_blocks_override
|
# Test num_gpu_blocks_override
|
||||||
vllm_config.cache_config.num_gpu_blocks_override = 16
|
vllm_config.cache_config.num_gpu_blocks_override = 16
|
||||||
kv_cache_config_override_blocks = get_kv_cache_configs(
|
kv_cache_config_override_blocks = get_kv_cache_configs(
|
||||||
|
|||||||
@ -134,6 +134,7 @@ def test_prefill(hash_fn):
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Complete 3 blocks (48 tokens)
|
# Complete 3 blocks (48 tokens)
|
||||||
@ -256,6 +257,7 @@ def test_prefill_hybrid_model():
|
|||||||
make_kv_cache_config_hybrid_model(block_size, 21),
|
make_kv_cache_config_hybrid_model(block_size, 21),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
hash_fn = sha256
|
hash_fn = sha256
|
||||||
@ -416,6 +418,7 @@ def test_prefill_plp():
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
# the default hash function is sha256
|
# the default hash function is sha256
|
||||||
hash_fn = sha256
|
hash_fn = sha256
|
||||||
@ -523,6 +526,7 @@ def test_decode():
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Complete 3 blocks (48 tokens)
|
# Complete 3 blocks (48 tokens)
|
||||||
@ -585,6 +589,7 @@ def test_evict():
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
last_token_id = 5 * 16 + 7
|
last_token_id = 5 * 16 + 7
|
||||||
@ -643,6 +648,7 @@ def test_hash_block_correct_reuse():
|
|||||||
make_kv_cache_config(16, 2),
|
make_kv_cache_config(16, 2),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Allocate 1 block and cache it.
|
# Allocate 1 block and cache it.
|
||||||
@ -683,6 +689,7 @@ def test_computed_blocks_not_evicted():
|
|||||||
make_kv_cache_config(block_size, 3),
|
make_kv_cache_config(block_size, 3),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Allocate a block and cache it.
|
# Allocate a block and cache it.
|
||||||
@ -741,6 +748,7 @@ def test_basic_prefix_caching_disabled():
|
|||||||
make_kv_cache_config(block_size, 5),
|
make_kv_cache_config(block_size, 5),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=False,
|
enable_caching=False,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
req1 = make_request(
|
req1 = make_request(
|
||||||
@ -790,6 +798,7 @@ def test_cache_blocks(hash_fn):
|
|||||||
block_pool = BlockPool(
|
block_pool = BlockPool(
|
||||||
num_gpu_blocks=5,
|
num_gpu_blocks=5,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
# Req:
|
# Req:
|
||||||
# Block 0: [0, 1, 2, 3]
|
# Block 0: [0, 1, 2, 3]
|
||||||
@ -833,7 +842,9 @@ def test_cache_blocks_multi_group():
|
|||||||
This tests that blocks are cached correctly for different kv cache groups.
|
This tests that blocks are cached correctly for different kv cache groups.
|
||||||
"""
|
"""
|
||||||
block_size = 4
|
block_size = 4
|
||||||
block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
|
block_pool = BlockPool(
|
||||||
|
num_gpu_blocks=10, enable_caching=True, hash_block_size=block_size
|
||||||
|
)
|
||||||
|
|
||||||
# Req:
|
# Req:
|
||||||
# Block 0/4: [0, 1, 2, 3]
|
# Block 0/4: [0, 1, 2, 3]
|
||||||
@ -921,6 +932,7 @@ def test_mm_prefix_caching():
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
|
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
|
||||||
@ -1020,6 +1032,7 @@ def test_cache_key_salting():
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3 complete blocks and an incomplete block with 11 tokens.
|
# 3 complete blocks and an incomplete block with 11 tokens.
|
||||||
@ -1101,6 +1114,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
# Complete 3 blocks (48 tokens)
|
# Complete 3 blocks (48 tokens)
|
||||||
# | Common-0 | Common-1 | Common-2 | ... |
|
# | Common-0 | Common-1 | Common-2 | ... |
|
||||||
@ -1173,6 +1187,7 @@ def test_reset_prefix_cache():
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
full_block_token_ids = [i for i in range(3) for _ in range(16)]
|
full_block_token_ids = [i for i in range(3) for _ in range(16)]
|
||||||
@ -1213,6 +1228,7 @@ def test_prefix_cache_stats_disabled():
|
|||||||
make_kv_cache_config(block_size, 11),
|
make_kv_cache_config(block_size, 11),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
log_stats=False, # Disable logging stats
|
log_stats=False, # Disable logging stats
|
||||||
)
|
)
|
||||||
assert manager.prefix_cache_stats is None
|
assert manager.prefix_cache_stats is None
|
||||||
@ -1232,7 +1248,7 @@ def test_prefix_cache_stats_disabled():
|
|||||||
|
|
||||||
|
|
||||||
def test_maybe_evict_cached_block():
|
def test_maybe_evict_cached_block():
|
||||||
pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
|
pool = BlockPool(num_gpu_blocks=4, enable_caching=True, hash_block_size=16)
|
||||||
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
|
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
|
||||||
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
|
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
|
||||||
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
|
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
|
||||||
@ -1293,6 +1309,7 @@ def test_kv_cache_events(blocks_to_cache: int):
|
|||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
enable_kv_cache_events=True,
|
enable_kv_cache_events=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_tokens = block_size * blocks_to_cache
|
num_tokens = block_size * blocks_to_cache
|
||||||
@ -1351,6 +1368,7 @@ def test_kv_cache_events_with_lora(blocks_to_cache: int):
|
|||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
enable_kv_cache_events=True,
|
enable_kv_cache_events=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test with LoRA request
|
# Test with LoRA request
|
||||||
@ -1405,6 +1423,7 @@ def test_eagle_enabled_removes_last_block():
|
|||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
use_eagle=True,
|
use_eagle=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Request with 3 full blocks (48 tokens)
|
# Request with 3 full blocks (48 tokens)
|
||||||
@ -1437,6 +1456,7 @@ def test_eagle_with_partial_blocks():
|
|||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
use_eagle=True,
|
use_eagle=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
# 2 full blocks + 5 tokens (non-divisible length)
|
# 2 full blocks + 5 tokens (non-divisible length)
|
||||||
token_ids = [0] * (2 * block_size + 5)
|
token_ids = [0] * (2 * block_size + 5)
|
||||||
@ -1476,6 +1496,7 @@ def test_eagle_with_sliding_window():
|
|||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
use_eagle=True,
|
use_eagle=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2 full blocks + 5 tokens (non-divisible length)
|
# 2 full blocks + 5 tokens (non-divisible length)
|
||||||
@ -1522,6 +1543,76 @@ def test_eagle_with_sliding_window():
|
|||||||
assert num_tokens == 0
|
assert num_tokens == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_different_block_size():
|
||||||
|
block_size = 16
|
||||||
|
# full attention and sliding window attention layers have the same page size:
|
||||||
|
# (32 tokens/block * float16 token, vs. 16 tokens/block * float32 token)
|
||||||
|
kv_cache_config = KVCacheConfig(
|
||||||
|
num_blocks=100,
|
||||||
|
kv_cache_tensors=[],
|
||||||
|
kv_cache_groups=[
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["layer1"],
|
||||||
|
FullAttentionSpec(block_size * 2, 1, 1, torch.float16),
|
||||||
|
),
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["layer2"],
|
||||||
|
SlidingWindowSpec(
|
||||||
|
block_size,
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
torch.float32,
|
||||||
|
sliding_window=2 * block_size,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
manager = KVCacheManager(
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
|
max_model_len=8192,
|
||||||
|
enable_caching=True,
|
||||||
|
hash_block_size=block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 10 blocks of 16 tokens each. Token ids are not strictly aligned for each block.
|
||||||
|
common_token_ids = [i for i in range(10) for _ in range(block_size)]
|
||||||
|
|
||||||
|
req0 = make_request("0", common_token_ids, block_size, sha256)
|
||||||
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||||
|
assert not computed_blocks.blocks[0]
|
||||||
|
assert not computed_blocks.blocks[1]
|
||||||
|
assert num_computed_tokens == 0
|
||||||
|
blocks = manager.allocate_slots(
|
||||||
|
req0, 7 * block_size, len(computed_blocks.blocks[0]) * 16, computed_blocks
|
||||||
|
)
|
||||||
|
assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11])
|
||||||
|
req1 = make_request("1", common_token_ids[: 7 * block_size + 1], block_size, sha256)
|
||||||
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
|
assert len(computed_blocks.blocks[0]) == 3
|
||||||
|
assert len(computed_blocks.blocks[1]) == 6
|
||||||
|
assert num_computed_tokens == 6 * 16
|
||||||
|
|
||||||
|
req2 = make_request("2", common_token_ids[: 6 * block_size + 1], block_size, sha256)
|
||||||
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||||
|
assert len(computed_blocks.blocks[0]) == 3
|
||||||
|
assert len(computed_blocks.blocks[1]) == 6
|
||||||
|
assert num_computed_tokens == 6 * 16
|
||||||
|
|
||||||
|
# Evict some blocks to make sliding window cache hit length 5*16
|
||||||
|
# But should return 4 * 16 because full attention cache hit length must be
|
||||||
|
# a multiple of 32
|
||||||
|
manager.block_pool.cached_block_hash_to_block.pop(
|
||||||
|
make_block_hash_with_group_id(req1.block_hashes[6], 1), 11
|
||||||
|
)
|
||||||
|
manager.block_pool.cached_block_hash_to_block.pop(
|
||||||
|
make_block_hash_with_group_id(req1.block_hashes[5], 1), 10
|
||||||
|
)
|
||||||
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
|
assert len(computed_blocks.blocks[0]) == 2
|
||||||
|
assert len(computed_blocks.blocks[1]) == 4
|
||||||
|
assert num_computed_tokens == 4 * 16
|
||||||
|
|
||||||
|
|
||||||
def test_block_lookup_cache_single_block_per_key():
|
def test_block_lookup_cache_single_block_per_key():
|
||||||
cache = BlockHashToBlockMap()
|
cache = BlockHashToBlockMap()
|
||||||
key0 = BlockHashWithGroupId(b"hash0")
|
key0 = BlockHashWithGroupId(b"hash0")
|
||||||
|
|||||||
@ -41,7 +41,9 @@ def test_chunked_local_attention_possible_cached_prefix():
|
|||||||
attention_chunk_size=4,
|
attention_chunk_size=4,
|
||||||
)
|
)
|
||||||
|
|
||||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
block_pool = BlockPool(
|
||||||
|
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
|
||||||
|
)
|
||||||
manager = get_chunked_local_attention_manager(
|
manager = get_chunked_local_attention_manager(
|
||||||
chunked_local_attention_spec, block_pool
|
chunked_local_attention_spec, block_pool
|
||||||
)
|
)
|
||||||
@ -70,6 +72,7 @@ def test_chunked_local_attention_possible_cached_prefix():
|
|||||||
block_pool=block_pool,
|
block_pool=block_pool,
|
||||||
kv_cache_spec=chunked_local_attention_spec,
|
kv_cache_spec=chunked_local_attention_spec,
|
||||||
use_eagle=False,
|
use_eagle=False,
|
||||||
|
alignment_tokens=block_size,
|
||||||
)[0]
|
)[0]
|
||||||
assert len(computed_blocks) == expect_length
|
assert len(computed_blocks) == expect_length
|
||||||
|
|
||||||
@ -111,7 +114,9 @@ def test_sliding_window_possible_cached_prefix():
|
|||||||
sliding_window=4,
|
sliding_window=4,
|
||||||
)
|
)
|
||||||
|
|
||||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
block_pool = BlockPool(
|
||||||
|
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
|
||||||
|
)
|
||||||
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
|
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
|
||||||
|
|
||||||
def run_one_case(block_is_cached, expect_length):
|
def run_one_case(block_is_cached, expect_length):
|
||||||
@ -138,6 +143,7 @@ def test_sliding_window_possible_cached_prefix():
|
|||||||
block_pool=block_pool,
|
block_pool=block_pool,
|
||||||
kv_cache_spec=sliding_window_spec,
|
kv_cache_spec=sliding_window_spec,
|
||||||
use_eagle=False,
|
use_eagle=False,
|
||||||
|
alignment_tokens=block_size,
|
||||||
)[0]
|
)[0]
|
||||||
assert len(computed_blocks) == expect_length
|
assert len(computed_blocks) == expect_length
|
||||||
|
|
||||||
@ -178,7 +184,7 @@ def test_chunked_local_attention_remove_skipped_blocks():
|
|||||||
attention_chunk_size=4,
|
attention_chunk_size=4,
|
||||||
)
|
)
|
||||||
|
|
||||||
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
|
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
|
||||||
|
|
||||||
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
|
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
|
||||||
|
|
||||||
@ -239,7 +245,7 @@ def test_sliding_window_remove_skipped_blocks():
|
|||||||
sliding_window=4,
|
sliding_window=4,
|
||||||
)
|
)
|
||||||
|
|
||||||
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
|
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
|
||||||
|
|
||||||
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
|
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
|
||||||
|
|
||||||
@ -316,7 +322,9 @@ def test_get_num_blocks_to_allocate():
|
|||||||
sliding_window=4, # Placeholder value, not related to test result
|
sliding_window=4, # Placeholder value, not related to test result
|
||||||
)
|
)
|
||||||
|
|
||||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
block_pool = BlockPool(
|
||||||
|
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
|
||||||
|
)
|
||||||
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
|
manager = get_sliding_window_manager(sliding_window_spec, block_pool)
|
||||||
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
|
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
|
||||||
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
|
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
|
||||||
@ -341,7 +349,9 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
|
|||||||
attention_chunk_size=4, # Placeholder value, not related to test result
|
attention_chunk_size=4, # Placeholder value, not related to test result
|
||||||
)
|
)
|
||||||
|
|
||||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
block_pool = BlockPool(
|
||||||
|
num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
|
||||||
|
)
|
||||||
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
|
manager = get_chunked_local_attention_manager(attention_spec, block_pool)
|
||||||
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
|
cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
|
||||||
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
|
cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
|
||||||
|
|||||||
@ -1816,9 +1816,11 @@ class EngineArgs:
|
|||||||
if model_config.runner_type != "pooling":
|
if model_config.runner_type != "pooling":
|
||||||
default_chunked_prefill = True
|
default_chunked_prefill = True
|
||||||
|
|
||||||
# Disable prefix caching default for hybrid models
|
# Disable prefix caching default for hybrid models and mamba-only
|
||||||
# since the feature is still experimental.
|
# models since the feature is still experimental.
|
||||||
default_prefix_caching = not model_config.is_hybrid
|
default_prefix_caching = not (
|
||||||
|
model_config.is_hybrid or model_config.is_attention_free
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
assert model_config.pooler_config is not None
|
assert model_config.pooler_config is not None
|
||||||
|
|
||||||
|
|||||||
@ -289,9 +289,6 @@ class MambaModelConfig(VerifyAndUpdateConfig):
|
|||||||
model_config = vllm_config.model_config
|
model_config = vllm_config.model_config
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
|
|
||||||
if cache_config.mamba_block_size is None:
|
|
||||||
cache_config.mamba_block_size = model_config.max_model_len
|
|
||||||
|
|
||||||
if cache_config.enable_prefix_caching:
|
if cache_config.enable_prefix_caching:
|
||||||
if model_config.supports_mamba_prefix_caching:
|
if model_config.supports_mamba_prefix_caching:
|
||||||
logger.info(
|
logger.info(
|
||||||
@ -299,6 +296,11 @@ class MambaModelConfig(VerifyAndUpdateConfig):
|
|||||||
"Its support for Mamba layers is experimental. "
|
"Its support for Mamba layers is experimental. "
|
||||||
"Please report any issues you may observe."
|
"Please report any issues you may observe."
|
||||||
)
|
)
|
||||||
|
# By default, mamba block size will be set to max_model_len (see
|
||||||
|
# below). When enabling prefix caching, we align mamba block size
|
||||||
|
# to the block size as the basic granularity for prefix caching.
|
||||||
|
if cache_config.mamba_block_size is None:
|
||||||
|
cache_config.mamba_block_size = cache_config.block_size
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Hybrid or mamba-based model detected without "
|
"Hybrid or mamba-based model detected without "
|
||||||
@ -306,6 +308,9 @@ class MambaModelConfig(VerifyAndUpdateConfig):
|
|||||||
)
|
)
|
||||||
cache_config.enable_prefix_caching = False
|
cache_config.enable_prefix_caching = False
|
||||||
|
|
||||||
|
if cache_config.mamba_block_size is None:
|
||||||
|
cache_config.mamba_block_size = model_config.max_model_len
|
||||||
|
|
||||||
# TODO(tdoublep): remove once cascade attention is supported
|
# TODO(tdoublep): remove once cascade attention is supported
|
||||||
logger.info(
|
logger.info(
|
||||||
"Disabling cascade attention since it is not supported for hybrid models."
|
"Disabling cascade attention since it is not supported for hybrid models."
|
||||||
|
|||||||
@ -13,6 +13,8 @@ from vllm.distributed.kv_events import (
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.core.kv_cache_utils import (
|
from vllm.v1.core.kv_cache_utils import (
|
||||||
BlockHash,
|
BlockHash,
|
||||||
|
BlockHashList,
|
||||||
|
BlockHashListWithBlockSize,
|
||||||
BlockHashWithGroupId,
|
BlockHashWithGroupId,
|
||||||
ExternalBlockHash,
|
ExternalBlockHash,
|
||||||
FreeKVCacheBlockQueue,
|
FreeKVCacheBlockQueue,
|
||||||
@ -133,6 +135,10 @@ class BlockPool:
|
|||||||
Args:
|
Args:
|
||||||
num_gpu_blocks: The number of blocks in the pool.
|
num_gpu_blocks: The number of blocks in the pool.
|
||||||
enable_caching: Whether to enable prefix caching.
|
enable_caching: Whether to enable prefix caching.
|
||||||
|
hash_block_size: The block size of which the block hashes are computed.
|
||||||
|
The actual block size usually equals hash_block_size, but in cases
|
||||||
|
where different KV cache groups have different block sizes, the
|
||||||
|
actual block size can be a multiple of hash_block_size.
|
||||||
enable_kv_cache_events: Whether to enable kv cache events.
|
enable_kv_cache_events: Whether to enable kv cache events.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -140,11 +146,13 @@ class BlockPool:
|
|||||||
self,
|
self,
|
||||||
num_gpu_blocks: int,
|
num_gpu_blocks: int,
|
||||||
enable_caching: bool,
|
enable_caching: bool,
|
||||||
|
hash_block_size: int,
|
||||||
enable_kv_cache_events: bool = False,
|
enable_kv_cache_events: bool = False,
|
||||||
):
|
):
|
||||||
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
|
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
|
||||||
self.num_gpu_blocks = num_gpu_blocks
|
self.num_gpu_blocks = num_gpu_blocks
|
||||||
self.enable_caching = enable_caching
|
self.enable_caching = enable_caching
|
||||||
|
self.hash_block_size = hash_block_size
|
||||||
# All kv-cache blocks.
|
# All kv-cache blocks.
|
||||||
self.blocks: list[KVCacheBlock] = [
|
self.blocks: list[KVCacheBlock] = [
|
||||||
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
|
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
|
||||||
@ -223,8 +231,20 @@ class BlockPool:
|
|||||||
return
|
return
|
||||||
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
|
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
|
||||||
assert len(request.block_hashes) >= num_full_blocks
|
assert len(request.block_hashes) >= num_full_blocks
|
||||||
new_block_hashes = request.block_hashes[num_cached_blocks:]
|
if block_size == self.hash_block_size:
|
||||||
|
# Common case.
|
||||||
|
block_hashes: BlockHashList = request.block_hashes
|
||||||
|
else:
|
||||||
|
# block_size is a multiple of hash_block_size. This happens when
|
||||||
|
# different KV cache groups have different block sizes.
|
||||||
|
assert block_size % self.hash_block_size == 0
|
||||||
|
# Recalculate block_hashes at the granularity of block_size, using
|
||||||
|
# the original block_hashes (at the granularity of hash_block_size).
|
||||||
|
block_hashes = BlockHashListWithBlockSize(
|
||||||
|
request.block_hashes, self.hash_block_size, block_size
|
||||||
|
)
|
||||||
|
|
||||||
|
new_block_hashes = block_hashes[num_cached_blocks:]
|
||||||
new_hashes: list[ExternalBlockHash] | None = (
|
new_hashes: list[ExternalBlockHash] | None = (
|
||||||
[] if self.enable_kv_cache_events else None
|
[] if self.enable_kv_cache_events else None
|
||||||
)
|
)
|
||||||
|
|||||||
@ -2,15 +2,25 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
from math import lcm
|
||||||
|
|
||||||
from vllm.v1.core.block_pool import BlockPool
|
from vllm.v1.core.block_pool import BlockPool
|
||||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
from vllm.v1.core.kv_cache_utils import (
|
||||||
|
BlockHash,
|
||||||
|
BlockHashList,
|
||||||
|
BlockHashListWithBlockSize,
|
||||||
|
KVCacheBlock,
|
||||||
|
)
|
||||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||||
CrossAttentionManager,
|
CrossAttentionManager,
|
||||||
FullAttentionManager,
|
FullAttentionManager,
|
||||||
get_manager_for_kv_cache_spec,
|
get_manager_for_kv_cache_spec,
|
||||||
)
|
)
|
||||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec
|
from vllm.v1.kv_cache_interface import (
|
||||||
|
FullAttentionSpec,
|
||||||
|
KVCacheConfig,
|
||||||
|
KVCacheSpec,
|
||||||
|
)
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
|
|
||||||
@ -28,13 +38,17 @@ class KVCacheCoordinator(ABC):
|
|||||||
enable_kv_cache_events: bool,
|
enable_kv_cache_events: bool,
|
||||||
dcp_world_size: int,
|
dcp_world_size: int,
|
||||||
pcp_world_size: int,
|
pcp_world_size: int,
|
||||||
|
hash_block_size: int,
|
||||||
):
|
):
|
||||||
self.kv_cache_config = kv_cache_config
|
self.kv_cache_config = kv_cache_config
|
||||||
self.max_model_len = max_model_len
|
self.max_model_len = max_model_len
|
||||||
self.enable_caching = enable_caching
|
self.enable_caching = enable_caching
|
||||||
|
|
||||||
self.block_pool = BlockPool(
|
self.block_pool = BlockPool(
|
||||||
kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events
|
kv_cache_config.num_blocks,
|
||||||
|
enable_caching,
|
||||||
|
hash_block_size,
|
||||||
|
enable_kv_cache_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Needs special handling for find_longest_cache_hit if eagle is enabled
|
# Needs special handling for find_longest_cache_hit if eagle is enabled
|
||||||
@ -213,6 +227,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
|
|||||||
enable_kv_cache_events: bool,
|
enable_kv_cache_events: bool,
|
||||||
dcp_world_size: int,
|
dcp_world_size: int,
|
||||||
pcp_world_size: int,
|
pcp_world_size: int,
|
||||||
|
hash_block_size: int,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
kv_cache_config,
|
kv_cache_config,
|
||||||
@ -222,6 +237,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
|
|||||||
enable_kv_cache_events,
|
enable_kv_cache_events,
|
||||||
dcp_world_size=dcp_world_size,
|
dcp_world_size=dcp_world_size,
|
||||||
pcp_world_size=pcp_world_size,
|
pcp_world_size=pcp_world_size,
|
||||||
|
hash_block_size=hash_block_size,
|
||||||
)
|
)
|
||||||
self.num_single_type_manager = len(self.single_type_managers)
|
self.num_single_type_manager = len(self.single_type_managers)
|
||||||
|
|
||||||
@ -255,6 +271,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
enable_kv_cache_events: bool,
|
enable_kv_cache_events: bool,
|
||||||
dcp_world_size: int,
|
dcp_world_size: int,
|
||||||
pcp_world_size: int,
|
pcp_world_size: int,
|
||||||
|
hash_block_size: int,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
kv_cache_config,
|
kv_cache_config,
|
||||||
@ -264,6 +281,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
enable_kv_cache_events,
|
enable_kv_cache_events,
|
||||||
dcp_world_size=dcp_world_size,
|
dcp_world_size=dcp_world_size,
|
||||||
pcp_world_size=pcp_world_size,
|
pcp_world_size=pcp_world_size,
|
||||||
|
hash_block_size=hash_block_size,
|
||||||
)
|
)
|
||||||
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
|
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
|
||||||
self.block_size = self.kv_cache_spec.block_size
|
self.block_size = self.kv_cache_spec.block_size
|
||||||
@ -273,6 +291,11 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
self.block_size *= dcp_world_size
|
self.block_size *= dcp_world_size
|
||||||
if pcp_world_size > 1:
|
if pcp_world_size > 1:
|
||||||
self.block_size *= pcp_world_size
|
self.block_size *= pcp_world_size
|
||||||
|
# For models using only Mamba, block_size is set to max_model_len when
|
||||||
|
# prefix caching is disabled, and hash_block_size validation is skipped.
|
||||||
|
assert not enable_caching or (hash_block_size == self.block_size), (
|
||||||
|
"UnitaryKVCacheCoordinator assumes hash_block_size == block_size"
|
||||||
|
)
|
||||||
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
|
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
|
||||||
"UnitaryKVCacheCoordinator assumes only one kv cache group"
|
"UnitaryKVCacheCoordinator assumes only one kv cache group"
|
||||||
)
|
)
|
||||||
@ -289,6 +312,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
block_pool=self.block_pool,
|
block_pool=self.block_pool,
|
||||||
kv_cache_spec=self.kv_cache_spec,
|
kv_cache_spec=self.kv_cache_spec,
|
||||||
use_eagle=self.use_eagle,
|
use_eagle=self.use_eagle,
|
||||||
|
alignment_tokens=self.block_size,
|
||||||
dcp_world_size=self.dcp_world_size,
|
dcp_world_size=self.dcp_world_size,
|
||||||
pcp_world_size=self.pcp_world_size,
|
pcp_world_size=self.pcp_world_size,
|
||||||
)
|
)
|
||||||
@ -313,6 +337,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
enable_kv_cache_events: bool,
|
enable_kv_cache_events: bool,
|
||||||
dcp_world_size: int,
|
dcp_world_size: int,
|
||||||
pcp_world_size: int,
|
pcp_world_size: int,
|
||||||
|
hash_block_size: int,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
kv_cache_config,
|
kv_cache_config,
|
||||||
@ -322,7 +347,17 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
enable_kv_cache_events,
|
enable_kv_cache_events,
|
||||||
dcp_world_size=dcp_world_size,
|
dcp_world_size=dcp_world_size,
|
||||||
pcp_world_size=pcp_world_size,
|
pcp_world_size=pcp_world_size,
|
||||||
|
hash_block_size=hash_block_size,
|
||||||
)
|
)
|
||||||
|
# hash_block_size: the block size used to compute block hashes.
|
||||||
|
# The actual block size usually equals hash_block_size, but in cases where
|
||||||
|
# different KV cache groups have different block sizes, the actual block size
|
||||||
|
# can be a multiple of hash_block_size.
|
||||||
|
self.hash_block_size = hash_block_size
|
||||||
|
assert all(
|
||||||
|
g.kv_cache_spec.block_size % hash_block_size == 0
|
||||||
|
for g in kv_cache_config.kv_cache_groups
|
||||||
|
), "block_size must be divisible by hash_block_size"
|
||||||
assert dcp_world_size == 1, "DCP not support hybrid attn now."
|
assert dcp_world_size == 1, "DCP not support hybrid attn now."
|
||||||
assert pcp_world_size == 1, "PCP not support hybrid attn now."
|
assert pcp_world_size == 1, "PCP not support hybrid attn now."
|
||||||
self.verify_and_split_kv_cache_groups()
|
self.verify_and_split_kv_cache_groups()
|
||||||
@ -373,14 +408,12 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
self.other_spec = other_spec
|
self.other_spec = other_spec
|
||||||
self.full_attention_block_size = self.full_attention_spec.block_size
|
self.full_attention_block_size = self.full_attention_spec.block_size
|
||||||
self.other_block_size = self.other_spec.block_size
|
self.other_block_size = self.other_spec.block_size
|
||||||
|
# The LCM of the block sizes of full attention and other attention.
|
||||||
if self.enable_caching:
|
# The cache hit length must be a multiple of the LCM of the block sizes
|
||||||
# this requirement is only needed for the prefix caching logic
|
# to make sure the cache hit length is a multiple of the block size of
|
||||||
divisible = self.other_block_size % self.full_attention_block_size
|
# each attention type. Requiring this because we don't support partial
|
||||||
assert divisible == 0, (
|
# block cache hit yet.
|
||||||
"KVCacheCoordinator assumes the block_size of full "
|
self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size)
|
||||||
"attention layers is divisible by other layers now."
|
|
||||||
)
|
|
||||||
|
|
||||||
if max(self.full_attention_group_ids) < min(self.other_group_ids):
|
if max(self.full_attention_group_ids) < min(self.other_group_ids):
|
||||||
self.full_attn_first = True
|
self.full_attn_first = True
|
||||||
@ -414,25 +447,48 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
- The number of tokens of the longest cache hit.
|
- The number of tokens of the longest cache hit.
|
||||||
"""
|
"""
|
||||||
# First, find the longest cache hit for full attention.
|
# First, find the longest cache hit for full attention.
|
||||||
|
if self.full_attention_spec.block_size == self.hash_block_size:
|
||||||
|
# Common case.
|
||||||
|
full_attention_block_hashes: BlockHashList = block_hashes
|
||||||
|
else:
|
||||||
|
# block_size is a multiple of hash_block_size. This happens when different
|
||||||
|
# KV cache groups have different block sizes. In this case, we need to
|
||||||
|
# recalculate block_hashes at the granularity of block_size, using the
|
||||||
|
# original block_hashes (at the granularity of hash_block_size).
|
||||||
|
full_attention_block_hashes = BlockHashListWithBlockSize(
|
||||||
|
block_hashes, self.hash_block_size, self.full_attention_spec.block_size
|
||||||
|
)
|
||||||
hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit(
|
hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit(
|
||||||
block_hashes=block_hashes,
|
block_hashes=full_attention_block_hashes,
|
||||||
max_length=max_cache_hit_length,
|
max_length=max_cache_hit_length,
|
||||||
kv_cache_group_ids=self.full_attention_group_ids,
|
kv_cache_group_ids=self.full_attention_group_ids,
|
||||||
block_pool=self.block_pool,
|
block_pool=self.block_pool,
|
||||||
kv_cache_spec=self.full_attention_spec,
|
kv_cache_spec=self.full_attention_spec,
|
||||||
use_eagle=self.use_eagle,
|
use_eagle=self.use_eagle,
|
||||||
|
alignment_tokens=self.lcm_block_size,
|
||||||
)
|
)
|
||||||
hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size
|
hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size
|
||||||
|
|
||||||
# Next, find the cache hit for the other attention WITHIN
|
# Next, find the cache hit for the other attention WITHIN
|
||||||
# the cache hit of full attention.
|
# the cache hit of full attention.
|
||||||
|
if self.other_spec.block_size == self.hash_block_size:
|
||||||
|
# Common case.
|
||||||
|
other_block_hashes: BlockHashList = block_hashes
|
||||||
|
else:
|
||||||
|
# Similar to the full attention case, here we need to recalculate
|
||||||
|
# block_hashes at the granularity of block_size, using the original
|
||||||
|
# block_hashes (at the granularity of hash_block_size).
|
||||||
|
other_block_hashes = BlockHashListWithBlockSize(
|
||||||
|
block_hashes, self.hash_block_size, self.other_spec.block_size
|
||||||
|
)
|
||||||
hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit(
|
hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit(
|
||||||
block_hashes=block_hashes,
|
block_hashes=other_block_hashes,
|
||||||
max_length=hit_length,
|
max_length=hit_length,
|
||||||
kv_cache_group_ids=self.other_group_ids,
|
kv_cache_group_ids=self.other_group_ids,
|
||||||
block_pool=self.block_pool,
|
block_pool=self.block_pool,
|
||||||
kv_cache_spec=self.other_spec,
|
kv_cache_spec=self.other_spec,
|
||||||
use_eagle=self.use_eagle,
|
use_eagle=self.use_eagle,
|
||||||
|
alignment_tokens=self.lcm_block_size,
|
||||||
)
|
)
|
||||||
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
|
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
|
||||||
|
|
||||||
@ -466,6 +522,7 @@ def get_kv_cache_coordinator(
|
|||||||
enable_kv_cache_events: bool,
|
enable_kv_cache_events: bool,
|
||||||
dcp_world_size: int,
|
dcp_world_size: int,
|
||||||
pcp_world_size: int,
|
pcp_world_size: int,
|
||||||
|
hash_block_size: int,
|
||||||
) -> KVCacheCoordinator:
|
) -> KVCacheCoordinator:
|
||||||
if not enable_caching:
|
if not enable_caching:
|
||||||
return KVCacheCoordinatorNoPrefixCache(
|
return KVCacheCoordinatorNoPrefixCache(
|
||||||
@ -473,8 +530,9 @@ def get_kv_cache_coordinator(
|
|||||||
max_model_len,
|
max_model_len,
|
||||||
use_eagle,
|
use_eagle,
|
||||||
enable_kv_cache_events,
|
enable_kv_cache_events,
|
||||||
dcp_world_size=dcp_world_size,
|
dcp_world_size,
|
||||||
pcp_world_size=pcp_world_size,
|
pcp_world_size,
|
||||||
|
hash_block_size,
|
||||||
)
|
)
|
||||||
if len(kv_cache_config.kv_cache_groups) == 1:
|
if len(kv_cache_config.kv_cache_groups) == 1:
|
||||||
return UnitaryKVCacheCoordinator(
|
return UnitaryKVCacheCoordinator(
|
||||||
@ -483,8 +541,9 @@ def get_kv_cache_coordinator(
|
|||||||
use_eagle,
|
use_eagle,
|
||||||
enable_caching,
|
enable_caching,
|
||||||
enable_kv_cache_events,
|
enable_kv_cache_events,
|
||||||
dcp_world_size=dcp_world_size,
|
dcp_world_size,
|
||||||
pcp_world_size=pcp_world_size,
|
pcp_world_size,
|
||||||
|
hash_block_size,
|
||||||
)
|
)
|
||||||
return HybridKVCacheCoordinator(
|
return HybridKVCacheCoordinator(
|
||||||
kv_cache_config,
|
kv_cache_config,
|
||||||
@ -492,6 +551,7 @@ def get_kv_cache_coordinator(
|
|||||||
use_eagle,
|
use_eagle,
|
||||||
enable_caching,
|
enable_caching,
|
||||||
enable_kv_cache_events,
|
enable_kv_cache_events,
|
||||||
dcp_world_size=dcp_world_size,
|
dcp_world_size,
|
||||||
pcp_world_size=pcp_world_size,
|
pcp_world_size,
|
||||||
|
hash_block_size,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -95,6 +95,7 @@ class KVCacheManager:
|
|||||||
self,
|
self,
|
||||||
kv_cache_config: KVCacheConfig,
|
kv_cache_config: KVCacheConfig,
|
||||||
max_model_len: int,
|
max_model_len: int,
|
||||||
|
hash_block_size: int,
|
||||||
enable_caching: bool = True,
|
enable_caching: bool = True,
|
||||||
use_eagle: bool = False,
|
use_eagle: bool = False,
|
||||||
log_stats: bool = False,
|
log_stats: bool = False,
|
||||||
@ -107,28 +108,11 @@ class KVCacheManager:
|
|||||||
self.enable_caching = enable_caching
|
self.enable_caching = enable_caching
|
||||||
self.use_eagle = use_eagle
|
self.use_eagle = use_eagle
|
||||||
self.log_stats = log_stats
|
self.log_stats = log_stats
|
||||||
# FIXME: make prefix cache stats conditional on log_stats
|
# FIXME: make prefix cache stats conditional on log_stats. We still need
|
||||||
|
# this comment because when the log stats is enabled there are still
|
||||||
|
# potential configs we could expose in the future.
|
||||||
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
|
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
|
||||||
|
|
||||||
self.block_size: int | None = None
|
|
||||||
if self.enable_caching:
|
|
||||||
assert (
|
|
||||||
len(
|
|
||||||
set(
|
|
||||||
g.kv_cache_spec.block_size
|
|
||||||
for g in kv_cache_config.kv_cache_groups
|
|
||||||
)
|
|
||||||
)
|
|
||||||
== 1
|
|
||||||
), "Only one block size is supported for now"
|
|
||||||
self.block_size = kv_cache_config.kv_cache_groups[
|
|
||||||
0
|
|
||||||
].kv_cache_spec.block_size
|
|
||||||
|
|
||||||
if dcp_world_size * pcp_world_size > 1:
|
|
||||||
assert len(kv_cache_config.kv_cache_groups) == 1
|
|
||||||
self.block_size *= dcp_world_size * pcp_world_size
|
|
||||||
|
|
||||||
self.coordinator = get_kv_cache_coordinator(
|
self.coordinator = get_kv_cache_coordinator(
|
||||||
kv_cache_config=kv_cache_config,
|
kv_cache_config=kv_cache_config,
|
||||||
max_model_len=self.max_model_len,
|
max_model_len=self.max_model_len,
|
||||||
@ -137,6 +121,7 @@ class KVCacheManager:
|
|||||||
enable_kv_cache_events=enable_kv_cache_events,
|
enable_kv_cache_events=enable_kv_cache_events,
|
||||||
dcp_world_size=dcp_world_size,
|
dcp_world_size=dcp_world_size,
|
||||||
pcp_world_size=pcp_world_size,
|
pcp_world_size=pcp_world_size,
|
||||||
|
hash_block_size=hash_block_size,
|
||||||
)
|
)
|
||||||
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
|
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
|
||||||
self.block_pool = self.coordinator.block_pool
|
self.block_pool = self.coordinator.block_pool
|
||||||
|
|||||||
@ -5,9 +5,9 @@
|
|||||||
import copy
|
import copy
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Callable, Iterable, Sequence
|
from collections.abc import Callable, Iterable, Iterator, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, replace
|
||||||
from typing import Any, NewType, TypeAlias
|
from typing import Any, NewType, TypeAlias, overload
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
@ -825,11 +825,11 @@ def get_num_blocks(
|
|||||||
return num_blocks
|
return num_blocks
|
||||||
|
|
||||||
|
|
||||||
def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
|
def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int:
|
||||||
"""
|
"""
|
||||||
Get the page size of the KV cache.
|
Get the page size of the KV cache.
|
||||||
"""
|
"""
|
||||||
page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values())
|
page_sizes = {layer.page_size_bytes for layer in kv_cache_specs}
|
||||||
assert len(page_sizes) == 1
|
assert len(page_sizes) == 1
|
||||||
return page_sizes.pop()
|
return page_sizes.pop()
|
||||||
|
|
||||||
@ -882,6 +882,46 @@ def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool
|
|||||||
return len(page_sizes) == 1
|
return len(page_sizes) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def unify_kv_cache_spec_page_size(
|
||||||
|
kv_cache_spec: dict[str, KVCacheSpec],
|
||||||
|
) -> dict[str, KVCacheSpec]:
|
||||||
|
"""
|
||||||
|
Unify the page size of the given KVCacheSpec. If the page size of all layers
|
||||||
|
are the same, return the original KVCacheSpec. If not same, unify the page
|
||||||
|
size by increasing the block size of layers with smaller page size. Raise
|
||||||
|
NotImplementedError if failed to unify the page size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kv_cache_spec: The KVCacheSpec of each attention layer in the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The updated KVCacheSpec with the same page_size_bytes.
|
||||||
|
"""
|
||||||
|
page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
|
||||||
|
if len(page_sizes) <= 1:
|
||||||
|
# All layers have the same page size, no need to unify.
|
||||||
|
return kv_cache_spec
|
||||||
|
|
||||||
|
max_page_size = max(page_sizes)
|
||||||
|
new_kv_cache_spec = {}
|
||||||
|
for layer_name, layer_spec in kv_cache_spec.items():
|
||||||
|
if layer_spec.page_size_bytes == max_page_size:
|
||||||
|
new_kv_cache_spec[layer_name] = layer_spec
|
||||||
|
else:
|
||||||
|
layer_page_size = layer_spec.page_size_bytes
|
||||||
|
if max_page_size % layer_page_size != 0:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"The page size of the layer is not divisible by the "
|
||||||
|
"maximum page size. Cannot unify by adjusting block_size."
|
||||||
|
)
|
||||||
|
ratio = max_page_size // layer_page_size
|
||||||
|
new_block_size = layer_spec.block_size * ratio
|
||||||
|
new_spec = replace(layer_spec, block_size=new_block_size)
|
||||||
|
assert new_spec.page_size_bytes == max_page_size
|
||||||
|
new_kv_cache_spec[layer_name] = new_spec
|
||||||
|
return new_kv_cache_spec
|
||||||
|
|
||||||
|
|
||||||
def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
|
def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
|
||||||
# kv_cache_spec is an empty dict for attention free models
|
# kv_cache_spec is an empty dict for attention free models
|
||||||
return not kv_cache_spec
|
return not kv_cache_spec
|
||||||
@ -1010,7 +1050,6 @@ def _get_kv_cache_groups_uniform_page_size(
|
|||||||
def get_kv_cache_config_from_groups(
|
def get_kv_cache_config_from_groups(
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
kv_cache_groups: list[KVCacheGroupSpec],
|
kv_cache_groups: list[KVCacheGroupSpec],
|
||||||
kv_cache_specs: dict[str, KVCacheSpec],
|
|
||||||
available_memory: int,
|
available_memory: int,
|
||||||
) -> KVCacheConfig:
|
) -> KVCacheConfig:
|
||||||
"""
|
"""
|
||||||
@ -1020,7 +1059,6 @@ def get_kv_cache_config_from_groups(
|
|||||||
Args:
|
Args:
|
||||||
vllm_config: The global VllmConfig
|
vllm_config: The global VllmConfig
|
||||||
kv_cache_groups: The KV cache groups
|
kv_cache_groups: The KV cache groups
|
||||||
kv_cache_specs: The KV cache spec of each attention layer in the model
|
|
||||||
available_memory: Memory available for KV cache in bytes
|
available_memory: Memory available for KV cache in bytes
|
||||||
Returns:
|
Returns:
|
||||||
The generated KVCacheConfig
|
The generated KVCacheConfig
|
||||||
@ -1064,7 +1102,9 @@ def get_kv_cache_config_from_groups(
|
|||||||
# full.1, sw.2: share another Tensor with size=available_memory//2
|
# full.1, sw.2: share another Tensor with size=available_memory//2
|
||||||
group_size = max(len(group.layer_names) for group in kv_cache_groups)
|
group_size = max(len(group.layer_names) for group in kv_cache_groups)
|
||||||
|
|
||||||
page_size = get_uniform_page_size(kv_cache_specs)
|
page_size = get_uniform_page_size(
|
||||||
|
[group.kv_cache_spec for group in kv_cache_groups]
|
||||||
|
)
|
||||||
assert group_size > 0, "group_size must be greater than 0"
|
assert group_size > 0, "group_size must be greater than 0"
|
||||||
num_blocks = get_num_blocks(
|
num_blocks = get_num_blocks(
|
||||||
vllm_config, group_size, available_memory, page_size
|
vllm_config, group_size, available_memory, page_size
|
||||||
@ -1166,7 +1206,8 @@ def get_kv_cache_groups(
|
|||||||
# This returns an empty list to allow for the KVCacheManager to handle
|
# This returns an empty list to allow for the KVCacheManager to handle
|
||||||
# attention free models.
|
# attention free models.
|
||||||
return []
|
return []
|
||||||
elif is_kv_cache_spec_uniform(kv_cache_spec):
|
|
||||||
|
if is_kv_cache_spec_uniform(kv_cache_spec):
|
||||||
# KV cache of all layers are the same, which is true for
|
# KV cache of all layers are the same, which is true for
|
||||||
# most models. Allocate the same amount of memory for
|
# most models. Allocate the same amount of memory for
|
||||||
# each layer.
|
# each layer.
|
||||||
@ -1176,14 +1217,16 @@ def get_kv_cache_groups(
|
|||||||
# full attention, or all layers are sliding window attention with the
|
# full attention, or all layers are sliding window attention with the
|
||||||
# same window size). Put all layers into one group.
|
# same window size). Put all layers into one group.
|
||||||
return _get_kv_cache_groups_uniform_type(uniform_spec)
|
return _get_kv_cache_groups_uniform_type(uniform_spec)
|
||||||
elif is_kv_cache_page_size_uniform(kv_cache_spec):
|
|
||||||
# Model contains multiple attention types, but KV cache of all layers
|
|
||||||
# have the same physical memory per block per layer. Split the layers
|
|
||||||
# into groups with the same number of layers, and thus same total page
|
|
||||||
# size.
|
|
||||||
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
|
|
||||||
|
|
||||||
raise NotImplementedError
|
# As KVCacheManager can only allocate memory of one size, we need to unify
|
||||||
|
# the page size of the layers. For cases cannot be unified, this function
|
||||||
|
# will raise an error.
|
||||||
|
kv_cache_spec = unify_kv_cache_spec_page_size(kv_cache_spec)
|
||||||
|
# Model contains multiple attention types, but KV cache of all layers
|
||||||
|
# have the same physical memory per block per layer. Split the layers
|
||||||
|
# into groups with the same number of layers, and thus same total page
|
||||||
|
# size.
|
||||||
|
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
|
||||||
|
|
||||||
|
|
||||||
def generate_scheduler_kv_cache_config(
|
def generate_scheduler_kv_cache_config(
|
||||||
@ -1327,10 +1370,7 @@ def get_kv_cache_configs(
|
|||||||
) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group."
|
) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group."
|
||||||
kv_cache_configs.append(
|
kv_cache_configs.append(
|
||||||
get_kv_cache_config_from_groups(
|
get_kv_cache_config_from_groups(
|
||||||
vllm_config,
|
vllm_config, kv_cache_groups_one_worker, available_memory_one_worker
|
||||||
kv_cache_groups_one_worker,
|
|
||||||
kv_cache_spec_one_worker,
|
|
||||||
available_memory_one_worker,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1353,3 +1393,79 @@ def get_kv_cache_configs(
|
|||||||
_report_kv_cache_config(vllm_config, kv_cache_config)
|
_report_kv_cache_config(vllm_config, kv_cache_config)
|
||||||
|
|
||||||
return kv_cache_configs
|
return kv_cache_configs
|
||||||
|
|
||||||
|
|
||||||
|
class BlockHashListWithBlockSize:
|
||||||
|
"""
|
||||||
|
Convert block-hash granularity from `hash_block_size` to `target_block_size`.
|
||||||
|
Used when KV cache groups have different block sizes: `hash_block_size`
|
||||||
|
is the size used to compute the original `block_hashes`; `target_block_size`
|
||||||
|
is the group's actual block size.
|
||||||
|
|
||||||
|
Currently, only scaling up by an integer factor is supported (i.e.,
|
||||||
|
`target_block_size` is a multiple of `hash_block_size`). Conversion is
|
||||||
|
performed lazily on access for efficiency, by concatenating consecutive
|
||||||
|
hashes at `hash_block_size` to form each hash at `target_block_size`.
|
||||||
|
|
||||||
|
Example (`hash_block_size` = 16, `target_block_size` = 32):
|
||||||
|
concatenating two 16-size hashes yields one 32-size hash:
|
||||||
|
|
||||||
|
Block hashes with block_size 16:
|
||||||
|
| Token Range | 0-15 | 16-31 | 32-47 | 48-63 |
|
||||||
|
|-------------|------|-------|-------|-------|
|
||||||
|
| Hash | A | B | C | D |
|
||||||
|
|
||||||
|
Block hashes with block_size 32:
|
||||||
|
| Token Range | 0-31 | 32-63 |
|
||||||
|
|-------------|------|-------|
|
||||||
|
| Hash | AB | CD |
|
||||||
|
|
||||||
|
Args:
|
||||||
|
block_hashes: Block hashes to convert, computed at `hash_block_size`.
|
||||||
|
hash_block_size: Block size at which `block_hashes` were computed.
|
||||||
|
target_block_size: Desired block size; must be a multiple of `hash_block_size`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
block_hashes: list[BlockHash],
|
||||||
|
hash_block_size: int,
|
||||||
|
target_block_size: int,
|
||||||
|
):
|
||||||
|
self.block_hashes = block_hashes
|
||||||
|
assert target_block_size % hash_block_size == 0
|
||||||
|
self.scale_factor = target_block_size // hash_block_size
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.block_hashes) // self.scale_factor
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, idx: int) -> BlockHash: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def __getitem__(self, idx: slice) -> list[BlockHash]: ...
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
if isinstance(idx, int):
|
||||||
|
return self._get_value_at(idx)
|
||||||
|
|
||||||
|
if isinstance(idx, slice):
|
||||||
|
start, stop, step = idx.indices(len(self))
|
||||||
|
return [self._get_value_at(i) for i in range(start, stop, step)]
|
||||||
|
|
||||||
|
raise TypeError(f"Invalid index type: {type(idx)!r}")
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[BlockHash]:
|
||||||
|
for i in range(len(self)):
|
||||||
|
yield self._get_value_at(i)
|
||||||
|
|
||||||
|
def _get_value_at(self, idx: int) -> BlockHash:
|
||||||
|
base = idx * self.scale_factor
|
||||||
|
end = base + self.scale_factor
|
||||||
|
merged_hash: bytes = self.block_hashes[base]
|
||||||
|
for i in range(base + 1, end):
|
||||||
|
merged_hash += self.block_hashes[i]
|
||||||
|
return BlockHash(merged_hash)
|
||||||
|
|
||||||
|
|
||||||
|
BlockHashList = list[BlockHash] | BlockHashListWithBlockSize
|
||||||
|
|||||||
@ -186,6 +186,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
enable_kv_cache_events=self.enable_kv_cache_events,
|
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||||
dcp_world_size=self.dcp_world_size,
|
dcp_world_size=self.dcp_world_size,
|
||||||
pcp_world_size=self.pcp_world_size,
|
pcp_world_size=self.pcp_world_size,
|
||||||
|
hash_block_size=self.block_size,
|
||||||
)
|
)
|
||||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||||
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
|
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from collections.abc import Sequence
|
|||||||
|
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.v1.core.block_pool import BlockPool
|
from vllm.v1.core.block_pool import BlockPool
|
||||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock
|
||||||
from vllm.v1.kv_cache_interface import (
|
from vllm.v1.kv_cache_interface import (
|
||||||
ChunkedLocalAttentionSpec,
|
ChunkedLocalAttentionSpec,
|
||||||
CrossAttentionSpec,
|
CrossAttentionSpec,
|
||||||
@ -207,12 +207,13 @@ class SingleTypeKVCacheManager(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def find_longest_cache_hit(
|
def find_longest_cache_hit(
|
||||||
cls,
|
cls,
|
||||||
block_hashes: list[BlockHash],
|
block_hashes: BlockHashList,
|
||||||
max_length: int,
|
max_length: int,
|
||||||
kv_cache_group_ids: list[int],
|
kv_cache_group_ids: list[int],
|
||||||
block_pool: BlockPool,
|
block_pool: BlockPool,
|
||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
|
alignment_tokens: int,
|
||||||
dcp_world_size: int = 1,
|
dcp_world_size: int = 1,
|
||||||
pcp_world_size: int = 1,
|
pcp_world_size: int = 1,
|
||||||
) -> tuple[list[KVCacheBlock], ...]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
@ -232,6 +233,11 @@ class SingleTypeKVCacheManager(ABC):
|
|||||||
block_pool: The block pool.
|
block_pool: The block pool.
|
||||||
kv_cache_spec: The kv cache spec.
|
kv_cache_spec: The kv cache spec.
|
||||||
use_eagle: Whether to use eagle.
|
use_eagle: Whether to use eagle.
|
||||||
|
alignment_tokens: The returned cache hit length (in tokens) should
|
||||||
|
be a multiple of this value (in tokens). By default, it should
|
||||||
|
be set to the block_size.
|
||||||
|
dcp_world_size: The world size of decode context parallelism.
|
||||||
|
pcp_world_size: The world size of prefill context parallelism.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of cached blocks with skipped blocks replaced by null block
|
A list of cached blocks with skipped blocks replaced by null block
|
||||||
@ -299,17 +305,18 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def find_longest_cache_hit(
|
def find_longest_cache_hit(
|
||||||
cls,
|
cls,
|
||||||
block_hashes: list[BlockHash],
|
block_hashes: BlockHashList,
|
||||||
max_length: int,
|
max_length: int,
|
||||||
kv_cache_group_ids: list[int],
|
kv_cache_group_ids: list[int],
|
||||||
block_pool: BlockPool,
|
block_pool: BlockPool,
|
||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
|
alignment_tokens: int,
|
||||||
dcp_world_size: int = 1,
|
dcp_world_size: int = 1,
|
||||||
pcp_world_size: int = 1,
|
pcp_world_size: int = 1,
|
||||||
) -> tuple[list[KVCacheBlock], ...]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
|
kv_cache_spec, FullAttentionSpec | ChunkedLocalAttentionSpec
|
||||||
), (
|
), (
|
||||||
"FullAttentionManager can only be used for full attention "
|
"FullAttentionManager can only be used for full attention "
|
||||||
"and chunked local attention groups"
|
"and chunked local attention groups"
|
||||||
@ -333,6 +340,13 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
|||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
if use_eagle and computed_blocks[0]:
|
if use_eagle and computed_blocks[0]:
|
||||||
|
# Need to drop the last matched block if eagle is enabled.
|
||||||
|
for computed in computed_blocks:
|
||||||
|
computed.pop()
|
||||||
|
while (
|
||||||
|
block_size != alignment_tokens # Faster for common case.
|
||||||
|
and len(computed_blocks[0]) * block_size % alignment_tokens != 0
|
||||||
|
):
|
||||||
for computed in computed_blocks:
|
for computed in computed_blocks:
|
||||||
computed.pop()
|
computed.pop()
|
||||||
return computed_blocks
|
return computed_blocks
|
||||||
@ -359,12 +373,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def find_longest_cache_hit(
|
def find_longest_cache_hit(
|
||||||
cls,
|
cls,
|
||||||
block_hashes: list[BlockHash],
|
block_hashes: BlockHashList,
|
||||||
max_length: int,
|
max_length: int,
|
||||||
kv_cache_group_ids: list[int],
|
kv_cache_group_ids: list[int],
|
||||||
block_pool: BlockPool,
|
block_pool: BlockPool,
|
||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
|
alignment_tokens: int,
|
||||||
dcp_world_size: int = 1,
|
dcp_world_size: int = 1,
|
||||||
pcp_world_size: int = 1,
|
pcp_world_size: int = 1,
|
||||||
) -> tuple[list[KVCacheBlock], ...]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
@ -396,6 +411,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
|||||||
[block_pool.null_block] * max_num_blocks
|
[block_pool.null_block] * max_num_blocks
|
||||||
for _ in range(len(kv_cache_group_ids))
|
for _ in range(len(kv_cache_group_ids))
|
||||||
)
|
)
|
||||||
|
block_size = kv_cache_spec.block_size
|
||||||
num_contiguous_blocks = 0
|
num_contiguous_blocks = 0
|
||||||
match_found = False
|
match_found = False
|
||||||
# Search from right to left and early stop when a match is found.
|
# Search from right to left and early stop when a match is found.
|
||||||
@ -403,6 +419,15 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
|||||||
if cached_block := block_pool.get_cached_block(
|
if cached_block := block_pool.get_cached_block(
|
||||||
block_hashes[i], kv_cache_group_ids
|
block_hashes[i], kv_cache_group_ids
|
||||||
):
|
):
|
||||||
|
# Skip prefix matching check if the block is not aligned with
|
||||||
|
# `alignment_tokens`.
|
||||||
|
if (
|
||||||
|
num_contiguous_blocks == 0
|
||||||
|
and block_size != alignment_tokens # Faster for common case.
|
||||||
|
and (i + 1) * block_size % alignment_tokens != 0
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
# Add the cached block to the computed blocks.
|
||||||
for computed, cached in zip(computed_blocks, cached_block):
|
for computed, cached in zip(computed_blocks, cached_block):
|
||||||
computed[i] = cached
|
computed[i] = cached
|
||||||
num_contiguous_blocks += 1
|
num_contiguous_blocks += 1
|
||||||
@ -421,7 +446,16 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
|||||||
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
|
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
|
||||||
for computed in computed_blocks:
|
for computed in computed_blocks:
|
||||||
del computed[num_contiguous_blocks:]
|
del computed[num_contiguous_blocks:]
|
||||||
|
while (
|
||||||
|
block_size != alignment_tokens # Faster for common case.
|
||||||
|
and len(computed_blocks[0]) * block_size % alignment_tokens != 0
|
||||||
|
):
|
||||||
|
for computed in computed_blocks:
|
||||||
|
computed.pop()
|
||||||
if use_eagle and computed_blocks[0]:
|
if use_eagle and computed_blocks[0]:
|
||||||
|
assert kv_cache_spec.block_size == alignment_tokens, (
|
||||||
|
"aligned_length is not compatible with eagle now"
|
||||||
|
)
|
||||||
for computed in computed_blocks:
|
for computed in computed_blocks:
|
||||||
computed.pop()
|
computed.pop()
|
||||||
return computed_blocks
|
return computed_blocks
|
||||||
@ -475,12 +509,13 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def find_longest_cache_hit(
|
def find_longest_cache_hit(
|
||||||
cls,
|
cls,
|
||||||
block_hashes: list[BlockHash],
|
block_hashes: BlockHashList,
|
||||||
max_length: int,
|
max_length: int,
|
||||||
kv_cache_group_ids: list[int],
|
kv_cache_group_ids: list[int],
|
||||||
block_pool: BlockPool,
|
block_pool: BlockPool,
|
||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
|
alignment_tokens: int,
|
||||||
dcp_world_size: int = 1,
|
dcp_world_size: int = 1,
|
||||||
pcp_world_size: int = 1,
|
pcp_world_size: int = 1,
|
||||||
) -> tuple[list[KVCacheBlock], ...]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
@ -511,6 +546,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
|||||||
block_pool: The block pool.
|
block_pool: The block pool.
|
||||||
kv_cache_spec: The kv cache spec.
|
kv_cache_spec: The kv cache spec.
|
||||||
use_eagle: Whether to use eagle.
|
use_eagle: Whether to use eagle.
|
||||||
|
dcp_world_size: The world size of decode context parallelism.
|
||||||
|
pcp_world_size: The world size of prefill context parallelism.
|
||||||
|
alignment_tokens: The returned cache hit length (in tokens) should
|
||||||
|
be a multiple of this value (in tokens).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A list of cached blocks
|
A list of cached blocks
|
||||||
@ -524,6 +563,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
|||||||
)
|
)
|
||||||
assert dcp_world_size == 1, "DCP not support chunked local attn now."
|
assert dcp_world_size == 1, "DCP not support chunked local attn now."
|
||||||
assert pcp_world_size == 1, "PCP not support chunked local attn now."
|
assert pcp_world_size == 1, "PCP not support chunked local attn now."
|
||||||
|
assert kv_cache_spec.block_size == alignment_tokens, (
|
||||||
|
"KV cache groups with different block sizes are not compatible with "
|
||||||
|
"chunked local attention now"
|
||||||
|
)
|
||||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||||
if max_length > 0:
|
if max_length > 0:
|
||||||
local_attention_start_idx = (
|
local_attention_start_idx = (
|
||||||
@ -612,12 +655,13 @@ class MambaManager(SingleTypeKVCacheManager):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def find_longest_cache_hit(
|
def find_longest_cache_hit(
|
||||||
cls,
|
cls,
|
||||||
block_hashes: list[BlockHash],
|
block_hashes: BlockHashList,
|
||||||
max_length: int,
|
max_length: int,
|
||||||
kv_cache_group_ids: list[int],
|
kv_cache_group_ids: list[int],
|
||||||
block_pool: BlockPool,
|
block_pool: BlockPool,
|
||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
|
alignment_tokens: int,
|
||||||
dcp_world_size: int = 1,
|
dcp_world_size: int = 1,
|
||||||
pcp_world_size: int = 1,
|
pcp_world_size: int = 1,
|
||||||
) -> tuple[list[KVCacheBlock], ...]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
@ -630,12 +674,21 @@ class MambaManager(SingleTypeKVCacheManager):
|
|||||||
[] for _ in range(len(kv_cache_group_ids))
|
[] for _ in range(len(kv_cache_group_ids))
|
||||||
)
|
)
|
||||||
|
|
||||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
block_size = kv_cache_spec.block_size
|
||||||
|
max_num_blocks = max_length // block_size
|
||||||
# Search from right to left and early stop when a match is found.
|
# Search from right to left and early stop when a match is found.
|
||||||
for i in range(max_num_blocks - 1, -1, -1):
|
for i in range(max_num_blocks - 1, -1, -1):
|
||||||
if cached_block := block_pool.get_cached_block(
|
if cached_block := block_pool.get_cached_block(
|
||||||
block_hashes[i], kv_cache_group_ids
|
block_hashes[i], kv_cache_group_ids
|
||||||
):
|
):
|
||||||
|
# When enable Mamba prefix caching, `block_size` will be aligned
|
||||||
|
# across full attention layers and Mamba layers to ensure the
|
||||||
|
# prefix hit length aligned at block
|
||||||
|
if (
|
||||||
|
block_size != alignment_tokens # Faster for common case.
|
||||||
|
and (i + 1) * block_size % alignment_tokens != 0
|
||||||
|
):
|
||||||
|
continue
|
||||||
for computed, cached in zip(computed_blocks, cached_block):
|
for computed, cached in zip(computed_blocks, cached_block):
|
||||||
# the hit length logic later assumes:
|
# the hit length logic later assumes:
|
||||||
# hit_length = len(hit_blocks_other_attn[0])
|
# hit_length = len(hit_blocks_other_attn[0])
|
||||||
@ -708,12 +761,13 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def find_longest_cache_hit(
|
def find_longest_cache_hit(
|
||||||
cls,
|
cls,
|
||||||
block_hashes: list[BlockHash],
|
block_hashes: BlockHashList,
|
||||||
max_length: int,
|
max_length: int,
|
||||||
kv_cache_group_ids: list[int],
|
kv_cache_group_ids: list[int],
|
||||||
block_pool: BlockPool,
|
block_pool: BlockPool,
|
||||||
kv_cache_spec: KVCacheSpec,
|
kv_cache_spec: KVCacheSpec,
|
||||||
use_eagle: bool,
|
use_eagle: bool,
|
||||||
|
alignment_tokens: int,
|
||||||
dcp_world_size: int = 1,
|
dcp_world_size: int = 1,
|
||||||
pcp_world_size: int = 1,
|
pcp_world_size: int = 1,
|
||||||
) -> tuple[list[KVCacheBlock], ...]:
|
) -> tuple[list[KVCacheBlock], ...]:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user