[Hybrid Allocator] Support full attention with different hidden size (#25101)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Chen Zhang 2025-09-19 23:43:59 -07:00 committed by GitHub
parent c60e6137f0
commit 9607d5eb44
6 changed files with 324 additions and 92 deletions

View File

@ -18,12 +18,14 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.kv_cache_utils import (
BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
estimate_max_model_len, generate_block_hash_extra_keys,
get_kv_cache_configs, get_max_concurrency_for_kv_cache_config,
get_request_block_hasher, hash_block_tokens, init_none_hash,
is_kv_cache_type_uniform, make_block_hash_with_group_id)
generate_scheduler_kv_cache_config, get_kv_cache_configs,
get_max_concurrency_for_kv_cache_config, get_request_block_hasher,
hash_block_tokens, init_none_hash, is_kv_cache_spec_uniform,
make_block_hash_with_group_id)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
KVCacheTensor, SlidingWindowSpec,
UniformTypeKVCacheSpecs)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
@ -927,36 +929,36 @@ def test_merge_kv_cache_spec():
assert merged_layer_spec.sliding_window == 1
def test_is_kv_cache_type_uniform():
def test_is_kv_cache_spec_uniform():
kv_cache_spec = {
"layer_1": new_kv_cache_spec(num_kv_heads=32),
"layer_2": new_kv_cache_spec(num_kv_heads=32),
}
assert is_kv_cache_type_uniform(kv_cache_spec)
assert is_kv_cache_spec_uniform(kv_cache_spec)
kv_cache_spec = {
"layer_1": new_kv_cache_spec(num_kv_heads=32),
"layer_2": new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
}
assert is_kv_cache_type_uniform(kv_cache_spec)
assert is_kv_cache_spec_uniform(kv_cache_spec)
kv_cache_spec = {
"layer_1": new_kv_cache_spec(num_kv_heads=32),
"layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
}
assert not is_kv_cache_type_uniform(kv_cache_spec)
assert not is_kv_cache_spec_uniform(kv_cache_spec)
kv_cache_spec = {
"layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
"layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
}
assert is_kv_cache_type_uniform(kv_cache_spec)
assert is_kv_cache_spec_uniform(kv_cache_spec)
kv_cache_spec = {
"layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1),
"layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=2),
}
assert not is_kv_cache_type_uniform(kv_cache_spec)
assert not is_kv_cache_spec_uniform(kv_cache_spec)
@pytest.mark.parametrize(
@ -1286,14 +1288,28 @@ def test_get_kv_cache_config_one_worker():
],
)
# different hidden size, unimplemented
# different hidden size
kv_cache_specs_hybrid = {
'layer_1': new_kv_cache_spec(head_size=128),
'layer_2': new_kv_cache_spec(),
'layer_2': new_kv_cache_spec(head_size=64),
}
with pytest.raises(NotImplementedError):
get_kv_cache_configs(vllm_config, [kv_cache_specs_hybrid],
[mem_per_block_per_layer * 2 * 32])[0]
kv_cache_config_hybrid = get_kv_cache_configs(
vllm_config, [kv_cache_specs_hybrid],
[mem_per_block_per_layer * 3 * 32])[0]
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(size=mem_per_block_per_layer * 32 * 2,
shared_by=["layer_1"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"],
UniformTypeKVCacheSpecs(
block_size=16,
kv_cache_specs=kv_cache_specs_hybrid))
])
# Test num_gpu_blocks_override
vllm_config.cache_config.num_gpu_blocks_override = 16
@ -1324,3 +1340,75 @@ def test_get_kv_cache_configs_attention_free():
kv_cache_groups=[],
)
]
def test_generate_uniform_type_kv_cache_specs():
# All layers are full attention, can be merged
kv_cache_specs = {
'layer_1': new_kv_cache_spec(),
'layer_2': new_kv_cache_spec(head_size=128),
}
uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs)
assert uniform_spec == UniformTypeKVCacheSpecs(
block_size=16, kv_cache_specs=kv_cache_specs)
# Full attention + sliding window, cannot be merged
kv_cache_specs = {
'layer_1': new_kv_cache_spec(),
'layer_2': new_sliding_window_spec(sliding_window=1),
}
uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs)
assert uniform_spec is None
# different order of full attention + sliding window, cannot be merged
kv_cache_specs = {
'layer_1': new_sliding_window_spec(sliding_window=1),
'layer_2': new_kv_cache_spec(),
}
uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs)
assert uniform_spec is None
# Same-size sliding window, can be merged
kv_cache_specs = {
'layer_1': new_sliding_window_spec(sliding_window=1),
'layer_2': new_sliding_window_spec(sliding_window=1, head_size=128),
}
uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs)
assert uniform_spec == UniformTypeKVCacheSpecs(
block_size=16, kv_cache_specs=kv_cache_specs)
# different block sizes, cannot be merged
kv_cache_specs = {
'layer_1': new_kv_cache_spec(block_size=16),
'layer_2': new_kv_cache_spec(block_size=32),
}
uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs)
assert uniform_spec is None
def test_generate_scheduler_kv_cache_config():
kv_cache_specs = {
'layer_1': new_kv_cache_spec(),
'layer_2': new_kv_cache_spec(head_size=128),
}
kv_cache_configs = [
KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer_1', 'layer_2'],
UniformTypeKVCacheSpecs(
block_size=16,
kv_cache_specs=kv_cache_specs)),
],
)
]
scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
kv_cache_configs)
assert scheduler_kv_cache_config == KVCacheConfig(
num_blocks=10,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer_1', 'layer_2'], new_kv_cache_spec())
],
)
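
The expected config in the hybrid-layer case above follows directly from the merged page size. Below is a minimal sketch of that arithmetic, assuming (as the expected tensor sizes suggest) that mem_per_block_per_layer is the per-block page size of a head_size=64 layer and that head_size=128 doubles it; the concrete byte count is made up.

# Sketch of the hybrid-layer allocation arithmetic; only the ratios matter.
mem_per_block_per_layer = 1024                    # hypothetical bytes per block
page_size_layer_1 = 2 * mem_per_block_per_layer   # head_size=128
page_size_layer_2 = 1 * mem_per_block_per_layer   # head_size=64
merged_page_size = page_size_layer_1 + page_size_layer_2

available_memory = mem_per_block_per_layer * 3 * 32
num_blocks = available_memory // merged_page_size
assert num_blocks == 32

# One tensor per layer, sized by that layer's own page size.
assert page_size_layer_1 * num_blocks == mem_per_block_per_layer * 32 * 2
assert page_size_layer_2 * num_blocks == mem_per_block_per_layer * 32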

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""KV-Cache Utilities."""
import copy
import os
from collections import defaultdict, deque
from collections.abc import Iterable, Sequence
@ -15,7 +16,8 @@ from vllm.utils import GiB_bytes, cdiv, sha256_cbor
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
KVCacheTensor, SlidingWindowSpec,
UniformTypeKVCacheSpecs)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
@ -750,7 +752,7 @@ def create_kv_cache_group_specs(
return kv_cache_groups
def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
def is_kv_cache_spec_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
"""
Whether all layers in the given KVCacheSpec have the same KV cache spec.
Note that we regard FullAttentionSpec with and without sliding window as
@ -793,6 +795,21 @@ def get_max_concurrency_for_kv_cache_config(
return max_concurrency
def may_override_num_blocks(vllm_config: VllmConfig, num_blocks: int) -> int:
"""
Override the number of kv cache blocks if `num_gpu_blocks_override` is set.
"""
if vllm_config.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = \
vllm_config.cache_config.num_gpu_blocks_override
logger.info(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
num_blocks = num_gpu_blocks_override
return num_blocks
def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
available_memory: int, page_size: int) -> int:
"""
@ -806,13 +823,7 @@ def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
"""
num_blocks = int(available_memory // page_size // num_layers)
num_blocks = max(num_blocks, 0)
if vllm_config.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = \
vllm_config.cache_config.num_gpu_blocks_override
logger.info(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
num_blocks = num_gpu_blocks_override
num_blocks = may_override_num_blocks(vllm_config, num_blocks)
return num_blocks
@ -825,11 +836,11 @@ def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
return page_sizes.pop()
def _get_kv_cache_groups_uniform_type(
def _get_kv_cache_groups_uniform_spec(
kv_cache_specs: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]:
"""
Generates the KV cache configuration for a model with one type of KV cache.
Divide the available memory equally among all layers.
Generates the KV cache configuration for a model with the same KV cache
spec for all layers.
Args:
kv_cache_specs: The kv cache spec of each attention layer in the model
@ -842,6 +853,22 @@ def _get_kv_cache_groups_uniform_type(
[list(kv_cache_specs.keys())])
def _get_kv_cache_groups_uniform_type(
spec: UniformTypeKVCacheSpecs) -> list[KVCacheGroupSpec]:
"""
Generates the KV cache configuration for a model with one type of KV cache
but possibly different hidden sizes. All layers are merged into one group.
Args:
spec: The UniformTypeKVCacheSpecs of the model
Returns:
The generated KVCacheGroupSpecs
"""
return [KVCacheGroupSpec(list(spec.kv_cache_specs.keys()), spec)]
def is_kv_cache_page_size_uniform(
kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
"""
@ -1000,28 +1027,45 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig,
)
# Determine how model runners should initialize the KV cache tensors.
# We will have group_size memory pools, each is shared by one layer from
# each group. As layers of different groups have different block table,
# they will use different parts of the shared Tensor.
# The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2),
# (sw.1, padding) will be: (group_size = 2)
# full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
# full.1, sw.2: share another Tensor with size=available_memory//2
group_size = max(len(group.layer_names) for group in kv_cache_groups)
if len(kv_cache_groups) == 1 and \
isinstance(kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs):
# Special case: all layers have the same type of KV cache but possibly
# different hidden sizes. Allocate a different amount of memory for each
# layer based on its hidden size.
num_blocks = available_memory // kv_cache_groups[
0].kv_cache_spec.page_size_bytes
num_blocks = may_override_num_blocks(vllm_config, num_blocks)
per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs
kv_cache_tensors = [
KVCacheTensor(size=per_layer_specs[layer_name].page_size_bytes *
num_blocks,
shared_by=[layer_name])
for layer_name in kv_cache_groups[0].layer_names
]
else:
# General case:
# We will have group_size memory pools, each shared by one layer from
# each group. As layers of different groups have different block tables,
# they will use different parts of the shared Tensor.
# The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2),
# (sw.1, padding) will be: (group_size = 2)
# full.0, sw.0, sw.1: share a Tensor with size=available_memory//2
# full.1, sw.2: share another Tensor with size=available_memory//2
group_size = max(len(group.layer_names) for group in kv_cache_groups)
page_size = get_uniform_page_size(kv_cache_specs)
assert group_size > 0, "group_size must be greater than 0"
num_blocks = get_num_blocks(vllm_config, group_size, available_memory,
page_size)
per_memory_pool_size = page_size * num_blocks
kv_cache_tensors = []
for i in range(group_size):
shared_by = []
for j in range(len(kv_cache_groups)):
if i < len(kv_cache_groups[j].layer_names):
shared_by.append(kv_cache_groups[j].layer_names[i])
kv_cache_tensors.append(
KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by))
page_size = get_uniform_page_size(kv_cache_specs)
assert group_size > 0, "group_size must be greater than 0"
num_blocks = get_num_blocks(vllm_config, group_size, available_memory,
page_size)
kv_cache_tensors = []
for i in range(group_size):
shared_by = []
for j in range(len(kv_cache_groups)):
if i < len(kv_cache_groups[j].layer_names):
shared_by.append(kv_cache_groups[j].layer_names[i])
kv_cache_tensors.append(
KVCacheTensor(size=page_size * num_blocks,
shared_by=shared_by))
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
@ -1059,7 +1103,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
kv_cache_spec: The kv cache spec of each attention layer in the model
"""
if is_kv_cache_type_uniform(kv_cache_spec):
if is_kv_cache_spec_uniform(kv_cache_spec):
return
logger.warning(
@ -1097,7 +1141,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
attention_chunk_size=spec.attention_chunk_size,
)
if not is_kv_cache_type_uniform(kv_cache_spec):
if not is_kv_cache_spec_uniform(kv_cache_spec):
raise ValueError("Hybrid KV cache manager is disabled but failed to "
"convert the KV cache specs to one unified type.")
@ -1122,11 +1166,16 @@ def get_kv_cache_groups(
# This returns an empty list to allow for the KVCacheManager to handle
# attention free models.
return []
elif is_kv_cache_type_uniform(kv_cache_spec):
elif is_kv_cache_spec_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
return _get_kv_cache_groups_uniform_type(kv_cache_spec)
return _get_kv_cache_groups_uniform_spec(kv_cache_spec)
elif uniform_spec := UniformTypeKVCacheSpecs.from_specs(kv_cache_spec):
# All layers need the same number of token slots (e.g., all layers are
# full attention, or all layers are sliding window attention with the
# same window size). Put all layers into one group.
return _get_kv_cache_groups_uniform_type(uniform_spec)
elif is_kv_cache_page_size_uniform(kv_cache_spec):
# Model contains multiple attention types, but KV cache of all layers
# have the same physical memory per block per layer. Split the layers
@ -1137,6 +1186,27 @@ def get_kv_cache_groups(
raise NotImplementedError
def generate_scheduler_kv_cache_config(
kv_cache_configs: list[KVCacheConfig]) -> KVCacheConfig:
"""
Generate the KV cache configuration for the scheduler.
"""
assert all([
cfg.num_blocks == kv_cache_configs[0].num_blocks
for cfg in kv_cache_configs
])
# All workers have the same kv_cache_config except layer names, so use
# an arbitrary one to initialize the scheduler.
cfg = copy.deepcopy(kv_cache_configs[0])
for group in cfg.kv_cache_groups:
if isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs):
# All layers in the UniformTypeKVCacheSpecs have the same type,
# so use an arbitrary one to initialize the scheduler.
group.kv_cache_spec = next(
iter(group.kv_cache_spec.kv_cache_specs.values()))
return cfg
def get_kv_cache_configs(vllm_config: VllmConfig,
kv_cache_specs: list[dict[str, KVCacheSpec]],
available_memory: list[int]) -> list[KVCacheConfig]:
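
For contrast with the new single-group path, the general multi-group layout described in the comment inside get_kv_cache_config_from_groups can be sketched as follows. The layer names mirror the (full.0, full.1), (sw.0, sw.2), (sw.1) example from that comment and are illustrative only.

# Sketch of the general multi-group memory layout (not the new special case).
kv_cache_groups = [["full.0", "full.1"], ["sw.0", "sw.2"], ["sw.1"]]
group_size = max(len(group) for group in kv_cache_groups)   # 2 memory pools

pools = []
for i in range(group_size):
    # Pool i is shared by the i-th layer of every group that has one.
    pools.append([group[i] for group in kv_cache_groups if i < len(group)])

assert pools == [["full.0", "sw.0", "sw.1"], ["full.1", "sw.2"]]
# Each pool gets available_memory // group_size bytes. The new
# UniformTypeKVCacheSpecs branch instead creates one tensor per layer, sized
# by that layer's own page_size_bytes.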

View File

@ -29,7 +29,9 @@ from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket,
resolve_obj_by_qualname, set_process_title)
from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_configs,
from vllm.v1.core.kv_cache_utils import (BlockHash,
generate_scheduler_kv_cache_config,
get_kv_cache_configs,
get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.interface import SchedulerInterface
@ -196,16 +198,10 @@ class EngineCore:
kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
available_gpu_memory)
# All workers have the same kv_cache_config except layer names, so use
# an arbitrary one to initialize the scheduler.
assert all([
cfg.num_blocks == kv_cache_configs[0].num_blocks
for cfg in kv_cache_configs
])
num_gpu_blocks = kv_cache_configs[0].num_blocks
scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
kv_cache_configs)
num_gpu_blocks = scheduler_kv_cache_config.num_blocks
num_cpu_blocks = 0
scheduler_kv_cache_config = kv_cache_configs[0]
# Initialize kv cache and warmup the execution
self.model_executor.initialize_from_config(kv_cache_configs)

View File

@ -234,6 +234,76 @@ class CrossAttentionSpec(AttentionSpec):
return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
@dataclass(frozen=True)
class UniformTypeKVCacheSpecs(KVCacheSpec):
"""
A KV cache spec for multiple layers with the same type of attention. Here,
"same type" means the layers always need the same number of token slots. For
example, sliding window attention layers with different window sizes are not
the same type and should not be merged into one UniformTypeKVCacheSpecs.
"""
kv_cache_specs: dict[str, KVCacheSpec]
@property
def page_size_bytes(self) -> int:
return sum(spec.page_size_bytes
for spec in self.kv_cache_specs.values())
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_num_pages = max(
cdiv(spec.max_memory_usage_bytes(vllm_config),
spec.page_size_bytes)
for spec in self.kv_cache_specs.values())
return max_num_pages * self.page_size_bytes
@classmethod
def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool:
"""
Whether all layers have the same type of KV cache spec.
"""
block_sizes = set(spec.block_size for spec in kv_cache_specs.values())
if len(block_sizes) > 1:
# Different block sizes, not uniform.
return False
one_spec = next(iter(kv_cache_specs.values()))
if isinstance(one_spec, (FullAttentionSpec, CrossAttentionSpec)):
return all(
isinstance(spec, type(one_spec))
for spec in kv_cache_specs.values())
elif isinstance(one_spec, SlidingWindowSpec):
return all(
isinstance(spec, SlidingWindowSpec)
and spec.sliding_window == one_spec.sliding_window
for spec in kv_cache_specs.values())
elif isinstance(one_spec, ChunkedLocalAttentionSpec):
return all(
isinstance(spec, ChunkedLocalAttentionSpec)
and spec.attention_chunk_size == one_spec.attention_chunk_size
for spec in kv_cache_specs.values())
elif isinstance(one_spec, MambaSpec):
return all(
isinstance(spec, MambaSpec) and spec.num_speculative_blocks ==
one_spec.num_speculative_blocks
for spec in kv_cache_specs.values())
else:
# NOTE(Chen): Please add new branches for new KV cache spec types.
raise NotImplementedError(
f"Unsupported KV cache spec type: {type(one_spec)}")
@classmethod
def from_specs(cls, kv_cache_specs: dict[str,
KVCacheSpec]) -> Optional[Self]:
"""
Return a UniformTypeKVCacheSpecs object if all layers have the same type
of KV cache spec. Return None otherwise.
"""
if cls.is_uniform_type(kv_cache_specs):
block_size = next(iter(kv_cache_specs.values())).block_size
return cls(block_size=block_size, kv_cache_specs=kv_cache_specs)
else:
return None
@dataclass
class KVCacheTensor:
"""

View File

@ -8,7 +8,7 @@ from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast
import numpy as np
import torch
@ -74,7 +74,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
EncoderOnlyAttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
MambaSpec, SlidingWindowSpec)
MambaSpec, SlidingWindowSpec,
UniformTypeKVCacheSpecs)
# yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, LogprobsLists, LogprobsTensors,
@ -1187,7 +1188,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
num_common_prefix_blocks,
kv_cache_group_spec.kv_cache_spec,
attn_group.kv_cache_spec,
builder,
)
@ -3453,12 +3454,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert len(self.attn_groups) == 0, \
"Attention backends are already initialized"
def get_attn_backends_for_layers(
layer_names: list[str]
) -> dict[type[AttentionBackend], list[str]]:
layers = get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase,
layer_names)
class AttentionGroupKey(NamedTuple):
attn_backend: type[AttentionBackend]
kv_cache_spec: KVCacheSpec
def get_attn_backends_for_group(
kv_cache_group_spec: KVCacheGroupSpec,
) -> dict[AttentionGroupKey, list[str]]:
layers = get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase,
kv_cache_group_spec.layer_names)
attn_backends = {}
attn_backend_layers = defaultdict(list)
# Dedupe based on full class name; this is a bit safer than
@ -3466,7 +3471,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# attention backend subclasses (e.g. ChunkedLocalAttention) unless
# they are cached correctly, there will be different objects per
# layer.
for layer_name in layer_names:
for layer_name in kv_cache_group_spec.layer_names:
attn_backend = layers[layer_name].get_attn_backend()
if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
@ -3475,8 +3480,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_backend,
)
key = attn_backend.full_cls_name()
attn_backends[key] = attn_backend
full_cls_name = attn_backend.full_cls_name()
layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
layer_name]
key = (full_cls_name, layer_kv_cache_spec)
attn_backends[key] = AttentionGroupKey(attn_backend,
layer_kv_cache_spec)
attn_backend_layers[key].append(layer_name)
return {
attn_backends[k]: v
@ -3484,11 +3495,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
}
def create_attn_groups(
attn_backends_map: dict[AttentionBackend, list[str]],
kv_cache_spec: KVCacheSpec,
attn_backends_map: dict[AttentionGroupKey, list[str]],
) -> list[AttentionGroup]:
attn_groups: list[AttentionGroup] = []
for attn_backend, layer_names in attn_backends_map.items():
for (attn_backend,
kv_cache_spec), layer_names in attn_backends_map.items():
attn_metadata_builders = []
attn_metadata_builders.append(attn_backend.get_builder_cls()(
kv_cache_spec,
@ -3506,16 +3517,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
))
attn_group = AttentionGroup(attn_backend,
attn_metadata_builders,
layer_names)
layer_names, kv_cache_spec)
attn_groups.append(attn_group)
return attn_groups
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
attn_backends = get_attn_backends_for_layers(
kv_cache_group_spec.layer_names)
self.attn_groups.append(
create_attn_groups(attn_backends, kv_cache_spec))
attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
self.attn_groups.append(create_attn_groups(attn_backends))
# Calculate reorder batch threshold (if needed)
self.calculate_reorder_batch_threshold()
@ -3680,14 +3688,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
return itertools.chain.from_iterable(self.attn_groups)
def _kv_cache_spec_attn_group_iterator(
self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]:
if not self.kv_cache_config.kv_cache_groups:
return
for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
for attn_group in attn_groups:
yield self.kv_cache_config.kv_cache_groups[
kv_cache_spec_id].kv_cache_spec, attn_group
for attn_groups in self.attn_groups:
yield from attn_groups
def _reshape_kv_cache_tensors(
self,
@ -3707,7 +3712,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"""
kv_caches: dict[str, torch.Tensor] = {}
has_attn, has_mamba = False, False
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
for group in self._kv_cache_spec_attn_group_iterator():
kv_cache_spec = group.kv_cache_spec
attn_backend = group.backend
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
@ -3787,7 +3793,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_caches: The KV cache buffer of each layer.
"""
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
for group in self._kv_cache_spec_attn_group_iterator():
kv_cache_spec = group.kv_cache_spec
for layer_name in group.layer_names:
kv_cache = kv_caches[layer_name]
if (isinstance(kv_cache_spec, AttentionSpec)
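
Because a single KV cache group can now hold layers with different per-layer specs, attention groups are keyed by (backend, per-layer spec) rather than by backend alone. A minimal sketch of that grouping, using plain strings as stand-ins for the real backend classes and KVCacheSpec objects:

# Sketch of grouping layers by (backend, per-layer spec); strings are
# stand-ins for type[AttentionBackend] and KVCacheSpec.
from collections import defaultdict
from typing import NamedTuple

class AttentionGroupKey(NamedTuple):
    attn_backend: str
    kv_cache_spec: str

layer_backends = {"layer_1": "FlashAttn", "layer_2": "FlashAttn"}
layer_specs = {"layer_1": "full(head_size=128)", "layer_2": "full(head_size=64)"}

groups = defaultdict(list)
for layer_name, backend in layer_backends.items():
    groups[AttentionGroupKey(backend, layer_specs[layer_name])].append(layer_name)

# Same backend but different per-layer specs -> two attention groups, so each
# group's metadata builder sees a spec that matches its layers' hidden size.
assert len(groups) == 2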

View File

@ -15,7 +15,7 @@ from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
if TYPE_CHECKING:
from vllm.attention.layer import Attention
@ -132,6 +132,7 @@ class AttentionGroup:
backend: type[AttentionBackend]
metadata_builders: list[AttentionMetadataBuilder]
layer_names: list[str]
kv_cache_spec: KVCacheSpec
def get_metadata_builder(self,
ubatch_id: Optional[int] = None