From 8e5cdcda4e5a55ff49d92d37139042dda44b6b3c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 14 Sep 2025 15:55:17 -0700 Subject: [PATCH] [Hybrid Allocator] Support Pipeline Parallel (#23974) Signed-off-by: Chen Zhang --- tests/distributed/test_pipeline_parallel.py | 4 +- tests/models/test_initialization.py | 10 +- tests/v1/core/test_kv_cache_utils.py | 404 ++++++++++++++----- tests/v1/tpu/worker/test_tpu_model_runner.py | 10 +- tests/v1/worker/test_gpu_model_runner.py | 14 +- vllm/v1/core/kv_cache_utils.py | 247 +++++++----- vllm/v1/engine/core.py | 18 +- 7 files changed, 472 insertions(+), 235 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index fffab1a984c2..08702e8c061f 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -215,9 +215,7 @@ TEXT_GENERATION_MODELS = { EMBEDDING_MODELS = { # type: ignore[var-annotated] # [Text-only] "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(runner="pooling"), - # TODO: re-enable when https://github.com/vllm-project/vllm/issues/23883 - # is fixed - #"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"), + "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(runner="pooling"), "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast( load_format="dummy", runner="pooling" ), diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index c22d94948d24..0e18c45a21ee 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -10,7 +10,7 @@ from vllm import LLM from vllm.config import ModelImpl from vllm.engine.llm_engine import LLMEngine as V0LLMEngine from vllm.utils import GiB_bytes -from vllm.v1.core.kv_cache_utils import get_kv_cache_config +from vllm.v1.core.kv_cache_utils import get_kv_cache_configs from vllm.v1.engine.core import EngineCore as V1EngineCore from ..utils import create_new_process_for_each_test @@ -68,11 +68,11 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, def _initialize_kv_caches_v1(self, vllm_config): kv_cache_specs = self.model_executor.get_kv_cache_specs() - scheduler_kv_cache_config = get_kv_cache_config( + scheduler_kv_cache_config = get_kv_cache_configs( vllm_config, - kv_cache_specs[0], - 10 * GiB_bytes, - ) + kv_cache_specs, + [10 * GiB_bytes], + )[0] # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config return 1, 0, scheduler_kv_cache_config diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 44e479098ad5..2b44b16fd63b 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -18,13 +18,12 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_utils import ( BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, estimate_max_model_len, generate_block_hash_extra_keys, - get_kv_cache_config, get_max_concurrency_for_kv_cache_config, + get_kv_cache_configs, get_max_concurrency_for_kv_cache_config, get_request_block_hasher, hash_block_tokens, init_none_hash, - is_kv_cache_type_uniform, make_block_hash_with_group_id, - unify_kv_cache_configs) + is_kv_cache_type_uniform, make_block_hash_with_group_id) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheTensor, - SlidingWindowSpec) + KVCacheGroupSpec, KVCacheSpec, + KVCacheTensor, SlidingWindowSpec) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request 
import Request @@ -531,102 +530,288 @@ def test_metrics(): assert not metrics.query_queue -def test_unify_kv_cache_configs(): - same_kv_cache_config = [ - KVCacheConfig( - num_blocks=10, - kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), - ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=4)), - ], - ), - KVCacheConfig( - num_blocks=20, - kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), - ], - kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=4)), - ], - ), - ] - unify_kv_cache_configs(same_kv_cache_config) - assert same_kv_cache_config[0].num_blocks == 10 - assert same_kv_cache_config[1].num_blocks == 10 +def test_get_kv_cache_configs_multiple_workers(): + model_config = ModelConfig(max_model_len=16) + vllm_config = VllmConfig(model_config=model_config) - need_sort_kv_cache_config = [ + ref_kv_cache_spec = new_kv_cache_spec() + same_kv_cache_specs = [{ + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }] + + # Basic case. All things are the same. + kv_cache_configs = get_kv_cache_configs(vllm_config, same_kv_cache_specs, [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10 + ]) + assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer2"]), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=4)), + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), ], ), KVCacheConfig( - num_blocks=20, + num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer2"]), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=4)), - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), ], ), ] - unify_kv_cache_configs(need_sort_kv_cache_config) - sorted_kv_cache_groups = [ - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], new_kv_cache_spec(num_kv_heads=4)), - ] - assert ( - need_sort_kv_cache_config[0].kv_cache_groups == sorted_kv_cache_groups) - assert ( - need_sort_kv_cache_config[1].kv_cache_groups == sorted_kv_cache_groups) - - diff_kv_cache_config = [ + # Different available memory. This is the case for TP. + # Use the smallest memory available. 
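# A standalone sketch (not the vLLM implementation; helper name is made up)
# of the arithmetic behind the memory arguments in this test: with a uniform
# spec, every layer needs one page per block, so the block count is the
# available bytes divided by page_size * num_layers. That is why passing
# page_size_bytes * 2 * 10 for two layers yields num_blocks=10 above.
def _approx_num_blocks(available_memory: int, num_layers: int,
                       page_size_bytes: int) -> int:
    return available_memory // (num_layers * page_size_bytes)

assert _approx_num_blocks(4096 * 2 * 10, num_layers=2, page_size_bytes=4096) == 10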
+ kv_cache_configs = get_kv_cache_configs(vllm_config, same_kv_cache_specs, [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 20 + ]) + assert kv_cache_configs == [ KVCacheConfig( num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer2"]), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=4)), + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), ], ), KVCacheConfig( - num_blocks=20, + num_blocks=10, kv_cache_tensors=[ - KVCacheTensor(size=100, shared_by=["layer1"]), - KVCacheTensor(size=100, shared_by=["layer2"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, + shared_by=["layer1"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, + shared_by=["layer2"]), ], kv_cache_groups=[ - KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer2"], - new_kv_cache_spec(num_kv_heads=8)), + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), ], ), ] + + # Different KV cache specs. This is the case for PP. + different_layer_specs = [{ + "layer1": new_kv_cache_spec(), + }, { + "layer2": new_kv_cache_spec(), + "layer3": new_kv_cache_spec(), + }] + + # Different workers have different layers. + kv_cache_configs = get_kv_cache_configs( + vllm_config, different_layer_specs, [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10 + ]) + assert kv_cache_configs == [ + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, + shared_by=["layer1"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1"], new_kv_cache_spec()), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer2"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer3"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer2", "layer3"], new_kv_cache_spec()), + ], + ), + ] + + # Some layers are the same, some are different. 
This is the case for TP+PP + tp_pp_kv_cache_specs = [{ + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, { + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, { + "layer3": new_kv_cache_spec(), + }, { + "layer3": new_kv_cache_spec(), + }] + + kv_cache_configs = get_kv_cache_configs( + vllm_config, tp_pp_kv_cache_specs, [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ]) + assert kv_cache_configs == [ + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, + shared_by=["layer3"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer3"], ref_kv_cache_spec), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 20, + shared_by=["layer3"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer3"], ref_kv_cache_spec), + ], + ), + ] + + # Different workers have different types of layers. This is the case for + # hybrid models + PP. + different_type_layer_specs = [{ + "layer1": new_kv_cache_spec(), + "layer2": new_kv_cache_spec(), + }, { + "layer3": new_sliding_window_spec(), + "layer4": new_sliding_window_spec(), + }] + kv_cache_configs = get_kv_cache_configs( + vllm_config, different_type_layer_specs, [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ]) + assert kv_cache_configs == [ + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1", "layer2"], ref_kv_cache_spec), + KVCacheGroupSpec([], new_sliding_window_spec()), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer3"]), + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer4"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec([], ref_kv_cache_spec), + KVCacheGroupSpec(["layer3", "layer4"], + new_sliding_window_spec()), + ], + ), + ] + + # When divided into multiple KVCacheGroups, need to ensure the number of + # layers per group is similar. 
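# A plain-Python sketch (assumed layer names, not the vLLM helper) of how the
# per-worker groups asserted above can be derived: the global grouping is
# filtered down to the layers a worker owns, and groups that worker has no
# layers for are kept as empty groups so group indices stay aligned across
# PP stages.
global_groups = [["layer1", "layer2"], ["layer3", "layer4"]]
pp_stage_1_layers = {"layer3", "layer4"}
per_worker_groups = [[name for name in group if name in pp_stage_1_layers]
                     for group in global_groups]
assert per_worker_groups == [[], ["layer3", "layer4"]]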
+ different_type_layer_specs = [{ + "layer1": new_kv_cache_spec(), + "layer2": new_sliding_window_spec(), + "layer3": new_sliding_window_spec(), + }, { + "layer4": new_kv_cache_spec(), + "layer5": new_sliding_window_spec(), + "layer6": new_sliding_window_spec(), + }] + kv_cache_configs = get_kv_cache_configs( + vllm_config, different_type_layer_specs, [ + ref_kv_cache_spec.page_size_bytes * 10, + ref_kv_cache_spec.page_size_bytes * 10, + ]) + assert kv_cache_configs == [ + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer1", "layer2", "layer3"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer1"], ref_kv_cache_spec), + KVCacheGroupSpec(["layer2"], new_sliding_window_spec()), + KVCacheGroupSpec(["layer3"], new_sliding_window_spec()), + ], + ), + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[ + KVCacheTensor(size=ref_kv_cache_spec.page_size_bytes * 10, + shared_by=["layer4", "layer5", "layer6"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer4"], ref_kv_cache_spec), + KVCacheGroupSpec(["layer5"], new_sliding_window_spec()), + KVCacheGroupSpec(["layer6"], new_sliding_window_spec()), + ], + ), + ] + + # Have conflicting layers. Need to raise an error. + conflicting_layer_specs = [{ + "layer1": new_kv_cache_spec(), + }, { + "layer1": new_sliding_window_spec(), + }] with pytest.raises(AssertionError): - unify_kv_cache_configs(diff_kv_cache_config) + get_kv_cache_configs(vllm_config, conflicting_layer_specs, [ + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ref_kv_cache_spec.page_size_bytes * 2 * 10, + ]) def test_merge_kv_cache_spec(): @@ -890,7 +1075,7 @@ def test_allocate_with_lookahead(): assert len(blocks.get_block_ids()[0]) == 2 -def test_get_kv_cache_config(): +def test_get_kv_cache_config_one_worker(): # pass max_model_len to pass check_enough_kv_cache_memory model_config = ModelConfig(max_model_len=16) vllm_config = VllmConfig(model_config=model_config) @@ -901,8 +1086,10 @@ def test_get_kv_cache_config(): 'layer_1': new_kv_cache_spec(), 'layer_2': new_kv_cache_spec(), } - kv_cache_config_full = get_kv_cache_config( - vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32) + kv_cache_config_full = get_kv_cache_configs( + vllm_config, [kv_cache_specs_full], + [mem_per_block_per_layer * 2 * 32])[0] + print(kv_cache_config_full) assert kv_cache_config_full == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ @@ -920,8 +1107,9 @@ def test_get_kv_cache_config(): 'layer_1': new_sliding_window_spec(), 'layer_2': new_sliding_window_spec(), } - kv_cache_config_sliding = get_kv_cache_config( - vllm_config, kv_cache_specs_sliding, mem_per_block_per_layer * 2 * 32) + kv_cache_config_sliding = get_kv_cache_configs( + vllm_config, [kv_cache_specs_sliding], + [mem_per_block_per_layer * 2 * 32])[0] assert kv_cache_config_sliding == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ @@ -940,8 +1128,9 @@ def test_get_kv_cache_config(): 'layer_1': new_kv_cache_spec(), 'layer_2': new_sliding_window_spec(), } - kv_cache_config_hybrid = get_kv_cache_config( - vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32) + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], + [mem_per_block_per_layer * 2 * 32])[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ @@ -962,8 +1151,9 @@ def test_get_kv_cache_config(): 'layer_1': new_kv_cache_spec(), 'layer_2': new_sliding_window_spec(), } - kv_cache_config_hybrid = 
get_kv_cache_config( - vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32) + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], + [mem_per_block_per_layer * 2 * 32])[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=64, kv_cache_tensors=[ @@ -985,21 +1175,22 @@ def test_get_kv_cache_config(): 'layer_5': new_sliding_window_spec(), 'layer_6': new_sliding_window_spec(), } - kv_cache_config_hybrid = get_kv_cache_config( - vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32) + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], + [mem_per_block_per_layer * 2 * 32])[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_1", "layer_3", "layer_5"]), + shared_by=["layer_1", "layer_3", "layer_4"]), KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2", "layer_4", "layer_6"]), + shared_by=["layer_2", "layer_5", "layer_6"]), ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer_3", "layer_4"], + KVCacheGroupSpec(["layer_3", "layer_5"], new_sliding_window_spec()), - KVCacheGroupSpec(["layer_5", "layer_6"], + KVCacheGroupSpec(["layer_4", "layer_6"], new_sliding_window_spec()), ], ) @@ -1017,27 +1208,30 @@ def test_get_kv_cache_config(): 'layer_9': new_sliding_window_spec(), 'layer_10': new_sliding_window_spec(), } - kv_cache_config_hybrid = get_kv_cache_config( - vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 3 * 32) + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], + [mem_per_block_per_layer * 3 * 32])[0] assert kv_cache_config_hybrid == KVCacheConfig( num_blocks=32, kv_cache_tensors=[ KVCacheTensor( size=mem_per_block_per_layer * 32, - shared_by=["layer_1", "layer_4", "layer_7", "layer_10"]), + shared_by=["layer_1", "layer_4", "layer_5", "layer_6"]), + KVCacheTensor( + size=mem_per_block_per_layer * 32, + shared_by=["layer_2", "layer_7", "layer_8", "layer_9"]), KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_2", "layer_5", "layer_8"]), - KVCacheTensor(size=mem_per_block_per_layer * 32, - shared_by=["layer_3", "layer_6", "layer_9"]), + shared_by=["layer_3", "layer_10"]), ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"], new_kv_cache_spec()), - KVCacheGroupSpec(["layer_4", "layer_5", "layer_6"], + KVCacheGroupSpec(["layer_4", "layer_7", "layer_10"], new_sliding_window_spec()), - KVCacheGroupSpec(["layer_7", "layer_8", "layer_9"], + KVCacheGroupSpec(["layer_5", "layer_8"], + new_sliding_window_spec()), + KVCacheGroupSpec(["layer_6", "layer_9"], new_sliding_window_spec()), - KVCacheGroupSpec(["layer_10"], new_sliding_window_spec()), ], ) @@ -1047,13 +1241,14 @@ def test_get_kv_cache_config(): 'layer_2': new_kv_cache_spec(), } with pytest.raises(NotImplementedError): - get_kv_cache_config(vllm_config, kv_cache_specs_hybrid, - mem_per_block_per_layer * 2 * 32) + get_kv_cache_configs(vllm_config, [kv_cache_specs_hybrid], + [mem_per_block_per_layer * 2 * 32])[0] # Test num_gpu_blocks_override vllm_config.cache_config.num_gpu_blocks_override = 16 - kv_cache_config_override_blocks = get_kv_cache_config( - vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32) + kv_cache_config_override_blocks = get_kv_cache_configs( + vllm_config, [kv_cache_specs_full], + [mem_per_block_per_layer * 2 * 32])[0] assert 
kv_cache_config_override_blocks == KVCacheConfig( num_blocks=16, kv_cache_tensors=[ @@ -1065,3 +1260,16 @@ def test_get_kv_cache_config(): kv_cache_groups=[ KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()) ]) + + +def test_get_kv_cache_configs_attention_free(): + kv_cache_specs: dict[str, KVCacheSpec] = {} + vllm_config = VllmConfig(model_config=ModelConfig(max_model_len=16)) + kv_cache_configs = get_kv_cache_configs(vllm_config, [kv_cache_specs], [0]) + assert kv_cache_configs == [ + KVCacheConfig( + num_blocks=1, + kv_cache_tensors=[], + kv_cache_groups=[], + ) + ] diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index c719e44acc9c..bd9b6131c222 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -10,7 +10,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import GiB_bytes from vllm.v1.core.kv_cache_utils import (estimate_max_model_len, - get_kv_cache_config) + get_kv_cache_configs) from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) from vllm.v1.worker.tpu_model_runner import ( @@ -477,8 +477,8 @@ def test_init_kv_cache_without_kv_sharing(): # 2 (non-MLA) * 8 (num_heads) * 128 (head_dim) # * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB num_expected_blocks = 20480 # 20GB / 512KB / 2 (num layers) - kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, - available_memory) + kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], + [available_memory])[0] assert kv_cache_config.num_blocks == num_expected_blocks assert len(kv_cache_config.kv_cache_tensors) == 2 assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2 @@ -550,8 +550,8 @@ def test_init_kv_cache_with_kv_sharing_valid(): # with KV sharing, we can allocate (available_mem//page_size//1) blocks # which is twice as many as without KV sharing num_expected_blocks = 2 * 20480 # 20GB / 512KB - kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, - available_memory) + kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], + [available_memory])[0] assert kv_cache_config.num_blocks == num_expected_blocks assert len(kv_cache_config.kv_cache_tensors) == 1 # Each layer now has twice the available memory for KV cache diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 5ebc00d57303..4ad8df1ce386 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -15,7 +15,7 @@ from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.utils import GiB_bytes, update_environment_variables from vllm.v1.core.kv_cache_utils import (estimate_max_model_len, - get_kv_cache_config) + get_kv_cache_configs) from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -585,8 +585,8 @@ def test_init_kv_cache_without_kv_sharing(): available_memory = 20 * GiB_bytes # page size for layer 0's kv_cache_spec is 32KB num_expected_blocks = 327680 # 20GB / 32KB / 2 (num layers) - kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, - available_memory) + kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], + [available_memory])[0] assert kv_cache_config.num_blocks == num_expected_blocks assert 
len(kv_cache_config.kv_cache_tensors) == 2 assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2 @@ -657,8 +657,8 @@ def test_init_kv_cache_with_kv_sharing_valid(): # with KV sharing, we can allocate (available_mem//page_size//1) blocks # which is twice as many as without KV sharing num_expected_blocks = 655360 # 20GB / 32KB - kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, - available_memory) + kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], + [available_memory])[0] assert kv_cache_config.num_blocks == num_expected_blocks assert len(kv_cache_config.kv_cache_tensors) == 1 # Each layer now has twice the available memory for KV cache @@ -788,8 +788,8 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): kv_cache_spec = runner.get_kv_cache_spec() available_memory = 5 * GiB_bytes - kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, - available_memory) + kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], + [available_memory])[0] runner.initialize_kv_cache(kv_cache_config) # random partition of blocks diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index f939da8c5b5c..533c0236dad7 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -5,7 +5,7 @@ import os from collections import defaultdict, deque from collections.abc import Iterable, Sequence -from dataclasses import astuple, dataclass +from dataclasses import dataclass from typing import Any, Callable, NewType, Optional, Union from vllm import envs @@ -811,59 +811,21 @@ def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int: return page_sizes.pop() -def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int) -> KVCacheConfig: +def _get_kv_cache_groups_uniform_type( + kv_cache_specs: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]: """ Generates the KV cache configuration for a model with one type of KV cache. Divide the available memory equally among all layers. Args: - vllm_config: The global VllmConfig - kv_cache_spec: The kv cache spec of each attention layer in the model - available_memory: Memory available for KV cache in bytes. + kv_cache_specs: The kv cache spec of each attention layer in the model Returns: - The generated KVCacheConfig + The generated KVCacheGroupSpecs """ - page_size = get_uniform_page_size(kv_cache_spec) - num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec), - available_memory, page_size) - - per_layer_size = page_size * num_blocks - # All layers have the same KV cache spec, so we create one kv cache group - # for all layers. - grouped_layer_names = [list(kv_cache_spec.keys())] - - # Each layer uses a separate Tensor to store its KV cache. 
- kv_cache_tensors = [ - KVCacheTensor(size=per_layer_size, shared_by=[layer_name]) - for layer_name in kv_cache_spec - ] - - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, - kv_cache_tensors=kv_cache_tensors, - kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec, - grouped_layer_names), - ) - - num_tokens = num_blocks * vllm_config.cache_config.block_size - if vllm_config.parallel_config.decode_context_parallel_size > 1: - num_tokens *= vllm_config.parallel_config.decode_context_parallel_size - logger.info( - "Multiplying the GPU KV cache size by the dcp_world_size %d.", - vllm_config.parallel_config.decode_context_parallel_size) - - num_tokens_str = f"{num_tokens:,}" - logger.info("GPU KV cache size: %s tokens", num_tokens_str) - max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" - max_concurrency = get_max_concurrency_for_kv_cache_config( - vllm_config, kv_cache_config) - logger.info("Maximum concurrency for %s tokens per request: %.2fx", - max_model_len_str, max_concurrency) - return kv_cache_config + return create_kv_cache_group_specs(kv_cache_specs, + [list(kv_cache_specs.keys())]) def is_kv_cache_page_size_uniform( @@ -888,11 +850,10 @@ def is_kv_cache_type_attention_free( return not kv_cache_spec -def _get_kv_cache_config_uniform_page_size( - vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int) -> KVCacheConfig: +def _get_kv_cache_groups_uniform_page_size( + kv_cache_spec: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]: """ - Generates the KV cache configuration for hybrid models with multiple + Generates the KV cache groups for hybrid models with multiple attention types but still with a uniform page size (physical memory per block per layer) for all layers. @@ -949,11 +910,9 @@ def _get_kv_cache_config_uniform_page_size( memory per block is the same for all groups. Args: - vllm_config: The global VllmConfig kv_cache_spec: The KVCacheSpec of each attention layer in the model - available_memory: Memory available for KV cache in bytes. Returns: - The generated KVCacheConfig + The generated KVCacheGroupSpecs """ # Group all layers by kv_cache_spec. # E.g., 2 full attention layers and 3 sliding window attention layers, @@ -966,7 +925,7 @@ def _get_kv_cache_config_uniform_page_size( # group identical. Add padding to the last group of each type if necessary. # E.g., (full.0, full.1), (sw.0, sw.1, sw.2) # split to 3 groups with 2 layers each: - # (full.0, full.1), (sw.0, sw.1), (sw.2, padding). + # (full.0, full.1), (sw.0, sw.2), (sw.1, padding). 
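# A self-contained sketch of the strided split described above (helper name
# is illustrative, not the vLLM function): layers of one attention type are
# dealt out round-robin over ceil(len(layers) / group_size) groups, so
# trailing groups come up at most one layer short, matching the
# (sw.0, sw.2), (sw.1, padding) example.
def _split_strided(layers: list[str], group_size: int) -> list[list[str]]:
    num_groups = -(-len(layers) // group_size)  # ceiling division
    return [layers[i::num_groups] for i in range(num_groups)]

assert _split_strided(["sw.0", "sw.1", "sw.2"], group_size=2) == [
    ["sw.0", "sw.2"], ["sw.1"]
]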
# FIXME(Chen): At the moment of writing this code (2025-06-02), all # open-source hybrid model follows a n:1 pattern between different attention # types (e.g., Gemma3 5:1 between sw and full, LLaMA4 3:1 between local and @@ -984,19 +943,60 @@ def _get_kv_cache_config_uniform_page_size( num_padding_layers, num_padding_layers / len(layers) * 100, ) - for i in range(0, len(layers), group_size): - grouped_layers.append(layers[i:i + group_size]) - kv_cache_groups = create_kv_cache_group_specs(kv_cache_spec, - grouped_layers) + num_groups = cdiv(len(layers), group_size) + # In PP case, say if we have + # - stage 0: full.0, sw.0, sw.1 + # - stage 1: full.1, sw.2, sw.3 + # We should have 3 groups: (full.0, full.1), (sw.0, sw.2), (sw.1, sw.3) + # It can't be (full.0, full.1), (sw.0, sw.1), (sw.2, sw.3) because + # the 3 groups in stage 0 will be (full.0), (sw.0, sw.1), (empty group) + # and it will be padded to (full.0, padding), (sw.0, sw.1), + # (padding, padding) to ensure the number of layers in each group is + # the same and will cause memory waste. + # To avoid this, we assign layers[i::num_groups] to the i-th group + # instead of layers[i * group_size: (i + 1) * group_size] + for i in range(num_groups): + grouped_layers.append(layers[i::num_groups]) + return create_kv_cache_group_specs(kv_cache_spec, grouped_layers) + + +def get_kv_cache_config_from_groups(vllm_config: VllmConfig, + kv_cache_groups: list[KVCacheGroupSpec], + kv_cache_specs: dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: + """ + Generate the KV cache configuration from the KV cache groups and spec + of each layer. + + Args: + vllm_config: The global VllmConfig + kv_cache_groups: The KV cache groups + kv_cache_specs: The KV cache spec of each attention layer in the model + available_memory: Memory available for KV cache in bytes + Returns: + The generated KVCacheConfig + """ + if len(kv_cache_groups) == 0: + # Attention free models do not have KV cache. + # Return num_blocks=1 as BlockPool always needs a null_block. + return KVCacheConfig( + num_blocks=1, + kv_cache_tensors=[], + kv_cache_groups=kv_cache_groups, + ) # Determine how model runners should initialize the KV cache tensors. # We will have group_size memory pools, each is shared by one layer from # each group. As layers of different groups have different block table, # they will use different parts of the shared Tensor. 
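# A toy illustration (plain lists, assumed names) of the sharing scheme
# described above: memory pool i is shared by the i-th layer of every group,
# and groups shorter than group_size simply skip the trailing pools.
toy_groups = [["full.0", "full.1"], ["sw.0", "sw.2"], ["sw.1"]]
toy_group_size = max(len(group) for group in toy_groups)
toy_pools = [[group[i] for group in toy_groups if i < len(group)]
             for i in range(toy_group_size)]
assert toy_pools == [["full.0", "sw.0", "sw.1"], ["full.1", "sw.2"]]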
- # The memory layout in the example will be: - # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2 - # full.1, sw.1: share another Tensor with size=available_memory//2 - page_size = get_uniform_page_size(kv_cache_spec) + # The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2), + # (sw.1, padding) will be: (group_size = 2) + # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2 + # full.1, sw.2: share another Tensor with size=available_memory//2 + group_size = max(len(group.layer_names) for group in kv_cache_groups) + + page_size = get_uniform_page_size(kv_cache_specs) + assert group_size > 0, "group_size must be greater than 0" num_blocks = get_num_blocks(vllm_config, group_size, available_memory, page_size) per_memory_pool_size = page_size * num_blocks @@ -1004,8 +1004,8 @@ def _get_kv_cache_config_uniform_page_size( for i in range(group_size): shared_by = [] for j in range(len(kv_cache_groups)): - if i < len(grouped_layers[j]): - shared_by.append(grouped_layers[j][i]) + if i < len(kv_cache_groups[j].layer_names): + shared_by.append(kv_cache_groups[j].layer_names[i]) kv_cache_tensors.append( KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by)) @@ -1019,7 +1019,12 @@ def _get_kv_cache_config_uniform_page_size( [group.kv_cache_spec.block_size for group in kv_cache_groups]) # Print the KV cache size and maximum concurrency. - num_tokens = num_blocks // len(grouped_layers) * min_block_size + num_tokens = num_blocks // len(kv_cache_groups) * min_block_size + if vllm_config.parallel_config.decode_context_parallel_size > 1: + num_tokens *= vllm_config.parallel_config.decode_context_parallel_size + logger.info( + "Multiplying the GPU KV cache size by the dcp_world_size %d.", + vllm_config.parallel_config.decode_context_parallel_size) num_tokens_str = f"{num_tokens:,}" logger.info("GPU KV cache size: %s tokens", num_tokens_str) max_model_len_str = f"{vllm_config.model_config.max_model_len:,}" @@ -1030,10 +1035,6 @@ def _get_kv_cache_config_uniform_page_size( return kv_cache_config -def _get_kv_cache_config_attention_free() -> KVCacheConfig: - return KVCacheConfig(num_blocks=1, kv_cache_tensors=[], kv_cache_groups=[]) - - def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): """ This function tries to convert the KV cache specs to one type if the model @@ -1087,72 +1088,112 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): "convert the KV cache specs to one unified type.") -def get_kv_cache_config( - vllm_config: VllmConfig, - kv_cache_spec: dict[str, KVCacheSpec], - available_memory: int, -) -> KVCacheConfig: +def get_kv_cache_groups( + vllm_config: VllmConfig, + kv_cache_spec: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]: """ - Generates the KV cache configuration for a model. + Split the layers in the model into groups with the same KV cache spec. Args: vllm_config: The global VllmConfig kv_cache_spec: The kv cache spec of each attention layer in the model - available_memory: Memory available for KV cache in bytes. Returns: - The generated KVCacheConfigs + The generated KVCacheGroups """ - check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: unify_hybrid_kv_cache_specs(kv_cache_spec) if is_kv_cache_type_attention_free(kv_cache_spec): - # This returns a kv_cache config with 0 kv_cache groups and 1 block - # to allow for the KVCache manager to handle attention free models. 
- return _get_kv_cache_config_attention_free() + # This returns an empty list to allow for the KVCacheManager to handle + # attention free models. + return [] elif is_kv_cache_type_uniform(kv_cache_spec): # KV cache of all layers are the same, which is true for # most models. Allocate the same amount of memory for # each layer. - return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, - available_memory) + return _get_kv_cache_groups_uniform_type(kv_cache_spec) elif is_kv_cache_page_size_uniform(kv_cache_spec): # Model contains multiple attention types, but KV cache of all layers # have the same physical memory per block per layer. Split the layers # into groups with the same number of layers, and thus same total page # size. - return _get_kv_cache_config_uniform_page_size(vllm_config, - kv_cache_spec, - available_memory) + return _get_kv_cache_groups_uniform_page_size(kv_cache_spec) raise NotImplementedError -def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]): +def get_kv_cache_configs(vllm_config: VllmConfig, + kv_cache_specs: list[dict[str, KVCacheSpec]], + available_memory: list[int]) -> list[KVCacheConfig]: """ - Make the KV cache configurations for each worker consistent, so that all - workers can be controlled by the same KVCacheManager. - This function verifies that the layer group of each worker are the same, - and changes the num_blocks of each worker to the smallest among all workers. + Generates the KV cache configurations for a model. + Since we use a shared centralized controller for all workers, we need the + `kv_cache_config` to be consistent across all workers to make sure + the KV cache allocation can be applied to all workers. However, different + workers may have different memory available, and different type of layers + (when pipeline parallel is enabled). To handle the difference between + workers, the current implementation is: + 1. Merge the KV cache specs of all workers to get the KVCacheSpecs for + the whole model. + 2. Generate the KV cache groups based on the layer ratio of the whole model. + 3. Generate the KV cache configs for each worker based on the KV cache + grouping strategy. (This is reasonable because the layer ratio of + different PP stages are similar.) + 4. Change the num_blocks of each worker to the smallest among all workers. Args: - kv_cache_configs: The KV cache configurations for each worker. Will be - in-place modified to make them consistent. + vllm_config: The global VllmConfig + kv_cache_specs: List of dict[layer_name, KVCacheSpec] for each worker. + available_memory: Memory available for KV cache in bytes for each + worker. + + Returns: + The generated KVCacheConfigs for each worker. """ - # Sort the kv cache groups by their KV cache spec. - # This can avoid the inconsistency caused by the order of groups. - for kv_cache_config in kv_cache_configs: - kv_cache_config.kv_cache_groups.sort(key=lambda x: (type( - x.kv_cache_spec).__name__, astuple(x.kv_cache_spec))) + # Check if the available memory is enough for each worker. + for kv_cache_spec_one_worker, available_memory_one_worker in zip( + kv_cache_specs, available_memory): + check_enough_kv_cache_memory(vllm_config, kv_cache_spec_one_worker, + available_memory_one_worker) - # Verify that the groups of each rank are the same. 
- for kv_cache_config in kv_cache_configs[1:]: - for group_rank_0, group_rank_i in zip( - kv_cache_configs[0].kv_cache_groups, - kv_cache_config.kv_cache_groups): - assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec + # Merge the KV cache specs of all workers. Different PP stages may have + # different layer names, and different TP ranks of the same PP stage should + # have the same KV cache spec. + merged_kv_cache_specs: dict[str, KVCacheSpec] = {} + for kv_cache_spec_one_worker in kv_cache_specs: + for layer_name, layer_spec in kv_cache_spec_one_worker.items(): + if layer_name not in merged_kv_cache_specs: + merged_kv_cache_specs[layer_name] = layer_spec + else: + assert merged_kv_cache_specs[layer_name] == layer_spec, ( + "The KV cache specs for the same layer are different " + "across workers. This is not supported yet.") + global_kv_cache_groups = get_kv_cache_groups(vllm_config, + merged_kv_cache_specs) + + kv_cache_configs: list[KVCacheConfig] = [] + for kv_cache_spec_one_worker, available_memory_one_worker in zip( + kv_cache_specs, available_memory): + kv_cache_groups_one_worker: list[KVCacheGroupSpec] = [] + for group in global_kv_cache_groups: + group_layer_names_one_worker = [ + layer_name for layer_name in group.layer_names + if layer_name in kv_cache_spec_one_worker + ] + kv_cache_groups_one_worker.append( + KVCacheGroupSpec(group_layer_names_one_worker, + group.kv_cache_spec)) + assert sum( + len(group.layer_names) for group in + kv_cache_groups_one_worker) == len(kv_cache_spec_one_worker), ( + "Some layers are not assigned to any group.") + kv_cache_configs.append( + get_kv_cache_config_from_groups(vllm_config, + kv_cache_groups_one_worker, + kv_cache_spec_one_worker, + available_memory_one_worker)) # Change the num_blocks of each rank to the smallest among all ranks. We # do not need to shrink the tensor size because it is valid to only use the diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 995e70385be8..64a67f3b438e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -29,10 +29,9 @@ from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket, resolve_obj_by_qualname, set_process_title) -from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_config, +from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_configs, get_request_block_hasher, - init_none_hash, - unify_kv_cache_configs) + init_none_hash) from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler @@ -191,18 +190,9 @@ class EngineCore: available_gpu_memory = [0] * len(kv_cache_specs) assert len(kv_cache_specs) == len(available_gpu_memory) - # Get the kv cache tensor size - kv_cache_configs = [ - get_kv_cache_config(vllm_config, kv_cache_spec_one_worker, - available_gpu_memory_one_worker) - for kv_cache_spec_one_worker, available_gpu_memory_one_worker in - zip(kv_cache_specs, available_gpu_memory) - ] - # Since we use a shared centralized controller, we need the - # `kv_cache_config` to be consistent across all workers to make sure - # all the memory operators can be applied to all workers. 
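# A standalone sketch (strings stand in for KVCacheSpec, layer names are
# made up) of the merge performed above: PP stages contribute disjoint layer
# names, while any two workers that report the same layer (e.g. TP ranks)
# must agree on its spec, otherwise the assertion fires.
worker_specs = [{"layer.0": "full"}, {"layer.0": "full"}, {"layer.1": "sliding"}]
merged: dict[str, str] = {}
for spec_one_worker in worker_specs:
    for layer_name, layer_spec in spec_one_worker.items():
        assert merged.setdefault(layer_name, layer_spec) == layer_spec
assert merged == {"layer.0": "full", "layer.1": "sliding"}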
- unify_kv_cache_configs(kv_cache_configs) + kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs, + available_gpu_memory) # All workers have the same kv_cache_config except layer names, so use # an arbitrary one to initialize the scheduler.
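Taken together, the new `get_kv_cache_configs` entry point replaces the old per-worker `get_kv_cache_config` calls plus `unify_kv_cache_configs`. The following is a simplified, dependency-free sketch of that flow, not vLLM's implementation: plain strings stand in for `KVCacheSpec`, dicts stand in for `KVCacheConfig`, layer names are invented, and attention-free workers and padding details are out of scope. It mirrors the four steps listed in the `get_kv_cache_configs` docstring and reproduces the block count of the two-worker hybrid test above (one full-attention plus two sliding-window layers per PP stage, page_size * 10 bytes per worker).

from math import ceil

def toy_get_kv_cache_configs(worker_specs: list[dict[str, str]],
                             worker_memory: list[int],
                             page_size: int) -> list[dict]:
    # Step 1: merge per-worker specs into a whole-model view; workers that
    # share a layer must agree on its spec.
    merged: dict[str, str] = {}
    for spec in worker_specs:
        for name, layer_spec in spec.items():
            assert merged.setdefault(name, layer_spec) == layer_spec
    # Step 2: group the whole model's layers so every group has the same
    # number of layers, dealing layers of each type out round-robin.
    by_type: dict[str, list[str]] = {}
    for name, layer_spec in merged.items():
        by_type.setdefault(layer_spec, []).append(name)
    group_size = min(len(layers) for layers in by_type.values())
    groups: list[list[str]] = []
    for layers in by_type.values():
        num_groups = ceil(len(layers) / group_size)
        groups.extend(layers[i::num_groups] for i in range(num_groups))
    # Steps 3 and 4: restrict the global groups to each worker's layers,
    # size that worker's memory pools from its own budget, then clamp every
    # worker to the smallest block count so one scheduler can drive them all.
    configs = []
    for spec, mem in zip(worker_specs, worker_memory):
        worker_groups = [[name for name in group if name in spec]
                         for group in groups]
        num_pools = max(len(group) for group in worker_groups)
        configs.append({"num_blocks": mem // (num_pools * page_size),
                        "groups": worker_groups})
    min_blocks = min(config["num_blocks"] for config in configs)
    for config in configs:
        config["num_blocks"] = min_blocks
    return configs

configs = toy_get_kv_cache_configs(
    worker_specs=[{"full.0": "full", "sw.0": "sw", "sw.1": "sw"},
                  {"full.1": "full", "sw.2": "sw", "sw.3": "sw"}],
    worker_memory=[4096 * 10, 4096 * 10],
    page_size=4096)
assert all(config["num_blocks"] == 10 for config in configs)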