mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 07:29:08 +08:00
[KVCache] Make KVCacheSpec hashable (#21791)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent
2470419119
commit
755fa8b657
@ -17,7 +17,7 @@ from vllm.v1.core.kv_cache_utils import (
|
|||||||
estimate_max_model_len, generate_block_hash_extra_keys,
|
estimate_max_model_len, generate_block_hash_extra_keys,
|
||||||
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
|
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
|
||||||
hash_block_tokens, hash_request_tokens, init_none_hash,
|
hash_block_tokens, hash_request_tokens, init_none_hash,
|
||||||
unify_kv_cache_configs)
|
is_kv_cache_type_uniform, unify_kv_cache_configs)
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheGroupSpec, KVCacheTensor,
|
KVCacheGroupSpec, KVCacheTensor,
|
||||||
SlidingWindowSpec)
|
SlidingWindowSpec)
|
||||||
@ -685,6 +685,38 @@ def test_merge_kv_cache_spec():
|
|||||||
assert merged_layer_spec.sliding_window == 1
|
assert merged_layer_spec.sliding_window == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_kv_cache_type_uniform():
|
||||||
|
kv_cache_spec = {
|
||||||
|
"layer_1": new_kv_cache_spec(num_kv_heads=32),
|
||||||
|
"layer_2": new_kv_cache_spec(num_kv_heads=32),
|
||||||
|
}
|
||||||
|
assert is_kv_cache_type_uniform(kv_cache_spec)
|
||||||
|
|
||||||
|
kv_cache_spec = {
|
||||||
|
"layer_1": new_kv_cache_spec(num_kv_heads=32),
|
||||||
|
"layer_2": new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
|
||||||
|
}
|
||||||
|
assert is_kv_cache_type_uniform(kv_cache_spec)
|
||||||
|
|
||||||
|
kv_cache_spec = {
|
||||||
|
"layer_1": new_kv_cache_spec(num_kv_heads=32),
|
||||||
|
"layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
|
||||||
|
}
|
||||||
|
assert not is_kv_cache_type_uniform(kv_cache_spec)
|
||||||
|
|
||||||
|
kv_cache_spec = {
|
||||||
|
"layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
|
||||||
|
"layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
|
||||||
|
}
|
||||||
|
assert is_kv_cache_type_uniform(kv_cache_spec)
|
||||||
|
|
||||||
|
kv_cache_spec = {
|
||||||
|
"layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
|
||||||
|
"layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=2),
|
||||||
|
}
|
||||||
|
assert not is_kv_cache_type_uniform(kv_cache_spec)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("model_id", "max_model_len", "want_estimated_max_len"), [
|
("model_id", "max_model_len", "want_estimated_max_len"), [
|
||||||
("Qwen/Qwen1.5-7B", 16385, 16384),
|
("Qwen/Qwen1.5-7B", 16385, 16384),
|
||||||
|
|||||||
@ -30,7 +30,9 @@ model_config = {
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [5])
|
@pytest.mark.parametrize("batch_size", [5])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed):
|
@pytest.mark.parametrize("disable_hybrid_kv_cache_manager", [True, False])
|
||||||
|
def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed,
|
||||||
|
disable_hybrid_kv_cache_manager):
|
||||||
"""
|
"""
|
||||||
The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
|
The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
|
||||||
asks for value of one of them (which is outside the sliding window).
|
asks for value of one of them (which is outside the sliding window).
|
||||||
@ -42,7 +44,9 @@ def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed):
|
|||||||
|
|
||||||
test_config = model_config[model]
|
test_config = model_config[model]
|
||||||
|
|
||||||
llm = LLM(model=model)
|
llm = LLM(
|
||||||
|
model=model,
|
||||||
|
disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager)
|
||||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
|
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
|
||||||
|
|
||||||
prompts, answer, indices = prep_prompts(batch_size,
|
prompts, answer, indices = prep_prompts(batch_size,
|
||||||
|
|||||||
@ -7,7 +7,8 @@ from vllm.v1.core.block_pool import BlockPool
|
|||||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||||
FullAttentionManager, get_manager_for_kv_cache_spec)
|
FullAttentionManager, get_manager_for_kv_cache_spec)
|
||||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
|
KVCacheSpec)
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
|
|
||||||
@ -258,44 +259,40 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
one of them is full attention. Then, split the kv cache groups into full
|
one of them is full attention. Then, split the kv cache groups into full
|
||||||
attention groups and other groups.
|
attention groups and other groups.
|
||||||
"""
|
"""
|
||||||
full_attention_type_id: Optional[str] = None
|
full_attention_spec: Optional[FullAttentionSpec] = None
|
||||||
other_type_id: Optional[str] = None
|
other_spec: Optional[KVCacheSpec] = None
|
||||||
self.full_attention_group_ids: list[int] = []
|
self.full_attention_group_ids: list[int] = []
|
||||||
self.other_group_ids: list[int] = []
|
self.other_group_ids: list[int] = []
|
||||||
for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
|
for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
|
||||||
if isinstance(g.kv_cache_spec, FullAttentionSpec):
|
if isinstance(g.kv_cache_spec, FullAttentionSpec):
|
||||||
if full_attention_type_id is None:
|
if full_attention_spec is None:
|
||||||
full_attention_type_id = g.kv_cache_spec.type_id
|
full_attention_spec = g.kv_cache_spec
|
||||||
else:
|
else:
|
||||||
assert full_attention_type_id == g.kv_cache_spec.type_id, (
|
assert full_attention_spec == g.kv_cache_spec, (
|
||||||
"HybridKVCacheCoordinator assumes exactly one type of "
|
"HybridKVCacheCoordinator assumes exactly one type of "
|
||||||
"full attention groups now.")
|
"full attention groups now.")
|
||||||
self.full_attention_group_ids.append(i)
|
self.full_attention_group_ids.append(i)
|
||||||
else:
|
else:
|
||||||
if other_type_id is None:
|
if other_spec is None:
|
||||||
other_type_id = g.kv_cache_spec.type_id
|
other_spec = g.kv_cache_spec
|
||||||
else:
|
else:
|
||||||
assert other_type_id == g.kv_cache_spec.type_id, (
|
assert other_spec == g.kv_cache_spec, (
|
||||||
"HybridKVCacheCoordinator assumes "
|
"HybridKVCacheCoordinator assumes "
|
||||||
"exactly one other type of groups now.")
|
"exactly one other type of groups now.")
|
||||||
self.other_group_ids.append(i)
|
self.other_group_ids.append(i)
|
||||||
|
|
||||||
assert full_attention_type_id is not None, (
|
assert full_attention_spec is not None, (
|
||||||
"HybridKVCacheCoordinator assumes exactly one type of full "
|
"HybridKVCacheCoordinator assumes exactly one type of full "
|
||||||
"attention groups now.")
|
"attention groups now.")
|
||||||
assert other_type_id is not None, (
|
assert other_spec is not None, (
|
||||||
"HybridKVCacheCoordinator assumes exactly one type of other "
|
"HybridKVCacheCoordinator assumes exactly one type of other "
|
||||||
"groups now.")
|
"groups now.")
|
||||||
|
|
||||||
self.full_attention_manager_cls = FullAttentionManager
|
self.full_attention_manager_cls = FullAttentionManager
|
||||||
self.other_attention_cls = self.single_type_managers[
|
self.other_attention_cls = self.single_type_managers[
|
||||||
self.other_group_ids[0]].__class__
|
self.other_group_ids[0]].__class__
|
||||||
|
self.full_attention_spec = full_attention_spec
|
||||||
self.full_attention_spec = self.kv_cache_config.kv_cache_groups[
|
self.other_spec = other_spec
|
||||||
self.full_attention_group_ids[0]].kv_cache_spec
|
|
||||||
self.other_spec = self.kv_cache_config.kv_cache_groups[
|
|
||||||
self.other_group_ids[0]].kv_cache_spec
|
|
||||||
|
|
||||||
self.full_attention_block_size = self.full_attention_spec.block_size
|
self.full_attention_block_size = self.full_attention_spec.block_size
|
||||||
self.other_block_size = self.other_spec.block_size
|
self.other_block_size = self.other_spec.block_size
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,7 @@
|
|||||||
import os
|
import os
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from collections.abc import Iterable, Sequence
|
from collections.abc import Iterable, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import astuple, dataclass
|
||||||
from typing import Any, Callable, NamedTuple, Optional
|
from typing import Any, Callable, NamedTuple, Optional
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
@ -727,7 +727,9 @@ def create_kv_cache_group_specs(
|
|||||||
|
|
||||||
def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
|
def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
|
||||||
"""
|
"""
|
||||||
Whether all layers in the given KVCacheSpec have the same type of KV cache.
|
Whether all layers in the given KVCacheSpec have the same KV cache spec.
|
||||||
|
Note that we regard FullAttentionSpec with and without sliding window as
|
||||||
|
the same type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
kv_cache_spec: The kv cache spec of each attention layer in the model
|
kv_cache_spec: The kv cache spec of each attention layer in the model
|
||||||
@ -736,8 +738,12 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
|
|||||||
True if all layers have the same type, False otherwise.
|
True if all layers have the same type, False otherwise.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
layer_keys = set(layer.type_id for layer in kv_cache_spec.values())
|
try:
|
||||||
return len(layer_keys) == 1
|
kv_cache_spec_values = list(kv_cache_spec.values())
|
||||||
|
_ = kv_cache_spec_values[0].merge(kv_cache_spec_values)
|
||||||
|
except AssertionError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def get_max_concurrency_for_kv_cache_config(
|
def get_max_concurrency_for_kv_cache_config(
|
||||||
@ -928,12 +934,12 @@ def _get_kv_cache_config_uniform_page_size(
|
|||||||
Returns:
|
Returns:
|
||||||
The generated KVCacheConfig
|
The generated KVCacheConfig
|
||||||
"""
|
"""
|
||||||
# Group all layers by type_id.
|
# Group all layers by kv_cache_spec.
|
||||||
# E.g., 2 full attention layers and 3 sliding window attention layers,
|
# E.g., 2 full attention layers and 3 sliding window attention layers,
|
||||||
# -> (full.0, full.1), (sw.0, sw.1, sw.2).
|
# -> (full.0, full.1), (sw.0, sw.1, sw.2).
|
||||||
same_type_layers: dict[str, list[str]] = defaultdict(list)
|
same_type_layers: dict[KVCacheSpec, list[str]] = defaultdict(list)
|
||||||
for layer_name, layer_spec in kv_cache_spec.items():
|
for layer_name, layer_spec in kv_cache_spec.items():
|
||||||
same_type_layers[layer_spec.type_id].append(layer_name)
|
same_type_layers[layer_spec].append(layer_name)
|
||||||
|
|
||||||
# Split each group into smaller groups, to make the number of layers in each
|
# Split each group into smaller groups, to make the number of layers in each
|
||||||
# group identical. Add padding to the last group of each type if necessary.
|
# group identical. Add padding to the last group of each type if necessary.
|
||||||
@ -1017,12 +1023,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
|||||||
kv_cache_spec: The kv cache spec of each attention layer in the model
|
kv_cache_spec: The kv cache spec of each attention layer in the model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
|
if is_kv_cache_type_uniform(kv_cache_spec):
|
||||||
type_ids = set(layer_spec.type_id
|
|
||||||
for layer_spec in kv_cache_spec.values())
|
|
||||||
return len(type_ids) > 1
|
|
||||||
|
|
||||||
if not is_hybrid(kv_cache_spec):
|
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -1060,7 +1061,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
|||||||
attention_chunk_size=spec.attention_chunk_size,
|
attention_chunk_size=spec.attention_chunk_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_hybrid(kv_cache_spec):
|
if not is_kv_cache_type_uniform(kv_cache_spec):
|
||||||
raise ValueError("Hybrid KV cache manager is disabled but failed to "
|
raise ValueError("Hybrid KV cache manager is disabled but failed to "
|
||||||
"convert the KV cache specs to one unified type.")
|
"convert the KV cache specs to one unified type.")
|
||||||
|
|
||||||
@ -1119,11 +1120,11 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
|
|||||||
in-place modified to make them consistent.
|
in-place modified to make them consistent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Sort the kv cache groups by the type_id of their KV cache spec.
|
# Sort the kv cache groups by their KV cache spec.
|
||||||
# This can avoid the inconsistency caused by the order of groups.
|
# This can avoid the inconsistency caused by the order of groups.
|
||||||
for kv_cache_config in kv_cache_configs:
|
for kv_cache_config in kv_cache_configs:
|
||||||
kv_cache_config.kv_cache_groups.sort(
|
kv_cache_config.kv_cache_groups.sort(key=lambda x: (type(
|
||||||
key=lambda x: x.kv_cache_spec.type_id)
|
x.kv_cache_spec).__name__, astuple(x.kv_cache_spec)))
|
||||||
|
|
||||||
# Verify that the groups of each rank are the same.
|
# Verify that the groups of each rank are the same.
|
||||||
for kv_cache_config in kv_cache_configs[1:]:
|
for kv_cache_config in kv_cache_configs[1:]:
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, fields
|
||||||
from math import prod
|
from math import prod
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -16,7 +16,7 @@ from vllm.utils import cdiv, get_dtype_size
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(frozen=True)
|
||||||
class KVCacheSpec:
|
class KVCacheSpec:
|
||||||
"""
|
"""
|
||||||
A base class for specifying the KV cache format of one layer.
|
A base class for specifying the KV cache format of one layer.
|
||||||
@ -25,20 +25,6 @@ class KVCacheSpec:
|
|||||||
# number of tokens in a block
|
# number of tokens in a block
|
||||||
block_size: int
|
block_size: int
|
||||||
|
|
||||||
@property
|
|
||||||
def type_id(self) -> str:
|
|
||||||
"""
|
|
||||||
The type identifier of this KV cache.
|
|
||||||
Return different strings for layers with different KV cache type (e.g.,
|
|
||||||
different number of tokens like full attention vs sliding window
|
|
||||||
attention, different KV cache size per token like layers with different
|
|
||||||
number of heads)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The type identifier of this KV cache.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def page_size_bytes(self) -> int:
|
def page_size_bytes(self) -> int:
|
||||||
"""
|
"""
|
||||||
@ -63,13 +49,12 @@ class KVCacheSpec:
|
|||||||
"""
|
"""
|
||||||
Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
|
Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
|
||||||
"""
|
"""
|
||||||
assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), (
|
assert all(spec == specs[0] for spec in specs[1:]), (
|
||||||
"All layers in the same KV cache group must share the same "
|
"All layers in the same KV cache group must be the same.")
|
||||||
"type_id.")
|
|
||||||
return copy.deepcopy(specs[0])
|
return copy.deepcopy(specs[0])
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(frozen=True)
|
||||||
class AttentionSpec(KVCacheSpec):
|
class AttentionSpec(KVCacheSpec):
|
||||||
num_kv_heads: int
|
num_kv_heads: int
|
||||||
head_size: int
|
head_size: int
|
||||||
@ -84,7 +69,7 @@ class AttentionSpec(KVCacheSpec):
|
|||||||
* get_dtype_size(self.dtype)
|
* get_dtype_size(self.dtype)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(frozen=True)
|
||||||
class FullAttentionSpec(AttentionSpec):
|
class FullAttentionSpec(AttentionSpec):
|
||||||
sliding_window: Optional[int] = None
|
sliding_window: Optional[int] = None
|
||||||
attention_chunk_size: Optional[int] = None
|
attention_chunk_size: Optional[int] = None
|
||||||
@ -98,10 +83,6 @@ class FullAttentionSpec(AttentionSpec):
|
|||||||
Default to None for not using sliding window attention.
|
Default to None for not using sliding window attention.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@property
|
|
||||||
def type_id(self) -> str:
|
|
||||||
return f"full_attention_{self.block_size}_{self.page_size_bytes}"
|
|
||||||
|
|
||||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||||
max_model_len = vllm_config.model_config.max_model_len
|
max_model_len = vllm_config.model_config.max_model_len
|
||||||
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
|
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
|
||||||
@ -123,15 +104,28 @@ class FullAttentionSpec(AttentionSpec):
|
|||||||
Merge a list of FullAttentionSpec objects into a single
|
Merge a list of FullAttentionSpec objects into a single
|
||||||
FullAttentionSpec object.
|
FullAttentionSpec object.
|
||||||
"""
|
"""
|
||||||
merged_spec = super().merge(specs)
|
assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
|
||||||
|
"All attention layers in the same KV cache group must be "
|
||||||
|
"FullAttentionSpec.")
|
||||||
|
|
||||||
sliding_window = set(spec.sliding_window for spec in specs
|
sliding_window = set(spec.sliding_window for spec in specs
|
||||||
if spec.sliding_window is not None)
|
if spec.sliding_window is not None)
|
||||||
attention_chunk_size = set(spec.attention_chunk_size for spec in specs
|
attention_chunk_size = set(spec.attention_chunk_size for spec in specs
|
||||||
if spec.attention_chunk_size is not None)
|
if spec.attention_chunk_size is not None)
|
||||||
|
merged_spec = cls(
|
||||||
merged_spec.sliding_window = cls.merge_window_sizes(sliding_window)
|
block_size=specs[0].block_size,
|
||||||
merged_spec.attention_chunk_size = (
|
num_kv_heads=specs[0].num_kv_heads,
|
||||||
cls.merge_window_sizes(attention_chunk_size))
|
head_size=specs[0].head_size,
|
||||||
|
dtype=specs[0].dtype,
|
||||||
|
use_mla=specs[0].use_mla,
|
||||||
|
sliding_window=cls.merge_window_sizes(sliding_window),
|
||||||
|
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
||||||
|
)
|
||||||
|
for spec in specs:
|
||||||
|
for f in fields(AttentionSpec):
|
||||||
|
assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
|
||||||
|
"All attention layers in the same KV cache group must have "
|
||||||
|
"the same attention spec.")
|
||||||
assert (
|
assert (
|
||||||
(merged_spec.sliding_window is not None) +
|
(merged_spec.sliding_window is not None) +
|
||||||
(merged_spec.attention_chunk_size is not None) <= 1
|
(merged_spec.attention_chunk_size is not None) <= 1
|
||||||
@ -140,16 +134,10 @@ class FullAttentionSpec(AttentionSpec):
|
|||||||
return merged_spec
|
return merged_spec
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(frozen=True)
|
||||||
class ChunkedLocalAttentionSpec(AttentionSpec):
|
class ChunkedLocalAttentionSpec(AttentionSpec):
|
||||||
attention_chunk_size: int
|
attention_chunk_size: int
|
||||||
|
|
||||||
@property
|
|
||||||
def type_id(self) -> str:
|
|
||||||
return (
|
|
||||||
f"local_attention_{self.attention_chunk_size}_{self.block_size}_{self.page_size_bytes}"
|
|
||||||
) # noqa
|
|
||||||
|
|
||||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||||
max_model_len = vllm_config.model_config.max_model_len
|
max_model_len = vllm_config.model_config.max_model_len
|
||||||
max_num_batched_tokens = (
|
max_num_batched_tokens = (
|
||||||
@ -165,17 +153,13 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
|
|||||||
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
|
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(frozen=True)
|
||||||
class SlidingWindowSpec(AttentionSpec):
|
class SlidingWindowSpec(AttentionSpec):
|
||||||
sliding_window: int
|
sliding_window: int
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
assert not self.use_mla, "MLA is not supported for sliding window"
|
assert not self.use_mla, "MLA is not supported for sliding window"
|
||||||
|
|
||||||
@property
|
|
||||||
def type_id(self) -> str:
|
|
||||||
return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}" # noqa
|
|
||||||
|
|
||||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||||
max_model_len = vllm_config.model_config.max_model_len
|
max_model_len = vllm_config.model_config.max_model_len
|
||||||
max_num_batched_tokens = (
|
max_num_batched_tokens = (
|
||||||
@ -195,23 +179,17 @@ class SlidingWindowSpec(AttentionSpec):
|
|||||||
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
|
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(frozen=True)
|
||||||
class MambaSpec(KVCacheSpec):
|
class MambaSpec(KVCacheSpec):
|
||||||
shapes: tuple[tuple[int, ...], ...]
|
shapes: tuple[tuple[int, ...], ...]
|
||||||
dtype: torch.dtype
|
dtype: torch.dtype
|
||||||
page_size_padded: Optional[int] = None
|
page_size_padded: Optional[int] = None
|
||||||
mamba_type: str = "mamba2"
|
mamba_type: str = "mamba2"
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
self.num_elements = sum(prod(shape) for shape in self.shapes)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def type_id(self) -> str:
|
|
||||||
return f"mamba_{self.shapes}_{self.dtype}_{self.mamba_type}"
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def page_size_bytes(self) -> int:
|
def page_size_bytes(self) -> int:
|
||||||
page_size = self.num_elements * get_dtype_size(self.dtype)
|
num_elements = sum(prod(shape) for shape in self.shapes)
|
||||||
|
page_size = num_elements * get_dtype_size(self.dtype)
|
||||||
if self.page_size_padded is not None:
|
if self.page_size_padded is not None:
|
||||||
assert self.page_size_padded >= page_size
|
assert self.page_size_padded >= page_size
|
||||||
return self.page_size_padded
|
return self.page_size_padded
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user