[Hybrid Allocator] Support Pipeline Parallel (#23974)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent 90f3f7d73e
commit 8e5cdcda4e
@@ -215,9 +215,7 @@ TEXT_GENERATION_MODELS = {
EMBEDDING_MODELS = { # type: ignore[var-annotated]
# [Text-only]
"intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(runner="pooling"),
# TODO: re-enable when https://github.com/vllm-project/vllm/issues/23883
# is fixed
#"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"),
"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"),
"Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(
load_format="dummy", runner="pooling"
),
@@ -10,7 +10,7 @@ from vllm import LLM
from vllm.config import ModelImpl
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.engine.core import EngineCore as V1EngineCore

from ..utils import create_new_process_for_each_test
@@ -68,11 +68,11 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,

def _initialize_kv_caches_v1(self, vllm_config):
kv_cache_specs = self.model_executor.get_kv_cache_specs()
scheduler_kv_cache_config = get_kv_cache_config(
scheduler_kv_cache_config = get_kv_cache_configs(
vllm_config,
kv_cache_specs[0],
10 * GiB_bytes,
)
kv_cache_specs,
[10 * GiB_bytes],
)[0]

# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
return 1, 0, scheduler_kv_cache_config
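The monkeypatched initializer above shows the interface change this commit makes: get_kv_cache_configs takes one KV cache spec dict and one memory budget per worker and returns one KVCacheConfig per worker. A minimal sketch of the new call shape for a hypothetical two-stage pipeline-parallel run follows (new_kv_cache_spec is a helper defined in tests/v1/core/test_kv_cache_utils.py, not a public API; the memory budgets mirror those used in that test file):

    # Sketch only: assumes new_kv_cache_spec() from the test file is in scope.
    from vllm.config import ModelConfig, VllmConfig
    from vllm.v1.core.kv_cache_utils import get_kv_cache_configs

    vllm_config = VllmConfig(model_config=ModelConfig(max_model_len=16))
    spec = new_kv_cache_spec()

    # One spec dict and one memory budget per worker (two PP stages with
    # disjoint layers).
    kv_cache_specs = [
        {"layer1": new_kv_cache_spec()},                                 # stage 0
        {"layer2": new_kv_cache_spec(), "layer3": new_kv_cache_spec()},  # stage 1
    ]
    available_memory = [spec.page_size_bytes * 2 * 10] * 2

    # One KVCacheConfig per worker; num_blocks is unified to the smallest
    # value across workers so a single scheduler can drive all of them.
    configs = get_kv_cache_configs(vllm_config, kv_cache_specs, available_memory)
    assert len(configs) == 2
    assert configs[0].num_blocks == configs[1].num_blocks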
@@ -18,13 +18,12 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager
|
||||
from vllm.v1.core.kv_cache_utils import (
|
||||
BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
|
||||
estimate_max_model_len, generate_block_hash_extra_keys,
|
||||
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
|
||||
get_kv_cache_configs, get_max_concurrency_for_kv_cache_config,
|
||||
get_request_block_hasher, hash_block_tokens, init_none_hash,
|
||||
is_kv_cache_type_uniform, make_block_hash_with_group_id,
|
||||
unify_kv_cache_configs)
|
||||
is_kv_cache_type_uniform, make_block_hash_with_group_id)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheTensor,
|
||||
SlidingWindowSpec)
|
||||
KVCacheGroupSpec, KVCacheSpec,
|
||||
KVCacheTensor, SlidingWindowSpec)
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
|
||||
@@ -531,102 +530,288 @@ def test_metrics():
|
||||
assert not metrics.query_queue
|
||||
|
||||
|
||||
def test_unify_kv_cache_configs():
|
||||
same_kv_cache_config = [
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=100, shared_by=["layer1"]),
|
||||
KVCacheTensor(size=100, shared_by=["layer2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer2"],
|
||||
new_kv_cache_spec(num_kv_heads=4)),
|
||||
],
|
||||
),
|
||||
KVCacheConfig(
|
||||
num_blocks=20,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=100, shared_by=["layer1"]),
|
||||
KVCacheTensor(size=100, shared_by=["layer2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer2"],
|
||||
new_kv_cache_spec(num_kv_heads=4)),
|
||||
],
|
||||
),
|
||||
]
|
||||
unify_kv_cache_configs(same_kv_cache_config)
|
||||
assert same_kv_cache_config[0].num_blocks == 10
|
||||
assert same_kv_cache_config[1].num_blocks == 10
|
||||
def test_get_kv_cache_configs_multiple_workers():
|
||||
model_config = ModelConfig(max_model_len=16)
|
||||
vllm_config = VllmConfig(model_config=model_config)
|
||||
|
||||
need_sort_kv_cache_config = [
|
||||
ref_kv_cache_spec = new_kv_cache_spec()
|
||||
same_kv_cache_specs = [{
|
||||
"layer1": new_kv_cache_spec(),
|
||||
"layer2": new_kv_cache_spec(),
|
||||
}, {
|
||||
"layer1": new_kv_cache_spec(),
|
||||
"layer2": new_kv_cache_spec(),
|
||||
}]
|
||||
|
||||
# Basic case. All things are the same.
|
||||
kv_cache_configs = get_kv_cache_configs(vllm_config, same_kv_cache_specs, [
|
||||
ref_kv_cache_spec.page_size_bytes * 2 * 10,
|
||||
ref_kv_cache_spec.page_size_bytes * 2 * 10
|
||||
])
|
||||
assert kv_cache_configs == [
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=100, shared_by=["layer1"]),
|
||||
KVCacheTensor(size=100, shared_by=["layer2"]),
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer1"]),
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer2"],
|
||||
new_kv_cache_spec(num_kv_heads=4)),
|
||||
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
|
||||
],
|
||||
),
|
||||
KVCacheConfig(
|
||||
num_blocks=20,
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=100, shared_by=["layer1"]),
|
||||
KVCacheTensor(size=100, shared_by=["layer2"]),
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer1"]),
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer2"],
|
||||
new_kv_cache_spec(num_kv_heads=4)),
|
||||
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
unify_kv_cache_configs(need_sort_kv_cache_config)
|
||||
sorted_kv_cache_groups = [
|
||||
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer2"], new_kv_cache_spec(num_kv_heads=4)),
|
||||
]
|
||||
assert (
|
||||
need_sort_kv_cache_config[0].kv_cache_groups == sorted_kv_cache_groups)
|
||||
assert (
|
||||
need_sort_kv_cache_config[1].kv_cache_groups == sorted_kv_cache_groups)
|
||||
|
||||
diff_kv_cache_config = [
|
||||
# Different available memory. This is the case for TP.
|
||||
# Use the smallest memory available.
|
||||
kv_cache_configs = get_kv_cache_configs(vllm_config, same_kv_cache_specs, [
|
||||
ref_kv_cache_spec.page_size_bytes * 2 * 10,
|
||||
ref_kv_cache_spec.page_size_bytes * 2 * 20
|
||||
])
|
||||
assert kv_cache_configs == [
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=100, shared_by=["layer1"]),
|
||||
KVCacheTensor(size=100, shared_by=["layer2"]),
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer1"]),
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer2"],
|
||||
new_kv_cache_spec(num_kv_heads=4)),
|
||||
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
|
||||
],
|
||||
),
|
||||
KVCacheConfig(
|
||||
num_blocks=20,
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=100, shared_by=["layer1"]),
|
||||
KVCacheTensor(size=100, shared_by=["layer2"]),
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20,
|
||||
shared_by=["layer1"]),
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20,
|
||||
shared_by=["layer2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer2"],
|
||||
new_kv_cache_spec(num_kv_heads=8)),
|
||||
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
# Different KV cache specs. This is the case for PP.
|
||||
different_layer_specs = [{
|
||||
"layer1": new_kv_cache_spec(),
|
||||
}, {
|
||||
"layer2": new_kv_cache_spec(),
|
||||
"layer3": new_kv_cache_spec(),
|
||||
}]
|
||||
|
||||
# Different workers have different layers.
|
||||
kv_cache_configs = get_kv_cache_configs(
|
||||
vllm_config, different_layer_specs, [
|
||||
ref_kv_cache_spec.page_size_bytes * 2 * 10,
|
||||
ref_kv_cache_spec.page_size_bytes * 2 * 10
|
||||
])
|
||||
assert kv_cache_configs == [
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20,
|
||||
shared_by=["layer1"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
|
||||
],
|
||||
),
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer2"]),
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer3"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer2", "layer3"], new_kv_cache_spec()),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
# Some layers are the same, some are different. This is the case for TP+PP
|
||||
tp_pp_kv_cache_specs = [{
|
||||
"layer1": new_kv_cache_spec(),
|
||||
"layer2": new_kv_cache_spec(),
|
||||
}, {
|
||||
"layer1": new_kv_cache_spec(),
|
||||
"layer2": new_kv_cache_spec(),
|
||||
}, {
|
||||
"layer3": new_kv_cache_spec(),
|
||||
}, {
|
||||
"layer3": new_kv_cache_spec(),
|
||||
}]
|
||||
|
||||
kv_cache_configs = get_kv_cache_configs(
|
||||
vllm_config, tp_pp_kv_cache_specs, [
|
||||
ref_kv_cache_spec.page_size_bytes * 2 * 10,
|
||||
ref_kv_cache_spec.page_size_bytes * 2 * 10,
|
||||
ref_kv_cache_spec.page_size_bytes * 2 * 10,
|
||||
ref_kv_cache_spec.page_size_bytes * 2 * 10,
|
||||
])
|
||||
assert kv_cache_configs == [
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer1"]),
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
|
||||
],
|
||||
),
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer1"]),
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
|
||||
],
|
||||
),
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20,
|
||||
shared_by=["layer3"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer3"], ref_kv_cache_spec),
|
||||
],
|
||||
),
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20,
|
||||
shared_by=["layer3"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer3"], ref_kv_cache_spec),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
# Different workers have different types of layers. This is the case for
|
||||
# hybrid models + PP.
|
||||
different_type_layer_specs = [{
|
||||
"layer1": new_kv_cache_spec(),
|
||||
"layer2": new_kv_cache_spec(),
|
||||
}, {
|
||||
"layer3": new_sliding_window_spec(),
|
||||
"layer4": new_sliding_window_spec(),
|
||||
}]
|
||||
kv_cache_configs = get_kv_cache_configs(
|
||||
vllm_config, different_type_layer_specs, [
|
||||
ref_kv_cache_spec.page_size_bytes * 2 * 10,
|
||||
ref_kv_cache_spec.page_size_bytes * 2 * 10,
|
||||
])
|
||||
assert kv_cache_configs == [
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer1"]),
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer2"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec),
|
||||
KVCacheGroupSpec([], new_sliding_window_spec()),
|
||||
],
|
||||
),
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer3"]),
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer4"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec([], ref_kv_cache_spec),
|
||||
KVCacheGroupSpec(["layer3", "layer4"],
|
||||
new_sliding_window_spec()),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
# When divided into multiple KVCacheGroups, need to ensure the number of
|
||||
# layers per group is similar.
|
||||
different_type_layer_specs = [{
|
||||
"layer1": new_kv_cache_spec(),
|
||||
"layer2": new_sliding_window_spec(),
|
||||
"layer3": new_sliding_window_spec(),
|
||||
}, {
|
||||
"layer4": new_kv_cache_spec(),
|
||||
"layer5": new_sliding_window_spec(),
|
||||
"layer6": new_sliding_window_spec(),
|
||||
}]
|
||||
kv_cache_configs = get_kv_cache_configs(
|
||||
vllm_config, different_type_layer_specs, [
|
||||
ref_kv_cache_spec.page_size_bytes * 10,
|
||||
ref_kv_cache_spec.page_size_bytes * 10,
|
||||
])
|
||||
assert kv_cache_configs == [
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer1", "layer2", "layer3"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer1"], ref_kv_cache_spec),
|
||||
KVCacheGroupSpec(["layer2"], new_sliding_window_spec()),
|
||||
KVCacheGroupSpec(["layer3"], new_sliding_window_spec()),
|
||||
],
|
||||
),
|
||||
KVCacheConfig(
|
||||
num_blocks=10,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10,
|
||||
shared_by=["layer4", "layer5", "layer6"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer4"], ref_kv_cache_spec),
|
||||
KVCacheGroupSpec(["layer5"], new_sliding_window_spec()),
|
||||
KVCacheGroupSpec(["layer6"], new_sliding_window_spec()),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
# Have conflicting layers. Need to raise an error.
|
||||
conflicting_layer_specs = [{
|
||||
"layer1": new_kv_cache_spec(),
|
||||
}, {
|
||||
"layer1": new_sliding_window_spec(),
|
||||
}]
|
||||
with pytest.raises(AssertionError):
|
||||
unify_kv_cache_configs(diff_kv_cache_config)
|
||||
get_kv_cache_configs(vllm_config, conflicting_layer_specs, [
|
||||
ref_kv_cache_spec.page_size_bytes * 2 * 10,
|
||||
ref_kv_cache_spec.page_size_bytes * 2 * 10,
|
||||
])
|
||||
|
||||
|
||||
def test_merge_kv_cache_spec():
|
||||
@@ -890,7 +1075,7 @@ def test_allocate_with_lookahead():
|
||||
assert len(blocks.get_block_ids()[0]) == 2
|
||||
|
||||
|
||||
def test_get_kv_cache_config():
|
||||
def test_get_kv_cache_config_one_worker():
|
||||
# pass max_model_len to pass check_enough_kv_cache_memory
|
||||
model_config = ModelConfig(max_model_len=16)
|
||||
vllm_config = VllmConfig(model_config=model_config)
|
||||
@@ -901,8 +1086,10 @@ def test_get_kv_cache_config():
|
||||
'layer_1': new_kv_cache_spec(),
|
||||
'layer_2': new_kv_cache_spec(),
|
||||
}
|
||||
kv_cache_config_full = get_kv_cache_config(
|
||||
vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32)
|
||||
kv_cache_config_full = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_specs_full],
|
||||
[mem_per_block_per_layer * 2 * 32])[0]
|
||||
print(kv_cache_config_full)
|
||||
assert kv_cache_config_full == KVCacheConfig(
|
||||
num_blocks=32,
|
||||
kv_cache_tensors=[
|
||||
@@ -920,8 +1107,9 @@ def test_get_kv_cache_config():
|
||||
'layer_1': new_sliding_window_spec(),
|
||||
'layer_2': new_sliding_window_spec(),
|
||||
}
|
||||
kv_cache_config_sliding = get_kv_cache_config(
|
||||
vllm_config, kv_cache_specs_sliding, mem_per_block_per_layer * 2 * 32)
|
||||
kv_cache_config_sliding = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_specs_sliding],
|
||||
[mem_per_block_per_layer * 2 * 32])[0]
|
||||
assert kv_cache_config_sliding == KVCacheConfig(
|
||||
num_blocks=32,
|
||||
kv_cache_tensors=[
|
||||
@@ -940,8 +1128,9 @@ def test_get_kv_cache_config():
|
||||
'layer_1': new_kv_cache_spec(),
|
||||
'layer_2': new_sliding_window_spec(),
|
||||
}
|
||||
kv_cache_config_hybrid = get_kv_cache_config(
|
||||
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
|
||||
kv_cache_config_hybrid = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_specs_hybrid],
|
||||
[mem_per_block_per_layer * 2 * 32])[0]
|
||||
assert kv_cache_config_hybrid == KVCacheConfig(
|
||||
num_blocks=32,
|
||||
kv_cache_tensors=[
|
||||
@@ -962,8 +1151,9 @@ def test_get_kv_cache_config():
|
||||
'layer_1': new_kv_cache_spec(),
|
||||
'layer_2': new_sliding_window_spec(),
|
||||
}
|
||||
kv_cache_config_hybrid = get_kv_cache_config(
|
||||
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
|
||||
kv_cache_config_hybrid = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_specs_hybrid],
|
||||
[mem_per_block_per_layer * 2 * 32])[0]
|
||||
assert kv_cache_config_hybrid == KVCacheConfig(
|
||||
num_blocks=64,
|
||||
kv_cache_tensors=[
|
||||
@@ -985,21 +1175,22 @@ def test_get_kv_cache_config():
|
||||
'layer_5': new_sliding_window_spec(),
|
||||
'layer_6': new_sliding_window_spec(),
|
||||
}
|
||||
kv_cache_config_hybrid = get_kv_cache_config(
|
||||
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
|
||||
kv_cache_config_hybrid = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_specs_hybrid],
|
||||
[mem_per_block_per_layer * 2 * 32])[0]
|
||||
assert kv_cache_config_hybrid == KVCacheConfig(
|
||||
num_blocks=32,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_1", "layer_3", "layer_5"]),
|
||||
shared_by=["layer_1", "layer_3", "layer_4"]),
|
||||
KVCacheTensor(size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_2", "layer_4", "layer_6"]),
|
||||
shared_by=["layer_2", "layer_5", "layer_6"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer_3", "layer_4"],
|
||||
KVCacheGroupSpec(["layer_3", "layer_5"],
|
||||
new_sliding_window_spec()),
|
||||
KVCacheGroupSpec(["layer_5", "layer_6"],
|
||||
KVCacheGroupSpec(["layer_4", "layer_6"],
|
||||
new_sliding_window_spec()),
|
||||
],
|
||||
)
|
||||
@@ -1017,27 +1208,30 @@ def test_get_kv_cache_config():
|
||||
'layer_9': new_sliding_window_spec(),
|
||||
'layer_10': new_sliding_window_spec(),
|
||||
}
|
||||
kv_cache_config_hybrid = get_kv_cache_config(
|
||||
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 3 * 32)
|
||||
kv_cache_config_hybrid = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_specs_hybrid],
|
||||
[mem_per_block_per_layer * 3 * 32])[0]
|
||||
assert kv_cache_config_hybrid == KVCacheConfig(
|
||||
num_blocks=32,
|
||||
kv_cache_tensors=[
|
||||
KVCacheTensor(
|
||||
size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_1", "layer_4", "layer_7", "layer_10"]),
|
||||
shared_by=["layer_1", "layer_4", "layer_5", "layer_6"]),
|
||||
KVCacheTensor(
|
||||
size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_2", "layer_7", "layer_8", "layer_9"]),
|
||||
KVCacheTensor(size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_2", "layer_5", "layer_8"]),
|
||||
KVCacheTensor(size=mem_per_block_per_layer * 32,
|
||||
shared_by=["layer_3", "layer_6", "layer_9"]),
|
||||
shared_by=["layer_3", "layer_10"]),
|
||||
],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"],
|
||||
new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["layer_4", "layer_5", "layer_6"],
|
||||
KVCacheGroupSpec(["layer_4", "layer_7", "layer_10"],
|
||||
new_sliding_window_spec()),
|
||||
KVCacheGroupSpec(["layer_7", "layer_8", "layer_9"],
|
||||
KVCacheGroupSpec(["layer_5", "layer_8"],
|
||||
new_sliding_window_spec()),
|
||||
KVCacheGroupSpec(["layer_6", "layer_9"],
|
||||
new_sliding_window_spec()),
|
||||
KVCacheGroupSpec(["layer_10"], new_sliding_window_spec()),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1047,13 +1241,14 @@ def test_get_kv_cache_config():
|
||||
'layer_2': new_kv_cache_spec(),
|
||||
}
|
||||
with pytest.raises(NotImplementedError):
|
||||
get_kv_cache_config(vllm_config, kv_cache_specs_hybrid,
|
||||
mem_per_block_per_layer * 2 * 32)
|
||||
get_kv_cache_configs(vllm_config, [kv_cache_specs_hybrid],
|
||||
[mem_per_block_per_layer * 2 * 32])[0]
|
||||
|
||||
# Test num_gpu_blocks_override
|
||||
vllm_config.cache_config.num_gpu_blocks_override = 16
|
||||
kv_cache_config_override_blocks = get_kv_cache_config(
|
||||
vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32)
|
||||
kv_cache_config_override_blocks = get_kv_cache_configs(
|
||||
vllm_config, [kv_cache_specs_full],
|
||||
[mem_per_block_per_layer * 2 * 32])[0]
|
||||
assert kv_cache_config_override_blocks == KVCacheConfig(
|
||||
num_blocks=16,
|
||||
kv_cache_tensors=[
|
||||
@@ -1065,3 +1260,16 @@ def test_get_kv_cache_config():
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())
|
||||
])
|
||||
|
||||
|
||||
def test_get_kv_cache_configs_attention_free():
|
||||
kv_cache_specs: dict[str, KVCacheSpec] = {}
|
||||
vllm_config = VllmConfig(model_config=ModelConfig(max_model_len=16))
|
||||
kv_cache_configs = get_kv_cache_configs(vllm_config, [kv_cache_specs], [0])
|
||||
assert kv_cache_configs == [
|
||||
KVCacheConfig(
|
||||
num_blocks=1,
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[],
|
||||
)
|
||||
]
|
||||
|
||||
@@ -10,7 +10,7 @@ from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import GiB_bytes
|
||||
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
|
||||
get_kv_cache_config)
|
||||
get_kv_cache_configs)
|
||||
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
||||
SchedulerOutput)
|
||||
from vllm.v1.worker.tpu_model_runner import (
|
||||
@@ -477,8 +477,8 @@ def test_init_kv_cache_without_kv_sharing():
|
||||
# 2 (non-MLA) * 8 (num_heads) * 128 (head_dim)
|
||||
# * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB
|
||||
num_expected_blocks = 20480 # 20GB / 512KB / 2 (num layers)
|
||||
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
|
||||
[available_memory])[0]
|
||||
assert kv_cache_config.num_blocks == num_expected_blocks
|
||||
assert len(kv_cache_config.kv_cache_tensors) == 2
|
||||
assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
|
||||
@@ -550,8 +550,8 @@ def test_init_kv_cache_with_kv_sharing_valid():
|
||||
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
|
||||
# which is twice as many as without KV sharing
|
||||
num_expected_blocks = 2 * 20480 # 20GB / 512KB
|
||||
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
|
||||
[available_memory])[0]
|
||||
assert kv_cache_config.num_blocks == num_expected_blocks
|
||||
assert len(kv_cache_config.kv_cache_tensors) == 1
|
||||
# Each layer now has twice the available memory for KV cache
|
||||
|
||||
@@ -15,7 +15,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import GiB_bytes, update_environment_variables
|
||||
from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
|
||||
get_kv_cache_config)
|
||||
get_kv_cache_configs)
|
||||
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
||||
SchedulerOutput)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
@@ -585,8 +585,8 @@ def test_init_kv_cache_without_kv_sharing():
|
||||
available_memory = 20 * GiB_bytes
|
||||
# page size for layer 0's kv_cache_spec is 32KB
|
||||
num_expected_blocks = 327680 # 20GB / 32KB / 2 (num layers)
|
||||
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
|
||||
[available_memory])[0]
|
||||
assert kv_cache_config.num_blocks == num_expected_blocks
|
||||
assert len(kv_cache_config.kv_cache_tensors) == 2
|
||||
assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2
|
||||
@@ -657,8 +657,8 @@ def test_init_kv_cache_with_kv_sharing_valid():
|
||||
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
|
||||
# which is twice as many as without KV sharing
|
||||
num_expected_blocks = 655360 # 20GB / 32KB
|
||||
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
|
||||
[available_memory])[0]
|
||||
assert kv_cache_config.num_blocks == num_expected_blocks
|
||||
assert len(kv_cache_config.kv_cache_tensors) == 1
|
||||
# Each layer now has twice the available memory for KV cache
|
||||
@@ -788,8 +788,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
|
||||
kv_cache_spec = runner.get_kv_cache_spec()
|
||||
|
||||
available_memory = 5 * GiB_bytes
|
||||
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec],
|
||||
[available_memory])[0]
|
||||
runner.initialize_kv_cache(kv_cache_config)
|
||||
|
||||
# random partition of blocks
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
import os
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import astuple, dataclass
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, NewType, Optional, Union
|
||||
|
||||
from vllm import envs
|
||||
@@ -811,59 +811,21 @@ def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
|
||||
return page_sizes.pop()
|
||||
|
||||
|
||||
def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
|
||||
kv_cache_spec: dict[str, KVCacheSpec],
|
||||
available_memory: int) -> KVCacheConfig:
|
||||
def _get_kv_cache_groups_uniform_type(
|
||||
kv_cache_specs: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]:
|
||||
"""
|
||||
Generates the KV cache configuration for a model with one type of KV cache.
|
||||
Divide the available memory equally among all layers.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
kv_cache_spec: The kv cache spec of each attention layer in the model
|
||||
available_memory: Memory available for KV cache in bytes.
|
||||
kv_cache_specs: The kv cache spec of each attention layer in the model
|
||||
|
||||
Returns:
|
||||
The generated KVCacheConfig
|
||||
The generated KVCacheGroupSpecs
|
||||
"""
|
||||
|
||||
page_size = get_uniform_page_size(kv_cache_spec)
|
||||
num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec),
|
||||
available_memory, page_size)
|
||||
|
||||
per_layer_size = page_size * num_blocks
|
||||
# All layers have the same KV cache spec, so we create one kv cache group
|
||||
# for all layers.
|
||||
grouped_layer_names = [list(kv_cache_spec.keys())]
|
||||
|
||||
# Each layer uses a separate Tensor to store its KV cache.
|
||||
kv_cache_tensors = [
|
||||
KVCacheTensor(size=per_layer_size, shared_by=[layer_name])
|
||||
for layer_name in kv_cache_spec
|
||||
]
|
||||
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=num_blocks,
|
||||
kv_cache_tensors=kv_cache_tensors,
|
||||
kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
|
||||
grouped_layer_names),
|
||||
)
|
||||
|
||||
num_tokens = num_blocks * vllm_config.cache_config.block_size
|
||||
if vllm_config.parallel_config.decode_context_parallel_size > 1:
|
||||
num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
|
||||
logger.info(
|
||||
"Multiplying the GPU KV cache size by the dcp_world_size %d.",
|
||||
vllm_config.parallel_config.decode_context_parallel_size)
|
||||
|
||||
num_tokens_str = f"{num_tokens:,}"
|
||||
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
|
||||
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
|
||||
max_concurrency = get_max_concurrency_for_kv_cache_config(
|
||||
vllm_config, kv_cache_config)
|
||||
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
|
||||
max_model_len_str, max_concurrency)
|
||||
return kv_cache_config
|
||||
return create_kv_cache_group_specs(kv_cache_specs,
|
||||
[list(kv_cache_specs.keys())])
|
||||
|
||||
|
||||
def is_kv_cache_page_size_uniform(
|
||||
@@ -888,11 +850,10 @@ def is_kv_cache_type_attention_free(
|
||||
return not kv_cache_spec
|
||||
|
||||
|
||||
def _get_kv_cache_config_uniform_page_size(
|
||||
vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec],
|
||||
available_memory: int) -> KVCacheConfig:
|
||||
def _get_kv_cache_groups_uniform_page_size(
|
||||
kv_cache_spec: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]:
|
||||
"""
|
||||
Generates the KV cache configuration for hybrid models with multiple
|
||||
Generates the KV cache groups for hybrid models with multiple
|
||||
attention types but still with a uniform page size (physical memory per
|
||||
block per layer) for all layers.
|
||||
|
||||
@@ -949,11 +910,9 @@ def _get_kv_cache_config_uniform_page_size(
|
||||
memory per block is the same for all groups.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
kv_cache_spec: The KVCacheSpec of each attention layer in the model
|
||||
available_memory: Memory available for KV cache in bytes.
|
||||
Returns:
|
||||
The generated KVCacheConfig
|
||||
The generated KVCacheGroupSpecs
|
||||
"""
|
||||
# Group all layers by kv_cache_spec.
|
||||
# E.g., 2 full attention layers and 3 sliding window attention layers,
|
||||
@@ -966,7 +925,7 @@ def _get_kv_cache_config_uniform_page_size(
|
||||
# group identical. Add padding to the last group of each type if necessary.
|
||||
# E.g., (full.0, full.1), (sw.0, sw.1, sw.2)
|
||||
# split to 3 groups with 2 layers each:
|
||||
# (full.0, full.1), (sw.0, sw.1), (sw.2, padding).
|
||||
# (full.0, full.1), (sw.0, sw.2), (sw.1, padding).
|
||||
# FIXME(Chen): At the moment of writing this code (2025-06-02), all
|
||||
# open-source hybrid model follows a n:1 pattern between different attention
|
||||
# types (e.g., Gemma3 5:1 between sw and full, LLaMA4 3:1 between local and
|
||||
@@ -984,19 +943,60 @@ def _get_kv_cache_config_uniform_page_size(
|
||||
num_padding_layers,
|
||||
num_padding_layers / len(layers) * 100,
|
||||
)
|
||||
for i in range(0, len(layers), group_size):
|
||||
grouped_layers.append(layers[i:i + group_size])
|
||||
kv_cache_groups = create_kv_cache_group_specs(kv_cache_spec,
|
||||
grouped_layers)
|
||||
num_groups = cdiv(len(layers), group_size)
|
||||
# In PP case, say if we have
|
||||
# - stage 0: full.0, sw.0, sw.1
|
||||
# - stage 1: full.1, sw.2, sw.3
|
||||
# We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3)
|
||||
# It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because
|
||||
# the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group)
|
||||
# and it will be padded to (full.0, padding), (sw.0, sw.1),
|
||||
# (padding, padding) to ensure the number of layers in each group is
|
||||
# the same and will cause memory waste.
|
||||
# To avoid this, we assign layers[i::num_groups] to the i-th group
|
||||
# instead of layers[i * group_size: (i + 1) * group_size]
|
||||
for i in range(num_groups):
|
||||
grouped_layers.append(layers[i::num_groups])
|
||||
return create_kv_cache_group_specs(kv_cache_spec, grouped_layers)
|
||||
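The comment above is the core of the pipeline-parallel fix: groups are filled with a stride of num_groups rather than with contiguous slices, so every PP stage contributes roughly equally to every group. A self-contained plain-Python illustration of the difference (not vLLM code; layer names follow the full/sw example in the comment):

    # Whole-model sliding-window layers in order; stage 0 owns sw.0/sw.1,
    # stage 1 owns sw.2/sw.3 (each stage also owns one full-attention layer).
    sw_layers = ["sw.0", "sw.1", "sw.2", "sw.3"]
    group_size = 2                                  # size of the full-attention group
    num_groups = -(-len(sw_layers) // group_size)   # cdiv(4, 2) == 2

    contiguous = [sw_layers[i * group_size:(i + 1) * group_size]
                  for i in range(num_groups)]       # [[sw.0, sw.1], [sw.2, sw.3]]
    strided = [sw_layers[i::num_groups]
               for i in range(num_groups)]          # [[sw.0, sw.2], [sw.1, sw.3]]

    # Restrict each group to the layers stage 0 actually owns.
    stage0 = {"sw.0", "sw.1"}
    print([[l for l in g if l in stage0] for g in contiguous])  # [['sw.0', 'sw.1'], []]
    print([[l for l in g if l in stage0] for g in strided])     # [['sw.0'], ['sw.1']]
    # The contiguous split leaves stage 0 with one full group and one empty
    # group, which would have to be padded and would waste memory; the strided
    # split keeps the per-stage group sizes balanced.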
|
||||
|
||||
def get_kv_cache_config_from_groups(vllm_config: VllmConfig,
|
||||
kv_cache_groups: list[KVCacheGroupSpec],
|
||||
kv_cache_specs: dict[str, KVCacheSpec],
|
||||
available_memory: int) -> KVCacheConfig:
|
||||
"""
|
||||
Generate the KV cache configuration from the KV cache groups and spec
|
||||
of each layer.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
kv_cache_groups: The KV cache groups
|
||||
kv_cache_specs: The KV cache spec of each attention layer in the model
|
||||
available_memory: Memory available for KV cache in bytes
|
||||
Returns:
|
||||
The generated KVCacheConfig
|
||||
"""
|
||||
if len(kv_cache_groups) == 0:
|
||||
# Attention free models do not have KV cache.
|
||||
# Return num_blocks=1 as BlockPool always needs a null_block.
|
||||
return KVCacheConfig(
|
||||
num_blocks=1,
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=kv_cache_groups,
|
||||
)
|
||||
|
||||
# Determine how model runners should initialize the KV cache tensors.
|
||||
# We will have group_size memory pools, each is shared by one layer from
|
||||
# each group. As layers of different groups have different block table,
|
||||
# they will use different parts of the shared Tensor.
|
||||
# The memory layout in the example will be:
|
||||
# full.0, sw.0, sw.2: share a Tensor with size=available_memory//2
|
||||
# full.1, sw.1: share another Tensor with size=available_memory//2
|
||||
page_size = get_uniform_page_size(kv_cache_spec)
|
||||
# The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2),
|
||||
# (sw.1, padding) will be: (group_size = 2)
|
||||
# full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
|
||||
# full.1, sw.2: share another Tensor with size=available_memory//2
|
||||
group_size = max(len(group.layer_names) for group in kv_cache_groups)
|
||||
|
||||
page_size = get_uniform_page_size(kv_cache_specs)
|
||||
assert group_size > 0, "group_size must be greater than 0"
|
||||
num_blocks = get_num_blocks(vllm_config, group_size, available_memory,
|
||||
page_size)
|
||||
per_memory_pool_size = page_size * num_blocks
|
||||
@@ -1004,8 +1004,8 @@ def _get_kv_cache_config_uniform_page_size(
|
||||
for i in range(group_size):
|
||||
shared_by = []
|
||||
for j in range(len(kv_cache_groups)):
|
||||
if i < len(grouped_layers[j]):
|
||||
shared_by.append(grouped_layers[j][i])
|
||||
if i < len(kv_cache_groups[j].layer_names):
|
||||
shared_by.append(kv_cache_groups[j].layer_names[i])
|
||||
kv_cache_tensors.append(
|
||||
KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by))
|
||||
|
||||
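The loop above assigns the i-th layer of every group to the i-th shared memory pool. A plain-Python sketch of that assignment for the 3-group example in the comments (illustrative only, using bare layer names instead of KVCacheGroupSpec objects):

    # Three groups: (full.0, full.1), (sw.0, sw.2), (sw.1,) -- the shorter
    # group simply contributes fewer layers.
    groups = [["full.0", "full.1"], ["sw.0", "sw.2"], ["sw.1"]]
    group_size = max(len(g) for g in groups)        # 2 memory pools

    pools = []
    for i in range(group_size):
        # Pool i is shared by the i-th layer of every group that has one.
        pools.append([g[i] for g in groups if i < len(g)])

    print(pools)
    # [['full.0', 'sw.0', 'sw.1'], ['full.1', 'sw.2']]
    # Matches the comment above: each pool holds available_memory // group_size
    # bytes and is time-shared by layers from different groups.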
@@ -1019,7 +1019,12 @@ def _get_kv_cache_config_uniform_page_size(
|
||||
[group.kv_cache_spec.block_size for group in kv_cache_groups])
|
||||
|
||||
# Print the KV cache size and maximum concurrency.
|
||||
num_tokens = num_blocks // len(grouped_layers) * min_block_size
|
||||
num_tokens = num_blocks // len(kv_cache_groups) * min_block_size
|
||||
if vllm_config.parallel_config.decode_context_parallel_size > 1:
|
||||
num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
|
||||
logger.info(
|
||||
"Multiplying the GPU KV cache size by the dcp_world_size %d.",
|
||||
vllm_config.parallel_config.decode_context_parallel_size)
|
||||
num_tokens_str = f"{num_tokens:,}"
|
||||
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
|
||||
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
|
||||
@@ -1030,10 +1035,6 @@ def _get_kv_cache_config_uniform_page_size(
|
||||
return kv_cache_config
|
||||
|
||||
|
||||
def _get_kv_cache_config_attention_free() -> KVCacheConfig:
|
||||
return KVCacheConfig(num_blocks=1, kv_cache_tensors=[], kv_cache_groups=[])
|
||||
|
||||
|
||||
def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
||||
"""
|
||||
This function tries to convert the KV cache specs to one type if the model
|
||||
@@ -1087,72 +1088,112 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
||||
"convert the KV cache specs to one unified type.")
|
||||
|
||||
|
||||
def get_kv_cache_config(
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_spec: dict[str, KVCacheSpec],
|
||||
available_memory: int,
|
||||
) -> KVCacheConfig:
|
||||
def get_kv_cache_groups(
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_spec: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]:
|
||||
"""
|
||||
Generates the KV cache configuration for a model.
|
||||
Split the layers in the model into groups with the same KV cache spec.
|
||||
|
||||
Args:
|
||||
vllm_config: The global VllmConfig
|
||||
kv_cache_spec: The kv cache spec of each attention layer in the model
|
||||
available_memory: Memory available for KV cache in bytes.
|
||||
|
||||
Returns:
|
||||
The generated KVCacheConfigs
|
||||
The generated KVCacheGroups
|
||||
"""
|
||||
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
|
||||
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
|
||||
unify_hybrid_kv_cache_specs(kv_cache_spec)
|
||||
|
||||
if is_kv_cache_type_attention_free(kv_cache_spec):
|
||||
# This returns a kv_cache config with 0 kv_cache groups and 1 block
|
||||
# to allow for the KVCache manager to handle attention free models.
|
||||
return _get_kv_cache_config_attention_free()
|
||||
# This returns an empty list to allow for the KVCacheManager to handle
|
||||
# attention free models.
|
||||
return []
|
||||
elif is_kv_cache_type_uniform(kv_cache_spec):
|
||||
# KV cache of all layers are the same, which is true for
|
||||
# most models. Allocate the same amount of memory for
|
||||
# each layer.
|
||||
return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
|
||||
available_memory)
|
||||
return _get_kv_cache_groups_uniform_type(kv_cache_spec)
|
||||
elif is_kv_cache_page_size_uniform(kv_cache_spec):
|
||||
# Model contains multiple attention types, but KV cache of all layers
|
||||
# have the same physical memory per block per layer. Split the layers
|
||||
# into groups with the same number of layers, and thus same total page
|
||||
# size.
|
||||
return _get_kv_cache_config_uniform_page_size(vllm_config,
|
||||
kv_cache_spec,
|
||||
available_memory)
|
||||
return _get_kv_cache_groups_uniform_page_size(kv_cache_spec)
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
|
||||
def get_kv_cache_configs(vllm_config: VllmConfig,
|
||||
kv_cache_specs: list[dict[str, KVCacheSpec]],
|
||||
available_memory: list[int]) -> list[KVCacheConfig]:
|
||||
"""
|
||||
Make the KV cache configurations for each worker consistent, so that all
|
||||
workers can be controlled by the same KVCacheManager.
|
||||
This function verifies that the layer group of each worker are the same,
|
||||
and changes the num_blocks of each worker to the smallest among all workers.
|
||||
Generates the KV cache configurations for a model.
|
||||
Since we use a shared centralized controller for all workers, we need the
|
||||
`kv_cache_config` to be consistent across all workers to make sure
|
||||
the KV cache allocation can be applied to all workers. However, different
|
||||
workers may have different memory available, and different type of layers
|
||||
(when pipeline parallel is enabled). To handle the difference between
|
||||
workers, the current implementation is:
|
||||
1. Merge the KV cache specs of all workers to get the KVCacheSpecs for
|
||||
the whole model.
|
||||
2. Generate the KV cache groups based on the layer ratio of the whole model.
|
||||
3. Generate the KV cache configs for each worker based on the KV cache
|
||||
grouping strategy. (This is reasonable because the layer ratio of
|
||||
different PP stages are similar.)
|
||||
4. Change the num_blocks of each worker to the smallest among all workers.
|
||||
|
||||
Args:
|
||||
kv_cache_configs: The KV cache configurations for each worker. Will be
|
||||
in-place modified to make them consistent.
|
||||
vllm_config: The global VllmConfig
|
||||
kv_cache_specs: List of dict[layer_name, KVCacheSpec] for each worker.
|
||||
available_memory: Memory available for KV cache in bytes for each
|
||||
worker.
|
||||
|
||||
Returns:
|
||||
The generated KVCacheConfigs for each worker.
|
||||
"""
|
||||
|
||||
# Sort the kv cache groups by their KV cache spec.
|
||||
# This can avoid the inconsistency caused by the order of groups.
|
||||
for kv_cache_config in kv_cache_configs:
|
||||
kv_cache_config.kv_cache_groups.sort(key=lambda x: (type(
|
||||
x.kv_cache_spec).__name__, astuple(x.kv_cache_spec)))
|
||||
# Check if the available memory is enough for each worker.
|
||||
for kv_cache_spec_one_worker, available_memory_one_worker in zip(
|
||||
kv_cache_specs, available_memory):
|
||||
check_enough_kv_cache_memory(vllm_config, kv_cache_spec_one_worker,
|
||||
available_memory_one_worker)
|
||||
|
||||
# Verify that the groups of each rank are the same.
|
||||
for kv_cache_config in kv_cache_configs[1:]:
|
||||
for group_rank_0, group_rank_i in zip(
|
||||
kv_cache_configs[0].kv_cache_groups,
|
||||
kv_cache_config.kv_cache_groups):
|
||||
assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec
|
||||
# Merge the KV cache specs of all workers. Different PP stages may have
|
||||
# different layer names, and different TP ranks of the same PP stage should
|
||||
# have the same KV cache spec.
|
||||
merged_kv_cache_specs: dict[str, KVCacheSpec] = {}
|
||||
for kv_cache_spec_one_worker in kv_cache_specs:
|
||||
for layer_name, layer_spec in kv_cache_spec_one_worker.items():
|
||||
if layer_name not in merged_kv_cache_specs:
|
||||
merged_kv_cache_specs[layer_name] = layer_spec
|
||||
else:
|
||||
assert merged_kv_cache_specs[layer_name] == layer_spec, (
|
||||
"The KV cache specs for the same layer are different "
|
||||
"across workers. This is not supported yet.")
|
||||
global_kv_cache_groups = get_kv_cache_groups(vllm_config,
|
||||
merged_kv_cache_specs)
|
||||
|
||||
kv_cache_configs: list[KVCacheConfig] = []
|
||||
for kv_cache_spec_one_worker, available_memory_one_worker in zip(
|
||||
kv_cache_specs, available_memory):
|
||||
kv_cache_groups_one_worker: list[KVCacheGroupSpec] = []
|
||||
for group in global_kv_cache_groups:
|
||||
group_layer_names_one_worker = [
|
||||
layer_name for layer_name in group.layer_names
|
||||
if layer_name in kv_cache_spec_one_worker
|
||||
]
|
||||
kv_cache_groups_one_worker.append(
|
||||
KVCacheGroupSpec(group_layer_names_one_worker,
|
||||
group.kv_cache_spec))
|
||||
assert sum(
|
||||
len(group.layer_names) for group in
|
||||
kv_cache_groups_one_worker) == len(kv_cache_spec_one_worker), (
|
||||
"Some layers are not assigned to any group.")
|
||||
kv_cache_configs.append(
|
||||
get_kv_cache_config_from_groups(vllm_config,
|
||||
kv_cache_groups_one_worker,
|
||||
kv_cache_spec_one_worker,
|
||||
available_memory_one_worker))
|
||||
|
||||
# Change the num_blocks of each rank to the smallest among all ranks. We
|
||||
# do not need to shrink the tensor size because it is valid to only use the
|
||||
|
||||
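To make the merging strategy in the get_kv_cache_configs docstring above concrete, here is a plain-Python sketch of steps 1, 3 and 4 (illustrative only; bare layer-name dicts and lists stand in for KVCacheSpec and KVCacheGroupSpec objects):

    def merge_specs(per_worker_specs):
        """Step 1: union of all workers' {layer_name: spec}, with a conflict check."""
        merged = {}
        for specs in per_worker_specs:
            for name, spec in specs.items():
                assert merged.setdefault(name, spec) == spec
        return merged

    def split_groups_per_worker(global_groups, per_worker_specs):
        """Step 3: keep, for each worker, only the layers it owns.

        A group may become empty for a worker (e.g. a PP stage with no
        sliding-window layers), but the number of groups stays the same
        everywhere so one scheduler can drive all workers.
        """
        return [[[name for name in group if name in specs]
                 for group in global_groups] for specs in per_worker_specs]

    per_worker_specs = [{"full.0": "full", "sw.0": "sw", "sw.1": "sw"},
                        {"full.1": "full", "sw.2": "sw", "sw.3": "sw"}]
    merged = merge_specs(per_worker_specs)          # 6 distinct layers
    global_groups = [["full.0", "full.1"], ["sw.0", "sw.2"], ["sw.1", "sw.3"]]

    print(split_groups_per_worker(global_groups, per_worker_specs))
    # [[['full.0'], ['sw.0'], ['sw.1']], [['full.1'], ['sw.2'], ['sw.3']]]

    # Step 4: every worker is clamped to the smallest per-worker block count.
    num_blocks = min([12, 10])  # -> 10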
@@ -29,10 +29,9 @@ from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket,
resolve_obj_by_qualname, set_process_title)
from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_config,
from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_configs,
get_request_block_hasher,
init_none_hash,
unify_kv_cache_configs)
init_none_hash)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
@@ -191,18 +190,9 @@ class EngineCore:
available_gpu_memory = [0] * len(kv_cache_specs)

assert len(kv_cache_specs) == len(available_gpu_memory)
# Get the kv cache tensor size
kv_cache_configs = [
get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
available_gpu_memory_one_worker)
for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
zip(kv_cache_specs, available_gpu_memory)
]

# Since we use a shared centralized controller, we need the
# `kv_cache_config` to be consistent across all workers to make sure
# all the memory operators can be applied to all workers.
unify_kv_cache_configs(kv_cache_configs)
kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
available_gpu_memory)

# All workers have the same kv_cache_config except layer names, so use
# an arbitrary one to initialize the scheduler.