[KVCache] Make KVCacheSpec hashable (#21791)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent 2470419119
commit 755fa8b657
@@ -17,7 +17,7 @@ from vllm.v1.core.kv_cache_utils import (
     estimate_max_model_len, generate_block_hash_extra_keys,
     get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
     hash_block_tokens, hash_request_tokens, init_none_hash,
-    unify_kv_cache_configs)
+    is_kv_cache_type_uniform, unify_kv_cache_configs)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheTensor,
                                         SlidingWindowSpec)
@@ -685,6 +685,38 @@ def test_merge_kv_cache_spec():
     assert merged_layer_spec.sliding_window == 1
 
 
+def test_is_kv_cache_type_uniform():
+    kv_cache_spec = {
+        "layer_1": new_kv_cache_spec(num_kv_heads=32),
+        "layer_2": new_kv_cache_spec(num_kv_heads=32),
+    }
+    assert is_kv_cache_type_uniform(kv_cache_spec)
+
+    kv_cache_spec = {
+        "layer_1": new_kv_cache_spec(num_kv_heads=32),
+        "layer_2": new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
+    }
+    assert is_kv_cache_type_uniform(kv_cache_spec)
+
+    kv_cache_spec = {
+        "layer_1": new_kv_cache_spec(num_kv_heads=32),
+        "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
+    }
+    assert not is_kv_cache_type_uniform(kv_cache_spec)
+
+    kv_cache_spec = {
+        "layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
+        "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
+    }
+    assert is_kv_cache_type_uniform(kv_cache_spec)
+
+    kv_cache_spec = {
+        "layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
+        "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=2),
+    }
+    assert not is_kv_cache_type_uniform(kv_cache_spec)
+
+
 @pytest.mark.parametrize(
     ("model_id", "max_model_len", "want_estimated_max_len"), [
         ("Qwen/Qwen1.5-7B", 16385, 16384),
@@ -30,7 +30,9 @@ model_config = {
 ])
 @pytest.mark.parametrize("batch_size", [5])
 @pytest.mark.parametrize("seed", [1])
-def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed):
+@pytest.mark.parametrize("disable_hybrid_kv_cache_manager", [True, False])
+def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed,
+                                  disable_hybrid_kv_cache_manager):
     """
     The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
     asks for value of one of them (which is outside the sliding window).
@@ -42,7 +44,9 @@ def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed):
 
     test_config = model_config[model]
 
-    llm = LLM(model=model)
+    llm = LLM(
+        model=model,
+        disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager)
     sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
 
     prompts, answer, indices = prep_prompts(batch_size,
@@ -7,7 +7,8 @@ from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
 from vllm.v1.core.single_type_kv_cache_manager import (
     FullAttentionManager, get_manager_for_kv_cache_spec)
-from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                        KVCacheSpec)
 from vllm.v1.request import Request
 
 
@@ -258,44 +259,40 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
         one of them is full attention. Then, split the kv cache groups into full
         attention groups and other groups.
         """
-        full_attention_type_id: Optional[str] = None
-        other_type_id: Optional[str] = None
+        full_attention_spec: Optional[FullAttentionSpec] = None
+        other_spec: Optional[KVCacheSpec] = None
         self.full_attention_group_ids: list[int] = []
         self.other_group_ids: list[int] = []
         for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
             if isinstance(g.kv_cache_spec, FullAttentionSpec):
-                if full_attention_type_id is None:
-                    full_attention_type_id = g.kv_cache_spec.type_id
+                if full_attention_spec is None:
+                    full_attention_spec = g.kv_cache_spec
                 else:
-                    assert full_attention_type_id == g.kv_cache_spec.type_id, (
+                    assert full_attention_spec == g.kv_cache_spec, (
                         "HybridKVCacheCoordinator assumes exactly one type of "
                         "full attention groups now.")
                 self.full_attention_group_ids.append(i)
             else:
-                if other_type_id is None:
-                    other_type_id = g.kv_cache_spec.type_id
+                if other_spec is None:
+                    other_spec = g.kv_cache_spec
                 else:
-                    assert other_type_id == g.kv_cache_spec.type_id, (
+                    assert other_spec == g.kv_cache_spec, (
                         "HybridKVCacheCoordinator assumes "
                         "exactly one other type of groups now.")
                 self.other_group_ids.append(i)
 
-        assert full_attention_type_id is not None, (
+        assert full_attention_spec is not None, (
             "HybridKVCacheCoordinator assumes exactly one type of full "
             "attention groups now.")
-        assert other_type_id is not None, (
+        assert other_spec is not None, (
             "HybridKVCacheCoordinator assumes exactly one type of other "
             "groups now.")
 
         self.full_attention_manager_cls = FullAttentionManager
         self.other_attention_cls = self.single_type_managers[
             self.other_group_ids[0]].__class__
 
-        self.full_attention_spec = self.kv_cache_config.kv_cache_groups[
-            self.full_attention_group_ids[0]].kv_cache_spec
-        self.other_spec = self.kv_cache_config.kv_cache_groups[
-            self.other_group_ids[0]].kv_cache_spec
-
+        self.full_attention_spec = full_attention_spec
+        self.other_spec = other_spec
         self.full_attention_block_size = self.full_attention_spec.block_size
         self.other_block_size = self.other_spec.block_size
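
Note on this hunk: the coordinator now remembers the first full-attention spec and the first "other" spec it sees and compares later groups against them with whole-object equality, which the frozen dataclasses below provide for free. A minimal, self-contained sketch of that grouping pattern, using hypothetical toy classes rather than the real specs from vllm.v1.kv_cache_interface:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass(frozen=True)
    class FullToy:                # hypothetical stand-in for FullAttentionSpec
        block_size: int

    @dataclass(frozen=True)
    class SlidingToy:             # hypothetical stand-in for SlidingWindowSpec
        block_size: int
        sliding_window: int

    groups = [FullToy(16), SlidingToy(16, 1024), FullToy(16), SlidingToy(16, 1024)]

    full_spec: Optional[FullToy] = None
    other_spec: Optional[SlidingToy] = None
    full_ids: list[int] = []
    other_ids: list[int] = []
    for i, spec in enumerate(groups):
        if isinstance(spec, FullToy):
            # Whole-spec equality replaces the old type_id string comparison.
            assert full_spec is None or full_spec == spec
            full_spec = spec
            full_ids.append(i)
        else:
            assert other_spec is None or other_spec == spec
            other_spec = spec
            other_ids.append(i)

    assert full_ids == [0, 2] and other_ids == [1, 3]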
@@ -5,7 +5,7 @@
 import os
 from collections import defaultdict, deque
 from collections.abc import Iterable, Sequence
-from dataclasses import dataclass
+from dataclasses import astuple, dataclass
 from typing import Any, Callable, NamedTuple, Optional
 
 from vllm.config import VllmConfig
@@ -727,7 +727,9 @@ def create_kv_cache_group_specs(
 
 def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
     """
-    Whether all layers in the given KVCacheSpec have the same type of KV cache.
+    Whether all layers in the given KVCacheSpec have the same KV cache spec.
+    Note that we regard FullAttentionSpec with and without sliding window as
+    the same type.
 
     Args:
         kv_cache_spec: The kv cache spec of each attention layer in the model
@@ -736,8 +738,12 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
         True if all layers have the same type, False otherwise.
     """
 
-    layer_keys = set(layer.type_id for layer in kv_cache_spec.values())
-    return len(layer_keys) == 1
+    try:
+        kv_cache_spec_values = list(kv_cache_spec.values())
+        _ = kv_cache_spec_values[0].merge(kv_cache_spec_values)
+    except AssertionError:
+        return False
+    return True
 
 
 def get_max_concurrency_for_kv_cache_config(
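
For orientation: the uniformity check now delegates to merge(). If every layer's spec can be merged into one, the cache layout is treated as uniform, and the AssertionError raised by an incompatible merge is turned into False. A self-contained sketch of that idiom with a toy spec class (not the vLLM one), assuming merge() asserts exact equality as the base KVCacheSpec.merge in the kv_cache_interface hunks below does:

    import copy
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ToySpec:                      # hypothetical stand-in, not KVCacheSpec
        block_size: int

        def merge(self, specs: list["ToySpec"]) -> "ToySpec":
            # Mirrors the base-class contract: all specs must be identical.
            assert all(s == specs[0] for s in specs[1:]), "specs differ"
            return copy.deepcopy(specs[0])

    def is_uniform(spec_by_layer: dict[str, ToySpec]) -> bool:
        values = list(spec_by_layer.values())
        try:
            values[0].merge(values)     # reuse merge() as the uniformity oracle
        except AssertionError:
            return False
        return True

    assert is_uniform({"a": ToySpec(16), "b": ToySpec(16)})
    assert not is_uniform({"a": ToySpec(16), "b": ToySpec(32)})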
@@ -928,12 +934,12 @@ def _get_kv_cache_config_uniform_page_size(
     Returns:
         The generated KVCacheConfig
     """
-    # Group all layers by type_id.
+    # Group all layers by kv_cache_spec.
     # E.g., 2 full attention layers and 3 sliding window attention layers,
     # -> (full.0, full.1), (sw.0, sw.1, sw.2).
-    same_type_layers: dict[str, list[str]] = defaultdict(list)
+    same_type_layers: dict[KVCacheSpec, list[str]] = defaultdict(list)
     for layer_name, layer_spec in kv_cache_spec.items():
-        same_type_layers[layer_spec.type_id].append(layer_name)
+        same_type_layers[layer_spec].append(layer_name)
 
     # Split each group into smaller groups, to make the number of layers in each
     # group identical. Add padding to the last group of each type if necessary.
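
Because the specs are now hashable, they can serve directly as dictionary keys, so grouping layers no longer goes through a type_id string. A small sketch of the same defaultdict pattern with a toy frozen spec (hypothetical names, not the vLLM classes):

    from collections import defaultdict
    from dataclasses import dataclass
    from typing import Optional

    @dataclass(frozen=True)
    class ToySpec:                      # stand-in for a hashable KVCacheSpec
        block_size: int
        sliding_window: Optional[int] = None

    layer_specs = {
        "layer_0": ToySpec(16),
        "layer_1": ToySpec(16),
        "layer_2": ToySpec(16, sliding_window=1024),
    }
    groups: dict[ToySpec, list[str]] = defaultdict(list)
    for name, spec in layer_specs.items():
        groups[spec].append(name)       # the spec object itself is the key
    assert groups[ToySpec(16)] == ["layer_0", "layer_1"]
    assert groups[ToySpec(16, sliding_window=1024)] == ["layer_2"]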
@@ -1017,12 +1023,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
         kv_cache_spec: The kv cache spec of each attention layer in the model
     """
 
-    def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
-        type_ids = set(layer_spec.type_id
-                       for layer_spec in kv_cache_spec.values())
-        return len(type_ids) > 1
-
-    if not is_hybrid(kv_cache_spec):
+    if is_kv_cache_type_uniform(kv_cache_spec):
         return
 
     logger.warning(
@@ -1060,7 +1061,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
             attention_chunk_size=spec.attention_chunk_size,
         )
 
-    if is_hybrid(kv_cache_spec):
+    if not is_kv_cache_type_uniform(kv_cache_spec):
         raise ValueError("Hybrid KV cache manager is disabled but failed to "
                          "convert the KV cache specs to one unified type.")
 
@@ -1119,11 +1120,11 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
         in-place modified to make them consistent.
     """
 
-    # Sort the kv cache groups by the type_id of their KV cache spec.
+    # Sort the kv cache groups by their KV cache spec.
     # This can avoid the inconsistency caused by the order of groups.
     for kv_cache_config in kv_cache_configs:
-        kv_cache_config.kv_cache_groups.sort(
-            key=lambda x: x.kv_cache_spec.type_id)
+        kv_cache_config.kv_cache_groups.sort(key=lambda x: (type(
+            x.kv_cache_spec).__name__, astuple(x.kv_cache_spec)))
 
     # Verify that the groups of each rank are the same.
     for kv_cache_config in kv_cache_configs[1:]:
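
The new sort key pairs the spec's class name with astuple() of its fields, which gives every rank the same deterministic group order without relying on type_id strings. A minimal sketch with hypothetical stand-in classes:

    from dataclasses import astuple, dataclass

    @dataclass(frozen=True)
    class FullToy:                      # hypothetical stand-ins for the specs
        block_size: int

    @dataclass(frozen=True)
    class SlidingToy:
        block_size: int
        sliding_window: int

    specs = [SlidingToy(16, 2048), FullToy(16), SlidingToy(16, 1024)]
    # Class name first, then the field values as a tuple.
    specs.sort(key=lambda s: (type(s).__name__, astuple(s)))
    assert specs == [FullToy(16), SlidingToy(16, 1024), SlidingToy(16, 2048)]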
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import copy
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from math import prod
 from typing import Optional
 
@@ -16,7 +16,7 @@ from vllm.utils import cdiv, get_dtype_size
 logger = init_logger(__name__)
 
 
-@dataclass
+@dataclass(frozen=True)
 class KVCacheSpec:
     """
     A base class for specifying the KV cache format of one layer.
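
This is the change the commit title refers to: declaring the spec dataclasses frozen makes instances immutable, which gives them value-based __eq__ and __hash__, so whole specs can be compared, placed in sets, and used as dict keys. A minimal illustration with a toy class (not the real KVCacheSpec):

    from dataclasses import FrozenInstanceError, dataclass

    @dataclass(frozen=True)
    class ToySpec:                      # stand-in, not the vLLM class
        block_size: int
        num_kv_heads: int

    a = ToySpec(block_size=16, num_kv_heads=32)
    b = ToySpec(block_size=16, num_kv_heads=32)
    assert a == b and hash(a) == hash(b)   # value equality and hashability
    assert len({a, b}) == 1                # usable in sets / as dict keys
    try:
        a.block_size = 32                  # mutation is rejected
    except FrozenInstanceError:
        pass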
@@ -25,20 +25,6 @@ class KVCacheSpec:
     # number of tokens in a block
     block_size: int
 
-    @property
-    def type_id(self) -> str:
-        """
-        The type identifier of this KV cache.
-        Return different strings for layers with different KV cache type (e.g.,
-        different number of tokens like full attention vs sliding window
-        attention, different KV cache size per token like layers with different
-        number of heads)
-
-        Returns:
-            The type identifier of this KV cache.
-        """
-        raise NotImplementedError
-
     @property
     def page_size_bytes(self) -> int:
         """
@@ -63,13 +49,12 @@ class KVCacheSpec:
         """
         Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
         """
-        assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), (
-            "All layers in the same KV cache group must share the same "
-            "type_id.")
+        assert all(spec == specs[0] for spec in specs[1:]), (
+            "All layers in the same KV cache group must be the same.")
         return copy.deepcopy(specs[0])
 
 
-@dataclass
+@dataclass(frozen=True)
 class AttentionSpec(KVCacheSpec):
     num_kv_heads: int
     head_size: int
@@ -84,7 +69,7 @@ class AttentionSpec(KVCacheSpec):
                 * get_dtype_size(self.dtype)
 
 
-@dataclass
+@dataclass(frozen=True)
 class FullAttentionSpec(AttentionSpec):
     sliding_window: Optional[int] = None
     attention_chunk_size: Optional[int] = None
@@ -98,10 +83,6 @@ class FullAttentionSpec(AttentionSpec):
         Default to None for not using sliding window attention.
     """
 
-    @property
-    def type_id(self) -> str:
-        return f"full_attention_{self.block_size}_{self.page_size_bytes}"
-
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
         return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@@ -123,15 +104,28 @@ class FullAttentionSpec(AttentionSpec):
         Merge a list of FullAttentionSpec objects into a single
         FullAttentionSpec object.
         """
-        merged_spec = super().merge(specs)
+        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
+            "All attention layers in the same KV cache group must be "
+            "FullAttentionSpec.")
+
         sliding_window = set(spec.sliding_window for spec in specs
                              if spec.sliding_window is not None)
         attention_chunk_size = set(spec.attention_chunk_size for spec in specs
                                    if spec.attention_chunk_size is not None)
-
-        merged_spec.sliding_window = cls.merge_window_sizes(sliding_window)
-        merged_spec.attention_chunk_size = (
-            cls.merge_window_sizes(attention_chunk_size))
+        merged_spec = cls(
+            block_size=specs[0].block_size,
+            num_kv_heads=specs[0].num_kv_heads,
+            head_size=specs[0].head_size,
+            dtype=specs[0].dtype,
+            use_mla=specs[0].use_mla,
+            sliding_window=cls.merge_window_sizes(sliding_window),
+            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
+        )
+        for spec in specs:
+            for f in fields(AttentionSpec):
+                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
+                    "All attention layers in the same KV cache group must have "
+                    "the same attention spec.")
         assert (
             (merged_spec.sliding_window is not None) +
             (merged_spec.attention_chunk_size is not None) <= 1
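
Since frozen instances cannot be modified after construction, the merged spec can no longer be produced by mutating the result of super().merge(); it is rebuilt in one cls(...) call and then cross-checked field by field. A sketch of that construct-then-verify pattern with a toy attention spec (hypothetical, simplified fields):

    from dataclasses import dataclass, fields
    from typing import Optional

    @dataclass(frozen=True)
    class ToyAttn:                      # stand-in for an attention spec
        block_size: int
        num_kv_heads: int
        sliding_window: Optional[int] = None

    specs = [ToyAttn(16, 32, sliding_window=1024),
             ToyAttn(16, 32, sliding_window=1024)]

    # Build the merged object in one shot instead of assigning to a frozen one.
    merged = ToyAttn(block_size=specs[0].block_size,
                     num_kv_heads=specs[0].num_kv_heads,
                     sliding_window=specs[0].sliding_window)
    for spec in specs:
        for f in fields(ToyAttn):
            assert getattr(spec, f.name) == getattr(merged, f.name), (
                "all specs in the group must agree on every field")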
@@ -140,16 +134,10 @@ class FullAttentionSpec(AttentionSpec):
         return merged_spec
 
 
-@dataclass
+@dataclass(frozen=True)
 class ChunkedLocalAttentionSpec(AttentionSpec):
     attention_chunk_size: int
 
-    @property
-    def type_id(self) -> str:
-        return (
-            f"local_attention_{self.attention_chunk_size}_{self.block_size}_{self.page_size_bytes}"
-        )  # noqa
-
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
         max_num_batched_tokens = (
@@ -165,17 +153,13 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
         return cdiv(num_tokens, self.block_size) * self.page_size_bytes
 
 
-@dataclass
+@dataclass(frozen=True)
 class SlidingWindowSpec(AttentionSpec):
     sliding_window: int
 
     def __post_init__(self):
         assert not self.use_mla, "MLA is not supported for sliding window"
 
-    @property
-    def type_id(self) -> str:
-        return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}"  # noqa
-
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
         max_num_batched_tokens = (
@@ -195,23 +179,17 @@ class SlidingWindowSpec(AttentionSpec):
         return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
 
 
-@dataclass
+@dataclass(frozen=True)
 class MambaSpec(KVCacheSpec):
     shapes: tuple[tuple[int, ...], ...]
     dtype: torch.dtype
     page_size_padded: Optional[int] = None
     mamba_type: str = "mamba2"
 
-    def __post_init__(self):
-        self.num_elements = sum(prod(shape) for shape in self.shapes)
-
-    @property
-    def type_id(self) -> str:
-        return f"mamba_{self.shapes}_{self.dtype}_{self.mamba_type}"
-
     @property
     def page_size_bytes(self) -> int:
-        page_size = self.num_elements * get_dtype_size(self.dtype)
+        num_elements = sum(prod(shape) for shape in self.shapes)
+        page_size = num_elements * get_dtype_size(self.dtype)
         if self.page_size_padded is not None:
             assert self.page_size_padded >= page_size
             return self.page_size_padded
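
The __post_init__ that cached num_elements had to go because a frozen dataclass rejects plain attribute assignment; the element count is now recomputed inside page_size_bytes. A small sketch of why, using a toy class rather than the real MambaSpec:

    from dataclasses import FrozenInstanceError, dataclass
    from math import prod

    @dataclass(frozen=True)
    class ToyMamba:                     # stand-in, not the vLLM MambaSpec
        shapes: tuple[tuple[int, ...], ...]

        @property
        def num_elements(self) -> int:
            # Recomputed on demand; a plain `self.num_elements = ...` in
            # __post_init__ would fail on a frozen instance.
            return sum(prod(shape) for shape in self.shapes)

    spec = ToyMamba(shapes=((4, 8), (2, 3)))
    assert spec.num_elements == 4 * 8 + 2 * 3
    try:
        spec.cached = 0                 # plain assignment, as the old __post_init__ used
    except FrozenInstanceError:
        pass                            # rejected on a frozen spec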