diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py
index ebe3a30e3352d..e9c6f1f95cd71 100644
--- a/tests/v1/core/test_kv_cache_utils.py
+++ b/tests/v1/core/test_kv_cache_utils.py
@@ -17,7 +17,7 @@ from vllm.v1.core.kv_cache_utils import (
     estimate_max_model_len, generate_block_hash_extra_keys,
     get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
     hash_block_tokens, hash_request_tokens, init_none_hash,
-    unify_kv_cache_configs)
+    is_kv_cache_type_uniform, unify_kv_cache_configs)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheTensor,
                                         SlidingWindowSpec)
@@ -685,6 +685,38 @@ def test_merge_kv_cache_spec():
     assert merged_layer_spec.sliding_window == 1
 
 
+def test_is_kv_cache_type_uniform():
+    kv_cache_spec = {
+        "layer_1": new_kv_cache_spec(num_kv_heads=32),
+        "layer_2": new_kv_cache_spec(num_kv_heads=32),
+    }
+    assert is_kv_cache_type_uniform(kv_cache_spec)
+
+    kv_cache_spec = {
+        "layer_1": new_kv_cache_spec(num_kv_heads=32),
+        "layer_2": new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
+    }
+    assert is_kv_cache_type_uniform(kv_cache_spec)
+
+    kv_cache_spec = {
+        "layer_1": new_kv_cache_spec(num_kv_heads=32),
+        "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
+    }
+    assert not is_kv_cache_type_uniform(kv_cache_spec)
+
+    kv_cache_spec = {
+        "layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
+        "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
+    }
+    assert is_kv_cache_type_uniform(kv_cache_spec)
+
+    kv_cache_spec = {
+        "layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
+        "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=2),
+    }
+    assert not is_kv_cache_type_uniform(kv_cache_spec)
+
+
 @pytest.mark.parametrize(
     ("model_id", "max_model_len", "want_estimated_max_len"), [
         ("Qwen/Qwen1.5-7B", 16385, 16384),
diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py
index 277ea3c838505..4dfe1d3bb33fa 100644
--- a/tests/v1/e2e/test_correctness_sliding_window.py
+++ b/tests/v1/e2e/test_correctness_sliding_window.py
@@ -30,7 +30,9 @@ model_config = {
 ])
 @pytest.mark.parametrize("batch_size", [5])
 @pytest.mark.parametrize("seed", [1])
-def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed):
+@pytest.mark.parametrize("disable_hybrid_kv_cache_manager", [True, False])
+def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed,
+                                  disable_hybrid_kv_cache_manager):
     """
     The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
     asks for value of one of them (which is outside the sliding window).
@@ -42,7 +44,9 @@ def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed):
 
     test_config = model_config[model]
 
-    llm = LLM(model=model)
+    llm = LLM(
+        model=model,
+        disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager)
     sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
 
     prompts, answer, indices = prep_prompts(batch_size,
diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py
index de72e60434ad7..0cce2ec81e08a 100644
--- a/vllm/v1/core/kv_cache_coordinator.py
+++ b/vllm/v1/core/kv_cache_coordinator.py
@@ -7,7 +7,8 @@ from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
 from vllm.v1.core.single_type_kv_cache_manager import (
     FullAttentionManager, get_manager_for_kv_cache_spec)
-from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                        KVCacheSpec)
 from vllm.v1.request import Request
 
 
@@ -258,44 +259,40 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
         one of them is full attention. Then, split the kv cache groups into
         full attention groups and other groups.
         """
-        full_attention_type_id: Optional[str] = None
-        other_type_id: Optional[str] = None
+        full_attention_spec: Optional[FullAttentionSpec] = None
+        other_spec: Optional[KVCacheSpec] = None
         self.full_attention_group_ids: list[int] = []
         self.other_group_ids: list[int] = []
         for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
             if isinstance(g.kv_cache_spec, FullAttentionSpec):
-                if full_attention_type_id is None:
-                    full_attention_type_id = g.kv_cache_spec.type_id
+                if full_attention_spec is None:
+                    full_attention_spec = g.kv_cache_spec
                 else:
-                    assert full_attention_type_id == g.kv_cache_spec.type_id, (
+                    assert full_attention_spec == g.kv_cache_spec, (
                         "HybridKVCacheCoordinator assumes exactly one type of "
                         "full attention groups now.")
                 self.full_attention_group_ids.append(i)
             else:
-                if other_type_id is None:
-                    other_type_id = g.kv_cache_spec.type_id
+                if other_spec is None:
+                    other_spec = g.kv_cache_spec
                 else:
-                    assert other_type_id == g.kv_cache_spec.type_id, (
+                    assert other_spec == g.kv_cache_spec, (
                         "HybridKVCacheCoordinator assumes "
                         "exactly one other type of groups now.")
                 self.other_group_ids.append(i)
 
-        assert full_attention_type_id is not None, (
+        assert full_attention_spec is not None, (
             "HybridKVCacheCoordinator assumes exactly one type of full "
             "attention groups now.")
-        assert other_type_id is not None, (
+        assert other_spec is not None, (
             "HybridKVCacheCoordinator assumes exactly one type of other "
             "groups now.")
 
         self.full_attention_manager_cls = FullAttentionManager
         self.other_attention_cls = self.single_type_managers[
             self.other_group_ids[0]].__class__
-
-        self.full_attention_spec = self.kv_cache_config.kv_cache_groups[
-            self.full_attention_group_ids[0]].kv_cache_spec
-        self.other_spec = self.kv_cache_config.kv_cache_groups[
-            self.other_group_ids[0]].kv_cache_spec
-
+        self.full_attention_spec = full_attention_spec
+        self.other_spec = other_spec
         self.full_attention_block_size = self.full_attention_spec.block_size
         self.other_block_size = self.other_spec.block_size
 
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 5b0218640a8c8..3a72ac271afa6 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -5,7 +5,7 @@
 import os
 from collections import defaultdict, deque
 from collections.abc import Iterable, Sequence
-from dataclasses import dataclass
+from dataclasses import astuple, dataclass
 from typing import Any, Callable, NamedTuple, Optional
 
 from vllm.config import VllmConfig
@@ -727,7 +727,9 @@ def create_kv_cache_group_specs(
 
 def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
     """
-    Whether all layers in the given KVCacheSpec have the same type of KV cache.
+    Whether all layers in the given KVCacheSpec have the same KV cache spec.
+    Note that we regard FullAttentionSpec with and without sliding window as
+    the same type.
 
     Args:
         kv_cache_spec: The kv cache spec of each attention layer in the model
@@ -736,8 +738,12 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
         True if all layers have the same type, False otherwise.
     """
 
-    layer_keys = set(layer.type_id for layer in kv_cache_spec.values())
-    return len(layer_keys) == 1
+    try:
+        kv_cache_spec_values = list(kv_cache_spec.values())
+        _ = kv_cache_spec_values[0].merge(kv_cache_spec_values)
+    except AssertionError:
+        return False
+    return True
 
 
 def get_max_concurrency_for_kv_cache_config(
@@ -928,12 +934,12 @@ def _get_kv_cache_config_uniform_page_size(
     Returns:
         The generated KVCacheConfig
     """
-    # Group all layers by type_id.
+    # Group all layers by kv_cache_spec.
     # E.g., 2 full attention layers and 3 sliding window attention layers,
     # -> (full.0, full.1), (sw.0, sw.1, sw.2).
-    same_type_layers: dict[str, list[str]] = defaultdict(list)
+    same_type_layers: dict[KVCacheSpec, list[str]] = defaultdict(list)
     for layer_name, layer_spec in kv_cache_spec.items():
-        same_type_layers[layer_spec.type_id].append(layer_name)
+        same_type_layers[layer_spec].append(layer_name)
 
     # Split each group into smaller groups, to make the number of layers in each
     # group identical. Add padding to the last group of each type if necessary.
@@ -1017,12 +1023,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
         kv_cache_spec: The kv cache spec of each attention layer in the model
     """
 
-    def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
-        type_ids = set(layer_spec.type_id
-                       for layer_spec in kv_cache_spec.values())
-        return len(type_ids) > 1
-
-    if not is_hybrid(kv_cache_spec):
+    if is_kv_cache_type_uniform(kv_cache_spec):
         return
 
     logger.warning(
@@ -1060,7 +1061,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
                 attention_chunk_size=spec.attention_chunk_size,
             )
 
-    if is_hybrid(kv_cache_spec):
+    if not is_kv_cache_type_uniform(kv_cache_spec):
         raise ValueError("Hybrid KV cache manager is disabled but failed to "
                          "convert the KV cache specs to one unified type.")
 
@@ -1119,11 +1120,11 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
         in-place modified to make them consistent.
     """
 
-    # Sort the kv cache groups by the type_id of their KV cache spec.
+    # Sort the kv cache groups by their KV cache spec.
     # This can avoid the inconsistency caused by the order of groups.
     for kv_cache_config in kv_cache_configs:
-        kv_cache_config.kv_cache_groups.sort(
-            key=lambda x: x.kv_cache_spec.type_id)
+        kv_cache_config.kv_cache_groups.sort(key=lambda x: (type(
+            x.kv_cache_spec).__name__, astuple(x.kv_cache_spec)))
 
     # Verify that the groups of each rank are the same.
     for kv_cache_config in kv_cache_configs[1:]:
diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py
index 1da5230116d26..4ff96f9786b88 100644
--- a/vllm/v1/kv_cache_interface.py
+++ b/vllm/v1/kv_cache_interface.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import copy
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from math import prod
 from typing import Optional
 
@@ -16,7 +16,7 @@ from vllm.utils import cdiv, get_dtype_size
 logger = init_logger(__name__)
 
 
-@dataclass
+@dataclass(frozen=True)
 class KVCacheSpec:
     """
     A base class for specifying the KV cache format of one layer.
@@ -25,20 +25,6 @@ class KVCacheSpec:
     # number of tokens in a block
     block_size: int
 
-    @property
-    def type_id(self) -> str:
-        """
-        The type identifier of this KV cache.
-        Return different strings for layers with different KV cache type (e.g.,
-        different number of tokens like full attention vs sliding window
-        attention, different KV cache size per token like layers with different
-        number of heads)
-
-        Returns:
-            The type identifier of this KV cache.
-        """
-        raise NotImplementedError
-
     @property
     def page_size_bytes(self) -> int:
         """
@@ -63,13 +49,12 @@ class KVCacheSpec:
         """
         Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
         """
-        assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), (
-            "All layers in the same KV cache group must share the same "
-            "type_id.")
+        assert all(spec == specs[0] for spec in specs[1:]), (
+            "All layers in the same KV cache group must be the same.")
         return copy.deepcopy(specs[0])
 
 
-@dataclass
+@dataclass(frozen=True)
 class AttentionSpec(KVCacheSpec):
     num_kv_heads: int
     head_size: int
@@ -84,7 +69,7 @@ class AttentionSpec(KVCacheSpec):
             * get_dtype_size(self.dtype)
 
 
-@dataclass
+@dataclass(frozen=True)
 class FullAttentionSpec(AttentionSpec):
     sliding_window: Optional[int] = None
     attention_chunk_size: Optional[int] = None
@@ -98,10 +83,6 @@ class FullAttentionSpec(AttentionSpec):
     Default to None for not using sliding window attention.
     """
 
-    @property
-    def type_id(self) -> str:
-        return f"full_attention_{self.block_size}_{self.page_size_bytes}"
-
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
         return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@@ -123,15 +104,28 @@ class FullAttentionSpec(AttentionSpec):
         Merge a list of FullAttentionSpec objects into a single
         FullAttentionSpec object.
         """
-        merged_spec = super().merge(specs)
+        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
+            "All attention layers in the same KV cache group must be "
+            "FullAttentionSpec.")
+
         sliding_window = set(spec.sliding_window for spec in specs
                              if spec.sliding_window is not None)
         attention_chunk_size = set(spec.attention_chunk_size for spec in specs
                                    if spec.attention_chunk_size is not None)
-
-        merged_spec.sliding_window = cls.merge_window_sizes(sliding_window)
-        merged_spec.attention_chunk_size = (
-            cls.merge_window_sizes(attention_chunk_size))
+        merged_spec = cls(
+            block_size=specs[0].block_size,
+            num_kv_heads=specs[0].num_kv_heads,
+            head_size=specs[0].head_size,
+            dtype=specs[0].dtype,
+            use_mla=specs[0].use_mla,
+            sliding_window=cls.merge_window_sizes(sliding_window),
+            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
+        )
+        for spec in specs:
+            for f in fields(AttentionSpec):
+                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
+                    "All attention layers in the same KV cache group must have "
+                    "the same attention spec.")
         assert (
             (merged_spec.sliding_window is not None) +
             (merged_spec.attention_chunk_size is not None) <= 1
@@ -140,16 +134,10 @@ class FullAttentionSpec(AttentionSpec):
         return merged_spec
 
 
-@dataclass
+@dataclass(frozen=True)
 class ChunkedLocalAttentionSpec(AttentionSpec):
     attention_chunk_size: int
 
-    @property
-    def type_id(self) -> str:
-        return (
-            f"local_attention_{self.attention_chunk_size}_{self.block_size}_{self.page_size_bytes}"
-        )  # noqa
-
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
         max_num_batched_tokens = (
@@ -165,17 +153,13 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
         return cdiv(num_tokens, self.block_size) * self.page_size_bytes
 
 
-@dataclass
+@dataclass(frozen=True)
 class SlidingWindowSpec(AttentionSpec):
     sliding_window: int
 
     def __post_init__(self):
         assert not self.use_mla, "MLA is not supported for sliding window"
 
-    @property
-    def type_id(self) -> str:
-        return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}"  # noqa
-
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
         max_num_batched_tokens = (
@@ -195,23 +179,17 @@ class SlidingWindowSpec(AttentionSpec):
         return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
 
 
-@dataclass
+@dataclass(frozen=True)
 class MambaSpec(KVCacheSpec):
     shapes: tuple[tuple[int, ...], ...]
     dtype: torch.dtype
     page_size_padded: Optional[int] = None
     mamba_type: str = "mamba2"
 
-    def __post_init__(self):
-        self.num_elements = sum(prod(shape) for shape in self.shapes)
-
-    @property
-    def type_id(self) -> str:
-        return f"mamba_{self.shapes}_{self.dtype}_{self.mamba_type}"
-
     @property
     def page_size_bytes(self) -> int:
-        page_size = self.num_elements * get_dtype_size(self.dtype)
+        num_elements = sum(prod(shape) for shape in self.shapes)
+        page_size = num_elements * get_dtype_size(self.dtype)
         if self.page_size_padded is not None:
             assert self.page_size_padded >= page_size
             return self.page_size_padded