Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 05:25:00 +08:00)
[Hybrid Allocator] Support full attention with different hidden size (#25101)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
parent c60e6137f0
commit 9607d5eb44
@@ -18,12 +18,14 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager
 from vllm.v1.core.kv_cache_utils import (
     BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
     estimate_max_model_len, generate_block_hash_extra_keys,
-    get_kv_cache_configs, get_max_concurrency_for_kv_cache_config,
-    get_request_block_hasher, hash_block_tokens, init_none_hash,
-    is_kv_cache_type_uniform, make_block_hash_with_group_id)
+    generate_scheduler_kv_cache_config, get_kv_cache_configs,
+    get_max_concurrency_for_kv_cache_config, get_request_block_hasher,
+    hash_block_tokens, init_none_hash, is_kv_cache_spec_uniform,
+    make_block_hash_with_group_id)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheSpec,
-                                        KVCacheTensor, SlidingWindowSpec)
+                                        KVCacheTensor, SlidingWindowSpec,
+                                        UniformTypeKVCacheSpecs)
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request
@@ -927,36 +929,36 @@ def test_merge_kv_cache_spec():
     assert merged_layer_spec.sliding_window == 1


-def test_is_kv_cache_type_uniform():
+def test_is_kv_cache_spec_uniform():
     kv_cache_spec = {
         "layer_1": new_kv_cache_spec(num_kv_heads=32),
         "layer_2": new_kv_cache_spec(num_kv_heads=32),
     }
-    assert is_kv_cache_type_uniform(kv_cache_spec)
+    assert is_kv_cache_spec_uniform(kv_cache_spec)

     kv_cache_spec = {
         "layer_1": new_kv_cache_spec(num_kv_heads=32),
         "layer_2": new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
     }
-    assert is_kv_cache_type_uniform(kv_cache_spec)
+    assert is_kv_cache_spec_uniform(kv_cache_spec)

     kv_cache_spec = {
         "layer_1": new_kv_cache_spec(num_kv_heads=32),
         "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
     }
-    assert not is_kv_cache_type_uniform(kv_cache_spec)
+    assert not is_kv_cache_spec_uniform(kv_cache_spec)

     kv_cache_spec = {
         "layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
         "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
     }
-    assert is_kv_cache_type_uniform(kv_cache_spec)
+    assert is_kv_cache_spec_uniform(kv_cache_spec)

     kv_cache_spec = {
         "layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
         "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=2),
     }
-    assert not is_kv_cache_type_uniform(kv_cache_spec)
+    assert not is_kv_cache_spec_uniform(kv_cache_spec)


 @pytest.mark.parametrize(
@@ -1286,14 +1288,28 @@ def test_get_kv_cache_config_one_worker():
         ],
     )

-    # different hidden size, unimplemented
+    # different hidden size
     kv_cache_specs_hybrid = {
         'layer_1': new_kv_cache_spec(head_size=128),
-        'layer_2': new_kv_cache_spec(),
+        'layer_2': new_kv_cache_spec(head_size=64),
     }
-    with pytest.raises(NotImplementedError):
-        get_kv_cache_configs(vllm_config, [kv_cache_specs_hybrid],
-                             [mem_per_block_per_layer * 2 * 32])[0]
+    kv_cache_config_hybrid = get_kv_cache_configs(
+        vllm_config, [kv_cache_specs_hybrid],
+        [mem_per_block_per_layer * 3 * 32])[0]
+    assert kv_cache_config_hybrid == KVCacheConfig(
+        num_blocks=32,
+        kv_cache_tensors=[
+            KVCacheTensor(size=mem_per_block_per_layer * 32 * 2,
+                          shared_by=["layer_1"]),
+            KVCacheTensor(size=mem_per_block_per_layer * 32,
+                          shared_by=["layer_2"]),
+        ],
+        kv_cache_groups=[
+            KVCacheGroupSpec(["layer_1", "layer_2"],
+                             UniformTypeKVCacheSpecs(
+                                 block_size=16,
+                                 kv_cache_specs=kv_cache_specs_hybrid))
+        ])

     # Test num_gpu_blocks_override
     vllm_config.cache_config.num_gpu_blocks_override = 16
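The expected sizes above follow from the per-layer page sizes: layer_1 (head_size=128) needs twice the bytes per block of layer_2 (head_size=64), so the merged group's page is three times the per-layer baseline and the mem_per_block_per_layer * 3 * 32 budget yields exactly 32 blocks. A minimal arithmetic sketch, using an assumed baseline value in place of the test fixture's real mem_per_block_per_layer:

# Assumed baseline: bytes per block for a head_size=64 layer. Only the ratios
# matter for the checks below.
mem_per_block_per_layer = 32_768

page_layer_1 = 2 * mem_per_block_per_layer      # head_size=128
page_layer_2 = mem_per_block_per_layer          # head_size=64
group_page_size = page_layer_1 + page_layer_2   # merged spec's page_size_bytes

available_memory = mem_per_block_per_layer * 3 * 32
num_blocks = available_memory // group_page_size
assert num_blocks == 32

# Each layer gets its own tensor, sized by its own page, matching the
# KVCacheTensor entries asserted in the test.
assert page_layer_1 * num_blocks == mem_per_block_per_layer * 32 * 2
assert page_layer_2 * num_blocks == mem_per_block_per_layer * 32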
@@ -1324,3 +1340,75 @@ def test_get_kv_cache_configs_attention_free():
             kv_cache_groups=[],
         )
     ]
+
+
+def test_generate_uniform_type_kv_cache_specs():
+    # All layers are full attention, can be merged
+    kv_cache_specs = {
+        'layer_1': new_kv_cache_spec(),
+        'layer_2': new_kv_cache_spec(head_size=128),
+    }
+    uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs)
+    assert uniform_spec == UniformTypeKVCacheSpecs(
+        block_size=16, kv_cache_specs=kv_cache_specs)
+
+    # Full attention + sliding window, cannot be merged
+    kv_cache_specs = {
+        'layer_1': new_kv_cache_spec(),
+        'layer_2': new_sliding_window_spec(sliding_window=1),
+    }
+    uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs)
+    assert uniform_spec is None
+
+    # different order of full attention + sliding window, cannot be merged
+    kv_cache_specs = {
+        'layer_1': new_sliding_window_spec(sliding_window=1),
+        'layer_2': new_kv_cache_spec(),
+    }
+    uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs)
+    assert uniform_spec is None
+
+    # Same-size sliding window, can be merged
+    kv_cache_specs = {
+        'layer_1': new_sliding_window_spec(sliding_window=1),
+        'layer_2': new_sliding_window_spec(sliding_window=1, head_size=128),
+    }
+    uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs)
+    assert uniform_spec == UniformTypeKVCacheSpecs(
+        block_size=16, kv_cache_specs=kv_cache_specs)
+
+    # different block sizes, cannot be merged
+    kv_cache_specs = {
+        'layer_1': new_kv_cache_spec(block_size=16),
+        'layer_2': new_kv_cache_spec(block_size=32),
+    }
+    uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs)
+    assert uniform_spec is None
+
+
+def test_generate_scheduler_kv_cache_config():
+    kv_cache_specs = {
+        'layer_1': new_kv_cache_spec(),
+        'layer_2': new_kv_cache_spec(head_size=128),
+    }
+    kv_cache_configs = [
+        KVCacheConfig(
+            num_blocks=10,
+            kv_cache_tensors=[],
+            kv_cache_groups=[
+                KVCacheGroupSpec(['layer_1', 'layer_2'],
+                                 UniformTypeKVCacheSpecs(
+                                     block_size=16,
+                                     kv_cache_specs=kv_cache_specs)),
+            ],
+        )
+    ]
+    scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
+        kv_cache_configs)
+    assert scheduler_kv_cache_config == KVCacheConfig(
+        num_blocks=10,
+        kv_cache_tensors=[],
+        kv_cache_groups=[
+            KVCacheGroupSpec(['layer_1', 'layer_2'], new_kv_cache_spec())
+        ],
+    )
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """KV-Cache Utilities."""

+import copy
 import os
 from collections import defaultdict, deque
 from collections.abc import Iterable, Sequence
@@ -15,7 +16,8 @@ from vllm.utils import GiB_bytes, cdiv, sha256_cbor
 from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
                                         FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheSpec,
-                                        KVCacheTensor, SlidingWindowSpec)
+                                        KVCacheTensor, SlidingWindowSpec,
+                                        UniformTypeKVCacheSpecs)
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request
@@ -750,7 +752,7 @@ def create_kv_cache_group_specs(
     return kv_cache_groups


-def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
+def is_kv_cache_spec_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
     """
     Whether all layers in the given KVCacheSpec have the same KV cache spec.
     Note that we regard FullAttentionSpec with and without sliding window as
@@ -793,6 +795,21 @@ def get_max_concurrency_for_kv_cache_config(
     return max_concurrency


+def may_override_num_blocks(vllm_config: VllmConfig, num_blocks: int) -> int:
+    """
+    Override the number of kv cache blocks if `num_gpu_blocks_override` is set.
+    """
+    if vllm_config.cache_config.num_gpu_blocks_override is not None:
+        num_gpu_blocks_override = \
+            vllm_config.cache_config.num_gpu_blocks_override
+        logger.info(
+            "Overriding num_gpu_blocks=%d with "
+            "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
+        num_blocks = num_gpu_blocks_override
+
+    return num_blocks
+
+
 def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
                    available_memory: int, page_size: int) -> int:
     """
@@ -806,13 +823,7 @@ def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
     """
     num_blocks = int(available_memory // page_size // num_layers)
     num_blocks = max(num_blocks, 0)
-    if vllm_config.cache_config.num_gpu_blocks_override is not None:
-        num_gpu_blocks_override = \
-            vllm_config.cache_config.num_gpu_blocks_override
-        logger.info(
-            "Overriding num_gpu_blocks=%d with "
-            "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
-        num_blocks = num_gpu_blocks_override
+    num_blocks = may_override_num_blocks(vllm_config, num_blocks)
     return num_blocks
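With the override logic factored into may_override_num_blocks, get_num_blocks reduces to an integer division followed by the shared helper. A rough sketch of the arithmetic with assumed, illustrative numbers:

# Illustrative values, not taken from any real model or GPU.
available_memory = 8 * 1024**3          # bytes left for the KV cache
page_size = 2 * 1024**2                 # bytes per block per layer
num_layers = 32

num_blocks = max(int(available_memory // page_size // num_layers), 0)
assert num_blocks == 128

# may_override_num_blocks() swaps in the configured value when
# cache_config.num_gpu_blocks_override is set:
num_gpu_blocks_override = 16
if num_gpu_blocks_override is not None:
    num_blocks = num_gpu_blocks_override
assert num_blocks == 16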
@@ -825,11 +836,11 @@ def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
     return page_sizes.pop()


-def _get_kv_cache_groups_uniform_type(
+def _get_kv_cache_groups_uniform_spec(
         kv_cache_specs: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]:
     """
-    Generates the KV cache configuration for a model with one type of KV cache.
-    Divide the available memory equally among all layers.
+    Generates the KV cache configuration for a model with the same KV cache
+    spec for all layers.

     Args:
         kv_cache_specs: The kv cache spec of each attention layer in the model
@@ -842,6 +853,22 @@
                                        [list(kv_cache_specs.keys())])


+def _get_kv_cache_groups_uniform_type(
+        spec: UniformTypeKVCacheSpecs) -> list[KVCacheGroupSpec]:
+    """
+    Generates the KV cache configuration for a model with one type of KV cache
+    but different hidden sizes. All layers are merged into one group.
+
+    Args:
+        spec: The UniformTypeKVCacheSpecs of the model
+
+    Returns:
+        The generated KVCacheGroupSpecs
+    """
+
+    return [KVCacheGroupSpec(list(spec.kv_cache_specs.keys()), spec)]
+
+
 def is_kv_cache_page_size_uniform(
         kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
     """
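The difference from _get_kv_cache_groups_uniform_spec above is only in what the single group carries: the merged UniformTypeKVCacheSpecs rather than one spec shared by every layer. A self-contained sketch with stand-in classes (the real types live in vllm.v1.kv_cache_interface):

from dataclasses import dataclass

@dataclass(frozen=True)
class StubLayerSpec:            # stand-in for a per-layer KVCacheSpec
    page_size_bytes: int

@dataclass(frozen=True)
class StubUniformSpecs:         # stand-in for UniformTypeKVCacheSpecs
    block_size: int
    kv_cache_specs: dict

def groups_for_uniform_type(spec: StubUniformSpecs):
    # Mirrors the helper above: every layer lands in one group that carries
    # the merged spec, so all layers share a single block table.
    return [(list(spec.kv_cache_specs.keys()), spec)]

merged = StubUniformSpecs(
    block_size=16,
    kv_cache_specs={"layer_1": StubLayerSpec(4096), "layer_2": StubLayerSpec(2048)})
assert groups_for_uniform_type(merged)[0][0] == ["layer_1", "layer_2"]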
@@ -1000,28 +1027,45 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig,
     )

     # Determine how model runners should initialize the KV cache tensors.
-    # We will have group_size memory pools, each is shared by one layer from
-    # each group. As layers of different groups have different block table,
-    # they will use different parts of the shared Tensor.
-    # The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2),
-    # (sw.1, padding) will be: (group_size = 2)
-    # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
-    # full.1, sw.2: share another Tensor with size=available_memory//2
-    group_size = max(len(group.layer_names) for group in kv_cache_groups)
+    if len(kv_cache_groups) == 1 and \
+        isinstance(kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs):
+        # Special case: all layers have the same type of KV cache but with
+        # different hidden size. Allocate different amount of memory for each
+        # layer based on its hidden size.
+        num_blocks = available_memory // kv_cache_groups[
+            0].kv_cache_spec.page_size_bytes
+        num_blocks = may_override_num_blocks(vllm_config, num_blocks)
+        per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs
+        kv_cache_tensors = [
+            KVCacheTensor(size=per_layer_specs[layer_name].page_size_bytes *
+                          num_blocks,
+                          shared_by=[layer_name])
+            for layer_name in kv_cache_groups[0].layer_names
+        ]
+    else:
+        # General case:
+        # We will have group_size memory pools, each is shared by one layer from
+        # each group. As layers of different groups have different block table,
+        # they will use different parts of the shared Tensor.
+        # The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2),
+        # (sw.1, padding) will be: (group_size = 2)
+        # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
+        # full.1, sw.2: share another Tensor with size=available_memory//2
+        group_size = max(len(group.layer_names) for group in kv_cache_groups)

-    page_size = get_uniform_page_size(kv_cache_specs)
-    assert group_size > 0, "group_size must be greater than 0"
-    num_blocks = get_num_blocks(vllm_config, group_size, available_memory,
-                                page_size)
-    per_memory_pool_size = page_size * num_blocks
-    kv_cache_tensors = []
-    for i in range(group_size):
-        shared_by = []
-        for j in range(len(kv_cache_groups)):
-            if i < len(kv_cache_groups[j].layer_names):
-                shared_by.append(kv_cache_groups[j].layer_names[i])
-        kv_cache_tensors.append(
-            KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by))
+        page_size = get_uniform_page_size(kv_cache_specs)
+        assert group_size > 0, "group_size must be greater than 0"
+        num_blocks = get_num_blocks(vllm_config, group_size, available_memory,
+                                    page_size)
+        kv_cache_tensors = []
+        for i in range(group_size):
+            shared_by = []
+            for j in range(len(kv_cache_groups)):
+                if i < len(kv_cache_groups[j].layer_names):
+                    shared_by.append(kv_cache_groups[j].layer_names[i])
+            kv_cache_tensors.append(
+                KVCacheTensor(size=page_size * num_blocks,
+                              shared_by=shared_by))

     kv_cache_config = KVCacheConfig(
         num_blocks=num_blocks,
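The two allocation branches split memory differently: in the special case each layer gets its own tensor sized by that layer's page, while in the general case every pool is page_size * num_blocks and is shared by one layer from each group. A worked example with assumed figures, independent of the real helpers:

# Assumed figures for illustration only.
available_memory = 3 * 1024**2                         # bytes for the whole KV cache

# Special case: one merged group whose layers have different page sizes.
per_layer_page = {"layer_1": 2048, "layer_2": 1024}    # bytes per block
group_page = sum(per_layer_page.values())              # 3072
num_blocks = available_memory // group_page            # 1024
tensor_sizes = {name: page * num_blocks for name, page in per_layer_page.items()}
assert tensor_sizes == {"layer_1": 2048 * 1024, "layer_2": 1024 * 1024}

# General case: equal page sizes, group_size pools of identical byte size.
page_size = 1024
group_size = 2                                         # most layers in any group
num_blocks_general = available_memory // page_size // group_size
per_pool_bytes = page_size * num_blocks_general        # shared by one layer per group
assert per_pool_bytes == available_memory // 2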
@@ -1059,7 +1103,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
         kv_cache_spec: The kv cache spec of each attention layer in the model
     """

-    if is_kv_cache_type_uniform(kv_cache_spec):
+    if is_kv_cache_spec_uniform(kv_cache_spec):
         return

     logger.warning(
@@ -1097,7 +1141,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
                 attention_chunk_size=spec.attention_chunk_size,
             )

-    if not is_kv_cache_type_uniform(kv_cache_spec):
+    if not is_kv_cache_spec_uniform(kv_cache_spec):
         raise ValueError("Hybrid KV cache manager is disabled but failed to "
                          "convert the KV cache specs to one unified type.")
@@ -1122,11 +1166,16 @@ def get_kv_cache_groups(
         # This returns an empty list to allow for the KVCacheManager to handle
         # attention free models.
        return []
-    elif is_kv_cache_type_uniform(kv_cache_spec):
+    elif is_kv_cache_spec_uniform(kv_cache_spec):
         # KV cache of all layers are the same, which is true for
         # most models. Allocate the same amount of memory for
         # each layer.
-        return _get_kv_cache_groups_uniform_type(kv_cache_spec)
+        return _get_kv_cache_groups_uniform_spec(kv_cache_spec)
+    elif uniform_spec := UniformTypeKVCacheSpecs.from_specs(kv_cache_spec):
+        # All layers need the same number of token slots (e.g., all layers are
+        # full attention, or all layers are sliding window attention with the
+        # same window size). Put all layers into one group.
+        return _get_kv_cache_groups_uniform_type(uniform_spec)
     elif is_kv_cache_page_size_uniform(kv_cache_spec):
         # Model contains multiple attention types, but KV cache of all layers
         # have the same physical memory per block per layer. Split the layers
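Read top to bottom, the branch order is: attention-free, identical specs, same attention type with differing page sizes (the new case), equal page sizes across differing types, and otherwise unimplemented. A stand-alone sketch of that ladder over simplified layer descriptions (tuples standing in for real KVCacheSpec objects):

# Each layer is summarised as (attn_type, window, page_size); the tuples are
# stand-ins for real KVCacheSpec objects.
def pick_branch(specs: dict) -> str:
    values = list(specs.values())
    if not values:
        return "attention_free"
    if all(v == values[0] for v in values):
        return "uniform_spec"        # identical specs: split memory evenly
    if all(v[:2] == values[0][:2] for v in values):
        return "uniform_type"        # same type, different page sizes: one merged group
    if all(v[2] == values[0][2] for v in values):
        return "uniform_page_size"   # different types, same page size
    return "unimplemented"

assert pick_branch({"l1": ("full", None, 2), "l2": ("full", None, 1)}) == "uniform_type"
assert pick_branch({"l1": ("full", None, 2), "l2": ("swa", 1024, 2)}) == "uniform_page_size"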
@@ -1137,6 +1186,27 @@
         raise NotImplementedError


+def generate_scheduler_kv_cache_config(
+        kv_cache_configs: list[KVCacheConfig]) -> KVCacheConfig:
+    """
+    Generate the KV cache configuration for the scheduler.
+    """
+    assert all([
+        cfg.num_blocks == kv_cache_configs[0].num_blocks
+        for cfg in kv_cache_configs
+    ])
+    # All workers have the same kv_cache_config except layer names, so use
+    # an arbitrary one to initialize the scheduler.
+    cfg = copy.deepcopy(kv_cache_configs[0])
+    for group in cfg.kv_cache_groups:
+        if isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs):
+            # All layers in the UniformTypeKVCacheSpecs have the same type,
+            # so use an arbitrary one to initialize the scheduler.
+            group.kv_cache_spec = next(
+                iter(group.kv_cache_spec.kv_cache_specs.values()))
+    return cfg
+
+
 def get_kv_cache_configs(vllm_config: VllmConfig,
                          kv_cache_specs: list[dict[str, KVCacheSpec]],
                          available_memory: list[int]) -> list[KVCacheConfig]:
@@ -29,7 +29,9 @@ from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket,
                         resolve_obj_by_qualname, set_process_title)
-from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_configs,
+from vllm.v1.core.kv_cache_utils import (BlockHash,
+                                          generate_scheduler_kv_cache_config,
+                                          get_kv_cache_configs,
                                           get_request_block_hasher,
                                           init_none_hash)
 from vllm.v1.core.sched.interface import SchedulerInterface
@@ -196,16 +198,10 @@ class EngineCore:

         kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
                                                 available_gpu_memory)

-        # All workers have the same kv_cache_config except layer names, so use
-        # an arbitrary one to initialize the scheduler.
-        assert all([
-            cfg.num_blocks == kv_cache_configs[0].num_blocks
-            for cfg in kv_cache_configs
-        ])
-        num_gpu_blocks = kv_cache_configs[0].num_blocks
+        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
+            kv_cache_configs)
+        num_gpu_blocks = scheduler_kv_cache_config.num_blocks
         num_cpu_blocks = 0
-        scheduler_kv_cache_config = kv_cache_configs[0]

         # Initialize kv cache and warmup the execution
         self.model_executor.initialize_from_config(kv_cache_configs)
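The engine-side flow is now: build one config per worker, derive the scheduler's view from all of them, and read num_gpu_blocks off that view. A condensed sketch of the control flow with stub stand-ins for the real collaborators (shapes assumed, not the actual APIs):

# Stub stand-ins; the real functions return KVCacheConfig objects.
def get_kv_cache_configs_stub(kv_cache_specs, available_gpu_memory):
    return [{"num_blocks": 32}, {"num_blocks": 32}]      # one per worker

def generate_scheduler_kv_cache_config_stub(configs):
    assert all(c["num_blocks"] == configs[0]["num_blocks"] for c in configs)
    return dict(configs[0])                               # arbitrary worker's view

kv_cache_configs = get_kv_cache_configs_stub([], [])
scheduler_kv_cache_config = generate_scheduler_kv_cache_config_stub(kv_cache_configs)
num_gpu_blocks = scheduler_kv_cache_config["num_blocks"]  # 32
num_cpu_blocks = 0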
@@ -234,6 +234,76 @@ class CrossAttentionSpec(AttentionSpec):
         return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes


+@dataclass(frozen=True)
+class UniformTypeKVCacheSpecs(KVCacheSpec):
+    """
+    A KV cache spec for multiple layers with the same type of attention. Here,
+    same types means always need the same number of token slots. For example,
+    sliding window attentions with different window sizes are not the same type
+    and should not be merged into one UniformTypeKVCacheSpecs.
+    """
+    kv_cache_specs: dict[str, KVCacheSpec]
+
+    @property
+    def page_size_bytes(self) -> int:
+        return sum(spec.page_size_bytes
+                   for spec in self.kv_cache_specs.values())
+
+    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
+        max_num_pages = max(
+            cdiv(spec.max_memory_usage_bytes(vllm_config),
+                 spec.page_size_bytes)
+            for spec in self.kv_cache_specs.values())
+        return max_num_pages * self.page_size_bytes
+
+    @classmethod
+    def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool:
+        """
+        Whether all layers have the same type of KV cache spec.
+        """
+        block_sizes = set(spec.block_size for spec in kv_cache_specs.values())
+        if len(block_sizes) > 1:
+            # Different block sizes, not uniform.
+            return False
+        one_spec = next(iter(kv_cache_specs.values()))
+        if isinstance(one_spec, (FullAttentionSpec, CrossAttentionSpec)):
+            return all(
+                isinstance(spec, type(one_spec))
+                for spec in kv_cache_specs.values())
+        elif isinstance(one_spec, SlidingWindowSpec):
+            return all(
+                isinstance(spec, SlidingWindowSpec)
+                and spec.sliding_window == one_spec.sliding_window
+                for spec in kv_cache_specs.values())
+        elif isinstance(one_spec, ChunkedLocalAttentionSpec):
+            return all(
+                isinstance(spec, ChunkedLocalAttentionSpec)
+                and spec.attention_chunk_size == one_spec.attention_chunk_size
+                for spec in kv_cache_specs.values())
+        elif isinstance(one_spec, MambaSpec):
+            return all(
+                isinstance(spec, MambaSpec) and spec.num_speculative_blocks ==
+                one_spec.num_speculative_blocks
+                for spec in kv_cache_specs.values())
+        else:
+            # NOTE(Chen): Please add new branches for new KV cache spec types.
+            raise NotImplementedError(
+                f"Unsupported KV cache spec type: {type(one_spec)}")
+
+    @classmethod
+    def from_specs(cls, kv_cache_specs: dict[str,
+                                             KVCacheSpec]) -> Optional[Self]:
+        """
+        Return a SameTypeKVCacheSpecs object if all layers have the same type
+        of KV cache spec. Return None if not.
+        """
+        if cls.is_uniform_type(kv_cache_specs):
+            block_size = next(iter(kv_cache_specs.values())).block_size
+            return cls(block_size=block_size, kv_cache_specs=kv_cache_specs)
+        else:
+            return None
+
+
 @dataclass
 class KVCacheTensor:
     """
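The two derived quantities on the merged spec are easy to check by hand: page_size_bytes sums the member pages, and max_memory_usage_bytes takes the worst-case page count over members and prices it at the merged page size. A small numeric sketch under assumed per-layer figures:

# Assumed per-layer figures (bytes); only the arithmetic matters here.
member_pages = {"layer_1": 4096, "layer_2": 2048}
member_max_bytes = {"layer_1": 4096 * 1000, "layer_2": 2048 * 1200}

def cdiv(a, b):      # ceiling division, mirroring vllm.utils.cdiv
    return -(-a // b)

merged_page = sum(member_pages.values())                        # 6144
max_pages = max(cdiv(member_max_bytes[name], member_pages[name])
                for name in member_pages)                       # max(1000, 1200)
merged_max_bytes = max_pages * merged_page                      # 1200 * 6144
assert merged_page == 6144 and max_pages == 1200 and merged_max_bytes == 7372800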
@@ -8,7 +8,7 @@ from collections import defaultdict
 from collections.abc import Iterator
 from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast

 import numpy as np
 import torch
@@ -74,7 +74,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         EncoderOnlyAttentionSpec,
                                         FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheSpec,
-                                        MambaSpec, SlidingWindowSpec)
+                                        MambaSpec, SlidingWindowSpec,
+                                        UniformTypeKVCacheSpecs)
 # yapf: enable
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsLists, LogprobsTensors,
@@ -1187,7 +1188,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 common_prefix_len = self._compute_cascade_attn_prefix_len(
                     num_scheduled_tokens,
                     num_common_prefix_blocks,
-                    kv_cache_group_spec.kv_cache_spec,
+                    attn_group.kv_cache_spec,
                     builder,
                 )
@@ -3453,12 +3454,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert len(self.attn_groups) == 0, \
             "Attention backends are already initialized"

-        def get_attn_backends_for_layers(
-                layer_names: list[str]
-        ) -> dict[type[AttentionBackend], list[str]]:
-            layers = get_layers_from_vllm_config(self.vllm_config,
-                                                 AttentionLayerBase,
-                                                 layer_names)
+        class AttentionGroupKey(NamedTuple):
+            attn_backend: type[AttentionBackend]
+            kv_cache_spec: KVCacheSpec
+
+        def get_attn_backends_for_group(
+            kv_cache_group_spec: KVCacheGroupSpec,
+        ) -> dict[AttentionGroupKey, list[str]]:
+            layers = get_layers_from_vllm_config(
+                self.vllm_config, AttentionLayerBase,
+                kv_cache_group_spec.layer_names)
             attn_backends = {}
             attn_backend_layers = defaultdict(list)
             # Dedupe based on full class name; this is a bit safer than
@@ -3466,7 +3471,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # attention backend subclasses (e.g. ChunkedLocalAttention) unless
             # they are cached correctly, there will be different objects per
             # layer.
-            for layer_name in layer_names:
+            for layer_name in kv_cache_group_spec.layer_names:
                 attn_backend = layers[layer_name].get_attn_backend()

                 if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
@@ -3475,8 +3480,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                         attn_backend,
                     )

-                key = attn_backend.full_cls_name()
-                attn_backends[key] = attn_backend
+                full_cls_name = attn_backend.full_cls_name()
+                layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+                if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
+                    layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
+                        layer_name]
+                key = (full_cls_name, layer_kv_cache_spec)
+                attn_backends[key] = AttentionGroupKey(attn_backend,
+                                                       layer_kv_cache_spec)
                 attn_backend_layers[key].append(layer_name)
             return {
                 attn_backends[k]: v
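Grouping now keys on the pair (backend class, per-layer spec) rather than the backend alone, so layers that share a backend but need different page sizes land in different attention groups. A minimal sketch of that bucketing with stand-in values:

from collections import defaultdict

# Stand-in (backend name, per-layer spec) keys; the real code uses the
# backend's full class name and the layer's KVCacheSpec pulled out of a
# UniformTypeKVCacheSpecs when present.
layer_keys = {
    "layer_1": ("FlashAttentionBackend", "full_attn(head=128)"),
    "layer_2": ("FlashAttentionBackend", "full_attn(head=64)"),
    "layer_3": ("FlashAttentionBackend", "full_attn(head=64)"),
}

groups = defaultdict(list)
for layer_name, key in layer_keys.items():
    groups[key].append(layer_name)

assert groups[("FlashAttentionBackend", "full_attn(head=64)")] == ["layer_2", "layer_3"]
assert len(groups) == 2   # same backend, two groups because the specs differ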
@@ -3484,11 +3495,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             }

         def create_attn_groups(
-            attn_backends_map: dict[AttentionBackend, list[str]],
-            kv_cache_spec: KVCacheSpec,
+            attn_backends_map: dict[AttentionGroupKey, list[str]],
         ) -> list[AttentionGroup]:
             attn_groups: list[AttentionGroup] = []
-            for attn_backend, layer_names in attn_backends_map.items():
+            for (attn_backend,
+                 kv_cache_spec), layer_names in attn_backends_map.items():
                 attn_metadata_builders = []
                 attn_metadata_builders.append(attn_backend.get_builder_cls()(
                     kv_cache_spec,
@@ -3506,16 +3517,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 ))
                 attn_group = AttentionGroup(attn_backend,
                                             attn_metadata_builders,
-                                            layer_names)
+                                            layer_names, kv_cache_spec)
                 attn_groups.append(attn_group)
             return attn_groups

         for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
-            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
-            attn_backends = get_attn_backends_for_layers(
-                kv_cache_group_spec.layer_names)
-            self.attn_groups.append(
-                create_attn_groups(attn_backends, kv_cache_spec))
+            attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
+            self.attn_groups.append(create_attn_groups(attn_backends))

         # Calculate reorder batch threshold (if needed)
         self.calculate_reorder_batch_threshold()
@@ -3680,14 +3688,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
         return itertools.chain.from_iterable(self.attn_groups)

-    def _kv_cache_spec_attn_group_iterator(
-            self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
+    def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]:
         if not self.kv_cache_config.kv_cache_groups:
             return
-        for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
-            for attn_group in attn_groups:
-                yield self.kv_cache_config.kv_cache_groups[
-                    kv_cache_spec_id].kv_cache_spec, attn_group
+        for attn_groups in self.attn_groups:
+            yield from attn_groups

     def _reshape_kv_cache_tensors(
         self,
@@ -3707,7 +3712,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         """
         kv_caches: dict[str, torch.Tensor] = {}
         has_attn, has_mamba = False, False
-        for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
+        for group in self._kv_cache_spec_attn_group_iterator():
+            kv_cache_spec = group.kv_cache_spec
             attn_backend = group.backend
             for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
@@ -3787,7 +3793,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             kv_caches: The KV cache buffer of each layer.
         """

-        for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
+        for group in self._kv_cache_spec_attn_group_iterator():
+            kv_cache_spec = group.kv_cache_spec
             for layer_name in group.layer_names:
                 kv_cache = kv_caches[layer_name]
                 if (isinstance(kv_cache_spec, AttentionSpec)
@@ -15,7 +15,7 @@ from vllm.multimodal.registry import MultiModalRegistry
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
 from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
-from vllm.v1.kv_cache_interface import KVCacheGroupSpec
+from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec

 if TYPE_CHECKING:
     from vllm.attention.layer import Attention
@@ -132,6 +132,7 @@ class AttentionGroup:
     backend: type[AttentionBackend]
     metadata_builders: list[AttentionMetadataBuilder]
     layer_names: list[str]
+    kv_cache_spec: KVCacheSpec

     def get_metadata_builder(self,
                              ubatch_id: Optional[int] = None
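With kv_cache_spec stored on AttentionGroup, iteration code no longer has to zip groups with kv_cache_config.kv_cache_groups; each group is self-describing. A compact sketch with a stand-in dataclass (the real class also carries metadata builders):

from dataclasses import dataclass

@dataclass
class StubAttentionGroup:          # stand-in for AttentionGroup
    backend: str
    layer_names: list
    kv_cache_spec: str             # the real field holds a KVCacheSpec

attn_groups = [
    [StubAttentionGroup("FlashAttentionBackend", ["layer_1"], "full_attn(head=128)")],
    [StubAttentionGroup("FlashAttentionBackend", ["layer_2"], "full_attn(head=64)")],
]

# Mirrors the simplified _kv_cache_spec_attn_group_iterator: just flatten.
for groups in attn_groups:
    for group in groups:
        spec = group.kv_cache_spec   # no lookup into kv_cache_config needed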