[Hybrid Allocator] Support KV cache groups with different block_size (#29143)
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>

parent e502098643
commit 48ddb02b79
@@ -1248,7 +1248,9 @@ def test_allocate_with_lookahead():
     )
 
     # Test case 1: Requires additional lookahead tokens
-    kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
+    kv_cache_manager = KVCacheManager(
+        kv_cache_config=config, max_model_len=100, hash_block_size=block_size
+    )
     blocks = kv_cache_manager.allocate_slots(
         request,
         num_new_tokens=3,
@@ -1257,7 +1259,9 @@ def test_allocate_with_lookahead():
     assert len(blocks.get_block_ids()[0]) == 2  # ceil(5/4)=2 blocks
 
     # Test case 2: With precomputed blocks
-    kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
+    kv_cache_manager = KVCacheManager(
+        kv_cache_config=config, max_model_len=100, hash_block_size=block_size
+    )
     # required_blocks = ceil((3 + 2) /4) = 2
     blocks = kv_cache_manager.allocate_slots(
         request,
@@ -1268,7 +1272,9 @@ def test_allocate_with_lookahead():
 
     # Test case 3: With precomputed blocks
    # required_blocks = ceil((3 + 4) / 4) = 2
-    kv_cache_manager = KVCacheManager(kv_cache_config=config, max_model_len=100)
+    kv_cache_manager = KVCacheManager(
+        kv_cache_config=config, max_model_len=100, hash_block_size=block_size
+    )
     blocks = kv_cache_manager.allocate_slots(
         request,
         num_new_tokens=3,
@@ -1495,7 +1501,8 @@ def test_get_kv_cache_config_one_worker():
             ),
         ],
     )
-    # different hidden size
+
+    # different hidden size but same type, use UniformTypeKVCacheSpecs
     kv_cache_specs_hybrid = {
         "layer_1": new_kv_cache_spec(head_size=128),
         "layer_2": new_kv_cache_spec(head_size=64),
@@ -1519,6 +1526,40 @@ def test_get_kv_cache_config_one_worker():
         ],
     )
 
+    # Different hidden size and different type, align by different block size
+    kv_cache_specs_hybrid = {
+        "layer_1": new_kv_cache_spec(head_size=64),
+        "layer_2": new_sliding_window_spec(head_size=32),
+    }
+    kv_cache_config_hybrid = get_kv_cache_configs(
+        vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 32]
+    )[0]
+    assert kv_cache_config_hybrid == KVCacheConfig(
+        num_blocks=32,
+        kv_cache_tensors=[
+            KVCacheTensor(
+                size=mem_per_block_per_layer * 32, shared_by=["layer_1", "layer_2"]
+            ),
+        ],
+        kv_cache_groups=[
+            KVCacheGroupSpec(["layer_1"], new_kv_cache_spec(head_size=64)),
+            KVCacheGroupSpec(
+                ["layer_2"], new_sliding_window_spec(head_size=32, block_size=32)
+            ),
+        ],
+    )
+
+    # different hidden size that cannot be aligned by using different block size
+    kv_cache_specs_hybrid = {
+        "layer_1": new_kv_cache_spec(head_size=64),
+        "layer_2": new_sliding_window_spec(head_size=96),
+    }
+
+    with pytest.raises(NotImplementedError):
+        get_kv_cache_configs(
+            vllm_config, [kv_cache_specs_hybrid], [mem_per_block_per_layer * 2 * 32]
+        )[0]
+
     # Test num_gpu_blocks_override
     vllm_config.cache_config.num_gpu_blocks_override = 16
     kv_cache_config_override_blocks = get_kv_cache_configs(
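The aligned-block-size test above leans on a page-size identity: a KV-cache page holds roughly 2 * block_size * head_size elements (K and V planes), so halving head_size can be compensated by doubling block_size. A minimal sketch of that arithmetic, not part of the commit (the helper name and the 2-byte element size are assumptions for illustration):

    def page_size_bytes(block_size: int, head_size: int, bytes_per_elem: int = 2) -> int:
        # K plane + V plane: 2 * tokens_per_block * head_size * element size
        return 2 * block_size * head_size * bytes_per_elem

    assert page_size_bytes(16, 64) == page_size_bytes(32, 32)  # aligned by doubling
    # head_size=96 gives a 3:2 page-size ratio against head_size=64, which no
    # integer block-size multiplier can fix -- hence the NotImplementedError case.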
@@ -134,6 +134,7 @@ def test_prefill(hash_fn):
         make_kv_cache_config(block_size, 11),
         max_model_len=8192,
         enable_caching=True,
+        hash_block_size=block_size,
     )
 
     # Complete 3 blocks (48 tokens)
@@ -256,6 +257,7 @@ def test_prefill_hybrid_model():
         make_kv_cache_config_hybrid_model(block_size, 21),
         max_model_len=8192,
         enable_caching=True,
+        hash_block_size=block_size,
     )
 
     hash_fn = sha256
@@ -416,6 +418,7 @@ def test_prefill_plp():
         make_kv_cache_config(block_size, 11),
         max_model_len=8192,
         enable_caching=True,
+        hash_block_size=block_size,
     )
     # the default hash function is sha256
     hash_fn = sha256
@@ -523,6 +526,7 @@ def test_decode():
         make_kv_cache_config(block_size, 11),
         max_model_len=8192,
         enable_caching=True,
+        hash_block_size=block_size,
     )
 
     # Complete 3 blocks (48 tokens)
@@ -585,6 +589,7 @@ def test_evict():
         make_kv_cache_config(block_size, 11),
         max_model_len=8192,
         enable_caching=True,
+        hash_block_size=block_size,
     )
 
     last_token_id = 5 * 16 + 7
@@ -643,6 +648,7 @@ def test_hash_block_correct_reuse():
         make_kv_cache_config(16, 2),
         max_model_len=8192,
         enable_caching=True,
+        hash_block_size=block_size,
     )
 
     # Allocate 1 block and cache it.
@@ -683,6 +689,7 @@ def test_computed_blocks_not_evicted():
         make_kv_cache_config(block_size, 3),
         max_model_len=8192,
         enable_caching=True,
+        hash_block_size=block_size,
     )
 
     # Allocate a block and cache it.
@@ -741,6 +748,7 @@ def test_basic_prefix_caching_disabled():
         make_kv_cache_config(block_size, 5),
         max_model_len=8192,
         enable_caching=False,
+        hash_block_size=block_size,
     )
 
     req1 = make_request(
@@ -790,6 +798,7 @@ def test_cache_blocks(hash_fn):
     block_pool = BlockPool(
         num_gpu_blocks=5,
         enable_caching=True,
+        hash_block_size=block_size,
     )
     # Req:
     # Block 0: [0, 1, 2, 3]
@@ -833,7 +842,9 @@ def test_cache_blocks_multi_group():
     This tests that blocks are cached correctly for different kv cache groups.
     """
     block_size = 4
-    block_pool = BlockPool(num_gpu_blocks=10, enable_caching=True)
+    block_pool = BlockPool(
+        num_gpu_blocks=10, enable_caching=True, hash_block_size=block_size
+    )
 
     # Req:
     # Block 0/4: [0, 1, 2, 3]
@@ -921,6 +932,7 @@ def test_mm_prefix_caching():
         make_kv_cache_config(block_size, 11),
         max_model_len=8192,
         enable_caching=True,
+        hash_block_size=block_size,
     )
 
     # Common prompt tokens (T is text tokens and P is image placeholder tokens)
@@ -1020,6 +1032,7 @@ def test_cache_key_salting():
         make_kv_cache_config(block_size, 11),
         max_model_len=8192,
         enable_caching=True,
+        hash_block_size=block_size,
     )
 
     # 3 complete blocks and an incomplete block with 11 tokens.
@@ -1101,6 +1114,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
         make_kv_cache_config(block_size, 11),
         max_model_len=8192,
         enable_caching=True,
+        hash_block_size=block_size,
     )
     # Complete 3 blocks (48 tokens)
     # | Common-0 | Common-1 | Common-2 | ... |
@@ -1173,6 +1187,7 @@ def test_reset_prefix_cache():
         make_kv_cache_config(block_size, 11),
         max_model_len=8192,
         enable_caching=True,
+        hash_block_size=block_size,
     )
 
     full_block_token_ids = [i for i in range(3) for _ in range(16)]
@@ -1213,6 +1228,7 @@ def test_prefix_cache_stats_disabled():
         make_kv_cache_config(block_size, 11),
         max_model_len=8192,
         enable_caching=True,
+        hash_block_size=block_size,
         log_stats=False,  # Disable logging stats
     )
     assert manager.prefix_cache_stats is None
@@ -1232,7 +1248,7 @@ def test_prefix_cache_stats_disabled():
 
 
 def test_maybe_evict_cached_block():
-    pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
+    pool = BlockPool(num_gpu_blocks=4, enable_caching=True, hash_block_size=16)
     block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
     block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
     block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
@@ -1293,6 +1309,7 @@ def test_kv_cache_events(blocks_to_cache: int):
         max_model_len=8192,
         enable_caching=True,
         enable_kv_cache_events=True,
+        hash_block_size=block_size,
     )
 
     num_tokens = block_size * blocks_to_cache
@@ -1351,6 +1368,7 @@ def test_kv_cache_events_with_lora(blocks_to_cache: int):
         max_model_len=8192,
         enable_caching=True,
         enable_kv_cache_events=True,
+        hash_block_size=block_size,
     )
 
     # Test with LoRA request
@@ -1405,6 +1423,7 @@ def test_eagle_enabled_removes_last_block():
         max_model_len=8192,
         enable_caching=True,
         use_eagle=True,
+        hash_block_size=block_size,
     )
 
     # Request with 3 full blocks (48 tokens)
@@ -1437,6 +1456,7 @@ def test_eagle_with_partial_blocks():
         max_model_len=8192,
         enable_caching=True,
         use_eagle=True,
+        hash_block_size=block_size,
     )
     # 2 full blocks + 5 tokens (non-divisible length)
     token_ids = [0] * (2 * block_size + 5)
@@ -1476,6 +1496,7 @@ def test_eagle_with_sliding_window():
         max_model_len=8192,
         enable_caching=True,
         use_eagle=True,
+        hash_block_size=block_size,
     )
 
     # 2 full blocks + 5 tokens (non-divisible length)
@@ -1522,6 +1543,76 @@ def test_eagle_with_sliding_window():
     assert num_tokens == 0
 
 
+def test_different_block_size():
+    block_size = 16
+    # full attention and sliding window attention layers have the same page size:
+    # (32 tokens/block * float16 token, vs. 16 tokens/block * float32 token)
+    kv_cache_config = KVCacheConfig(
+        num_blocks=100,
+        kv_cache_tensors=[],
+        kv_cache_groups=[
+            KVCacheGroupSpec(
+                ["layer1"],
+                FullAttentionSpec(block_size * 2, 1, 1, torch.float16),
+            ),
+            KVCacheGroupSpec(
+                ["layer2"],
+                SlidingWindowSpec(
+                    block_size,
+                    1,
+                    1,
+                    torch.float32,
+                    sliding_window=2 * block_size,
+                ),
+            ),
+        ],
+    )
+    manager = KVCacheManager(
+        kv_cache_config=kv_cache_config,
+        max_model_len=8192,
+        enable_caching=True,
+        hash_block_size=block_size,
+    )
+
+    # 10 blocks of 16 tokens each. Token ids are not strictly aligned for each block.
+    common_token_ids = [i for i in range(10) for _ in range(block_size)]
+
+    req0 = make_request("0", common_token_ids, block_size, sha256)
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
+    assert not computed_blocks.blocks[0]
+    assert not computed_blocks.blocks[1]
+    assert num_computed_tokens == 0
+    blocks = manager.allocate_slots(
+        req0, 7 * block_size, len(computed_blocks.blocks[0]) * 16, computed_blocks
+    )
+    assert blocks.get_block_ids() == ([1, 2, 3, 4], [5, 6, 7, 8, 9, 10, 11])
+    req1 = make_request("1", common_token_ids[: 7 * block_size + 1], block_size, sha256)
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
+    assert len(computed_blocks.blocks[0]) == 3
+    assert len(computed_blocks.blocks[1]) == 6
+    assert num_computed_tokens == 6 * 16
+
+    req2 = make_request("2", common_token_ids[: 6 * block_size + 1], block_size, sha256)
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
+    assert len(computed_blocks.blocks[0]) == 3
+    assert len(computed_blocks.blocks[1]) == 6
+    assert num_computed_tokens == 6 * 16
+
+    # Evict some blocks to make sliding window cache hit length 5*16
+    # But should return 4 * 16 because full attention cache hit length must be
+    # a multiple of 32
+    manager.block_pool.cached_block_hash_to_block.pop(
+        make_block_hash_with_group_id(req1.block_hashes[6], 1), 11
+    )
+    manager.block_pool.cached_block_hash_to_block.pop(
+        make_block_hash_with_group_id(req1.block_hashes[5], 1), 10
+    )
+    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
+    assert len(computed_blocks.blocks[0]) == 2
+    assert len(computed_blocks.blocks[1]) == 4
+    assert num_computed_tokens == 4 * 16
+
+
 def test_block_lookup_cache_single_block_per_key():
     cache = BlockHashToBlockMap()
     key0 = BlockHashWithGroupId(b"hash0")
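The final assertions of test_different_block_size follow from alignment arithmetic: after the two evictions the sliding-window group could still match 5 of its 16-token blocks, but a combined hit must be a multiple of lcm(32, 16) = 32 tokens. A small sketch of that rounding with the test's numbers (illustrative, not part of the commit):

    from math import lcm

    full_block, sw_block = 32, 16
    sw_hit_tokens = 5 * sw_block                      # 80 tokens still cached for group 1
    lcm_block = lcm(full_block, sw_block)             # 32
    aligned = sw_hit_tokens // lcm_block * lcm_block  # 80 -> 64 tokens
    assert aligned == 4 * 16  # 2 full-attention blocks, 4 sliding-window blocks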
@@ -41,7 +41,9 @@ def test_chunked_local_attention_possible_cached_prefix():
         attention_chunk_size=4,
     )
 
-    block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+    block_pool = BlockPool(
+        num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
+    )
     manager = get_chunked_local_attention_manager(
         chunked_local_attention_spec, block_pool
     )
@@ -70,6 +72,7 @@ def test_chunked_local_attention_possible_cached_prefix():
             block_pool=block_pool,
             kv_cache_spec=chunked_local_attention_spec,
             use_eagle=False,
+            alignment_tokens=block_size,
         )[0]
         assert len(computed_blocks) == expect_length
 
@@ -111,7 +114,9 @@ def test_sliding_window_possible_cached_prefix():
         sliding_window=4,
     )
 
-    block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+    block_pool = BlockPool(
+        num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
+    )
     manager = get_sliding_window_manager(sliding_window_spec, block_pool)
 
     def run_one_case(block_is_cached, expect_length):
@@ -138,6 +143,7 @@ def test_sliding_window_possible_cached_prefix():
             block_pool=block_pool,
             kv_cache_spec=sliding_window_spec,
             use_eagle=False,
+            alignment_tokens=block_size,
         )[0]
         assert len(computed_blocks) == expect_length
 
@@ -178,7 +184,7 @@ def test_chunked_local_attention_remove_skipped_blocks():
         attention_chunk_size=4,
     )
 
-    block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
+    block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
 
     manager = get_chunked_local_attention_manager(attention_spec, block_pool)
 
@@ -239,7 +245,7 @@ def test_sliding_window_remove_skipped_blocks():
         sliding_window=4,
     )
 
-    block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
+    block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True, hash_block_size=2)
 
     manager = get_sliding_window_manager(sliding_window_spec, block_pool)
 
@@ -316,7 +322,9 @@ def test_get_num_blocks_to_allocate():
         sliding_window=4,  # Placeholder value, not related to test result
     )
 
-    block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+    block_pool = BlockPool(
+        num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
+    )
     manager = get_sliding_window_manager(sliding_window_spec, block_pool)
     cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
     cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
@@ -341,7 +349,9 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
         attention_chunk_size=4,  # Placeholder value, not related to test result
     )
 
-    block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+    block_pool = BlockPool(
+        num_gpu_blocks=100, enable_caching=True, hash_block_size=block_size
+    )
     manager = get_chunked_local_attention_manager(attention_spec, block_pool)
     cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
     cached_blocks_2 = [block_pool.null_block for _ in range(5)] + [
@@ -1816,9 +1816,11 @@ class EngineArgs:
             if model_config.runner_type != "pooling":
                 default_chunked_prefill = True
 
-                # Disable prefix caching default for hybrid models
-                # since the feature is still experimental.
-                default_prefix_caching = not model_config.is_hybrid
+                # Disable prefix caching default for hybrid models and mamba-only
+                # models since the feature is still experimental.
+                default_prefix_caching = not (
+                    model_config.is_hybrid or model_config.is_attention_free
+                )
             else:
                 assert model_config.pooler_config is not None
 
@@ -289,9 +289,6 @@ class MambaModelConfig(VerifyAndUpdateConfig):
         model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
 
-        if cache_config.mamba_block_size is None:
-            cache_config.mamba_block_size = model_config.max_model_len
-
         if cache_config.enable_prefix_caching:
             if model_config.supports_mamba_prefix_caching:
                 logger.info(
@@ -299,6 +296,11 @@ class MambaModelConfig(VerifyAndUpdateConfig):
                     "Its support for Mamba layers is experimental. "
                     "Please report any issues you may observe."
                 )
+                # By default, mamba block size will be set to max_model_len (see
+                # below). When enabling prefix caching, we align mamba block size
+                # to the block size as the basic granularity for prefix caching.
+                if cache_config.mamba_block_size is None:
+                    cache_config.mamba_block_size = cache_config.block_size
             else:
                 logger.info(
                     "Hybrid or mamba-based model detected without "
@@ -306,6 +308,9 @@ class MambaModelConfig(VerifyAndUpdateConfig):
                 )
                 cache_config.enable_prefix_caching = False
 
+        if cache_config.mamba_block_size is None:
+            cache_config.mamba_block_size = model_config.max_model_len
+
         # TODO(tdoublep): remove once cascade attention is supported
         logger.info(
             "Disabling cascade attention since it is not supported for hybrid models."
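Taken together, the three hunks above reorder the mamba_block_size defaulting so that prefix caching sees an aligned block size. A condensed sketch of the resulting control flow (a paraphrase for illustration, not the literal source):

    if cache_config.enable_prefix_caching and model_config.supports_mamba_prefix_caching:
        # Align Mamba to the attention block size, the granularity of prefix caching.
        if cache_config.mamba_block_size is None:
            cache_config.mamba_block_size = cache_config.block_size
    else:
        cache_config.enable_prefix_caching = False
    # Fallback, as before the change: one Mamba block spanning the whole sequence.
    if cache_config.mamba_block_size is None:
        cache_config.mamba_block_size = model_config.max_model_len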
@@ -13,6 +13,8 @@ from vllm.distributed.kv_events import (
 from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_utils import (
     BlockHash,
+    BlockHashList,
+    BlockHashListWithBlockSize,
     BlockHashWithGroupId,
     ExternalBlockHash,
     FreeKVCacheBlockQueue,
@@ -133,6 +135,10 @@ class BlockPool:
     Args:
         num_gpu_blocks: The number of blocks in the pool.
         enable_caching: Whether to enable prefix caching.
+        hash_block_size: The block size of which the block hashes are computed.
+            The actual block size usually equals hash_block_size, but in cases
+            where different KV cache groups have different block sizes, the
+            actual block size can be a multiple of hash_block_size.
         enable_kv_cache_events: Whether to enable kv cache events.
     """
 
@@ -140,11 +146,13 @@ class BlockPool:
         self,
         num_gpu_blocks: int,
         enable_caching: bool,
+        hash_block_size: int,
         enable_kv_cache_events: bool = False,
     ):
         assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
         self.num_gpu_blocks = num_gpu_blocks
         self.enable_caching = enable_caching
+        self.hash_block_size = hash_block_size
         # All kv-cache blocks.
         self.blocks: list[KVCacheBlock] = [
             KVCacheBlock(idx) for idx in range(num_gpu_blocks)
@@ -223,8 +231,20 @@ class BlockPool:
             return
         new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
         assert len(request.block_hashes) >= num_full_blocks
-        new_block_hashes = request.block_hashes[num_cached_blocks:]
+        if block_size == self.hash_block_size:
+            # Common case.
+            block_hashes: BlockHashList = request.block_hashes
+        else:
+            # block_size is a multiple of hash_block_size. This happens when
+            # different KV cache groups have different block sizes.
+            assert block_size % self.hash_block_size == 0
+            # Recalculate block_hashes at the granularity of block_size, using
+            # the original block_hashes (at the granularity of hash_block_size).
+            block_hashes = BlockHashListWithBlockSize(
+                request.block_hashes, self.hash_block_size, block_size
+            )
+
+        new_block_hashes = block_hashes[num_cached_blocks:]
         new_hashes: list[ExternalBlockHash] | None = (
             [] if self.enable_kv_cache_events else None
         )
 
@@ -2,15 +2,25 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
+from math import lcm
 
 from vllm.v1.core.block_pool import BlockPool
-from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
+from vllm.v1.core.kv_cache_utils import (
+    BlockHash,
+    BlockHashList,
+    BlockHashListWithBlockSize,
+    KVCacheBlock,
+)
 from vllm.v1.core.single_type_kv_cache_manager import (
     CrossAttentionManager,
     FullAttentionManager,
     get_manager_for_kv_cache_spec,
 )
-from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec
+from vllm.v1.kv_cache_interface import (
+    FullAttentionSpec,
+    KVCacheConfig,
+    KVCacheSpec,
+)
 from vllm.v1.request import Request
 
 
@@ -28,13 +38,17 @@ class KVCacheCoordinator(ABC):
         enable_kv_cache_events: bool,
         dcp_world_size: int,
         pcp_world_size: int,
+        hash_block_size: int,
     ):
         self.kv_cache_config = kv_cache_config
         self.max_model_len = max_model_len
         self.enable_caching = enable_caching
 
         self.block_pool = BlockPool(
-            kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events
+            kv_cache_config.num_blocks,
+            enable_caching,
+            hash_block_size,
+            enable_kv_cache_events,
         )
 
         # Needs special handling for find_longest_cache_hit if eagle is enabled
@@ -213,6 +227,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
         enable_kv_cache_events: bool,
         dcp_world_size: int,
         pcp_world_size: int,
+        hash_block_size: int,
     ):
         super().__init__(
             kv_cache_config,
@@ -222,6 +237,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
             enable_kv_cache_events,
             dcp_world_size=dcp_world_size,
             pcp_world_size=pcp_world_size,
+            hash_block_size=hash_block_size,
         )
         self.num_single_type_manager = len(self.single_type_managers)
 
@@ -255,6 +271,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
         enable_kv_cache_events: bool,
         dcp_world_size: int,
         pcp_world_size: int,
+        hash_block_size: int,
     ):
         super().__init__(
             kv_cache_config,
@@ -264,6 +281,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
             enable_kv_cache_events,
             dcp_world_size=dcp_world_size,
             pcp_world_size=pcp_world_size,
+            hash_block_size=hash_block_size,
         )
         self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
         self.block_size = self.kv_cache_spec.block_size
@@ -273,6 +291,11 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
             self.block_size *= dcp_world_size
         if pcp_world_size > 1:
             self.block_size *= pcp_world_size
+        # For models using only Mamba, block_size is set to max_model_len when
+        # prefix caching is disabled, and hash_block_size validation is skipped.
+        assert not enable_caching or (hash_block_size == self.block_size), (
+            "UnitaryKVCacheCoordinator assumes hash_block_size == block_size"
+        )
         assert len(self.kv_cache_config.kv_cache_groups) == 1, (
             "UnitaryKVCacheCoordinator assumes only one kv cache group"
         )
@@ -289,6 +312,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
             block_pool=self.block_pool,
             kv_cache_spec=self.kv_cache_spec,
             use_eagle=self.use_eagle,
+            alignment_tokens=self.block_size,
             dcp_world_size=self.dcp_world_size,
             pcp_world_size=self.pcp_world_size,
         )
@@ -313,6 +337,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
         enable_kv_cache_events: bool,
         dcp_world_size: int,
         pcp_world_size: int,
+        hash_block_size: int,
     ):
         super().__init__(
             kv_cache_config,
@@ -322,7 +347,17 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
             enable_kv_cache_events,
             dcp_world_size=dcp_world_size,
             pcp_world_size=pcp_world_size,
+            hash_block_size=hash_block_size,
         )
+        # hash_block_size: the block size used to compute block hashes.
+        # The actual block size usually equals hash_block_size, but in cases where
+        # different KV cache groups have different block sizes, the actual block size
+        # can be a multiple of hash_block_size.
+        self.hash_block_size = hash_block_size
+        assert all(
+            g.kv_cache_spec.block_size % hash_block_size == 0
+            for g in kv_cache_config.kv_cache_groups
+        ), "block_size must be divisible by hash_block_size"
         assert dcp_world_size == 1, "DCP not support hybrid attn now."
         assert pcp_world_size == 1, "PCP not support hybrid attn now."
         self.verify_and_split_kv_cache_groups()
@@ -373,14 +408,12 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
         self.other_spec = other_spec
         self.full_attention_block_size = self.full_attention_spec.block_size
         self.other_block_size = self.other_spec.block_size
 
-        if self.enable_caching:
-            # this requirement is only needed for the prefix caching logic
-            divisible = self.other_block_size % self.full_attention_block_size
-            assert divisible == 0, (
-                "KVCacheCoordinator assumes the block_size of full "
-                "attention layers is divisible by other layers now."
-            )
+        # The LCM of the block sizes of full attention and other attention.
+        # The cache hit length must be a multiple of the LCM of the block sizes
+        # to make sure the cache hit length is a multiple of the block size of
+        # each attention type. Requiring this because we don't support partial
+        # block cache hit yet.
+        self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size)
 
         if max(self.full_attention_group_ids) < min(self.other_group_ids):
             self.full_attn_first = True
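The old divisibility assert is thus replaced by a weaker rule: hit lengths only need to be multiples of the LCM of the two block sizes. A worked example with the values used elsewhere in this diff (illustrative only):

    from math import lcm

    full_attention_block_size, other_block_size = 32, 16
    lcm_block_size = lcm(full_attention_block_size, other_block_size)  # 32
    # 80 cached tokens are a whole number of 16-token blocks (5) but not of
    # 32-token blocks (2.5), so the hit is trimmed down to 64 tokens.
    assert 80 // lcm_block_size * lcm_block_size == 64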
@@ -414,25 +447,48 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
             - The number of tokens of the longest cache hit.
         """
         # First, find the longest cache hit for full attention.
+        if self.full_attention_spec.block_size == self.hash_block_size:
+            # Common case.
+            full_attention_block_hashes: BlockHashList = block_hashes
+        else:
+            # block_size is a multiple of hash_block_size. This happens when different
+            # KV cache groups have different block sizes. In this case, we need to
+            # recalculate block_hashes at the granularity of block_size, using the
+            # original block_hashes (at the granularity of hash_block_size).
+            full_attention_block_hashes = BlockHashListWithBlockSize(
+                block_hashes, self.hash_block_size, self.full_attention_spec.block_size
+            )
         hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit(
-            block_hashes=block_hashes,
+            block_hashes=full_attention_block_hashes,
             max_length=max_cache_hit_length,
             kv_cache_group_ids=self.full_attention_group_ids,
             block_pool=self.block_pool,
             kv_cache_spec=self.full_attention_spec,
             use_eagle=self.use_eagle,
+            alignment_tokens=self.lcm_block_size,
         )
         hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size
 
         # Next, find the cache hit for the other attention WITHIN
         # the cache hit of full attention.
+        if self.other_spec.block_size == self.hash_block_size:
+            # Common case.
+            other_block_hashes: BlockHashList = block_hashes
+        else:
+            # Similar to the full attention case, here we need to recalculate
+            # block_hashes at the granularity of block_size, using the original
+            # block_hashes (at the granularity of hash_block_size).
+            other_block_hashes = BlockHashListWithBlockSize(
+                block_hashes, self.hash_block_size, self.other_spec.block_size
+            )
         hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit(
-            block_hashes=block_hashes,
+            block_hashes=other_block_hashes,
             max_length=hit_length,
             kv_cache_group_ids=self.other_group_ids,
             block_pool=self.block_pool,
             kv_cache_spec=self.other_spec,
             use_eagle=self.use_eagle,
+            alignment_tokens=self.lcm_block_size,
         )
         hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
 
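The two passes above narrow the hit monotonically: the full-attention pass bounds the search, the other-attention pass searches only within that bound (max_length=hit_length), and both align to lcm_block_size. A toy model of the resulting length computation, assuming every block is cached (a sketch, not the vLLM API):

    from math import lcm

    def hybrid_hit_length(total_tokens: int, full_block: int = 32, other_block: int = 16) -> int:
        lcm_block = lcm(full_block, other_block)
        hit = total_tokens // full_block * full_block  # pass 1: full attention
        hit = hit // other_block * other_block         # pass 2: other attention, within pass 1
        return hit // lcm_block * lcm_block            # both passes align to the LCM

    assert hybrid_hit_length(80) == 64  # a multiple of both 32 and 16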
@@ -466,6 +522,7 @@ def get_kv_cache_coordinator(
     enable_kv_cache_events: bool,
     dcp_world_size: int,
     pcp_world_size: int,
+    hash_block_size: int,
 ) -> KVCacheCoordinator:
     if not enable_caching:
         return KVCacheCoordinatorNoPrefixCache(
@@ -473,8 +530,9 @@ def get_kv_cache_coordinator(
             max_model_len,
             use_eagle,
             enable_kv_cache_events,
-            dcp_world_size=dcp_world_size,
-            pcp_world_size=pcp_world_size,
+            dcp_world_size,
+            pcp_world_size,
+            hash_block_size,
         )
     if len(kv_cache_config.kv_cache_groups) == 1:
         return UnitaryKVCacheCoordinator(
@@ -483,8 +541,9 @@ def get_kv_cache_coordinator(
             use_eagle,
             enable_caching,
             enable_kv_cache_events,
-            dcp_world_size=dcp_world_size,
-            pcp_world_size=pcp_world_size,
+            dcp_world_size,
+            pcp_world_size,
+            hash_block_size,
         )
     return HybridKVCacheCoordinator(
         kv_cache_config,
@@ -492,6 +551,7 @@ def get_kv_cache_coordinator(
         use_eagle,
         enable_caching,
         enable_kv_cache_events,
-        dcp_world_size=dcp_world_size,
-        pcp_world_size=pcp_world_size,
+        dcp_world_size,
+        pcp_world_size,
+        hash_block_size,
     )
 
@@ -95,6 +95,7 @@ class KVCacheManager:
         self,
         kv_cache_config: KVCacheConfig,
         max_model_len: int,
+        hash_block_size: int,
         enable_caching: bool = True,
         use_eagle: bool = False,
         log_stats: bool = False,
@@ -107,28 +108,11 @@ class KVCacheManager:
         self.enable_caching = enable_caching
         self.use_eagle = use_eagle
         self.log_stats = log_stats
-        # FIXME: make prefix cache stats conditional on log_stats
+        # FIXME: make prefix cache stats conditional on log_stats. We still need
+        # this comment because when the log stats is enabled there are still
+        # potential configs we could expose in the future.
         self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
 
-        self.block_size: int | None = None
-        if self.enable_caching:
-            assert (
-                len(
-                    set(
-                        g.kv_cache_spec.block_size
-                        for g in kv_cache_config.kv_cache_groups
-                    )
-                )
-                == 1
-            ), "Only one block size is supported for now"
-            self.block_size = kv_cache_config.kv_cache_groups[
-                0
-            ].kv_cache_spec.block_size
-
-            if dcp_world_size * pcp_world_size > 1:
-                assert len(kv_cache_config.kv_cache_groups) == 1
-                self.block_size *= dcp_world_size * pcp_world_size
-
         self.coordinator = get_kv_cache_coordinator(
             kv_cache_config=kv_cache_config,
             max_model_len=self.max_model_len,
@@ -137,6 +121,7 @@ class KVCacheManager:
             enable_kv_cache_events=enable_kv_cache_events,
             dcp_world_size=dcp_world_size,
             pcp_world_size=pcp_world_size,
+            hash_block_size=hash_block_size,
         )
         self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
         self.block_pool = self.coordinator.block_pool
 
@@ -5,9 +5,9 @@
 import copy
 import os
 from collections import defaultdict
-from collections.abc import Callable, Iterable, Sequence
-from dataclasses import dataclass
-from typing import Any, NewType, TypeAlias
+from collections.abc import Callable, Iterable, Iterator, Sequence
+from dataclasses import dataclass, replace
+from typing import Any, NewType, TypeAlias, overload
 
 from vllm import envs
 from vllm.config import VllmConfig
@@ -825,11 +825,11 @@ def get_num_blocks(
     return num_blocks
 
 
-def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
+def get_uniform_page_size(kv_cache_specs: Iterable[KVCacheSpec]) -> int:
     """
     Get the page size of the KV cache.
     """
-    page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values())
+    page_sizes = {layer.page_size_bytes for layer in kv_cache_specs}
     assert len(page_sizes) == 1
     return page_sizes.pop()
 
@@ -882,6 +882,46 @@ def is_kv_cache_page_size_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool
     return len(page_sizes) == 1
 
 
+def unify_kv_cache_spec_page_size(
+    kv_cache_spec: dict[str, KVCacheSpec],
+) -> dict[str, KVCacheSpec]:
+    """
+    Unify the page size of the given KVCacheSpec. If the page size of all layers
+    are the same, return the original KVCacheSpec. If not same, unify the page
+    size by increasing the block size of layers with smaller page size. Raise
+    NotImplementedError if failed to unify the page size.
+
+    Args:
+        kv_cache_spec: The KVCacheSpec of each attention layer in the model
+
+    Returns:
+        The updated KVCacheSpec with the same page_size_bytes.
+    """
+    page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
+    if len(page_sizes) <= 1:
+        # All layers have the same page size, no need to unify.
+        return kv_cache_spec
+
+    max_page_size = max(page_sizes)
+    new_kv_cache_spec = {}
+    for layer_name, layer_spec in kv_cache_spec.items():
+        if layer_spec.page_size_bytes == max_page_size:
+            new_kv_cache_spec[layer_name] = layer_spec
+        else:
+            layer_page_size = layer_spec.page_size_bytes
+            if max_page_size % layer_page_size != 0:
+                raise NotImplementedError(
+                    "The page size of the layer is not divisible by the "
+                    "maximum page size. Cannot unify by adjusting block_size."
+                )
+            ratio = max_page_size // layer_page_size
+            new_block_size = layer_spec.block_size * ratio
+            new_spec = replace(layer_spec, block_size=new_block_size)
+            assert new_spec.page_size_bytes == max_page_size
+            new_kv_cache_spec[layer_name] = new_spec
+    return new_kv_cache_spec
+
+
 def is_kv_cache_type_attention_free(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
     # kv_cache_spec is an empty dict for attention free models
     return not kv_cache_spec
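To see unify_kv_cache_spec_page_size end to end without constructing real specs, here is a toy stand-in that mirrors its ratio logic (ToySpec and its fields are invented for illustration):

    from dataclasses import dataclass, replace

    @dataclass(frozen=True)
    class ToySpec:  # stand-in for KVCacheSpec
        block_size: int
        bytes_per_token: int

        @property
        def page_size_bytes(self) -> int:
            return self.block_size * self.bytes_per_token

    spec = {"full": ToySpec(32, 2), "sliding": ToySpec(16, 2)}
    max_page = max(s.page_size_bytes for s in spec.values())  # 64
    ratio = max_page // spec["sliding"].page_size_bytes       # 2
    spec["sliding"] = replace(spec["sliding"], block_size=16 * ratio)
    assert len({s.page_size_bytes for s in spec.values()}) == 1  # unified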
@@ -1010,7 +1050,6 @@ def _get_kv_cache_groups_uniform_page_size(
 def get_kv_cache_config_from_groups(
     vllm_config: VllmConfig,
     kv_cache_groups: list[KVCacheGroupSpec],
-    kv_cache_specs: dict[str, KVCacheSpec],
     available_memory: int,
 ) -> KVCacheConfig:
     """
@@ -1020,7 +1059,6 @@ def get_kv_cache_config_from_groups(
     Args:
         vllm_config: The global VllmConfig
         kv_cache_groups: The KV cache groups
-        kv_cache_specs: The KV cache spec of each attention layer in the model
         available_memory: Memory available for KV cache in bytes
     Returns:
         The generated KVCacheConfig
@@ -1064,7 +1102,9 @@ def get_kv_cache_config_from_groups(
     # full.1, sw.2: share another Tensor with size=available_memory//2
     group_size = max(len(group.layer_names) for group in kv_cache_groups)
 
-    page_size = get_uniform_page_size(kv_cache_specs)
+    page_size = get_uniform_page_size(
+        [group.kv_cache_spec for group in kv_cache_groups]
+    )
     assert group_size > 0, "group_size must be greater than 0"
     num_blocks = get_num_blocks(
         vllm_config, group_size, available_memory, page_size
@@ -1166,7 +1206,8 @@ def get_kv_cache_groups(
         # This returns an empty list to allow for the KVCacheManager to handle
         # attention free models.
         return []
-    elif is_kv_cache_spec_uniform(kv_cache_spec):
+
+    if is_kv_cache_spec_uniform(kv_cache_spec):
         # KV cache of all layers are the same, which is true for
         # most models. Allocate the same amount of memory for
         # each layer.
@@ -1176,14 +1217,16 @@ def get_kv_cache_groups(
         # full attention, or all layers are sliding window attention with the
         # same window size). Put all layers into one group.
         return _get_kv_cache_groups_uniform_type(uniform_spec)
-    elif is_kv_cache_page_size_uniform(kv_cache_spec):
-        # Model contains multiple attention types, but KV cache of all layers
-        # have the same physical memory per block per layer. Split the layers
-        # into groups with the same number of layers, and thus same total page
-        # size.
-        return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
-
-    raise NotImplementedError
+
+    # As KVCacheManager can only allocate memory of one size, we need to unify
+    # the page size of the layers. For cases cannot be unified, this function
+    # will raise an error.
+    kv_cache_spec = unify_kv_cache_spec_page_size(kv_cache_spec)
+    # Model contains multiple attention types, but KV cache of all layers
+    # have the same physical memory per block per layer. Split the layers
+    # into groups with the same number of layers, and thus same total page
+    # size.
+    return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
 
 
 def generate_scheduler_kv_cache_config(
@@ -1327,10 +1370,7 @@ def get_kv_cache_configs(
     ) == len(kv_cache_spec_one_worker), "Some layers are not assigned to any group."
     kv_cache_configs.append(
         get_kv_cache_config_from_groups(
-            vllm_config,
-            kv_cache_groups_one_worker,
-            kv_cache_spec_one_worker,
-            available_memory_one_worker,
+            vllm_config, kv_cache_groups_one_worker, available_memory_one_worker
         )
     )
 
@@ -1353,3 +1393,79 @@ def get_kv_cache_configs(
         _report_kv_cache_config(vllm_config, kv_cache_config)
 
     return kv_cache_configs
+
+
+class BlockHashListWithBlockSize:
+    """
+    Convert block-hash granularity from `hash_block_size` to `target_block_size`.
+    Used when KV cache groups have different block sizes: `hash_block_size`
+    is the size used to compute the original `block_hashes`; `target_block_size`
+    is the group's actual block size.
+
+    Currently, only scaling up by an integer factor is supported (i.e.,
+    `target_block_size` is a multiple of `hash_block_size`). Conversion is
+    performed lazily on access for efficiency, by concatenating consecutive
+    hashes at `hash_block_size` to form each hash at `target_block_size`.
+
+    Example (`hash_block_size` = 16, `target_block_size` = 32):
+    concatenating two 16-size hashes yields one 32-size hash:
+
+    Block hashes with block_size 16:
+    | Token Range | 0-15 | 16-31 | 32-47 | 48-63 |
+    |-------------|------|-------|-------|-------|
+    | Hash        | A    | B     | C     | D     |
+
+    Block hashes with block_size 32:
+    | Token Range | 0-31 | 32-63 |
+    |-------------|------|-------|
+    | Hash        | AB   | CD    |
+
+    Args:
+        block_hashes: Block hashes to convert, computed at `hash_block_size`.
+        hash_block_size: Block size at which `block_hashes` were computed.
+        target_block_size: Desired block size; must be a multiple of `hash_block_size`.
+    """
+
+    def __init__(
+        self,
+        block_hashes: list[BlockHash],
+        hash_block_size: int,
+        target_block_size: int,
+    ):
+        self.block_hashes = block_hashes
+        assert target_block_size % hash_block_size == 0
+        self.scale_factor = target_block_size // hash_block_size
+
+    def __len__(self) -> int:
+        return len(self.block_hashes) // self.scale_factor
+
+    @overload
+    def __getitem__(self, idx: int) -> BlockHash: ...
+
+    @overload
+    def __getitem__(self, idx: slice) -> list[BlockHash]: ...
+
+    def __getitem__(self, idx):
+        if isinstance(idx, int):
+            return self._get_value_at(idx)
+
+        if isinstance(idx, slice):
+            start, stop, step = idx.indices(len(self))
+            return [self._get_value_at(i) for i in range(start, stop, step)]
+
+        raise TypeError(f"Invalid index type: {type(idx)!r}")
+
+    def __iter__(self) -> Iterator[BlockHash]:
+        for i in range(len(self)):
+            yield self._get_value_at(i)
+
+    def _get_value_at(self, idx: int) -> BlockHash:
+        base = idx * self.scale_factor
+        end = base + self.scale_factor
+        merged_hash: bytes = self.block_hashes[base]
+        for i in range(base + 1, end):
+            merged_hash += self.block_hashes[i]
+        return BlockHash(merged_hash)
+
+
+BlockHashList = list[BlockHash] | BlockHashListWithBlockSize
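A quick usage sketch of the new class (hash values invented; BlockHash wraps bytes, as in the tests earlier in this diff):

    hashes_16 = [BlockHash(b"A"), BlockHash(b"B"), BlockHash(b"C"), BlockHash(b"D")]
    hashes_32 = BlockHashListWithBlockSize(hashes_16, hash_block_size=16, target_block_size=32)
    assert len(hashes_32) == 2
    assert hashes_32[0] == BlockHash(b"AB")  # A and B concatenated lazily
    assert hashes_32[0:2] == [BlockHash(b"AB"), BlockHash(b"CD")]
    assert list(hashes_32) == [BlockHash(b"AB"), BlockHash(b"CD")]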
@@ -186,6 +186,7 @@ class Scheduler(SchedulerInterface):
             enable_kv_cache_events=self.enable_kv_cache_events,
             dcp_world_size=self.dcp_world_size,
             pcp_world_size=self.pcp_world_size,
+            hash_block_size=self.block_size,
         )
         self.use_pp = self.parallel_config.pipeline_parallel_size > 1
         self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
 
@@ -7,7 +7,7 @@ from collections.abc import Sequence
 
 from vllm.utils.math_utils import cdiv
 from vllm.v1.core.block_pool import BlockPool
-from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
+from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock
 from vllm.v1.kv_cache_interface import (
     ChunkedLocalAttentionSpec,
     CrossAttentionSpec,
@@ -207,12 +207,13 @@ class SingleTypeKVCacheManager(ABC):
     @abstractmethod
     def find_longest_cache_hit(
         cls,
-        block_hashes: list[BlockHash],
+        block_hashes: BlockHashList,
         max_length: int,
         kv_cache_group_ids: list[int],
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        alignment_tokens: int,
         dcp_world_size: int = 1,
         pcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
@@ -232,6 +233,11 @@ class SingleTypeKVCacheManager(ABC):
             block_pool: The block pool.
             kv_cache_spec: The kv cache spec.
             use_eagle: Whether to use eagle.
+            alignment_tokens: The returned cache hit length (in tokens) should
+                be a multiple of this value (in tokens). By default, it should
+                be set to the block_size.
             dcp_world_size: The world size of decode context parallelism.
             pcp_world_size: The world size of prefill context parallelism.
 
         Returns:
             A list of cached blocks with skipped blocks replaced by null block
@@ -299,17 +305,18 @@ class FullAttentionManager(SingleTypeKVCacheManager):
     @classmethod
     def find_longest_cache_hit(
         cls,
-        block_hashes: list[BlockHash],
+        block_hashes: BlockHashList,
         max_length: int,
         kv_cache_group_ids: list[int],
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        alignment_tokens: int,
         dcp_world_size: int = 1,
         pcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(
-            kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
+            kv_cache_spec, FullAttentionSpec | ChunkedLocalAttentionSpec
         ), (
             "FullAttentionManager can only be used for full attention "
             "and chunked local attention groups"
@@ -333,6 +340,13 @@ class FullAttentionManager(SingleTypeKVCacheManager):
             else:
                 break
         if use_eagle and computed_blocks[0]:
             # Need to drop the last matched block if eagle is enabled.
             for computed in computed_blocks:
                 computed.pop()
+        while (
+            block_size != alignment_tokens  # Faster for common case.
+            and len(computed_blocks[0]) * block_size % alignment_tokens != 0
+        ):
+            for computed in computed_blocks:
+                computed.pop()
         return computed_blocks
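The new while loop pops whole blocks until the hit is a multiple of alignment_tokens; in the common single-block-size case block_size equals alignment_tokens and the loop never iterates. Worked numbers (illustrative only):

    block_size, alignment_tokens = 16, 32
    computed = list(range(5))  # 5 matched blocks = 80 tokens
    while (
        block_size != alignment_tokens
        and len(computed) * block_size % alignment_tokens != 0
    ):
        computed.pop()         # drop the trailing unaligned block
    assert len(computed) == 4  # 64 tokens, a multiple of 32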
@@ -359,12 +373,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
     @classmethod
     def find_longest_cache_hit(
         cls,
-        block_hashes: list[BlockHash],
+        block_hashes: BlockHashList,
         max_length: int,
         kv_cache_group_ids: list[int],
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        alignment_tokens: int,
         dcp_world_size: int = 1,
         pcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
@@ -396,6 +411,7 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
             [block_pool.null_block] * max_num_blocks
             for _ in range(len(kv_cache_group_ids))
         )
+        block_size = kv_cache_spec.block_size
         num_contiguous_blocks = 0
         match_found = False
         # Search from right to left and early stop when a match is found.
@@ -403,6 +419,15 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
             if cached_block := block_pool.get_cached_block(
                 block_hashes[i], kv_cache_group_ids
             ):
+                # Skip prefix matching check if the block is not aligned with
+                # `alignment_tokens`.
+                if (
+                    num_contiguous_blocks == 0
+                    and block_size != alignment_tokens  # Faster for common case.
+                    and (i + 1) * block_size % alignment_tokens != 0
+                ):
+                    continue
+                # Add the cached block to the computed blocks.
                 for computed, cached in zip(computed_blocks, cached_block):
                     computed[i] = cached
                 num_contiguous_blocks += 1
@@ -421,7 +446,16 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
             # `num_contiguous_blocks < sliding_window_contiguous_blocks`.
             for computed in computed_blocks:
                 del computed[num_contiguous_blocks:]
+            while (
+                block_size != alignment_tokens  # Faster for common case.
+                and len(computed_blocks[0]) * block_size % alignment_tokens != 0
+            ):
+                for computed in computed_blocks:
+                    computed.pop()
         if use_eagle and computed_blocks[0]:
+            assert kv_cache_spec.block_size == alignment_tokens, (
+                "aligned_length is not compatible with eagle now"
+            )
             for computed in computed_blocks:
                 computed.pop()
         return computed_blocks
@@ -475,12 +509,13 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
     @classmethod
     def find_longest_cache_hit(
         cls,
-        block_hashes: list[BlockHash],
+        block_hashes: BlockHashList,
         max_length: int,
         kv_cache_group_ids: list[int],
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        alignment_tokens: int,
         dcp_world_size: int = 1,
         pcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
@@ -511,6 +546,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
             block_pool: The block pool.
             kv_cache_spec: The kv cache spec.
             use_eagle: Whether to use eagle.
             dcp_world_size: The world size of decode context parallelism.
             pcp_world_size: The world size of prefill context parallelism.
+            alignment_tokens: The returned cache hit length (in tokens) should
+                be a multiple of this value (in tokens).
 
         Returns:
             A list of cached blocks
@@ -524,6 +563,10 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
         )
         assert dcp_world_size == 1, "DCP not support chunked local attn now."
         assert pcp_world_size == 1, "PCP not support chunked local attn now."
+        assert kv_cache_spec.block_size == alignment_tokens, (
+            "KV cache groups with different block sizes are not compatible with "
+            "chunked local attention now"
+        )
         max_num_blocks = max_length // kv_cache_spec.block_size
         if max_length > 0:
             local_attention_start_idx = (
@@ -612,12 +655,13 @@ class MambaManager(SingleTypeKVCacheManager):
     @classmethod
     def find_longest_cache_hit(
         cls,
-        block_hashes: list[BlockHash],
+        block_hashes: BlockHashList,
         max_length: int,
         kv_cache_group_ids: list[int],
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        alignment_tokens: int,
         dcp_world_size: int = 1,
         pcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
@@ -630,12 +674,21 @@ class MambaManager(SingleTypeKVCacheManager):
             [] for _ in range(len(kv_cache_group_ids))
         )
 
-        max_num_blocks = max_length // kv_cache_spec.block_size
+        block_size = kv_cache_spec.block_size
+        max_num_blocks = max_length // block_size
         # Search from right to left and early stop when a match is found.
         for i in range(max_num_blocks - 1, -1, -1):
             if cached_block := block_pool.get_cached_block(
                 block_hashes[i], kv_cache_group_ids
             ):
+                # When enable Mamba prefix caching, `block_size` will be aligned
+                # across full attention layers and Mamba layers to ensure the
+                # prefix hit length aligned at block
+                if (
+                    block_size != alignment_tokens  # Faster for common case.
+                    and (i + 1) * block_size % alignment_tokens != 0
+                ):
+                    continue
                 for computed, cached in zip(computed_blocks, cached_block):
                     # the hit length logic later assumes:
                     # hit_length = len(hit_blocks_other_attn[0])
@@ -708,12 +761,13 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
     @classmethod
     def find_longest_cache_hit(
         cls,
-        block_hashes: list[BlockHash],
+        block_hashes: BlockHashList,
         max_length: int,
         kv_cache_group_ids: list[int],
         block_pool: BlockPool,
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
+        alignment_tokens: int,
         dcp_world_size: int = 1,
         pcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]: