diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index ad34becb1e8d..71ea43383a7e 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -12,13 +12,11 @@ from vllm.utils import GiB_bytes, sha256
 from vllm.v1.core.kv_cache_manager import KVCacheManager
 # disable yapf here as it formats differently than isort such that both fail
 # yapf: disable
-from vllm.v1.core.kv_cache_utils import (FreeKVCacheBlockQueue, KVCacheBlock,
-                                         PrefixCachingMetrics,
-                                         estimate_max_model_len,
-                                         generate_block_hash_extra_keys,
-                                         hash_block_tokens,
-                                         hash_request_tokens,
-                                         unify_kv_cache_configs)
+from vllm.v1.core.kv_cache_utils import (
+    FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
+    estimate_max_model_len, generate_block_hash_extra_keys,
+    get_max_concurrency_for_kv_cache_config, hash_block_tokens,
+    hash_request_tokens, unify_kv_cache_configs)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheTensor,
                                         SlidingWindowSpec)
@@ -597,6 +595,84 @@ def test_estimate_max_model_len(model_id, max_model_len,
     assert estimated_max_len == want_estimated_max_len
 
 
+def test_get_max_concurrency_for_kv_cache_config():
+    # Create a VllmConfig
+    model_id = "Qwen/Qwen1.5-7B"
+    max_model_len = 16384
+    model_config = ModelConfig(
+        model_id,
+        task="generate",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        max_model_len=max_model_len,
+    )
+    scheduler_config = SchedulerConfig(max_num_batched_tokens=1024,
+                                       enable_chunked_prefill=True)
+
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        scheduler_config=scheduler_config,
+    )
+
+    full_attention_spec = FullAttentionSpec(
+        block_size=16,
+        num_kv_heads=32,
+        head_size=128,
+        dtype=torch.float16,
+        use_mla=False,
+    )
+
+    sliding_window_spec = SlidingWindowSpec(
+        block_size=16,
+        num_kv_heads=32,
+        head_size=128,
+        dtype=torch.float16,
+        use_mla=False,
+        sliding_window=1024,
+    )
+
+    kv_cache_config_full_attention = KVCacheConfig(
+        num_blocks=int(1024 * 1.5),
+        tensors={},
+        kv_cache_groups=[
+            KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
+                             full_attention_spec),
+        ],
+    )
+    max_concurrency_full_attention = get_max_concurrency_for_kv_cache_config(
+        vllm_config, kv_cache_config_full_attention)
+    assert max_concurrency_full_attention == 1.5
+
+    kv_cache_config_sliding_window = KVCacheConfig(
+        num_blocks=129 * 3,
+        tensors={},
+        kv_cache_groups=[
+            KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
+                             sliding_window_spec),
+        ],
+    )
+    max_concurrency_sliding_window = get_max_concurrency_for_kv_cache_config(
+        vllm_config, kv_cache_config_sliding_window)
+    assert max_concurrency_sliding_window == 3
+
+    kv_cache_config_hybrid_model = KVCacheConfig(
+        num_blocks=(1024 + 129) * 3,
+        tensors={},
+        kv_cache_groups=[
+            KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
+                             full_attention_spec),
+            KVCacheGroupSpec([f"layer_{i}" for i in range(32, 64)],
+                             sliding_window_spec),
+        ],
+    )
+    max_concurrency_hybrid_model = get_max_concurrency_for_kv_cache_config(
+        vllm_config, kv_cache_config_hybrid_model)
+    assert max_concurrency_hybrid_model == 3
+
+
 def test_allocate_with_lookahead():
     """Verify that lookahead tokens correctly affect block allocation"""
     block_size = 4
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 3b5a379267e5..ad3c21f794b9 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -3,13 +3,13 @@
 """KV-Cache Utilities."""
 import os
 from collections import deque
-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from typing import Any, Callable, NamedTuple, Optional
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import GiB_bytes, sha256
+from vllm.utils import GiB_bytes, cdiv, sha256
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheSpec,
                                         KVCacheTensor, SlidingWindowSpec)
@@ -468,6 +468,15 @@ def hash_request_tokens(hash_function: Any, block_size: int,
     return ret
 
 
+def max_memory_usage_bytes(vllm_config: VllmConfig,
+                           kv_cache_specs: Iterable[KVCacheSpec]) -> int:
+    """
+    Get the maximum memory usage in bytes for the given KV cache specs.
+    """
+    return sum(
+        spec.max_memory_usage_bytes(vllm_config) for spec in kv_cache_specs)
+
+
 def estimate_max_model_len(vllm_config: VllmConfig,
                            kv_cache_spec: dict[str, KVCacheSpec],
                            available_memory: int) -> int:
@@ -489,11 +498,8 @@ def estimate_max_model_len(vllm_config: VllmConfig,
         # Modify the max_model_len for this calculation
         vllm_config.model_config.max_model_len = model_len
         # Calculate memory needed for the given model length
-        memory_needed = sum(
-            (layer_spec.max_memory_usage_bytes(vllm_config)
-             for layer_spec in kv_cache_spec.values()),
-            start=0,
-        )
+        memory_needed = max_memory_usage_bytes(vllm_config,
+                                               kv_cache_spec.values())
         return memory_needed <= available_memory
 
     # Binary search for the maximum model length
@@ -538,9 +544,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
                          "initializing the engine.")
 
     max_model_len = vllm_config.model_config.max_model_len
-    needed_memory = 0
-    for layer_spec in kv_cache_spec.values():
-        needed_memory += layer_spec.max_memory_usage_bytes(vllm_config)
+    needed_memory = max_memory_usage_bytes(vllm_config, kv_cache_spec.values())
 
     if needed_memory > available_memory:
         # Estimate the maximum model length that can fit in the available memory
@@ -606,6 +610,24 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
     return len(layer_keys) == 1
 
 
+def get_max_concurrency_for_kv_cache_config(
+        vllm_config: VllmConfig, kv_cache_config: KVCacheConfig) -> float:
+    """
+    Get the maximum concurrency for the given KV cache configuration.
+    """
+    num_layer_per_group = max(
+        len(group.layer_names) for group in kv_cache_config.kv_cache_groups)
+    max_memory_usage_per_request = num_layer_per_group * max_memory_usage_bytes(
+        vllm_config,
+        (group.kv_cache_spec for group in kv_cache_config.kv_cache_groups))
+    memory_per_block = kv_cache_config.kv_cache_groups[
+        0].kv_cache_spec.page_size_bytes * num_layer_per_group
+    num_block_per_request = cdiv(max_memory_usage_per_request,
+                                 memory_per_block)
+    max_concurrency = kv_cache_config.num_blocks / num_block_per_request
+    return max_concurrency
+
+
 def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                                       kv_cache_spec: dict[str, KVCacheSpec],
                                       available_memory: int) -> KVCacheConfig:
@@ -637,14 +659,6 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                     "num_gpu_blocks_override=%d", num_blocks,
                     num_gpu_blocks_override)
         num_blocks = num_gpu_blocks_override
-    num_tokens = num_blocks * vllm_config.cache_config.block_size
-    num_tokens_str = f"{num_tokens:,}"
-    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
-    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
-    max_concurrency = num_tokens / vllm_config.model_config.max_model_len
-    logger.info("Maximum concurrency for %s tokens per request: %.2fx",
-                max_model_len_str, max_concurrency)
-
     per_layer_size = page_size * num_blocks
     # All layers have the same KV cache spec, so we create one kv cache group
     # for all layers.
@@ -659,6 +673,15 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
         kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
                                                     grouped_layer_names),
     )
+
+    num_tokens = num_blocks * vllm_config.cache_config.block_size
+    num_tokens_str = f"{num_tokens:,}"
+    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
+    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
+    max_concurrency = get_max_concurrency_for_kv_cache_config(
+        vllm_config, kv_cache_config)
+    logger.info("Maximum concurrency for %s tokens per request: %.2fx",
+                max_model_len_str, max_concurrency)
     return kv_cache_config
 
 
@@ -705,8 +728,8 @@ def get_kv_cache_config(vllm_config: VllmConfig,
     Returns:
         The generated KVCacheConfigs
     """
-    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
     unify_hybrid_kv_cache_specs(kv_cache_spec)
+    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
 
     if is_kv_cache_type_uniform(kv_cache_spec):
         # KV cache of all layers are the same, which is true for
         # most models. Allocate the same amount of memory for