[KVCache] Make KVCacheSpec hashable (#21791)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Chen Zhang 2025-07-29 04:58:29 -07:00 committed by GitHub
parent 2470419119
commit 755fa8b657
5 changed files with 100 additions and 88 deletions

View File

@ -17,7 +17,7 @@ from vllm.v1.core.kv_cache_utils import (
estimate_max_model_len, generate_block_hash_extra_keys,
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
hash_block_tokens, hash_request_tokens, init_none_hash,
unify_kv_cache_configs)
is_kv_cache_type_uniform, unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor,
SlidingWindowSpec)
@ -685,6 +685,38 @@ def test_merge_kv_cache_spec():
assert merged_layer_spec.sliding_window == 1
def test_is_kv_cache_type_uniform():
kv_cache_spec = {
"layer_1": new_kv_cache_spec(num_kv_heads=32),
"layer_2": new_kv_cache_spec(num_kv_heads=32),
}
assert is_kv_cache_type_uniform(kv_cache_spec)
kv_cache_spec = {
"layer_1": new_kv_cache_spec(num_kv_heads=32),
"layer_2": new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
}
assert is_kv_cache_type_uniform(kv_cache_spec)
kv_cache_spec = {
"layer_1": new_kv_cache_spec(num_kv_heads=32),
"layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
}
assert not is_kv_cache_type_uniform(kv_cache_spec)
kv_cache_spec = {
"layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
"layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
}
assert is_kv_cache_type_uniform(kv_cache_spec)
kv_cache_spec = {
"layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
"layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=2),
}
assert not is_kv_cache_type_uniform(kv_cache_spec)
@pytest.mark.parametrize(
("model_id", "max_model_len", "want_estimated_max_len"), [
("Qwen/Qwen1.5-7B", 16385, 16384),

View File

@ -30,7 +30,9 @@ model_config = {
])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed):
@pytest.mark.parametrize("disable_hybrid_kv_cache_manager", [True, False])
def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed,
disable_hybrid_kv_cache_manager):
"""
The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
asks for value of one of them (which is outside the sliding window).
@ -42,7 +44,9 @@ def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed):
test_config = model_config[model]
llm = LLM(model=model)
llm = LLM(
model=model,
disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
prompts, answer, indices = prep_prompts(batch_size,

View File

@ -7,7 +7,8 @@ from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.single_type_kv_cache_manager import (
FullAttentionManager, get_manager_for_kv_cache_spec)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.request import Request
@ -258,44 +259,40 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
one of them is full attention. Then, split the kv cache groups into full
attention groups and other groups.
"""
full_attention_type_id: Optional[str] = None
other_type_id: Optional[str] = None
full_attention_spec: Optional[FullAttentionSpec] = None
other_spec: Optional[KVCacheSpec] = None
self.full_attention_group_ids: list[int] = []
self.other_group_ids: list[int] = []
for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
if isinstance(g.kv_cache_spec, FullAttentionSpec):
if full_attention_type_id is None:
full_attention_type_id = g.kv_cache_spec.type_id
if full_attention_spec is None:
full_attention_spec = g.kv_cache_spec
else:
assert full_attention_type_id == g.kv_cache_spec.type_id, (
assert full_attention_spec == g.kv_cache_spec, (
"HybridKVCacheCoordinator assumes exactly one type of "
"full attention groups now.")
self.full_attention_group_ids.append(i)
else:
if other_type_id is None:
other_type_id = g.kv_cache_spec.type_id
if other_spec is None:
other_spec = g.kv_cache_spec
else:
assert other_type_id == g.kv_cache_spec.type_id, (
assert other_spec == g.kv_cache_spec, (
"HybridKVCacheCoordinator assumes "
"exactly one other type of groups now.")
self.other_group_ids.append(i)
assert full_attention_type_id is not None, (
assert full_attention_spec is not None, (
"HybridKVCacheCoordinator assumes exactly one type of full "
"attention groups now.")
assert other_type_id is not None, (
assert other_spec is not None, (
"HybridKVCacheCoordinator assumes exactly one type of other "
"groups now.")
self.full_attention_manager_cls = FullAttentionManager
self.other_attention_cls = self.single_type_managers[
self.other_group_ids[0]].__class__
self.full_attention_spec = self.kv_cache_config.kv_cache_groups[
self.full_attention_group_ids[0]].kv_cache_spec
self.other_spec = self.kv_cache_config.kv_cache_groups[
self.other_group_ids[0]].kv_cache_spec
self.full_attention_spec = full_attention_spec
self.other_spec = other_spec
self.full_attention_block_size = self.full_attention_spec.block_size
self.other_block_size = self.other_spec.block_size
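Note: the verification pass above now keeps the first FullAttentionSpec and the first other spec it sees and asserts that every later group carries an equal spec; dataclass equality compares every field, replacing the old comparison of hand-built type_id strings. A minimal sketch of that partition-and-check pattern, using hypothetical stand-in classes rather than the vLLM specs:

from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class FakeFullSpec:  # hypothetical stand-in for FullAttentionSpec
    block_size: int

@dataclass(frozen=True)
class FakeSlidingSpec:  # hypothetical stand-in for SlidingWindowSpec
    block_size: int
    sliding_window: int

def partition_groups(specs):
    # Split group indices into full-attention vs. other, allowing one spec per side.
    full_spec: Optional[FakeFullSpec] = None
    other_spec = None
    full_ids: list[int] = []
    other_ids: list[int] = []
    for i, spec in enumerate(specs):
        if isinstance(spec, FakeFullSpec):
            if full_spec is None:
                full_spec = spec
            else:
                assert full_spec == spec, "expect one kind of full attention spec"
            full_ids.append(i)
        else:
            if other_spec is None:
                other_spec = spec
            else:
                assert other_spec == spec, "expect one kind of other spec"
            other_ids.append(i)
    return full_ids, other_ids

print(partition_groups([FakeFullSpec(16), FakeSlidingSpec(16, 1024), FakeFullSpec(16)]))
# ([0, 2], [1])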

View File

@ -5,7 +5,7 @@
import os
from collections import defaultdict, deque
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from dataclasses import astuple, dataclass
from typing import Any, Callable, NamedTuple, Optional
from vllm.config import VllmConfig
@ -727,7 +727,9 @@ def create_kv_cache_group_specs(
def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
"""
Whether all layers in the given KVCacheSpec have the same type of KV cache.
Whether all layers in the given KVCacheSpec have the same KV cache spec.
Note that we regard FullAttentionSpec with and without sliding window as
the same type.
Args:
kv_cache_spec: The kv cache spec of each attention layer in the model
@ -736,8 +738,12 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
True if all layers have the same type, False otherwise.
"""
layer_keys = set(layer.type_id for layer in kv_cache_spec.values())
return len(layer_keys) == 1
try:
kv_cache_spec_values = list(kv_cache_spec.values())
_ = kv_cache_spec_values[0].merge(kv_cache_spec_values)
except AssertionError:
return False
return True
def get_max_concurrency_for_kv_cache_config(
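Note: is_kv_cache_type_uniform no longer compares type_id strings; it asks whether all specs can be merged into one and treats the AssertionError raised by merge as "not uniform". Because FullAttentionSpec.merge tolerates differing sliding_window values, full-attention layers with and without sliding window still count as uniform, matching the updated docstring. A minimal sketch of the same try/merge pattern with a hypothetical ToySpec (not the vLLM class):

import copy
from dataclasses import dataclass

@dataclass(frozen=True)
class ToySpec:  # hypothetical; its merge mimics KVCacheSpec.merge's equality assert
    block_size: int

    @classmethod
    def merge(cls, specs):
        assert all(s == specs[0] for s in specs[1:]), "specs differ"
        return copy.deepcopy(specs[0])

def is_uniform(spec_by_layer):
    specs = list(spec_by_layer.values())
    try:
        _ = specs[0].merge(specs)  # merge raises AssertionError on mismatched specs
    except AssertionError:
        return False
    return True

print(is_uniform({"l1": ToySpec(16), "l2": ToySpec(16)}))  # True
print(is_uniform({"l1": ToySpec(16), "l2": ToySpec(32)}))  # False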
@ -928,12 +934,12 @@ def _get_kv_cache_config_uniform_page_size(
Returns:
The generated KVCacheConfig
"""
# Group all layers by type_id.
# Group all layers by kv_cache_spec.
# E.g., 2 full attention layers and 3 sliding window attention layers,
# -> (full.0, full.1), (sw.0, sw.1, sw.2).
same_type_layers: dict[str, list[str]] = defaultdict(list)
same_type_layers: dict[KVCacheSpec, list[str]] = defaultdict(list)
for layer_name, layer_spec in kv_cache_spec.items():
same_type_layers[layer_spec.type_id].append(layer_name)
same_type_layers[layer_spec].append(layer_name)
# Split each group into smaller groups, to make the number of layers in each
# group identical. Add padding to the last group of each type if necessary.
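Note: same_type_layers is now keyed by the spec object itself instead of its type_id string, which is the reason KVCacheSpec has to be hashable in the first place (frozen dataclasses derive __hash__ from their fields). A minimal sketch of that grouping with a hypothetical ToySpec:

from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class ToySpec:  # hypothetical, not the vLLM class
    block_size: int
    sliding_window: Optional[int] = None

layers = {
    "full.0": ToySpec(16),
    "full.1": ToySpec(16),
    "sw.0": ToySpec(16, sliding_window=1024),
    "sw.1": ToySpec(16, sliding_window=1024),
    "sw.2": ToySpec(16, sliding_window=1024),
}

same_spec_layers: defaultdict[ToySpec, list[str]] = defaultdict(list)
for layer_name, layer_spec in layers.items():
    same_spec_layers[layer_spec].append(layer_name)  # the hashable spec can key the dict

for spec, names in same_spec_layers.items():
    print(spec, names)
# ToySpec(block_size=16, sliding_window=None) ['full.0', 'full.1']
# ToySpec(block_size=16, sliding_window=1024) ['sw.0', 'sw.1', 'sw.2']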
@ -1017,12 +1023,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
kv_cache_spec: The kv cache spec of each attention layer in the model
"""
def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
type_ids = set(layer_spec.type_id
for layer_spec in kv_cache_spec.values())
return len(type_ids) > 1
if not is_hybrid(kv_cache_spec):
if is_kv_cache_type_uniform(kv_cache_spec):
return
logger.warning(
@ -1060,7 +1061,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
attention_chunk_size=spec.attention_chunk_size,
)
if is_hybrid(kv_cache_spec):
if not is_kv_cache_type_uniform(kv_cache_spec):
raise ValueError("Hybrid KV cache manager is disabled but failed to "
"convert the KV cache specs to one unified type.")
@ -1119,11 +1120,11 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
in-place modified to make them consistent.
"""
# Sort the kv cache groups by the type_id of their KV cache spec.
# Sort the kv cache groups by their KV cache spec.
# This can avoid the inconsistency caused by the order of groups.
for kv_cache_config in kv_cache_configs:
kv_cache_config.kv_cache_groups.sort(
key=lambda x: x.kv_cache_spec.type_id)
kv_cache_config.kv_cache_groups.sort(key=lambda x: (type(
x.kv_cache_spec).__name__, astuple(x.kv_cache_spec)))
# Verify that the groups of each rank are the same.
for kv_cache_config in kv_cache_configs[1:]:
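Note: with type_id gone, groups are ordered by a tuple of the spec's class name and its field values (astuple), which keeps the group order deterministic across ranks as long as the field values within one class are mutually comparable. A minimal sketch of that sort key on hypothetical classes:

from dataclasses import astuple, dataclass

@dataclass(frozen=True)
class FullSpec:  # hypothetical
    block_size: int

@dataclass(frozen=True)
class SlidingSpec:  # hypothetical
    block_size: int
    sliding_window: int

groups = [SlidingSpec(16, 2048), FullSpec(16), SlidingSpec(16, 1024)]
groups.sort(key=lambda spec: (type(spec).__name__, astuple(spec)))
print(groups)
# [FullSpec(block_size=16),
#  SlidingSpec(block_size=16, sliding_window=1024),
#  SlidingSpec(block_size=16, sliding_window=2048)]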

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from dataclasses import dataclass
from dataclasses import dataclass, fields
from math import prod
from typing import Optional
@ -16,7 +16,7 @@ from vllm.utils import cdiv, get_dtype_size
logger = init_logger(__name__)
@dataclass
@dataclass(frozen=True)
class KVCacheSpec:
"""
A base class for specifying the KV cache format of one layer.
@ -25,20 +25,6 @@ class KVCacheSpec:
# number of tokens in a block
block_size: int
@property
def type_id(self) -> str:
"""
The type identifier of this KV cache.
Return different strings for layers with different KV cache type (e.g.,
different number of tokens like full attention vs sliding window
attention, different KV cache size per token like layers with different
number of heads)
Returns:
The type identifier of this KV cache.
"""
raise NotImplementedError
@property
def page_size_bytes(self) -> int:
"""
@ -63,13 +49,12 @@ class KVCacheSpec:
"""
Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
"""
assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), (
"All layers in the same KV cache group must share the same "
"type_id.")
assert all(spec == specs[0] for spec in specs[1:]), (
"All layers in the same KV cache group must be the same.")
return copy.deepcopy(specs[0])
@dataclass
@dataclass(frozen=True)
class AttentionSpec(KVCacheSpec):
num_kv_heads: int
head_size: int
@ -84,7 +69,7 @@ class AttentionSpec(KVCacheSpec):
* get_dtype_size(self.dtype)
@dataclass
@dataclass(frozen=True)
class FullAttentionSpec(AttentionSpec):
sliding_window: Optional[int] = None
attention_chunk_size: Optional[int] = None
@ -98,10 +83,6 @@ class FullAttentionSpec(AttentionSpec):
Default to None for not using sliding window attention.
"""
@property
def type_id(self) -> str:
return f"full_attention_{self.block_size}_{self.page_size_bytes}"
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@ -123,15 +104,28 @@ class FullAttentionSpec(AttentionSpec):
Merge a list of FullAttentionSpec objects into a single
FullAttentionSpec object.
"""
merged_spec = super().merge(specs)
assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
"All attention layers in the same KV cache group must be "
"FullAttentionSpec.")
sliding_window = set(spec.sliding_window for spec in specs
if spec.sliding_window is not None)
attention_chunk_size = set(spec.attention_chunk_size for spec in specs
if spec.attention_chunk_size is not None)
merged_spec.sliding_window = cls.merge_window_sizes(sliding_window)
merged_spec.attention_chunk_size = (
cls.merge_window_sizes(attention_chunk_size))
merged_spec = cls(
block_size=specs[0].block_size,
num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size,
dtype=specs[0].dtype,
use_mla=specs[0].use_mla,
sliding_window=cls.merge_window_sizes(sliding_window),
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
)
for spec in specs:
for f in fields(AttentionSpec):
assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
"All attention layers in the same KV cache group must have "
"the same attention spec.")
assert (
(merged_spec.sliding_window is not None) +
(merged_spec.attention_chunk_size is not None) <= 1
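Note: because the spec dataclasses are now frozen, merge can no longer copy a spec and then overwrite sliding_window / attention_chunk_size on it; it builds the merged spec in a single cls(...) call, as above. dataclasses.replace is the general way to derive a modified copy of a frozen instance, sketched here with a hypothetical class:

import dataclasses
from typing import Optional

@dataclasses.dataclass(frozen=True)
class ToyFullSpec:  # hypothetical, not the vLLM class
    block_size: int
    sliding_window: Optional[int] = None

base = ToyFullSpec(block_size=16)
# base.sliding_window = 1024  # would raise dataclasses.FrozenInstanceError
merged = dataclasses.replace(base, sliding_window=1024)  # new instance with one field changed
print(base, merged)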
@ -140,16 +134,10 @@ class FullAttentionSpec(AttentionSpec):
return merged_spec
@dataclass
@dataclass(frozen=True)
class ChunkedLocalAttentionSpec(AttentionSpec):
attention_chunk_size: int
@property
def type_id(self) -> str:
return (
f"local_attention_{self.attention_chunk_size}_{self.block_size}_{self.page_size_bytes}"
) # noqa
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
max_num_batched_tokens = (
@ -165,17 +153,13 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
@dataclass
@dataclass(frozen=True)
class SlidingWindowSpec(AttentionSpec):
sliding_window: int
def __post_init__(self):
assert not self.use_mla, "MLA is not supported for sliding window"
@property
def type_id(self) -> str:
return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}" # noqa
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
max_num_batched_tokens = (
@ -195,23 +179,17 @@ class SlidingWindowSpec(AttentionSpec):
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
@dataclass
@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
shapes: tuple[tuple[int, ...], ...]
dtype: torch.dtype
page_size_padded: Optional[int] = None
mamba_type: str = "mamba2"
def __post_init__(self):
self.num_elements = sum(prod(shape) for shape in self.shapes)
@property
def type_id(self) -> str:
return f"mamba_{self.shapes}_{self.dtype}_{self.mamba_type}"
@property
def page_size_bytes(self) -> int:
page_size = self.num_elements * get_dtype_size(self.dtype)
num_elements = sum(prod(shape) for shape in self.shapes)
page_size = num_elements * get_dtype_size(self.dtype)
if self.page_size_padded is not None:
assert self.page_size_padded >= page_size
return self.page_size_padded
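Note: a second consequence of frozen=True shows up in this hunk: __post_init__ can no longer cache a derived value with plain attribute assignment (any setattr on a frozen instance raises dataclasses.FrozenInstanceError), so the element count is recomputed inside page_size_bytes. A minimal sketch of the constraint and the recompute-on-demand pattern, with a hypothetical class:

from dataclasses import FrozenInstanceError, dataclass
from math import prod

@dataclass(frozen=True)
class ToyMambaSpec:  # hypothetical, not the vLLM class
    shapes: tuple[tuple[int, ...], ...]
    dtype_size: int  # bytes per element, standing in for get_dtype_size(dtype)

    @property
    def page_size_bytes(self) -> int:
        num_elements = sum(prod(shape) for shape in self.shapes)  # recomputed on demand
        return num_elements * self.dtype_size

spec = ToyMambaSpec(shapes=((4, 8), (16,)), dtype_size=2)
try:
    spec.num_elements = 48  # caching the value by assignment is rejected on a frozen dataclass
except FrozenInstanceError:
    pass
print(spec.page_size_bytes)  # (4*8 + 16) * 2 == 96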