diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 4bf6bbbfeae2..4cb7ed6ce382 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -18,12 +18,14 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_utils import ( BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, estimate_max_model_len, generate_block_hash_extra_keys, - get_kv_cache_configs, get_max_concurrency_for_kv_cache_config, - get_request_block_hasher, hash_block_tokens, init_none_hash, - is_kv_cache_type_uniform, make_block_hash_with_group_id) + generate_scheduler_kv_cache_config, get_kv_cache_configs, + get_max_concurrency_for_kv_cache_config, get_request_block_hasher, + hash_block_tokens, init_none_hash, is_kv_cache_spec_uniform, + make_block_hash_with_group_id) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, - KVCacheTensor, SlidingWindowSpec) + KVCacheTensor, SlidingWindowSpec, + UniformTypeKVCacheSpecs) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -927,36 +929,36 @@ def test_merge_kv_cache_spec(): assert merged_layer_spec.sliding_window == 1 -def test_is_kv_cache_type_uniform(): +def test_is_kv_cache_spec_uniform(): kv_cache_spec = { "layer_1": new_kv_cache_spec(num_kv_heads=32), "layer_2": new_kv_cache_spec(num_kv_heads=32), } - assert is_kv_cache_type_uniform(kv_cache_spec) + assert is_kv_cache_spec_uniform(kv_cache_spec) kv_cache_spec = { "layer_1": new_kv_cache_spec(num_kv_heads=32), "layer_2": new_kv_cache_spec(num_kv_heads=32, sliding_window=1), } - assert is_kv_cache_type_uniform(kv_cache_spec) + assert is_kv_cache_spec_uniform(kv_cache_spec) kv_cache_spec = { "layer_1": new_kv_cache_spec(num_kv_heads=32), "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1), } - assert not is_kv_cache_type_uniform(kv_cache_spec) + assert not is_kv_cache_spec_uniform(kv_cache_spec) kv_cache_spec = { "layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1), "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=1), } - assert is_kv_cache_type_uniform(kv_cache_spec) + assert is_kv_cache_spec_uniform(kv_cache_spec) kv_cache_spec = { "layer_1": new_sliding_window_spec(num_kv_heads=32, sliding_window=1), "layer_2": new_sliding_window_spec(num_kv_heads=32, sliding_window=2), } - assert not is_kv_cache_type_uniform(kv_cache_spec) + assert not is_kv_cache_spec_uniform(kv_cache_spec) @pytest.mark.parametrize( @@ -1286,14 +1288,28 @@ def test_get_kv_cache_config_one_worker(): ], ) - # different hidden size, unimplemented + # different hidden size kv_cache_specs_hybrid = { 'layer_1': new_kv_cache_spec(head_size=128), - 'layer_2': new_kv_cache_spec(), + 'layer_2': new_kv_cache_spec(head_size=64), } - with pytest.raises(NotImplementedError): - get_kv_cache_configs(vllm_config, [kv_cache_specs_hybrid], - [mem_per_block_per_layer * 2 * 32])[0] + kv_cache_config_hybrid = get_kv_cache_configs( + vllm_config, [kv_cache_specs_hybrid], + [mem_per_block_per_layer * 3 * 32])[0] + assert kv_cache_config_hybrid == KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[ + KVCacheTensor(size=mem_per_block_per_layer * 32 * 2, + shared_by=["layer_1"]), + KVCacheTensor(size=mem_per_block_per_layer * 32, + shared_by=["layer_2"]), + ], + kv_cache_groups=[ + KVCacheGroupSpec(["layer_1", "layer_2"], + UniformTypeKVCacheSpecs( + block_size=16, + kv_cache_specs=kv_cache_specs_hybrid)) + ]) # Test num_gpu_blocks_override vllm_config.cache_config.num_gpu_blocks_override = 16 @@ -1324,3 +1340,75 @@ def test_get_kv_cache_configs_attention_free(): kv_cache_groups=[], ) ] + + +def test_generate_uniform_type_kv_cache_specs(): + # All layers are full attention, can be merged + kv_cache_specs = { + 'layer_1': new_kv_cache_spec(), + 'layer_2': new_kv_cache_spec(head_size=128), + } + uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) + assert uniform_spec == UniformTypeKVCacheSpecs( + block_size=16, kv_cache_specs=kv_cache_specs) + + # Full attention + sliding window, cannot be merged + kv_cache_specs = { + 'layer_1': new_kv_cache_spec(), + 'layer_2': new_sliding_window_spec(sliding_window=1), + } + uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) + assert uniform_spec is None + + # different order of full attention + sliding window, cannot be merged + kv_cache_specs = { + 'layer_1': new_sliding_window_spec(sliding_window=1), + 'layer_2': new_kv_cache_spec(), + } + uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) + assert uniform_spec is None + + # Same-size sliding window, can be merged + kv_cache_specs = { + 'layer_1': new_sliding_window_spec(sliding_window=1), + 'layer_2': new_sliding_window_spec(sliding_window=1, head_size=128), + } + uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) + assert uniform_spec == UniformTypeKVCacheSpecs( + block_size=16, kv_cache_specs=kv_cache_specs) + + # different block sizes, cannot be merged + kv_cache_specs = { + 'layer_1': new_kv_cache_spec(block_size=16), + 'layer_2': new_kv_cache_spec(block_size=32), + } + uniform_spec = UniformTypeKVCacheSpecs.from_specs(kv_cache_specs) + assert uniform_spec is None + + +def test_generate_scheduler_kv_cache_config(): + kv_cache_specs = { + 'layer_1': new_kv_cache_spec(), + 'layer_2': new_kv_cache_spec(head_size=128), + } + kv_cache_configs = [ + KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(['layer_1', 'layer_2'], + UniformTypeKVCacheSpecs( + block_size=16, + kv_cache_specs=kv_cache_specs)), + ], + ) + ] + scheduler_kv_cache_config = generate_scheduler_kv_cache_config( + kv_cache_configs) + assert scheduler_kv_cache_config == KVCacheConfig( + num_blocks=10, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(['layer_1', 'layer_2'], new_kv_cache_spec()) + ], + ) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index bc2ec5e42ea2..3ccd00121f8e 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """KV-Cache Utilities.""" +import copy import os from collections import defaultdict, deque from collections.abc import Iterable, Sequence @@ -15,7 +16,8 @@ from vllm.utils import GiB_bytes, cdiv, sha256_cbor from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, - KVCacheTensor, SlidingWindowSpec) + KVCacheTensor, SlidingWindowSpec, + UniformTypeKVCacheSpecs) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -750,7 +752,7 @@ def create_kv_cache_group_specs( return kv_cache_groups -def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: +def is_kv_cache_spec_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: """ Whether all layers in the given KVCacheSpec have the same KV cache spec. Note that we regard FullAttentionSpec with and without sliding window as @@ -793,6 +795,21 @@ def get_max_concurrency_for_kv_cache_config( return max_concurrency +def may_override_num_blocks(vllm_config: VllmConfig, num_blocks: int) -> int: + """ + Override the number of kv cache blocks if `num_gpu_blocks_override` is set. + """ + if vllm_config.cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = \ + vllm_config.cache_config.num_gpu_blocks_override + logger.info( + "Overriding num_gpu_blocks=%d with " + "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override) + num_blocks = num_gpu_blocks_override + + return num_blocks + + def get_num_blocks(vllm_config: VllmConfig, num_layers: int, available_memory: int, page_size: int) -> int: """ @@ -806,13 +823,7 @@ def get_num_blocks(vllm_config: VllmConfig, num_layers: int, """ num_blocks = int(available_memory // page_size // num_layers) num_blocks = max(num_blocks, 0) - if vllm_config.cache_config.num_gpu_blocks_override is not None: - num_gpu_blocks_override = \ - vllm_config.cache_config.num_gpu_blocks_override - logger.info( - "Overriding num_gpu_blocks=%d with " - "num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override) - num_blocks = num_gpu_blocks_override + num_blocks = may_override_num_blocks(vllm_config, num_blocks) return num_blocks @@ -825,11 +836,11 @@ def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int: return page_sizes.pop() -def _get_kv_cache_groups_uniform_type( +def _get_kv_cache_groups_uniform_spec( kv_cache_specs: dict[str, KVCacheSpec]) -> list[KVCacheGroupSpec]: """ - Generates the KV cache configuration for a model with one type of KV cache. - Divide the available memory equally among all layers. + Generates the KV cache configuration for a model with the same KV cache + spec for all layers. Args: kv_cache_specs: The kv cache spec of each attention layer in the model @@ -842,6 +853,22 @@ def _get_kv_cache_groups_uniform_type( [list(kv_cache_specs.keys())]) +def _get_kv_cache_groups_uniform_type( + spec: UniformTypeKVCacheSpecs) -> list[KVCacheGroupSpec]: + """ + Generates the KV cache configuration for a model with one type of KV cache + but different hidden sizes. All layers are merged into one group. + + Args: + spec: The UniformTypeKVCacheSpecs of the model + + Returns: + The generated KVCacheGroupSpecs + """ + + return [KVCacheGroupSpec(list(spec.kv_cache_specs.keys()), spec)] + + def is_kv_cache_page_size_uniform( kv_cache_spec: dict[str, KVCacheSpec]) -> bool: """ @@ -1000,28 +1027,45 @@ def get_kv_cache_config_from_groups(vllm_config: VllmConfig, ) # Determine how model runners should initialize the KV cache tensors. - # We will have group_size memory pools, each is shared by one layer from - # each group. As layers of different groups have different block table, - # they will use different parts of the shared Tensor. - # The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2), - # (sw.1, padding) will be: (group_size = 2) - # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2 - # full.1, sw.2: share another Tensor with size=available_memory//2 - group_size = max(len(group.layer_names) for group in kv_cache_groups) + if len(kv_cache_groups) == 1 and \ + isinstance(kv_cache_groups[0].kv_cache_spec, UniformTypeKVCacheSpecs): + # Special case: all layers have the same type of KV cache but with + # different hidden size. Allocate different amount of memory for each + # layer based on its hidden size. + num_blocks = available_memory // kv_cache_groups[ + 0].kv_cache_spec.page_size_bytes + num_blocks = may_override_num_blocks(vllm_config, num_blocks) + per_layer_specs = kv_cache_groups[0].kv_cache_spec.kv_cache_specs + kv_cache_tensors = [ + KVCacheTensor(size=per_layer_specs[layer_name].page_size_bytes * + num_blocks, + shared_by=[layer_name]) + for layer_name in kv_cache_groups[0].layer_names + ] + else: + # General case: + # We will have group_size memory pools, each is shared by one layer from + # each group. As layers of different groups have different block table, + # they will use different parts of the shared Tensor. + # The memory layout for 3 groups (full.0, full.1), (sw.0, sw.2), + # (sw.1, padding) will be: (group_size = 2) + # full.0, sw.0, sw.1: share a Tensor with size=available_memory//2 + # full.1, sw.2: share another Tensor with size=available_memory//2 + group_size = max(len(group.layer_names) for group in kv_cache_groups) - page_size = get_uniform_page_size(kv_cache_specs) - assert group_size > 0, "group_size must be greater than 0" - num_blocks = get_num_blocks(vllm_config, group_size, available_memory, - page_size) - per_memory_pool_size = page_size * num_blocks - kv_cache_tensors = [] - for i in range(group_size): - shared_by = [] - for j in range(len(kv_cache_groups)): - if i < len(kv_cache_groups[j].layer_names): - shared_by.append(kv_cache_groups[j].layer_names[i]) - kv_cache_tensors.append( - KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by)) + page_size = get_uniform_page_size(kv_cache_specs) + assert group_size > 0, "group_size must be greater than 0" + num_blocks = get_num_blocks(vllm_config, group_size, available_memory, + page_size) + kv_cache_tensors = [] + for i in range(group_size): + shared_by = [] + for j in range(len(kv_cache_groups)): + if i < len(kv_cache_groups[j].layer_names): + shared_by.append(kv_cache_groups[j].layer_names[i]) + kv_cache_tensors.append( + KVCacheTensor(size=page_size * num_blocks, + shared_by=shared_by)) kv_cache_config = KVCacheConfig( num_blocks=num_blocks, @@ -1059,7 +1103,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): kv_cache_spec: The kv cache spec of each attention layer in the model """ - if is_kv_cache_type_uniform(kv_cache_spec): + if is_kv_cache_spec_uniform(kv_cache_spec): return logger.warning( @@ -1097,7 +1141,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): attention_chunk_size=spec.attention_chunk_size, ) - if not is_kv_cache_type_uniform(kv_cache_spec): + if not is_kv_cache_spec_uniform(kv_cache_spec): raise ValueError("Hybrid KV cache manager is disabled but failed to " "convert the KV cache specs to one unified type.") @@ -1122,11 +1166,16 @@ def get_kv_cache_groups( # This returns an empty list to allow for the KVCacheManager to handle # attention free models. return [] - elif is_kv_cache_type_uniform(kv_cache_spec): + elif is_kv_cache_spec_uniform(kv_cache_spec): # KV cache of all layers are the same, which is true for # most models. Allocate the same amount of memory for # each layer. - return _get_kv_cache_groups_uniform_type(kv_cache_spec) + return _get_kv_cache_groups_uniform_spec(kv_cache_spec) + elif uniform_spec := UniformTypeKVCacheSpecs.from_specs(kv_cache_spec): + # All layers need the same number of token slots (e.g., all layers are + # full attention, or all layers are sliding window attention with the + # same window size). Put all layers into one group. + return _get_kv_cache_groups_uniform_type(uniform_spec) elif is_kv_cache_page_size_uniform(kv_cache_spec): # Model contains multiple attention types, but KV cache of all layers # have the same physical memory per block per layer. Split the layers @@ -1137,6 +1186,27 @@ def get_kv_cache_groups( raise NotImplementedError +def generate_scheduler_kv_cache_config( + kv_cache_configs: list[KVCacheConfig]) -> KVCacheConfig: + """ + Generate the KV cache configuration for the scheduler. + """ + assert all([ + cfg.num_blocks == kv_cache_configs[0].num_blocks + for cfg in kv_cache_configs + ]) + # All workers have the same kv_cache_config except layer names, so use + # an arbitrary one to initialize the scheduler. + cfg = copy.deepcopy(kv_cache_configs[0]) + for group in cfg.kv_cache_groups: + if isinstance(group.kv_cache_spec, UniformTypeKVCacheSpecs): + # All layers in the UniformTypeKVCacheSpecs have the same type, + # so use an arbitrary one to initialize the scheduler. + group.kv_cache_spec = next( + iter(group.kv_cache_spec.kv_cache_specs.values())) + return cfg + + def get_kv_cache_configs(vllm_config: VllmConfig, kv_cache_specs: list[dict[str, KVCacheSpec]], available_memory: list[int]) -> list[KVCacheConfig]: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index a022e9c0d705..a43042a5510a 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -29,7 +29,9 @@ from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket, resolve_obj_by_qualname, set_process_title) -from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_configs, +from vllm.v1.core.kv_cache_utils import (BlockHash, + generate_scheduler_kv_cache_config, + get_kv_cache_configs, get_request_block_hasher, init_none_hash) from vllm.v1.core.sched.interface import SchedulerInterface @@ -196,16 +198,10 @@ class EngineCore: kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs, available_gpu_memory) - - # All workers have the same kv_cache_config except layer names, so use - # an arbitrary one to initialize the scheduler. - assert all([ - cfg.num_blocks == kv_cache_configs[0].num_blocks - for cfg in kv_cache_configs - ]) - num_gpu_blocks = kv_cache_configs[0].num_blocks + scheduler_kv_cache_config = generate_scheduler_kv_cache_config( + kv_cache_configs) + num_gpu_blocks = scheduler_kv_cache_config.num_blocks num_cpu_blocks = 0 - scheduler_kv_cache_config = kv_cache_configs[0] # Initialize kv cache and warmup the execution self.model_executor.initialize_from_config(kv_cache_configs) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 0cf92a680a68..f72cc8f93a6c 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -234,6 +234,76 @@ class CrossAttentionSpec(AttentionSpec): return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes +@dataclass(frozen=True) +class UniformTypeKVCacheSpecs(KVCacheSpec): + """ + A KV cache spec for multiple layers with the same type of attention. Here, + same types means always need the same number of token slots. For example, + sliding window attentions with different window sizes are not the same type + and should not be merged into one UniformTypeKVCacheSpecs. + """ + kv_cache_specs: dict[str, KVCacheSpec] + + @property + def page_size_bytes(self) -> int: + return sum(spec.page_size_bytes + for spec in self.kv_cache_specs.values()) + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + max_num_pages = max( + cdiv(spec.max_memory_usage_bytes(vllm_config), + spec.page_size_bytes) + for spec in self.kv_cache_specs.values()) + return max_num_pages * self.page_size_bytes + + @classmethod + def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool: + """ + Whether all layers have the same type of KV cache spec. + """ + block_sizes = set(spec.block_size for spec in kv_cache_specs.values()) + if len(block_sizes) > 1: + # Different block sizes, not uniform. + return False + one_spec = next(iter(kv_cache_specs.values())) + if isinstance(one_spec, (FullAttentionSpec, CrossAttentionSpec)): + return all( + isinstance(spec, type(one_spec)) + for spec in kv_cache_specs.values()) + elif isinstance(one_spec, SlidingWindowSpec): + return all( + isinstance(spec, SlidingWindowSpec) + and spec.sliding_window == one_spec.sliding_window + for spec in kv_cache_specs.values()) + elif isinstance(one_spec, ChunkedLocalAttentionSpec): + return all( + isinstance(spec, ChunkedLocalAttentionSpec) + and spec.attention_chunk_size == one_spec.attention_chunk_size + for spec in kv_cache_specs.values()) + elif isinstance(one_spec, MambaSpec): + return all( + isinstance(spec, MambaSpec) and spec.num_speculative_blocks == + one_spec.num_speculative_blocks + for spec in kv_cache_specs.values()) + else: + # NOTE(Chen): Please add new branches for new KV cache spec types. + raise NotImplementedError( + f"Unsupported KV cache spec type: {type(one_spec)}") + + @classmethod + def from_specs(cls, kv_cache_specs: dict[str, + KVCacheSpec]) -> Optional[Self]: + """ + Return a SameTypeKVCacheSpecs object if all layers have the same type + of KV cache spec. Return None if not. + """ + if cls.is_uniform_type(kv_cache_specs): + block_size = next(iter(kv_cache_specs.values())).block_size + return cls(block_size=block_size, kv_cache_specs=kv_cache_specs) + else: + return None + + @dataclass class KVCacheTensor: """ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index dffadd1d769b..233df8f1b0e9 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -8,7 +8,7 @@ from collections import defaultdict from collections.abc import Iterator from contextlib import contextmanager from copy import deepcopy -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast import numpy as np import torch @@ -74,7 +74,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, EncoderOnlyAttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, - MambaSpec, SlidingWindowSpec) + MambaSpec, SlidingWindowSpec, + UniformTypeKVCacheSpecs) # yapf: enable from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, LogprobsLists, LogprobsTensors, @@ -1187,7 +1188,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): common_prefix_len = self._compute_cascade_attn_prefix_len( num_scheduled_tokens, num_common_prefix_blocks, - kv_cache_group_spec.kv_cache_spec, + attn_group.kv_cache_spec, builder, ) @@ -3453,12 +3454,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): assert len(self.attn_groups) == 0, \ "Attention backends are already initialized" - def get_attn_backends_for_layers( - layer_names: list[str] - ) -> dict[type[AttentionBackend], list[str]]: - layers = get_layers_from_vllm_config(self.vllm_config, - AttentionLayerBase, - layer_names) + class AttentionGroupKey(NamedTuple): + attn_backend: type[AttentionBackend] + kv_cache_spec: KVCacheSpec + + def get_attn_backends_for_group( + kv_cache_group_spec: KVCacheGroupSpec, + ) -> dict[AttentionGroupKey, list[str]]: + layers = get_layers_from_vllm_config( + self.vllm_config, AttentionLayerBase, + kv_cache_group_spec.layer_names) attn_backends = {} attn_backend_layers = defaultdict(list) # Dedupe based on full class name; this is a bit safer than @@ -3466,7 +3471,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # attention backend subclasses (e.g. ChunkedLocalAttention) unless # they are cached correctly, there will be different objects per # layer. - for layer_name in layer_names: + for layer_name in kv_cache_group_spec.layer_names: attn_backend = layers[layer_name].get_attn_backend() if layer_name in self.kv_sharing_fast_prefill_eligible_layers: @@ -3475,8 +3480,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): attn_backend, ) - key = attn_backend.full_cls_name() - attn_backends[key] = attn_backend + full_cls_name = attn_backend.full_cls_name() + layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec + if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): + layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[ + layer_name] + key = (full_cls_name, layer_kv_cache_spec) + attn_backends[key] = AttentionGroupKey(attn_backend, + layer_kv_cache_spec) attn_backend_layers[key].append(layer_name) return { attn_backends[k]: v @@ -3484,11 +3495,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): } def create_attn_groups( - attn_backends_map: dict[AttentionBackend, list[str]], - kv_cache_spec: KVCacheSpec, + attn_backends_map: dict[AttentionGroupKey, list[str]], ) -> list[AttentionGroup]: attn_groups: list[AttentionGroup] = [] - for attn_backend, layer_names in attn_backends_map.items(): + for (attn_backend, + kv_cache_spec), layer_names in attn_backends_map.items(): attn_metadata_builders = [] attn_metadata_builders.append(attn_backend.get_builder_cls()( kv_cache_spec, @@ -3506,16 +3517,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): )) attn_group = AttentionGroup(attn_backend, attn_metadata_builders, - layer_names) + layer_names, kv_cache_spec) attn_groups.append(attn_group) return attn_groups for kv_cache_group_spec in kv_cache_config.kv_cache_groups: - kv_cache_spec = kv_cache_group_spec.kv_cache_spec - attn_backends = get_attn_backends_for_layers( - kv_cache_group_spec.layer_names) - self.attn_groups.append( - create_attn_groups(attn_backends, kv_cache_spec)) + attn_backends = get_attn_backends_for_group(kv_cache_group_spec) + self.attn_groups.append(create_attn_groups(attn_backends)) # Calculate reorder batch threshold (if needed) self.calculate_reorder_batch_threshold() @@ -3680,14 +3688,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def _attn_group_iterator(self) -> Iterator[AttentionGroup]: return itertools.chain.from_iterable(self.attn_groups) - def _kv_cache_spec_attn_group_iterator( - self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]: + def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]: if not self.kv_cache_config.kv_cache_groups: return - for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups): - for attn_group in attn_groups: - yield self.kv_cache_config.kv_cache_groups[ - kv_cache_spec_id].kv_cache_spec, attn_group + for attn_groups in self.attn_groups: + yield from attn_groups def _reshape_kv_cache_tensors( self, @@ -3707,7 +3712,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): """ kv_caches: dict[str, torch.Tensor] = {} has_attn, has_mamba = False, False - for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): + for group in self._kv_cache_spec_attn_group_iterator(): + kv_cache_spec = group.kv_cache_spec attn_backend = group.backend for layer_name in group.layer_names: if layer_name in self.runner_only_attn_layers: @@ -3787,7 +3793,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): kv_caches: The KV cache buffer of each layer. """ - for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator(): + for group in self._kv_cache_spec_attn_group_iterator(): + kv_cache_spec = group.kv_cache_spec for layer_name in group.layer_names: kv_cache = kv_caches[layer_name] if (isinstance(kv_cache_spec, AttentionSpec) diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index b76ac633892f..021d18b2500f 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -15,7 +15,7 @@ from vllm.multimodal.registry import MultiModalRegistry from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget -from vllm.v1.kv_cache_interface import KVCacheGroupSpec +from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec if TYPE_CHECKING: from vllm.attention.layer import Attention @@ -132,6 +132,7 @@ class AttentionGroup: backend: type[AttentionBackend] metadata_builders: list[AttentionMetadataBuilder] layer_names: list[str] + kv_cache_spec: KVCacheSpec def get_metadata_builder(self, ubatch_id: Optional[int] = None