diff --git a/tests/v1/test_kv_sharing.py b/tests/v1/test_kv_sharing.py
new file mode 100644
index 0000000000000..6b01b7d3e1d6c
--- /dev/null
+++ b/tests/v1/test_kv_sharing.py
@@ -0,0 +1,189 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from unittest.mock import Mock
+
+import torch
+
+from vllm.v1.attention.backends.flash_attn import (
+    FlashAttentionBackend, FlashAttentionMetadataBuilder)
+from vllm.v1.attention.backends.flex_attention import (
+    FlexAttentionBackend, FlexAttentionMetadataBuilder)
+from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec
+from vllm.v1.worker.utils import (AttentionGroup,
+                                  initialize_kv_cache_for_kv_sharing)
+
+
+def new_kv_cache_spec():
+    return FullAttentionSpec(16, 1, 1, torch.float32, False)
+
+
+def test_initialize_kv_cache_for_kv_sharing_different_attn_groups():
+    """
+    Test initializing KV cache sharing with different attention groups.
+    Layers in the same KV cache group might be placed in different attn groups
+    if they have different attention backends.
+    """
+    shared_kv_cache_layers = {
+        "model.layers.2": "model.layers.0",
+        "model.layers.3": "model.layers.1",
+    }
+
+    # Layers 0 and 1 both belong in KV cache group 0
+    # However, if they have different attention backends, they will be
+    # placed in different attention groups for KV cache group 0
+    kv_cache_groups = [
+        KVCacheGroupSpec(["model.layers.0", "model.layers.1"],
+                         new_kv_cache_spec()),
+    ]
+
+    attn_groups = [
+        # KV cache group 0 has two attention groups
+        [
+            AttentionGroup(
+                backend=FlashAttentionBackend,
+                metadata_builder=Mock(spec=FlashAttentionMetadataBuilder),
+                layer_names=["model.layers.0"],
+            ),
+            AttentionGroup(
+                backend=FlexAttentionBackend,
+                metadata_builder=Mock(spec=FlexAttentionMetadataBuilder),
+                layer_names=["model.layers.1"],
+            ),
+        ],
+    ]
+
+    # Only layers 0 and 1 will have KV caches allocated
+    kv_caches = {
+        "model.layers.0": torch.zeros(1, 2, 3),
+        "model.layers.1": torch.ones(1, 2, 3),
+    }
+
+    initialize_kv_cache_for_kv_sharing(
+        shared_kv_cache_layers=shared_kv_cache_layers,
+        kv_cache_groups=kv_cache_groups,
+        kv_caches=kv_caches,
+        attn_groups=attn_groups,
+    )
+
+    # Check that the KV caches were shared correctly
+    assert kv_caches["model.layers.2"].data_ptr(
+    ) == kv_caches["model.layers.0"].data_ptr()
+    assert kv_caches["model.layers.3"].data_ptr(
+    ) == kv_caches["model.layers.1"].data_ptr()
+
+    # Check that the layers were added to the correct KV cache group
+    assert len(kv_cache_groups) == 1
+    assert kv_cache_groups[0].layer_names == [
+        "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3"
+    ]
+
+    # Check that the layers were added to the attention groups
+    assert len(attn_groups) == 1 and len(attn_groups[0]) == 2
+    assert attn_groups[0][0].layer_names == [
+        "model.layers.0", "model.layers.2"
+    ]
+    assert attn_groups[0][1].layer_names == [
+        "model.layers.1", "model.layers.3"
+    ]
+
+
+def test_initialize_kv_cache_for_kv_sharing_same_attn_groups():
+    """
+    Test case assuming that all layers in the same KV cache group have the same
+    attention backend. This is true for most models.
+ """ + shared_kv_cache_layers = { + "model.layers.2": "model.layers.0", + "model.layers.3": "model.layers.1", + } + + kv_cache_groups = [ + KVCacheGroupSpec(["model.layers.0", "model.layers.1"], + new_kv_cache_spec()), + ] + + attn_groups = [ + # KV cache group 0 has a single attention group + # as all layers have the same flash attention backend + [ + AttentionGroup( + backend=FlashAttentionBackend, + metadata_builder=Mock(spec=FlashAttentionMetadataBuilder), + layer_names=["model.layers.0", "model.layers.1"], + ), + ], + ] + + kv_caches = { + "model.layers.0": torch.zeros(1, 2, 3), + "model.layers.1": torch.ones(1, 2, 3), + } + + initialize_kv_cache_for_kv_sharing( + shared_kv_cache_layers=shared_kv_cache_layers, + kv_cache_groups=kv_cache_groups, + kv_caches=kv_caches, + attn_groups=attn_groups, + ) + + # Check that the KV caches were shared correctly + assert kv_caches["model.layers.2"].data_ptr( + ) == kv_caches["model.layers.0"].data_ptr() + assert kv_caches["model.layers.3"].data_ptr( + ) == kv_caches["model.layers.1"].data_ptr() + + # Check that the layers were added to the correct KV cache group + assert len(kv_cache_groups) == 1 + assert kv_cache_groups[0].layer_names == [ + "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3" + ] + + # Check that the layers were added to the attention groups + assert len(attn_groups) == 1 and len(attn_groups[0]) == 1 + assert attn_groups[0][0].layer_names == [ + "model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3" + ] + + +def test_initialize_kv_cache_for_kv_sharing_no_attn_groups(): + """ + Test KV sharing set up when no attention groups are provided. + This is the case for the TPU model runner, which doesn't have + support for attention groups yet. + """ + shared_kv_cache_layers = { + "model.layers.2": "model.layers.0", + "model.layers.3": "model.layers.1", + } + + kv_cache_groups = [ + KVCacheGroupSpec(["model.layers.0"], new_kv_cache_spec()), + KVCacheGroupSpec(["model.layers.1"], new_kv_cache_spec()), + ] + + kv_caches = { + "model.layers.0": torch.zeros(1, 2, 3), + "model.layers.1": torch.ones(1, 2, 3), + } + + initialize_kv_cache_for_kv_sharing( + shared_kv_cache_layers=shared_kv_cache_layers, + kv_cache_groups=kv_cache_groups, + kv_caches=kv_caches, + ) + + # Check that the KV caches were shared correctly + assert kv_caches["model.layers.2"].data_ptr( + ) == kv_caches["model.layers.0"].data_ptr() + assert kv_caches["model.layers.3"].data_ptr( + ) == kv_caches["model.layers.1"].data_ptr() + + # Check that the layers were added to the correct KV cache group + assert len(kv_cache_groups) == 2 + assert kv_cache_groups[0].layer_names == [ + "model.layers.0", "model.layers.2" + ] + assert kv_cache_groups[1].layer_names == [ + "model.layers.1", "model.layers.3" + ] diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index e7079235d6510..b138f11af1eb1 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -225,26 +225,34 @@ def initialize_kv_cache_for_kv_sharing( Note that layers in shared_kv_cache_layers.keys() are not originally included as it only contains layers which have its own KV cache allocation. + attn_groups: Optional list of attention groups. Layers in the same KV + cache group may be placed in different attention groups if they + have different attention backends. Currently only provided by + GPU model runner. """ - # Record index of KV cache group for each layer that allocates a KV cache. 
-    layer_to_kv_cache_group_idx: dict[str, int] = {}
-    for i, kv_cache_group in enumerate(kv_cache_groups):
-        for layer_name in kv_cache_group.layer_names:
-            layer_to_kv_cache_group_idx[layer_name] = i
+    # mapping from layer name to tuple of (kv_cache_group_idx, attn_group_idx)
+    layer_to_attn_group_idx: dict[str, tuple[int, int]] = {}
+    if attn_groups:
+        for kv_cache_group_idx, kv_attn_groups in enumerate(attn_groups):
+            for attn_group_idx, attn_group in enumerate(kv_attn_groups):
+                for layer_name in attn_group.layer_names:
+                    layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx,
+                                                           attn_group_idx)
+    else:
+        for kv_cache_group_idx, kv_cache_group in enumerate(kv_cache_groups):
+            for layer_name in kv_cache_group.layer_names:
+                # attn group idx defaults to 0 if not provided
+                layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, 0)
 
     for layer_name, target_layer_name in shared_kv_cache_layers.items():
         kv_caches[layer_name] = kv_caches[target_layer_name]
-        group_idx = layer_to_kv_cache_group_idx[target_layer_name]
-        kv_cache_groups[group_idx].layer_names.append(layer_name)
+        kv_cache_group_idx = layer_to_attn_group_idx[target_layer_name][0]
+        kv_cache_groups[kv_cache_group_idx].layer_names.append(layer_name)
 
-        if attn_groups is not None:
-            assert len(attn_groups[group_idx]) == 1, (
-                "Only one attention group per KV cache group is supported "
-                "for KV-cache sharing for now.")
-            # TODO(lucas): I think in the future the layers that re-use a
-            # KV cache will be in a different attention group so we can
-            # remove this code from here.
-            attn_groups[group_idx][0].layer_names.append(layer_name)
+        if attn_groups:
+            attn_group_idx = layer_to_attn_group_idx[target_layer_name][1]
+            attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append(
+                layer_name)
 
 
 def bind_kv_cache(
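For reference, a condensed usage sketch of the updated helper, distilled from the tests above. All objects are constructed inline purely for illustration; `attn_groups` can simply be omitted on runners (such as the TPU model runner) that do not support attention groups yet.

# Illustrative sketch only (not part of the diff); mirrors the test setup.
from unittest.mock import Mock

import torch

from vllm.v1.attention.backends.flash_attn import (
    FlashAttentionBackend, FlashAttentionMetadataBuilder)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec
from vllm.v1.worker.utils import (AttentionGroup,
                                  initialize_kv_cache_for_kv_sharing)

# Layer 1 reuses layer 0's KV cache; only layer 0 gets an allocation.
shared_kv_cache_layers = {"model.layers.1": "model.layers.0"}
kv_cache_groups = [
    KVCacheGroupSpec(["model.layers.0"],
                     FullAttentionSpec(16, 1, 1, torch.float32, False)),
]
attn_groups = [[
    AttentionGroup(backend=FlashAttentionBackend,
                   metadata_builder=Mock(spec=FlashAttentionMetadataBuilder),
                   layer_names=["model.layers.0"]),
]]
kv_caches = {"model.layers.0": torch.zeros(1, 2, 3)}

initialize_kv_cache_for_kv_sharing(
    shared_kv_cache_layers=shared_kv_cache_layers,
    kv_cache_groups=kv_cache_groups,
    kv_caches=kv_caches,
    attn_groups=attn_groups,  # optional; omit on runners without attn groups
)

# The shared layer now aliases the target layer's cache and is registered in
# the same KV cache group and attention group.
assert kv_caches["model.layers.1"].data_ptr() == \
    kv_caches["model.layers.0"].data_ptr()
assert attn_groups[0][0].layer_names == ["model.layers.0", "model.layers.1"]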