mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:45:00 +08:00
Support multiple attention groups for KV sharing (#22672)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
parent
c280066f9d
commit
3e2f7985a2
189
tests/v1/test_kv_sharing.py
Normal file
189
tests/v1/test_kv_sharing.py
Normal file
@ -0,0 +1,189 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.attention.backends.flash_attn import (
|
||||
FlashAttentionBackend, FlashAttentionMetadataBuilder)
|
||||
from vllm.v1.attention.backends.flex_attention import (
|
||||
FlexAttentionBackend, FlexAttentionMetadataBuilder)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheGroupSpec
|
||||
from vllm.v1.worker.utils import (AttentionGroup,
|
||||
initialize_kv_cache_for_kv_sharing)
|
||||
|
||||
|
||||
def new_kv_cache_spec():
|
||||
return FullAttentionSpec(16, 1, 1, torch.float32, False)
|
||||
|
||||
|
||||
def test_initialize_kv_cache_for_kv_sharing_different_attn_groups():
|
||||
"""
|
||||
Test initializing KV cache sharing with different attention groups.
|
||||
Layers in the same KV cache group might be placed in different attn groups
|
||||
if they have different attention backends.
|
||||
"""
|
||||
shared_kv_cache_layers = {
|
||||
"model.layers.2": "model.layers.0",
|
||||
"model.layers.3": "model.layers.1",
|
||||
}
|
||||
|
||||
# Layers 0 and 1 both belong in KV cache group 0
|
||||
# However, if they have have different attention backends, they will be
|
||||
# placed in different attention groups for KV cache group 0
|
||||
kv_cache_groups = [
|
||||
KVCacheGroupSpec(["model.layers.0", "model.layers.1"],
|
||||
new_kv_cache_spec()),
|
||||
]
|
||||
|
||||
attn_groups = [
|
||||
# KV cache group 0 has two attention groups
|
||||
[
|
||||
AttentionGroup(
|
||||
backend=FlashAttentionBackend,
|
||||
metadata_builder=Mock(spec=FlashAttentionMetadataBuilder),
|
||||
layer_names=["model.layers.0"],
|
||||
),
|
||||
AttentionGroup(
|
||||
backend=FlexAttentionBackend,
|
||||
metadata_builder=Mock(spec=FlexAttentionMetadataBuilder),
|
||||
layer_names=["model.layers.1"],
|
||||
),
|
||||
],
|
||||
]
|
||||
|
||||
# Only layers 0 and 1 will have KV caches allocated
|
||||
kv_caches = {
|
||||
"model.layers.0": torch.zeros(1, 2, 3),
|
||||
"model.layers.1": torch.ones(1, 2, 3),
|
||||
}
|
||||
|
||||
initialize_kv_cache_for_kv_sharing(
|
||||
shared_kv_cache_layers=shared_kv_cache_layers,
|
||||
kv_cache_groups=kv_cache_groups,
|
||||
kv_caches=kv_caches,
|
||||
attn_groups=attn_groups,
|
||||
)
|
||||
|
||||
# Check that the KV caches were shared correctly
|
||||
assert kv_caches["model.layers.2"].data_ptr(
|
||||
) == kv_caches["model.layers.0"].data_ptr()
|
||||
assert kv_caches["model.layers.3"].data_ptr(
|
||||
) == kv_caches["model.layers.1"].data_ptr()
|
||||
|
||||
# Check that the layers were added to the correct KV cache group
|
||||
assert len(kv_cache_groups) == 1
|
||||
assert kv_cache_groups[0].layer_names == [
|
||||
"model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3"
|
||||
]
|
||||
|
||||
# Check that the layers were added to the attention groups
|
||||
assert len(attn_groups) == 1 and len(attn_groups[0]) == 2
|
||||
assert attn_groups[0][0].layer_names == [
|
||||
"model.layers.0", "model.layers.2"
|
||||
]
|
||||
assert attn_groups[0][1].layer_names == [
|
||||
"model.layers.1", "model.layers.3"
|
||||
]
|
||||
|
||||
|
||||
def test_initialize_kv_cache_for_kv_sharing_same_attn_groups():
|
||||
"""
|
||||
Test case assuming that all layers in the same KV cache group have the same
|
||||
attention backends. This is true for most models.
|
||||
"""
|
||||
shared_kv_cache_layers = {
|
||||
"model.layers.2": "model.layers.0",
|
||||
"model.layers.3": "model.layers.1",
|
||||
}
|
||||
|
||||
kv_cache_groups = [
|
||||
KVCacheGroupSpec(["model.layers.0", "model.layers.1"],
|
||||
new_kv_cache_spec()),
|
||||
]
|
||||
|
||||
attn_groups = [
|
||||
# KV cache group 0 has a single attention group
|
||||
# as all layers have the same flash attention backend
|
||||
[
|
||||
AttentionGroup(
|
||||
backend=FlashAttentionBackend,
|
||||
metadata_builder=Mock(spec=FlashAttentionMetadataBuilder),
|
||||
layer_names=["model.layers.0", "model.layers.1"],
|
||||
),
|
||||
],
|
||||
]
|
||||
|
||||
kv_caches = {
|
||||
"model.layers.0": torch.zeros(1, 2, 3),
|
||||
"model.layers.1": torch.ones(1, 2, 3),
|
||||
}
|
||||
|
||||
initialize_kv_cache_for_kv_sharing(
|
||||
shared_kv_cache_layers=shared_kv_cache_layers,
|
||||
kv_cache_groups=kv_cache_groups,
|
||||
kv_caches=kv_caches,
|
||||
attn_groups=attn_groups,
|
||||
)
|
||||
|
||||
# Check that the KV caches were shared correctly
|
||||
assert kv_caches["model.layers.2"].data_ptr(
|
||||
) == kv_caches["model.layers.0"].data_ptr()
|
||||
assert kv_caches["model.layers.3"].data_ptr(
|
||||
) == kv_caches["model.layers.1"].data_ptr()
|
||||
|
||||
# Check that the layers were added to the correct KV cache group
|
||||
assert len(kv_cache_groups) == 1
|
||||
assert kv_cache_groups[0].layer_names == [
|
||||
"model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3"
|
||||
]
|
||||
|
||||
# Check that the layers were added to the attention groups
|
||||
assert len(attn_groups) == 1 and len(attn_groups[0]) == 1
|
||||
assert attn_groups[0][0].layer_names == [
|
||||
"model.layers.0", "model.layers.1", "model.layers.2", "model.layers.3"
|
||||
]
|
||||
|
||||
|
||||
def test_initialize_kv_cache_for_kv_sharing_no_attn_groups():
|
||||
"""
|
||||
Test KV sharing set up when no attention groups are provided.
|
||||
This is the case for the TPU model runner, which doesn't have
|
||||
support for attention groups yet.
|
||||
"""
|
||||
shared_kv_cache_layers = {
|
||||
"model.layers.2": "model.layers.0",
|
||||
"model.layers.3": "model.layers.1",
|
||||
}
|
||||
|
||||
kv_cache_groups = [
|
||||
KVCacheGroupSpec(["model.layers.0"], new_kv_cache_spec()),
|
||||
KVCacheGroupSpec(["model.layers.1"], new_kv_cache_spec()),
|
||||
]
|
||||
|
||||
kv_caches = {
|
||||
"model.layers.0": torch.zeros(1, 2, 3),
|
||||
"model.layers.1": torch.ones(1, 2, 3),
|
||||
}
|
||||
|
||||
initialize_kv_cache_for_kv_sharing(
|
||||
shared_kv_cache_layers=shared_kv_cache_layers,
|
||||
kv_cache_groups=kv_cache_groups,
|
||||
kv_caches=kv_caches,
|
||||
)
|
||||
|
||||
# Check that the KV caches were shared correctly
|
||||
assert kv_caches["model.layers.2"].data_ptr(
|
||||
) == kv_caches["model.layers.0"].data_ptr()
|
||||
assert kv_caches["model.layers.3"].data_ptr(
|
||||
) == kv_caches["model.layers.1"].data_ptr()
|
||||
|
||||
# Check that the layers were added to the correct KV cache group
|
||||
assert len(kv_cache_groups) == 2
|
||||
assert kv_cache_groups[0].layer_names == [
|
||||
"model.layers.0", "model.layers.2"
|
||||
]
|
||||
assert kv_cache_groups[1].layer_names == [
|
||||
"model.layers.1", "model.layers.3"
|
||||
]
|
||||
@ -225,26 +225,34 @@ def initialize_kv_cache_for_kv_sharing(
|
||||
Note that layers in shared_kv_cache_layers.keys() are not
|
||||
originally included as it only contains layers which have its own
|
||||
KV cache allocation.
|
||||
attn_groups: Optional list of attention groups. Layers in the same KV
|
||||
cache group may be placed in different attention groups if they
|
||||
have different attention backends. Currently only provided by
|
||||
GPU model runner.
|
||||
"""
|
||||
# Record index of KV cache group for each layer that allocates a KV cache.
|
||||
layer_to_kv_cache_group_idx: dict[str, int] = {}
|
||||
for i, kv_cache_group in enumerate(kv_cache_groups):
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
layer_to_kv_cache_group_idx[layer_name] = i
|
||||
# mapping from layer name to tuple of (kv_cache_group_idx, attn_group_idx)
|
||||
layer_to_attn_group_idx: dict[str, tuple[int, int]] = {}
|
||||
if attn_groups:
|
||||
for kv_cache_group_idx, kv_attn_groups in enumerate(attn_groups):
|
||||
for attn_group_idx, attn_group in enumerate(kv_attn_groups):
|
||||
for layer_name in attn_group.layer_names:
|
||||
layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx,
|
||||
attn_group_idx)
|
||||
else:
|
||||
for kv_cache_group_idx, kv_cache_group in enumerate(kv_cache_groups):
|
||||
for layer_name in kv_cache_group.layer_names:
|
||||
# attn group idx default to 0 if not provided
|
||||
layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, 0)
|
||||
|
||||
for layer_name, target_layer_name in shared_kv_cache_layers.items():
|
||||
kv_caches[layer_name] = kv_caches[target_layer_name]
|
||||
group_idx = layer_to_kv_cache_group_idx[target_layer_name]
|
||||
kv_cache_groups[group_idx].layer_names.append(layer_name)
|
||||
kv_cache_group_idx = layer_to_attn_group_idx[target_layer_name][0]
|
||||
kv_cache_groups[kv_cache_group_idx].layer_names.append(layer_name)
|
||||
|
||||
if attn_groups is not None:
|
||||
assert len(attn_groups[group_idx]) == 1, (
|
||||
"Only one attention group per KV cache group is supported "
|
||||
"for KV-cache sharing for now.")
|
||||
# TODO(lucas): I think in the future the layers that re-use a
|
||||
# KV cache will be in a different attention group so we can
|
||||
# remove this code from here.
|
||||
attn_groups[group_idx][0].layer_names.append(layer_name)
|
||||
if attn_groups:
|
||||
attn_group_idx = layer_to_attn_group_idx[target_layer_name][1]
|
||||
attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append(
|
||||
layer_name)
|
||||
|
||||
|
||||
def bind_kv_cache(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user