[Hybrid Allocator] Support Pipeline Parallel (#23974)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Chen Zhang 2025-09-14 15:55:17 -07:00 committed by GitHub
parent 90f3f7d73e
commit 8e5cdcda4e
7 changed files with 472 additions and 235 deletions

View File

@@ -215,9 +215,7 @@ TEXT_GENERATION_MODELS = {
EMBEDDING_MODELS = { # type: ignore[var-annotated]
# [Text-only]
"intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(runner="pooling"),
# TODO: re-enable when https://github.com/vllm-project/vllm/issues/23883
# is fixed
#"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"),
"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"),
"Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(
load_format="dummy", runner="pooling"
),

View File

@@ -10,7 +10,7 @@ from vllm import LLM
from vllm.config import ModelImpl
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.engine.core import EngineCore as V1EngineCore
from ..utils import create_new_process_for_each_test
@@ -68,11 +68,11 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
def _initialize_kv_caches_v1(self, vllm_config):
kv_cache_specs = self.model_executor.get_kv_cache_specs()
scheduler_kv_cache_config = get_kv_cache_config(
scheduler_kv_cache_config = get_kv_cache_configs(
vllm_config,
kv_cache_specs[0],
10 * GiB_bytes,
)
kv_cache_specs,
[10 * GiB_bytes],
)[0]
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
return 1, 0, scheduler_kv_cache_config

View File

@@ -18,13 +18,12 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.kv_cache_utils import (
BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
estimate_max_model_len, generate_block_hash_extra_keys,
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
get_kv_cache_configs, get_max_concurrency_for_kv_cache_config,
get_request_block_hasher, hash_block_tokens, init_none_hash,
is_kv_cache_type_uniform, make_block_hash_with_group_id,
unify_kv_cache_configs)
is_kv_cache_type_uniform, make_block_hash_with_group_id)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor,
SlidingWindowSpec)
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
@@ -531,102 +530,288 @@ def test_metrics():
assert not metrics.query_queue
def test_unify_kv_cache_configs():
same_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
],
),
KVCacheConfig(
num_blocks=20,
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
],
),
]
unify_kv_cache_configs(same_kv_cache_config)
assert same_kv_cache_config[0].num_blocks == 10
assert same_kv_cache_config[1].num_blocks == 10
def test_get_kv_cache_configs_multiple_workers():
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config)
need_sort_kv_cache_config = [
ref_kv_cache_spec = new_kv_cache_spec()
same_kv_cache_specs = [{
"layer1": new_kv_cache_spec(),
"layer2": new_kv_cache_spec(),
}, {
"layer1": new_kv_cache_spec(),
"layer2": new_kv_cache_spec(),
}]
# Basic case. All things are the same.
kv_cache_configs = get_kv_cache_configs(vllm_config, same_kv_cache_specs, [
ref_kv_cache_spec.page_size_bytes * 2 * 10,
ref_kv_cache_spec.page_size_bytes * 2 * 10
])
assert kv_cache_configs == [
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer1"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
],
),
KVCacheConfig(
num_blocks=20,
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer1"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
],
),
]
unify_kv_cache_configs(need_sort_kv_cache_config)
sorted_kv_cache_groups = [
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"], new_kv_cache_spec(num_kv_heads=4)),
]
assert (
need_sort_kv_cache_config[0].kv_cache_groups == sorted_kv_cache_groups)
assert (
need_sort_kv_cache_config[1].kv_cache_groups == sorted_kv_cache_groups)
diff_kv_cache_config = [
# Different available memory. This is the case for TP.
# Use the smallest memory available.
kv_cache_configs = get_kv_cache_configs(vllm_config, same_kv_cache_specs, [
ref_kv_cache_spec.page_size_bytes * 2 * 10,
ref_kv_cache_spec.page_size_bytes * 2 * 20
])
assert kv_cache_configs == [
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer1"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
],
),
KVCacheConfig(
num_blocks=20,
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20,
shared_by=["layer1"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20,
shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=8)),
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
],
),
]
# Different KV cache specs. This is the case for PP.
different_layer_specs = [{
"layer1": new_kv_cache_spec(),
}, {
"layer2": new_kv_cache_spec(),
"layer3": new_kv_cache_spec(),
}]
# Different workers have different layers.
kv_cache_configs = get_kv_cache_configs(
vllm_config, different_layer_specs, [
ref_kv_cache_spec.page_size_bytes * 2 * 10,
ref_kv_cache_spec.page_size_bytes * 2 * 10
])
assert kv_cache_configs == [
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20,
shared_by=["layer1"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
],
),
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer2"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer3"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer2", "layer3"], new_kv_cache_spec()),
],
),
]
# Some layers are the same, some are different. This is the case for TP+PP
tp_pp_kv_cache_specs = [{
"layer1": new_kv_cache_spec(),
"layer2": new_kv_cache_spec(),
}, {
"layer1": new_kv_cache_spec(),
"layer2": new_kv_cache_spec(),
}, {
"layer3": new_kv_cache_spec(),
}, {
"layer3": new_kv_cache_spec(),
}]
kv_cache_configs = get_kv_cache_configs(
vllm_config, tp_pp_kv_cache_specs, [
ref_kv_cache_spec.page_size_bytes * 2 * 10,
ref_kv_cache_spec.page_size_bytes * 2 * 10,
ref_kv_cache_spec.page_size_bytes * 2 * 10,
ref_kv_cache_spec.page_size_bytes * 2 * 10,
])
assert kv_cache_configs == [
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer1"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
],
),
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer1"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
],
),
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20,
shared_by=["layer3"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer3"], ref_kv_cache_spec),
],
),
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20,
shared_by=["layer3"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer3"], ref_kv_cache_spec),
],
),
]
# Different workers have different types of layers. This is the case for
# hybrid models + PP.
different_type_layer_specs = [{
"layer1": new_kv_cache_spec(),
"layer2": new_kv_cache_spec(),
}, {
"layer3": new_sliding_window_spec(),
"layer4": new_sliding_window_spec(),
}]
kv_cache_configs = get_kv_cache_configs(
vllm_config, different_type_layer_specs, [
ref_kv_cache_spec.page_size_bytes * 2 * 10,
ref_kv_cache_spec.page_size_bytes * 2 * 10,
])
assert kv_cache_configs == [
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer1"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
KVCacheGroupSpec([], new_sliding_window_spec()),
],
),
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer3"]),
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer4"]),
],
kv_cache_groups=[
KVCacheGroupSpec([], ref_kv_cache_spec),
KVCacheGroupSpec(["layer3", "layer4"],
new_sliding_window_spec()),
],
),
]
# When the layers are divided into multiple KVCacheGroups, ensure that the
# number of layers per group is similar.
different_type_layer_specs = [{
"layer1": new_kv_cache_spec(),
"layer2": new_sliding_window_spec(),
"layer3": new_sliding_window_spec(),
}, {
"layer4": new_kv_cache_spec(),
"layer5": new_sliding_window_spec(),
"layer6": new_sliding_window_spec(),
}]
kv_cache_configs = get_kv_cache_configs(
vllm_config, different_type_layer_specs, [
ref_kv_cache_spec.page_size_bytes * 10,
ref_kv_cache_spec.page_size_bytes * 10,
])
assert kv_cache_configs == [
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer1", "layer2", "layer3"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], ref_kv_cache_spec),
KVCacheGroupSpec(["layer2"], new_sliding_window_spec()),
KVCacheGroupSpec(["layer3"], new_sliding_window_spec()),
],
),
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
shared_by=["layer4", "layer5", "layer6"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer4"], ref_kv_cache_spec),
KVCacheGroupSpec(["layer5"], new_sliding_window_spec()),
KVCacheGroupSpec(["layer6"], new_sliding_window_spec()),
],
),
]
# Have conflicting layers. Need to raise an error.
conflicting_layer_specs = [{
"layer1": new_kv_cache_spec(),
}, {
"layer1": new_sliding_window_spec(),
}]
with pytest.raises(AssertionError):
unify_kv_cache_configs(diff_kv_cache_config)
get_kv_cache_configs(vllm_config, conflicting_layer_specs, [
ref_kv_cache_spec.page_size_bytes * 2 * 10,
ref_kv_cache_spec.page_size_bytes * 2 * 10,
])
def test_merge_kv_cache_spec():
@@ -890,7 +1075,7 @@ def test_allocate_with_lookahead():
assert len(blocks.get_block_ids()[0]) == 2
def test_get_kv_cache_config():
def test_get_kv_cache_config_one_worker():
# pass max_model_len to pass check_enough_kv_cache_memory
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config)
@@ -901,8 +1086,10 @@ def test_get_kv_cache_config():
'layer_1': new_kv_cache_spec(),
'layer_2': new_kv_cache_spec(),
}
kv_cache_config_full = get_kv_cache_config(
vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32)
kv_cache_config_full = get_kv_cache_configs(
vllm_config, [kv_cache_specs_full],
[mem_per_block_per_layer * 2 * 32])[0]
print(kv_cache_config_full)
assert kv_cache_config_full == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
@@ -920,8 +1107,9 @@ def test_get_kv_cache_config():
'layer_1': new_sliding_window_spec(),
'layer_2': new_sliding_window_spec(),
}
kv_cache_config_sliding = get_kv_cache_config(
vllm_config, kv_cache_specs_sliding, mem_per_block_per_layer * 2 * 32)
kv_cache_config_sliding = get_kv_cache_configs(
vllm_config, [kv_cache_specs_sliding],
[mem_per_block_per_layer * 2 * 32])[0]
assert kv_cache_config_sliding == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
@@ -940,8 +1128,9 @@ def test_get_kv_cache_config():
'layer_1': new_kv_cache_spec(),
'layer_2': new_sliding_window_spec(),
}
kv_cache_config_hybrid = get_kv_cache_config(
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid],
[mem_per_block_per_layer * 2 * 32])[0]
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
@@ -962,8 +1151,9 @@ def test_get_kv_cache_config():
'layer_1': new_kv_cache_spec(),
'layer_2': new_sliding_window_spec(),
}
kv_cache_config_hybrid = get_kv_cache_config(
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid],
[mem_per_block_per_layer * 2 * 32])[0]
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=64,
kv_cache_tensors=[
@@ -985,21 +1175,22 @@ def test_get_kv_cache_config():
'layer_5': new_sliding_window_spec(),
'layer_6': new_sliding_window_spec(),
}
kv_cache_config_hybrid = get_kv_cache_config(
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid],
[mem_per_block_per_layer * 2 * 32])[0]
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_1", "layer_3", "layer_5"]),
shared_by=["layer_1", "layer_3", "layer_4"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_2", "layer_4", "layer_6"]),
shared_by=["layer_2", "layer_5", "layer_6"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer_3", "layer_4"],
KVCacheGroupSpec(["layer_3", "layer_5"],
new_sliding_window_spec()),
KVCacheGroupSpec(["layer_5", "layer_6"],
KVCacheGroupSpec(["layer_4", "layer_6"],
new_sliding_window_spec()),
],
)
@@ -1017,27 +1208,30 @@ def test_get_kv_cache_config():
'layer_9': new_sliding_window_spec(),
'layer_10': new_sliding_window_spec(),
}
kv_cache_config_hybrid = get_kv_cache_config(
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 3 * 32)
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid],
[mem_per_block_per_layer * 3 * 32])[0]
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(
size=mem_per_block_per_layer * 32,
shared_by=["layer_1", "layer_4", "layer_7", "layer_10"]),
shared_by=["layer_1", "layer_4", "layer_5", "layer_6"]),
KVCacheTensor(
size=mem_per_block_per_layer * 32,
shared_by=["layer_2", "layer_7", "layer_8", "layer_9"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_2", "layer_5", "layer_8"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_3", "layer_6", "layer_9"]),
shared_by=["layer_3", "layer_10"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"],
new_kv_cache_spec()),
KVCacheGroupSpec(["layer_4", "layer_5", "layer_6"],
KVCacheGroupSpec(["layer_4", "layer_7", "layer_10"],
new_sliding_window_spec()),
KVCacheGroupSpec(["layer_7", "layer_8", "layer_9"],
KVCacheGroupSpec(["layer_5", "layer_8"],
new_sliding_window_spec()),
KVCacheGroupSpec(["layer_6", "layer_9"],
new_sliding_window_spec()),
KVCacheGroupSpec(["layer_10"], new_sliding_window_spec()),
],
)
@@ -1047,13 +1241,14 @@ def test_get_kv_cache_config():
'layer_2': new_kv_cache_spec(),
}
with pytest.raises(NotImplementedError):
get_kv_cache_config(vllm_config, kv_cache_specs_hybrid,
mem_per_block_per_layer * 2 * 32)
get_kv_cache_configs(vllm_config, [kv_cache_specs_hybrid],
[mem_per_block_per_layer * 2 * 32])[0]
# Test num_gpu_blocks_override
vllm_config.cache_config.num_gpu_blocks_override = 16
kv_cache_config_override_blocks = get_kv_cache_config(
vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32)
kv_cache_config_override_blocks = get_kv_cache_configs(
vllm_config, [kv_cache_specs_full],
[mem_per_block_per_layer * 2 * 32])[0]
assert kv_cache_config_override_blocks == KVCacheConfig(
num_blocks=16,
kv_cache_tensors=[
@@ -1065,3 +1260,16 @@ def test_get_kv_cache_config():
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())
])
def test_get_kv_cache_configs_attention_free():
kv_cache_specs: dict[str, KVCacheSpec] = {}
vllm_config = VllmConfig(model_config=ModelConfig(max_model_len=16))
kv_cache_configs = get_kv_cache_configs(vllm_config, [kv_cache_specs], [0])
assert kv_cache_configs == [
KVCacheConfig(
num_blocks=1,
kv_cache_tensors=[],
kv_cache_groups=[],
)
]

View File

@@ -10,7 +10,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
get_kv_cache_config)
get_kv_cache_configs)
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.worker.tpu_model_runner import (
@@ -477,8 +477,8 @@ def test_init_kv_cache_without_kv_sharing():
# 2 (non-MLA) * 8 (num_heads) * 128 (head_dim)
# * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB
num_expected_blocks = 20480 # 20GB / 512KB / 2 (num layers)
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
[available_memory])[0]
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.kv_cache_tensors) == 2
assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
@@ -550,8 +550,8 @@ def test_init_kv_cache_with_kv_sharing_valid():
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
# which is twice as many as without KV sharing
num_expected_blocks = 2 * 20480 # 20GB / 512KB
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
[available_memory])[0]
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.kv_cache_tensors) == 1
# Each layer now has twice the available memory for KV cache

View File

@@ -15,7 +15,7 @@ from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, update_environment_variables
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
get_kv_cache_config)
get_kv_cache_configs)
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -585,8 +585,8 @@ def test_init_kv_cache_without_kv_sharing():
available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
num_expected_blocks = 327680 # 20GB / 32KB / 2 (num layers)
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
[available_memory])[0]
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.kv_cache_tensors) == 2
assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
@@ -657,8 +657,8 @@ def test_init_kv_cache_with_kv_sharing_valid():
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
# which is twice as many as without KV sharing
num_expected_blocks = 655360 # 20GB / 32KB
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
[available_memory])[0]
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.kv_cache_tensors) == 1
# Each layer now has twice the available memory for KV cache
@@ -788,8 +788,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
kv_cache_spec = runner.get_kv_cache_spec()
available_memory = 5 * GiB_bytes
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
available_memory)
kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
[available_memory])[0]
runner.initialize_kv_cache(kv_cache_config)
# random partition of blocks

View File

@@ -5,7 +5,7 @@
import os
from collections import defaultdict, deque
from collections.abc import Iterable, Sequence
from dataclasses import astuple, dataclass
from dataclasses import dataclass
from typing import Any, Callable, NewType, Optional, Union
from vllm import envs
@@ -811,59 +811,21 @@ def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
return page_sizes.pop()
def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
def _get_kv_cache_groups_uniform_type(
kv_cache_specs: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]:
"""
Generates the KV cache configuration for a model with one type of KV cache.
Divide the available memory equally among all layers.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
kv_cache_specs: The kv cache spec of each attention layer in the model
Returns:
The generated KVCacheConfig
The generated KVCacheGroupSpecs
"""
page_size = get_uniform_page_size(kv_cache_spec)
num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec),
available_memory, page_size)
per_layer_size = page_size * num_blocks
# All layers have the same KV cache spec, so we create one kv cache group
# for all layers.
grouped_layer_names = [list(kv_cache_spec.keys())]
# Each layer uses a separate Tensor to store its KV cache.
kv_cache_tensors = [
KVCacheTensor(size=per_layer_size, shared_by=[layer_name])
for layer_name in kv_cache_spec
]
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=kv_cache_tensors,
kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
grouped_layer_names),
)
num_tokens = num_blocks * vllm_config.cache_config.block_size
if vllm_config.parallel_config.decode_context_parallel_size > 1:
num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
logger.info(
"Multiplying the GPU KV cache size by the dcp_world_size %d.",
vllm_config.parallel_config.decode_context_parallel_size)
num_tokens_str = f"{num_tokens:,}"
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
max_concurrency = get_max_concurrency_for_kv_cache_config(
vllm_config, kv_cache_config)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
max_model_len_str, max_concurrency)
return kv_cache_config
return create_kv_cache_group_specs(kv_cache_specs,
[list(kv_cache_specs.keys())])
def is_kv_cache_page_size_uniform(
@@ -888,11 +850,10 @@ def is_kv_cache_type_attention_free(
return not kv_cache_spec
def _get_kv_cache_config_uniform_page_size(
vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
def _get_kv_cache_groups_uniform_page_size(
kv_cache_spec: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]:
"""
Generates the KV cache configuration for hybrid models with multiple
Generates the KV cache groups for hybrid models with multiple
attention types but still with a uniform page size (physical memory per
block per layer) for all layers.
@@ -949,11 +910,9 @@ def _get_kv_cache_config_uniform_page_size(
memory per block is the same for all groups.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The KVCacheSpec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Returns:
The generated KVCacheConfig
The generated KVCacheGroupSpecs
"""
# Group all layers by kv_cache_spec.
# E.g., 2 full attention layers and 3 sliding window attention layers,
@@ -966,7 +925,7 @@ def _get_kv_cache_config_uniform_page_size(
# group identical. Add padding to the last group of each type if necessary.
# E.g., (full.0, full.1), (sw.0, sw.1, sw.2)
# split to 3 groups with 2 layers each:
# (full.0, full.1), (sw.0, sw.1), (sw.2, padding).
# (full.0, full.1), (sw.0, sw.2), (sw.1, padding).
# FIXME(Chen): At the moment of writing this code (2025-06-02), all
# open-source hybrid models follow an n:1 pattern between different attention
# types (e.g., Gemma3 5:1 between sw and full, LLaMA4 3:1 between local and
@@ -984,19 +943,60 @@ def _get_kv_cache_config_uniform_page_size(
num_padding_layers,
num_padding_layers / len(layers) * 100,
)
for i in range(0, len(layers), group_size):
grouped_layers.append(layers[i:i + group_size])
kv_cache_groups = create_kv_cache_group_specs(kv_cache_spec,
grouped_layers)
num_groups = cdiv(len(layers), group_size)
# In PP case, say if we have
# - stage 0: full.0, sw.0, sw.1
# - stage 1: full.1, sw.2, sw.3
# We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3)
# It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because
# the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group)
# and it will be padded to (full.0, padding), (sw.0, sw.1),
# (padding, padding) to ensure the number of layers in each group is
# the same and will cause memory waste.
# To avoid this, we assign layers[i::num_groups] to the i-th group
# instead of layers[i * group_size: (i + 1) * group_size]
for i in range(num_groups):
grouped_layers.append(layers[i::num_groups])
return create_kv_cache_group_specs(kv_cache_spec, grouped_layers)
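A minimal standalone sketch of this interleaved assignment, using hypothetical layer names (sw.0–sw.3 spread across two PP stages) rather than real specs:

layers = ["sw.0", "sw.1", "sw.2", "sw.3"]  # sw.0/sw.1 on PP stage 0, sw.2/sw.3 on stage 1
num_groups, group_size = 2, 2

chunked = [layers[i * group_size:(i + 1) * group_size] for i in range(num_groups)]
# [['sw.0', 'sw.1'], ['sw.2', 'sw.3']]: stage 0 owns 2 layers of group 0 and
# 0 layers of group 1, so its groups must be padded and memory is wasted.

strided = [layers[i::num_groups] for i in range(num_groups)]
# [['sw.0', 'sw.2'], ['sw.1', 'sw.3']]: each stage owns exactly one layer of
# every group, so no padding is needed.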
def get_kv_cache_config_from_groups(vllm_config: VllmConfig,
kv_cache_groups: list[KVCacheGroupSpec],
kv_cache_specs: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
"""
Generate the KV cache configuration from the KV cache groups and spec
of each layer.
Args:
vllm_config: The global VllmConfig
kv_cache_groups: The KV cache groups
kv_cache_specs: The KV cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes
Returns:
The generated KVCacheConfig
"""
if len(kv_cache_groups) == 0:
# Attention free models do not have KV cache.
# Return num_blocks=1 as BlockPool always needs a null_block.
return KVCacheConfig(
num_blocks=1,
kv_cache_tensors=[],
kv_cache_groups=kv_cache_groups,
)
# Determine how model runners should initialize the KV cache tensors.
# We will have group_size memory pools, each shared by one layer from
# each group. As layers of different groups have different block tables,
# they will use different parts of the shared Tensor.
# The memory layout in the example will be:
# full.0, sw.0, sw.2: share a Tensor with size=available_memory//2
# full.1, sw.1: share another Tensor with size=available_memory//2
page_size = get_uniform_page_size(kv_cache_spec)
# The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2),
# (sw.1, padding) will be: (group_size = 2)
# full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
# full.1, sw.2: share another Tensor with size=available_memory//2
group_size = max(len(group.layer_names) for group in kv_cache_groups)
page_size = get_uniform_page_size(kv_cache_specs)
assert group_size > 0, "group_size must be greater than 0"
num_blocks = get_num_blocks(vllm_config, group_size, available_memory,
page_size)
per_memory_pool_size = page_size * num_blocks
@@ -1004,8 +1004,8 @@ def _get_kv_cache_config_uniform_page_size(
for i in range(group_size):
shared_by = []
for j in range(len(kv_cache_groups)):
if i < len(grouped_layers[j]):
shared_by.append(grouped_layers[j][i])
if i < len(kv_cache_groups[j].layer_names):
shared_by.append(kv_cache_groups[j].layer_names[i])
kv_cache_tensors.append(
KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by))
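As a rough illustration of how those shared_by lists come out (toy group shapes, not taken from this diff): with groups (full.0, full.1) and (sw.0,), group_size is 2, so two tensors are created and the shorter group simply contributes to fewer of them.

groups = [["full.0", "full.1"], ["sw.0"]]   # layer names per KV cache group
group_size = max(len(g) for g in groups)    # -> 2 memory pools
shared_by_per_tensor = [
    [g[i] for g in groups if i < len(g)]    # i-th layer of each group
    for i in range(group_size)
]
# shared_by_per_tensor == [['full.0', 'sw.0'], ['full.1']]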
@@ -1019,7 +1019,12 @@ def _get_kv_cache_config_uniform_page_size(
[group.kv_cache_spec.block_size for group in kv_cache_groups])
# Print the KV cache size and maximum concurrency.
num_tokens = num_blocks // len(grouped_layers) * min_block_size
num_tokens = num_blocks // len(kv_cache_groups) * min_block_size
if vllm_config.parallel_config.decode_context_parallel_size > 1:
num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
logger.info(
"Multiplying the GPU KV cache size by the dcp_world_size %d.",
vllm_config.parallel_config.decode_context_parallel_size)
num_tokens_str = f"{num_tokens:,}"
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
@@ -1030,10 +1035,6 @@ def _get_kv_cache_config_uniform_page_size(
return kv_cache_config
def _get_kv_cache_config_attention_free() -> KVCacheConfig:
return KVCacheConfig(num_blocks=1, kv_cache_tensors=[], kv_cache_groups=[])
def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
"""
This function tries to convert the KV cache specs to one type if the model
@@ -1087,72 +1088,112 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
"convert the KV cache specs to one unified type.")
def get_kv_cache_config(
vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int,
) -> KVCacheConfig:
def get_kv_cache_groups(
vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]:
"""
Generates the KV cache configuration for a model.
Split the layers in the model into groups with the same KV cache spec.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Returns:
The generated KVCacheConfigs
The generated KVCacheGroups
"""
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
unify_hybrid_kv_cache_specs(kv_cache_spec)
if is_kv_cache_type_attention_free(kv_cache_spec):
# This returns a kv_cache config with 0 kv_cache groups and 1 block
# to allow for the KVCache manager to handle attention free models.
return _get_kv_cache_config_attention_free()
# This returns an empty list to allow for the KVCacheManager to handle
# attention free models.
return []
elif is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory)
return _get_kv_cache_groups_uniform_type(kv_cache_spec)
elif is_kv_cache_page_size_uniform(kv_cache_spec):
# Model contains multiple attention types, but KV cache of all layers
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_config_uniform_page_size(vllm_config,
kv_cache_spec,
available_memory)
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
raise NotImplementedError
def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
def get_kv_cache_configs(vllm_config: VllmConfig,
kv_cache_specs: list[dict[str, KVCacheSpec]],
available_memory: list[int]) -> list[KVCacheConfig]:
"""
Make the KV cache configurations for each worker consistent, so that all
workers can be controlled by the same KVCacheManager.
This function verifies that the layer group of each worker are the same,
and changes the num_blocks of each worker to the smallest among all workers.
Generates the KV cache configurations for a model.
Since we use a shared centralized controller for all workers, we need the
`kv_cache_config` to be consistent across all workers to make sure
the KV cache allocation can be applied to all workers. However, different
workers may have different amounts of memory available and different types of
layers (when pipeline parallelism is enabled). To handle these differences
between workers, the current implementation proceeds as follows:
1. Merge the KV cache specs of all workers to get the KVCacheSpecs for
the whole model.
2. Generate the KV cache groups based on the layer ratio of the whole model.
3. Generate the KV cache configs for each worker based on the KV cache
grouping strategy. (This is reasonable because the layer ratios of
different PP stages are similar.)
4. Change the num_blocks of each worker to the smallest among all workers.
Args:
kv_cache_configs: The KV cache configurations for each worker. Will be
in-place modified to make them consistent.
vllm_config: The global VllmConfig
kv_cache_specs: List of dict[layer_name, KVCacheSpec] for each worker.
available_memory: Memory available for KV cache in bytes for each
worker.
Returns:
The generated KVCacheConfigs for each worker.
"""
# Sort the kv cache groups by their KV cache spec.
# This can avoid the inconsistency caused by the order of groups.
for kv_cache_config in kv_cache_configs:
kv_cache_config.kv_cache_groups.sort(key=lambda x: (type(
x.kv_cache_spec).__name__, astuple(x.kv_cache_spec)))
# Check if the available memory is enough for each worker.
for kv_cache_spec_one_worker, available_memory_one_worker in zip(
kv_cache_specs, available_memory):
check_enough_kv_cache_memory(vllm_config, kv_cache_spec_one_worker,
available_memory_one_worker)
# Verify that the groups of each rank are the same.
for kv_cache_config in kv_cache_configs[1:]:
for group_rank_0, group_rank_i in zip(
kv_cache_configs[0].kv_cache_groups,
kv_cache_config.kv_cache_groups):
assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec
# Merge the KV cache specs of all workers. Different PP stages may have
# different layer names, and different TP ranks of the same PP stage should
# have the same KV cache spec.
merged_kv_cache_specs: dict[str, KVCacheSpec] = {}
for kv_cache_spec_one_worker in kv_cache_specs:
for layer_name, layer_spec in kv_cache_spec_one_worker.items():
if layer_name not in merged_kv_cache_specs:
merged_kv_cache_specs[layer_name] = layer_spec
else:
assert merged_kv_cache_specs[layer_name] == layer_spec, (
"The KV cache specs for the same layer are different "
"across workers. This is not supported yet.")
global_kv_cache_groups = get_kv_cache_groups(vllm_config,
merged_kv_cache_specs)
kv_cache_configs: list[KVCacheConfig] = []
for kv_cache_spec_one_worker, available_memory_one_worker in zip(
kv_cache_specs, available_memory):
kv_cache_groups_one_worker: list[KVCacheGroupSpec] = []
for group in global_kv_cache_groups:
group_layer_names_one_worker = [
layer_name for layer_name in group.layer_names
if layer_name in kv_cache_spec_one_worker
]
kv_cache_groups_one_worker.append(
KVCacheGroupSpec(group_layer_names_one_worker,
group.kv_cache_spec))
assert sum(
len(group.layer_names) for group in
kv_cache_groups_one_worker) == len(kv_cache_spec_one_worker), (
"Some layers are not assigned to any group.")
kv_cache_configs.append(
get_kv_cache_config_from_groups(vllm_config,
kv_cache_groups_one_worker,
kv_cache_spec_one_worker,
available_memory_one_worker))
# Change the num_blocks of each rank to the smallest among all ranks. We
# do not need to shrink the tensor size because it is valid to only use the

View File

@@ -29,10 +29,9 @@ from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket,
resolve_obj_by_qualname, set_process_title)
from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_config,
from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_configs,
get_request_block_hasher,
init_none_hash,
unify_kv_cache_configs)
init_none_hash)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
@@ -191,18 +190,9 @@ class EngineCore:
available_gpu_memory = [0] * len(kv_cache_specs)
assert len(kv_cache_specs) == len(available_gpu_memory)
# Get the kv cache tensor size
kv_cache_configs = [
get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
available_gpu_memory_one_worker)
for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
zip(kv_cache_specs, available_gpu_memory)
]
# Since we use a shared centralized controller, we need the
# `kv_cache_config` to be consistent across all workers to make sure
# all the memory operators can be applied to all workers.
unify_kv_cache_configs(kv_cache_configs)
kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
available_gpu_memory)
# All workers have the same kv_cache_config except for the layer names, so
# use an arbitrary one to initialize the scheduler.
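A simplified walk-through of the merge-then-filter strategy behind get_kv_cache_configs, using plain dicts and hypothetical layer names instead of vLLM's real config types:

# Per-worker specs for a 2-stage PP model (string values stand in for KVCacheSpec).
worker_specs = [
    {"full.0": "full", "sw.0": "sw"},   # PP stage 0
    {"full.1": "full", "sw.1": "sw"},   # PP stage 1
]

# Step 1: merge the per-worker specs into a model-wide view.
merged = {}
for spec in worker_specs:
    merged.update(spec)

# Step 2: group the whole model's layers by spec type
# (hard-coded here; done by get_kv_cache_groups in the real code).
global_groups = [["full.0", "full.1"], ["sw.0", "sw.1"]]

# Step 3: restrict every global group to the layers a worker actually holds,
# keeping the group structure identical across workers.
per_worker_groups = [
    [[name for name in group if name in spec] for group in global_groups]
    for spec in worker_specs
]
# per_worker_groups[0] == [['full.0'], ['sw.0']]
# per_worker_groups[1] == [['full.1'], ['sw.1']]
# Step 4 (not shown): build one KVCacheConfig per worker and clamp num_blocks
# to the smallest value across workers.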