[V1] Support cross-layer KV sharing (#18212)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>

parent fa98d77773
commit bdf13965ab
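
Overview: this change lets an attention layer reuse the KV cache of an earlier layer. `Attention` and every attention backend `__init__` gain a `kv_sharing_target_layer_name` argument; the target is checked against the static forward context by the new `validate_kv_sharing_target` helper, V0 backends reject the option with `NotImplementedError`, and V1 backends skip writing to the KV cache for layers that declare a target. New tests cover the GPU and TPU model runners. Below is a minimal usage sketch mirroring the new tests; it is illustrative only and assumes a current vLLM config is active, as the tests arrange with `set_current_vllm_config`.

    from vllm.attention.layer import Attention

    layer_0 = "model.layers.0.self_attn.attn"
    layer_1 = "model.layers.1.self_attn.attn"

    attn_0 = Attention(num_heads=8, head_size=64, scale=1.0, prefix=layer_0)
    attn_1 = Attention(
        num_heads=8,
        head_size=64,
        scale=1.0,
        prefix=layer_1,
        # Reuse layer 0's KV cache instead of allocating a separate one.
        kv_sharing_target_layer_name=layer_0,
    )
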
@@ -4,8 +4,13 @@ import unittest.mock as mock
 
 import pytest
 
-from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
+from vllm.attention.layer import Attention
+from vllm.config import (CacheConfig, ModelConfig, SchedulerConfig, VllmConfig,
+                         set_current_vllm_config)
 from vllm.sampling_params import SamplingParams
+from vllm.utils import GiB_bytes
+from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
+                                         get_kv_cache_config)
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
 from vllm.v1.worker.tpu_model_runner import (
@@ -363,3 +368,223 @@ def test_get_req_paddings():
     assert _get_req_paddings(1, 32) == [8, 16, 32]
     assert _get_req_paddings(8, 32) == [8, 16, 32]
     assert _get_req_paddings(8, 36) == [8, 16, 32, 36]
+
+
+def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    error_msg = f"{layer_1} must come before the current layer"
+    with pytest.raises(ValueError, match=error_msg):
+        fwd_context = {
+            # initialization below will fail because target layer is invalid;
+            # the target layer needs to come before layer 1
+            layer_0:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_0,
+                kv_sharing_target_layer_name=layer_1,
+            ),
+            layer_1:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_1,
+            )
+        }
+        # suppress var not used error
+        assert fwd_context is not None
+
+
+def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    invalid_layer = "model.layers.0.cross_attn.attn"
+    error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
+    with pytest.raises(ValueError, match=error_msg):
+        fwd_context = {
+            layer_0:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_0,
+            ),
+            layer_1:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_1,
+                # invalid layer: cross_attn.atn doesn't exist!
+                kv_sharing_target_layer_name=invalid_layer,
+            )
+        }
+        # suppress var not used error
+        assert fwd_context is not None
+
+
+def test_init_kv_cache_with_kv_sharing_target_same_as_current():
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    error_msg = f"{layer_1} cannot be the same as the current layer"
+    with pytest.raises(ValueError, match=error_msg):
+        fwd_context = {
+            # initialization below will fail because target layer is invalid;
+            # the target layer needs to come before layer 1
+            layer_0:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_0,
+            ),
+            layer_1:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_1,
+                kv_sharing_target_layer_name=layer_1,
+            )
+        }
+        # suppress var not used error
+        assert fwd_context is not None
+
+
+def test_init_kv_cache_without_kv_sharing(model_runner):
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    vllm_config = model_runner.vllm_config
+    with set_current_vllm_config(vllm_config):
+        fwd_context = {
+            layer_0:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_0,
+            ),
+            layer_1:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_1,
+            )
+        }
+        # suppress var not used error
+        assert fwd_context is not None
+    # Set high context length to test max context length estimation
+    vllm_config.model_config.max_model_len = 3_000_000
+    vllm_ctx = vllm_config.compilation_config.static_forward_context
+    kv_cache_spec = model_runner.get_kv_cache_spec()
+    assert len(kv_cache_spec) == 2
+    assert len(model_runner.shared_kv_cache_layers) == 0
+
+    available_memory = 20 * GiB_bytes
+    # page size for layer 0's kv_cache_spec is 32KB
+    num_expected_blocks = 327680  # 20GB / 32KB / 2 (num layers)
+    kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
+                                          available_memory)
+    assert kv_cache_config.num_blocks == num_expected_blocks
+    assert len(kv_cache_config.tensors) == 2
+    assert kv_cache_config.tensors[layer_0].size == available_memory // 2
+    assert kv_cache_config.tensors[layer_1].size == available_memory // 2
+
+    max_context_len = \
+        estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
+    # max context len with KV sharing should be 2x as large as without
+    assert max_context_len == 1310720
+
+    # important: override tensor size to prevent large mem alloc during test
+    # this will only allocate 2 block worth of memory (2 * 32kb)
+    kv_cache_config.num_blocks = 1
+    for layer in kv_cache_config.tensors:
+        kv_cache_config.tensors[layer].size = \
+            kv_cache_spec[layer].page_size_bytes
+
+    model_runner.initialize_kv_cache(kv_cache_config)
+
+    layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
+    layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
+    # check layer 1 kv cache does NOT share memory with layer 0
+    assert id(layer_1_kv) != id(layer_0_kv)
+
+    # check layer 1 added to kv cache group's layer names
+    assert len(kv_cache_config.kv_cache_groups) == 1
+    assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
+    assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
+    assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
+
+
+def test_init_kv_cache_with_kv_sharing_valid(model_runner):
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    vllm_config = model_runner.vllm_config
+    with set_current_vllm_config(vllm_config):
+        fwd_context = {
+            layer_0:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_0,
+            ),
+            layer_1:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_1,
+                kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
+            )
+        }
+        # suppress var not used error
+        assert fwd_context is not None
+    # Set high context length to test max context length estimation
+    vllm_config.model_config.max_model_len = 3_000_000
+    vllm_ctx = vllm_config.compilation_config.static_forward_context
+    kv_cache_spec = model_runner.get_kv_cache_spec()
+    assert len(kv_cache_spec) == 1
+    assert layer_0 in kv_cache_spec
+    assert model_runner.shared_kv_cache_layers[layer_1] == layer_0
+
+    available_memory = 20 * GiB_bytes
+    # page size for layer 0's kv_cache_spec is 32KB
+    # with KV sharing, we can allocate (available_mem//page_size//1) blocks
+    # which is twice as many as without KV sharing
+    num_expected_blocks = 655360  # 20GB / 32KB
+    kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
+                                          available_memory)
+    assert kv_cache_config.num_blocks == num_expected_blocks
+    assert len(kv_cache_config.tensors) == 1
+    # Each layer now has twice the available memory for KV cache
+    # compared to no KV sharing
+    assert kv_cache_config.tensors[layer_0].size == available_memory
+
+    max_context_len = \
+        estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
+    # max context len with KV sharing should be 2x as large as without
+    assert max_context_len == 2 * 1310720
+
+    # important: override tensor size to prevent large mem alloc during test
+    # this will only allocate 1 block worth of memory (32kb)
+    kv_cache_config.num_blocks = 1
+    kv_cache_config.tensors[layer_0].size = \
+        kv_cache_spec[layer_0].page_size_bytes
+
+    model_runner.initialize_kv_cache(kv_cache_config)
+
+    layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
+    layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
+    # check layer 1 kv cache shares memory with layer 0
+    assert id(layer_1_kv) == id(layer_0_kv)
+
+    # check layer 1 added to kv cache group's layer names
+    assert len(kv_cache_config.kv_cache_groups) == 1
+    assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
+    assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
+    assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
@@ -7,8 +7,11 @@ import pytest
 
 from vllm.attention import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig, VllmConfig)
+                         SchedulerConfig, VllmConfig, set_current_vllm_config)
 from vllm.sampling_params import SamplingParams
+from vllm.utils import GiB_bytes
+from vllm.v1.core.kv_cache_utils import (estimate_max_model_len,
+                                         get_kv_cache_config)
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -19,6 +22,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 BLOCK_SIZE = 16
 NUM_BLOCKS = 10
+DEVICE = "cuda"
 
 
 def initialize_kv_cache(runner: GPUModelRunner):
@@ -55,8 +59,7 @@ def initialize_kv_cache(runner: GPUModelRunner):
     runner.initialize_attn_backend(kv_cache_config)
 
 
-@pytest.fixture
-def model_runner():
+def get_vllm_config():
     scheduler_config = SchedulerConfig(
         max_num_seqs=10,
         max_num_batched_tokens=512,
@@ -84,13 +87,18 @@ def model_runner():
         scheduler_config=scheduler_config,
         parallel_config=parallel_config,
     )
-    num_heads = model_config.get_num_kv_heads(parallel_config)
+    return vllm_config
+
+
+@pytest.fixture
+def model_runner():
+    vllm_config = get_vllm_config()
+    model_config = vllm_config.model_config
+    num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config)
     head_size = model_config.get_head_size()
     vllm_config.compilation_config.static_forward_context[
        "layer.0"] = Attention(num_heads, head_size, 0.1)
-    device = "cuda"
-    runner = GPUModelRunner(vllm_config, device)
+    runner = GPUModelRunner(vllm_config, DEVICE)
     initialize_kv_cache(runner)
     return runner
 
@@ -385,3 +393,225 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
     model_runner_2.load_model()  # Load real weights inplace
     assert str(model_runner.get_model().state_dict()) == str(
         model_runner_2.get_model().state_dict())
+
+
+def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    error_msg = f"{layer_1} must come before the current layer"
+    with pytest.raises(ValueError, match=error_msg):
+        fwd_context = {
+            # initialization below will fail because target layer is invalid;
+            # the target layer needs to come before layer 1
+            layer_0:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_0,
+                kv_sharing_target_layer_name=layer_1,
+            ),
+            layer_1:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_1,
+            )
+        }
+        # suppress var not used error
+        assert fwd_context is not None
+
+
+def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    invalid_layer = "model.layers.0.cross_attn.attn"
+    error_msg = f"{invalid_layer} is not a valid Attention layer in the model"
+    with pytest.raises(ValueError, match=error_msg):
+        fwd_context = {
+            layer_0:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_0,
+            ),
+            layer_1:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_1,
+                # invalid layer: cross_attn.atn doesn't exist!
+                kv_sharing_target_layer_name=invalid_layer,
+            )
+        }
+        # suppress var not used error
+        assert fwd_context is not None
+
+
+def test_init_kv_cache_with_kv_sharing_target_same_as_current():
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    error_msg = f"{layer_1} cannot be the same as the current layer"
+    with pytest.raises(ValueError, match=error_msg):
+        fwd_context = {
+            # initialization below will fail because target layer is invalid;
+            # the target layer needs to come before layer 1
+            layer_0:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_0,
+            ),
+            layer_1:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_1,
+                kv_sharing_target_layer_name=layer_1,
+            )
+        }
+        # suppress var not used error
+        assert fwd_context is not None
+
+
+def test_init_kv_cache_without_kv_sharing():
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    vllm_config = get_vllm_config()
+    with set_current_vllm_config(vllm_config):
+        fwd_context = {
+            layer_0:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_0,
+            ),
+            layer_1:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_1,
+            )
+        }
+        # suppress var not used error
+        assert fwd_context is not None
+    # Set high context length to test max context length estimation
+    vllm_config.model_config.max_model_len = 3_000_000
+    vllm_ctx = vllm_config.compilation_config.static_forward_context
+    runner = GPUModelRunner(vllm_config, DEVICE)
+    kv_cache_spec = runner.get_kv_cache_spec()
+    assert len(kv_cache_spec) == 2
+    assert len(runner.shared_kv_cache_layers) == 0
+
+    available_memory = 20 * GiB_bytes
+    # page size for layer 0's kv_cache_spec is 32KB
+    num_expected_blocks = 327680  # 20GB / 32KB / 2 (num layers)
+    kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
+                                          available_memory)
+    assert kv_cache_config.num_blocks == num_expected_blocks
+    assert len(kv_cache_config.tensors) == 2
+    assert kv_cache_config.tensors[layer_0].size == available_memory // 2
+    assert kv_cache_config.tensors[layer_1].size == available_memory // 2
+
+    max_context_len = \
+        estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
+    # max context len with KV sharing should be 2x as large as without
+    assert max_context_len == 1310720
+
+    # important: override tensor size to prevent large mem alloc during test
+    # this will only allocate 2 block worth of memory (2 * 32kb)
+    kv_cache_config.num_blocks = 1
+    for layer in kv_cache_config.tensors:
+        kv_cache_config.tensors[layer].size = \
+            kv_cache_spec[layer].page_size_bytes
+
+    runner.initialize_kv_cache(kv_cache_config)
+
+    layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
+    layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
+    # check layer 1 kv cache does NOT share memory with layer 0
+    assert id(layer_1_kv) != id(layer_0_kv)
+
+    # check layer 1 added to kv cache group's layer names
+    assert len(kv_cache_config.kv_cache_groups) == 1
+    assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
+    assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
+    assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
+
+
+def test_init_kv_cache_with_kv_sharing_valid():
+    layer_0 = "model.layers.0.self_attn.attn"
+    layer_1 = "model.layers.1.self_attn.attn"
+    vllm_config = get_vllm_config()
+    with set_current_vllm_config(vllm_config):
+        fwd_context = {
+            layer_0:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_0,
+            ),
+            layer_1:
+            Attention(
+                num_heads=8,
+                head_size=64,
+                scale=1.0,
+                prefix=layer_1,
+                kv_sharing_target_layer_name="model.layers.0.self_attn.attn",
+            )
+        }
+        # suppress var not used error
+        assert fwd_context is not None
+    # Set high context length to test max context length estimation
+    vllm_config.model_config.max_model_len = 3_000_000
+    vllm_ctx = vllm_config.compilation_config.static_forward_context
+    runner = GPUModelRunner(vllm_config, DEVICE)
+    kv_cache_spec = runner.get_kv_cache_spec()
+    assert len(kv_cache_spec) == 1
+    assert layer_0 in kv_cache_spec
+    assert runner.shared_kv_cache_layers[layer_1] == layer_0
+
+    available_memory = 20 * GiB_bytes
+    # page size for layer 0's kv_cache_spec is 32KB
+    # with KV sharing, we can allocate (available_mem//page_size//1) blocks
+    # which is twice as many as without KV sharing
+    num_expected_blocks = 655360  # 20GB / 32KB
+    kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
+                                          available_memory)
+    assert kv_cache_config.num_blocks == num_expected_blocks
+    assert len(kv_cache_config.tensors) == 1
+    # Each layer now has twice the available memory for KV cache
+    # compared to no KV sharing
+    assert kv_cache_config.tensors[layer_0].size == available_memory
+
+    max_context_len = \
+        estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes)
+    # max context len with KV sharing should be 2x as large as without
+    assert max_context_len == 2 * 1310720
+
+    # important: override tensor size to prevent large mem alloc during test
+    # this will only allocate 1 block worth of memory (32kb)
+    kv_cache_config.num_blocks = 1
+    kv_cache_config.tensors[layer_0].size = \
+        kv_cache_spec[layer_0].page_size_bytes
+
+    runner.initialize_kv_cache(kv_cache_config)
+
+    layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
+    layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
+    # check layer 1 kv cache shares memory with layer 0
+    assert id(layer_1_kv) == id(layer_0_kv)
+
+    # check layer 1 added to kv cache group's layer names
+    assert len(kv_cache_config.kv_cache_groups) == 1
+    assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
+    assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
+    assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
@@ -270,6 +270,7 @@ class AttentionImpl(ABC, Generic[T]):
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
     ) -> None:
         raise NotImplementedError
 
@@ -306,7 +306,10 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
     ) -> None:
+        if kv_sharing_target_layer_name is not None:
+            raise NotImplementedError("KV sharing is not supported in V0.")
         assert blocksparse_params is not None
         assert alibi_slopes is None, ValueError(
             "Alibi not support for blocksparse flash attention.")
@@ -206,12 +206,13 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
             blocksparse_params: Optional[Dict[str, Any]],
             logits_soft_cap: Optional[float],
             attn_type: str,
+            kv_sharing_target_layer_name: Optional[str],
             # MLA Specific Arguments
             **mla_args) -> None:
         super().__init__(num_heads, head_size, scale, num_kv_heads,
                          alibi_slopes, sliding_window, kv_cache_dtype,
                          blocksparse_params, logits_soft_cap, attn_type,
-                         **mla_args)
+                         kv_sharing_target_layer_name, **mla_args)
 
         unsupported_features = [
             alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
@@ -290,9 +290,12 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
         layer_idx: int = -1,
         dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
     ) -> None:
+        if kv_sharing_target_layer_name is not None:
+            raise NotImplementedError("KV sharing is not supported in V0.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -618,8 +618,11 @@ class FlashAttentionImpl(AttentionImpl):
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
     ) -> None:
+        if kv_sharing_target_layer_name is not None:
+            raise NotImplementedError("KV sharing is not supported in V0.")
         if blocksparse_params is not None:
             raise ValueError(
                 "FlashAttention does not support block-sparse attention.")
@@ -936,8 +936,11 @@ class FlashInferImpl(AttentionImpl):
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
     ) -> None:
+        if kv_sharing_target_layer_name is not None:
+            raise NotImplementedError("KV sharing is not supported in V0.")
         if use_irope:
             logger.warning_once(
                 "Using irope in FlashInfer is not supported yet, it will fall"
@@ -184,12 +184,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
             blocksparse_params: Optional[Dict[str, Any]],
             logits_soft_cap: Optional[float],
             attn_type: str,
+            kv_sharing_target_layer_name: Optional[str] = None,
             # MLA Specific Arguments
             **mla_args) -> None:
         super().__init__(num_heads, head_size, scale, num_kv_heads,
                          alibi_slopes, sliding_window, kv_cache_dtype,
                          blocksparse_params, logits_soft_cap, attn_type,
-                         **mla_args)
+                         kv_sharing_target_layer_name, **mla_args)
 
         assert is_flashmla_supported(), \
             "FlashMLA is not supported on this device"
@@ -110,9 +110,12 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
         blocksparse_params: Optional[Dict[str, Any]] = None,
         max_seq_len: int = 4096,
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
     ) -> None:
         super(AttentionImpl, self).__init__()
+        if kv_sharing_target_layer_name is not None:
+            raise NotImplementedError("KV sharing is not supported in V0.")
         if use_irope:
             logger.warning_once(
                 "Using irope in HPU is not supported yet, it will fall back "
@@ -123,8 +123,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
     ) -> None:
+        if kv_sharing_target_layer_name is not None:
+            raise NotImplementedError("KV sharing is not supported in V0.")
         if use_irope:
             logger.warning_once(
                 "Using irope in Ipex is not supported yet, it will fall"
@@ -1000,6 +1000,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             blocksparse_params: Optional[Dict[str, Any]],
             logits_soft_cap: Optional[float],
             attn_type: str,
+            kv_sharing_target_layer_name: Optional[str],
             # MLA Specific Arguments
             q_lora_rank: Optional[int],
             kv_lora_rank: int,
@@ -1009,6 +1010,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             v_head_dim: int,
             kv_b_proj: ColumnParallelLinear,
     ) -> None:
+        if kv_sharing_target_layer_name is not None:
+            raise NotImplementedError("KV sharing not supported in V0.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -109,8 +109,11 @@ class PallasAttentionBackendImpl(AttentionImpl):
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
     ) -> None:
+        if kv_sharing_target_layer_name is not None:
+            raise NotImplementedError("KV sharing is not supported in V0.")
         if use_irope:
             logger.warning_once(
                 "Using irope in Pallas is not supported yet, it will fall back "
@@ -370,12 +370,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
             blocksparse_params: Optional[dict[str, Any]],
             logits_soft_cap: Optional[float],
             attn_type: str,
+            kv_sharing_target_layer_name: Optional[str],
             # MLA Specific Arguments
             **mla_args) -> None:
         super().__init__(num_heads, head_size, scale, num_kv_heads,
                          alibi_slopes, sliding_window, kv_cache_dtype,
                          blocksparse_params, logits_soft_cap, attn_type,
-                         **mla_args)
+                         kv_sharing_target_layer_name, **mla_args)
 
         unsupported_features = [
             alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
@@ -494,8 +494,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
     ) -> None:
+        if kv_sharing_target_layer_name is not None:
+            raise NotImplementedError("KV sharing is not supported in V0.")
         if use_irope:
             logger.warning_once(
                 "Using irope in ROCm Flash Attention is not supported yet, it "
@@ -405,8 +405,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
     ) -> None:
+        if kv_sharing_target_layer_name is not None:
+            raise NotImplementedError("KV sharing is not supported in V0.")
         if blocksparse_params is not None:
             raise ValueError(
                 "Torch SPDA does not support block-sparse attention.")
@@ -38,12 +38,13 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
             blocksparse_params: Optional[Dict[str, Any]],
             logits_soft_cap: Optional[float],
             attn_type: str,
+            kv_sharing_target_layer_name: Optional[str],
             # MLA Specific Arguments
             **mla_args) -> None:
         super().__init__(num_heads, head_size, scale, num_kv_heads,
                          alibi_slopes, sliding_window, kv_cache_dtype,
                          blocksparse_params, logits_soft_cap, attn_type,
-                         **mla_args)
+                         kv_sharing_target_layer_name, **mla_args)
 
         unsupported_features = [
             alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
@@ -390,8 +390,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
     ) -> None:
+        if kv_sharing_target_layer_name is not None:
+            raise NotImplementedError("KV sharing is not supported in V0.")
         if blocksparse_params is not None:
             raise ValueError(
                 "XFormers does not support block-sparse attention.")
@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import direct_register_custom_op
+from vllm.v1.attention.backends.utils import validate_kv_sharing_target
 
 
 class Attention(nn.Module):
@@ -50,6 +51,7 @@ class Attention(nn.Module):
         use_mla: bool = False,
         prefix: str = "",
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
         **extra_impl_args,
     ) -> None:
         """
@@ -135,7 +137,7 @@ class Attention(nn.Module):
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
                              blocksparse_params, logits_soft_cap, attn_type,
-                             **extra_impl_args)
+                             kv_sharing_target_layer_name, **extra_impl_args)
         self.backend = backend_name_to_enum(attn_backend.get_name())
         self.dtype = dtype
 
@@ -153,6 +155,19 @@ class Attention(nn.Module):
         compilation_config.static_forward_context[prefix] = self
         self.layer_name = prefix
         self.attn_type = attn_type
+
+        if kv_sharing_target_layer_name is not None:
+            if not envs.VLLM_USE_V1:
+                raise NotImplementedError(
+                    "Cross-layer KV sharing is not supported in V0.")
+
+            validate_kv_sharing_target(
+                prefix,
+                kv_sharing_target_layer_name,
+                compilation_config.static_forward_context,
+            )
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
+
         # use a placeholder kv cache tensor during init, which will be replaced
         # by bind_kv_cache
         # this variable will not be accessed if use_direct_call is True
@@ -485,6 +485,7 @@ class FlashAttentionImpl(AttentionImpl):
         blocksparse_params: Optional[dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
     ) -> None:
         if blocksparse_params is not None:
@@ -506,6 +507,7 @@ class FlashAttentionImpl(AttentionImpl):
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
         self.logits_soft_cap = logits_soft_cap
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -569,22 +571,26 @@ class FlashAttentionImpl(AttentionImpl):
         # performance to make sure it does not introduce any overhead.
 
         num_actual_tokens = attn_metadata.num_actual_tokens
-        # Reshape the input keys and values and store them in the cache.
-        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
-        # not padded. However, we don't need to do key[:num_actual_tokens] and
-        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
-        # the slot_mapping's shape to determine the number of actual tokens.
         key_cache, value_cache = kv_cache.unbind(0)
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            attn_metadata.slot_mapping,
-            self.kv_cache_dtype,
-            layer._k_scale,
-            layer._v_scale,
-        )
+
+        if self.kv_sharing_target_layer_name is None:
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+            # not padded. However, we don't need to do key[:num_actual_tokens]
+            # and value[:num_actual_tokens] because the reshape_and_cache_flash
+            # op uses the slot_mapping's shape to determine the number of
+            # actual tokens.
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
 
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(torch.float8_e4m3fn)
@@ -507,6 +507,7 @@ class FlashInferImpl(AttentionImpl):
         blocksparse_params: Optional[dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: AttentionType = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[int] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -521,6 +522,7 @@ class FlashInferImpl(AttentionImpl):
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
         self.logits_soft_cap = logits_soft_cap
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -568,21 +570,25 @@ class FlashInferImpl(AttentionImpl):
         # performance to make sure it does not introduce any overhead.
 
         num_actual_tokens = attn_metadata.num_actual_tokens
-        # Reshape the input keys and values and store them in the cache.
-        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
-        # not padded. However, we don't need to do key[:num_actual_tokens] and
-        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
-        # the slot_mapping's shape to determine the number of actual tokens.
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key,
-            value,
-            kv_cache[:, 0],
-            kv_cache[:, 1],
-            attn_metadata.slot_mapping,
-            self.kv_cache_dtype,
-            layer._k_scale,
-            layer._v_scale,
-        )
+
+        if self.kv_sharing_target_layer_name is None:
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+            # not padded. However, we don't need to do key[:num_actual_tokens]
+            # and value[:num_actual_tokens] because the reshape_and_cache_flash
+            # op uses the slot_mapping's shape to determine the number of
+            # actual tokens.
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                kv_cache[:, 0],
+                kv_cache[:, 1],
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
 
         window_left = (self.sliding_window[0]
                        if self.sliding_window is not None else -1)
@@ -586,6 +586,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             blocksparse_params: Optional[dict[str, Any]],
             logits_soft_cap: Optional[float],
             attn_type: str,
+            kv_sharing_target_layer_name: Optional[str],
             # MLA Specific Arguments
             q_lora_rank: Optional[int],
             kv_lora_rank: int,
@@ -595,6 +596,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             v_head_dim: int,
             kv_b_proj: ColumnParallelLinear,
     ) -> None:
+        if kv_sharing_target_layer_name is not None:
+            raise NotImplementedError("KV sharing is not supported for MLA")
+
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -93,12 +93,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
             blocksparse_params: Optional[dict[str, Any]],
             logits_soft_cap: Optional[float],
             attn_type: str,
+            kv_sharing_target_layer_name: Optional[str],
             # MLA Specific Arguments
             **mla_args) -> None:
         super().__init__(num_heads, head_size, scale, num_kv_heads,
                          alibi_slopes, sliding_window, kv_cache_dtype,
                          blocksparse_params, logits_soft_cap, attn_type,
-                         **mla_args)
+                         kv_sharing_target_layer_name, **mla_args)
 
         assert is_flashmla_supported(), \
             "FlashMLA is not supported on this device"
@@ -139,12 +139,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
             blocksparse_params: Optional[dict[str, Any]],
             logits_soft_cap: Optional[float],
             attn_type: str,
+            kv_sharing_target_layer_name: Optional[str],
             # MLA Specific Arguments
             **mla_args) -> None:
         super().__init__(num_heads, head_size, scale, num_kv_heads,
                          alibi_slopes, sliding_window, kv_cache_dtype,
                          blocksparse_params, logits_soft_cap, attn_type,
-                         **mla_args)
+                         kv_sharing_target_layer_name, **mla_args)
         assert (num_heads == 16 or num_heads == 128), (
             f"Aiter MLA only supports 16 or 128 number of heads.\n"
             f"Provided {num_heads} number of heads.\n"
@@ -41,12 +41,13 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
             blocksparse_params: Optional[dict[str, Any]],
             logits_soft_cap: Optional[float],
             attn_type: str,
+            kv_sharing_target_layer_name: Optional[str],
             # MLA Specific Arguments
             **mla_args) -> None:
         super().__init__(num_heads, head_size, scale, num_kv_heads,
                          alibi_slopes, sliding_window, kv_cache_dtype,
                          blocksparse_params, logits_soft_cap, attn_type,
-                         **mla_args)
+                         kv_sharing_target_layer_name, **mla_args)
 
         unsupported_features = [
             alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
@@ -113,6 +113,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         blocksparse_params: Optional[dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
         attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: Optional[int] = None,
         use_irope: bool = False,
     ) -> None:
         if use_irope:
@@ -128,6 +129,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         self.num_kv_heads = num_kv_heads
         self.sliding_window = sliding_window
         self.logits_soft_cap = logits_soft_cap
+        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -181,7 +183,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
         num_tokens, hidden_size = query.shape
         query = query.view(num_tokens, self.num_heads, self.head_size)
 
-        if kv_cache.numel() > 0:
+        if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0:
+            # Write input keys and values to the KV cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
             slot_mapping = attn_metadata.slot_mapping
             write_to_kv_cache(key, value, kv_cache, slot_mapping)
 
@ -88,6 +88,7 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
blocksparse_params: Optional[dict[str, Any]] = None,
|
blocksparse_params: Optional[dict[str, Any]] = None,
|
||||||
logits_soft_cap: Optional[float] = None,
|
logits_soft_cap: Optional[float] = None,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
|
kv_sharing_target_layer_name: Optional[int] = None,
|
||||||
use_irope: bool = False,
|
use_irope: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if blocksparse_params is not None:
|
if blocksparse_params is not None:
|
||||||
@ -109,6 +110,7 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
|
||||||
logits_soft_cap = 0
|
logits_soft_cap = 0
|
||||||
self.logits_soft_cap = logits_soft_cap
|
self.logits_soft_cap = logits_soft_cap
|
||||||
|
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||||
|
|
||||||
self.use_irope = use_irope
|
self.use_irope = use_irope
|
||||||
|
|
||||||
@@ -178,31 +180,34 @@ class TritonAttentionImpl(AttentionImpl):
         if use_prefill_decode_attn:
             key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
-
-            # Reshape the input keys and values and store them in the cache.
-            PagedAttention.write_to_paged_cache(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
-
         else:
             key_cache, value_cache = kv_cache.unbind(0)
-            torch.ops._C_cache_ops.reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
-            )
+
+        if self.kv_sharing_target_layer_name is None:
+            # Reshape the input keys and values and store them in the cache.
+            # Skip this if sharing KV cache with an earlier attention layer.
+            if use_prefill_decode_attn:
+                PagedAttention.write_to_paged_cache(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
+            else:
+                torch.ops._C_cache_ops.reshape_and_cache_flash(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )

         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(self.fp8_dtype)
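The same gating appears in both the Pallas and Triton backends above: a layer only writes keys and values into a KV cache it owns, while a layer configured with `kv_sharing_target_layer_name` skips the write and reads from the target layer's cache at attention time. A minimal sketch of that gating in isolation (the function name and the `(2, num_slots, ...)` cache layout below are invented for this illustration and are not taken from the backends):

from typing import Optional

import torch


def maybe_write_kv(kv_cache: torch.Tensor,
                   key: torch.Tensor,
                   value: torch.Tensor,
                   slot_mapping: torch.Tensor,
                   kv_sharing_target_layer_name: Optional[str] = None) -> None:
    # Layers that share another layer's KV cache skip the write entirely;
    # the target layer, which runs earlier, has already populated these slots.
    if kv_sharing_target_layer_name is not None or kv_cache.numel() == 0:
        return
    # Assumed layout for this sketch: (2, num_slots, num_kv_heads, head_size).
    kv_cache[0, slot_mapping] = key
    kv_cache[1, slot_mapping] = value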
@@ -17,3 +17,36 @@ class CommonAttentionMetadata:
     seq_lens: torch.Tensor
     """(batch_size,), the length of each request including both computed tokens
     and newly scheduled tokens"""
+
+
+def validate_kv_sharing_target(current_layer_name, target_layer_name,
+                               static_forward_context):
+    error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
+                 f"is not valid: target layer {target_layer_name} ")
+
+    if current_layer_name == target_layer_name:
+        raise ValueError(error_msg +
+                         "cannot be the same as the current layer.")
+
+    if target_layer_name not in static_forward_context:
+        from vllm.model_executor.models.utils import extract_layer_index
+
+        # If target layer name is not in the static fwd context, it means either
+        # a) the target layer does not come BEFORE the current layer, or
+        # b) the target layer is not an Attention layer that exists in the model
+        current_layer_idx = extract_layer_index(current_layer_name)
+        target_layer_idx = extract_layer_index(target_layer_name)
+        if current_layer_idx <= target_layer_idx:
+            raise ValueError(error_msg + "must come before the current layer.")
+        else:
+            raise ValueError(error_msg +
+                             "is not a valid Attention layer in the model.")
+
+    # Currently KV sharing is only supported between layers of the same type
+    target_layer_attn_type = static_forward_context[
+        target_layer_name].attn_type
+    expected = static_forward_context[current_layer_name].attn_type
+    if target_layer_attn_type != expected:
+        raise ValueError(
+            error_msg +
+            f"must be the same type as the current layer ({expected}).")
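As a quick illustration of the checks added above, the validator can be exercised directly against a mocked forward context; the layer names are made up for this sketch, and the import path is an assumption based on the `CommonAttentionMetadata` hunk context:

from types import SimpleNamespace

import pytest

# Assumed location of the helper introduced in the hunk above.
from vllm.v1.attention.backends.utils import validate_kv_sharing_target

static_forward_context = {
    "model.layers.0.self_attn.attn": SimpleNamespace(attn_type="decoder"),
    "model.layers.2.self_attn.attn": SimpleNamespace(attn_type="decoder"),
}

# Passes: the target exists, comes before the current layer, and has the
# same attention type.
validate_kv_sharing_target("model.layers.2.self_attn.attn",
                           "model.layers.0.self_attn.attn",
                           static_forward_context)

# Rejected: a layer cannot share KV cache with itself.
with pytest.raises(ValueError):
    validate_kv_sharing_target("model.layers.2.self_attn.attn",
                               "model.layers.2.self_attn.attn",
                               static_forward_context)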
@@ -59,8 +59,8 @@ from vllm.v1.worker.block_table import BlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

-from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
-                    scatter_mm_placeholders)
+from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
+                    sanity_check_mm_encoder_outputs, scatter_mm_placeholders)

 if TYPE_CHECKING:
     import xgrammar as xgr
@@ -276,6 +276,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             pin_memory=self.pin_memory)
         self.seq_lens_np = self.seq_lens_cpu.numpy()

+        # Layer pairings for cross-layer KV sharing.
+        # If an Attention layer `layer_name` is in the keys of this dict, it
+        # means this layer will perform attention using the keys and values
+        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
+        self.shared_kv_cache_layers: dict[str, str] = {}
+
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
         """
         Update the order of requests in the batch based on the attention
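To make the mapping direction concrete, an entry like the one below (layer names invented for illustration) means the layer on the left performs attention against the KV cache owned by the layer on the right:

# Hypothetical contents of self.shared_kv_cache_layers: decoder layer 2
# consumes the keys and values cached by decoder layer 1.
shared_kv_cache_layers = {
    "model.layers.2.self_attn.attn": "model.layers.1.self_attn.attn",
}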
@@ -2097,6 +2103,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 # KV cache specs.
                 raise ValueError("Unknown KV cache spec type.")

+        # Setup `kv_cache_config` and `kv_caches` for models
+        # with cross-layer KV sharing
+        if self.shared_kv_cache_layers:
+            initialize_kv_cache_for_kv_sharing(
+                self.shared_kv_cache_layers,
+                kv_cache_config.kv_cache_groups,
+                kv_caches,
+            )
+
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
             # validate all draft model layers belong to the same kv cache
@@ -2125,6 +2140,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
         for layer_name, attn_module in layers.items():
+            if (kv_tgt_layer :=
+                    attn_module.kv_sharing_target_layer_name) is not None:
+                # The layer doesn't need its own KV cache and will use that of
+                # the target layer. We skip creating a KVCacheSpec for it, so
+                # that KV cache management logic will act as this layer does
+                # not exist, and doesn't allocate KV cache for the layer. This
+                # enables the memory saving of cross-layer kv sharing, allowing
+                # a given amount of memory to accommodate longer context lengths
+                # or enable more requests to be processed simultaneously.
+                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
+                continue
+
             # TODO: Support other attention modules, e.g., cross-attention
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
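The `continue` above is what realizes the memory saving: only layers without a sharing target receive a KVCacheSpec, so the KV cache manager never allocates memory for the consumer layers. A rough, self-contained sketch of that bookkeeping (the FakeAttn class and the opaque spec string are stand-ins for this illustration):

from typing import Optional


class FakeAttn:
    # Stand-in for an attention module; only the attribute used above is modeled.
    def __init__(self, kv_sharing_target_layer_name: Optional[str] = None):
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name


layers = {
    "model.layers.0.self_attn.attn": FakeAttn(),
    "model.layers.1.self_attn.attn": FakeAttn("model.layers.0.self_attn.attn"),
}

kv_cache_spec: dict[str, str] = {}
shared_kv_cache_layers: dict[str, str] = {}
for layer_name, attn_module in layers.items():
    if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None:
        shared_kv_cache_layers[layer_name] = kv_tgt_layer
        continue
    kv_cache_spec[layer_name] = "full-attention spec (opaque in this sketch)"

# Only layer 0 gets a spec (and hence KV cache memory); layer 1 is recorded
# as a consumer of layer 0's cache.
assert list(kv_cache_spec) == ["model.layers.0.self_attn.attn"]
assert shared_kv_cache_layers == {
    "model.layers.1.self_attn.attn": "model.layers.0.self_attn.attn"
}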
@@ -44,7 +44,8 @@ from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

-from .utils import sanity_check_mm_encoder_outputs
+from .utils import (initialize_kv_cache_for_kv_sharing,
+                    sanity_check_mm_encoder_outputs)

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -238,6 +239,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.num_reqs_paddings = _get_req_paddings(
             min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)

+        # Layer pairings for cross-layer KV sharing.
+        # If an Attention layer `layer_name` is in the keys of this dict, it
+        # means this layer will perform attention using the keys and values
+        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
+        self.shared_kv_cache_layers: dict[str, str] = {}
+
         # tensors for structured decoding
         self.grammar_bitmask_cpu = torch.zeros(
             (self.max_num_reqs, cdiv(self.vocab_size, 32)),
@@ -455,6 +462,18 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         block_size = self.vllm_config.cache_config.block_size
         kv_cache_spec: dict[str, KVCacheSpec] = {}
         for layer_name, attn_module in layers.items():
+            if (kv_tgt_layer :=
+                    attn_module.kv_sharing_target_layer_name) is not None:
+                # The layer doesn't need its own KV cache and will use that of
+                # the target layer. We skip creating a KVCacheSpec for it, so
+                # that KV cache management logic will act as this layer does
+                # not exist, and doesn't allocate KV cache for the layer. This
+                # enables the memory saving of cross-layer kv sharing, allowing
+                # a given amount of memory to accommodate longer context lengths
+                # or enable more requests to be processed simultaneously.
+                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
+                continue
+
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
@@ -1376,6 +1395,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             else:
                 raise NotImplementedError

+        # Setup `kv_cache_config` and `kv_caches` for models
+        # with cross-layer KV sharing
+        if self.shared_kv_cache_layers:
+            initialize_kv_cache_for_kv_sharing(
+                self.shared_kv_cache_layers,
+                kv_cache_config.kv_cache_groups,
+                kv_caches,
+            )
+
         bind_kv_cache(
             kv_caches,
             self.vllm_config.compilation_config.static_forward_context,
@@ -4,6 +4,8 @@ from typing import Optional

 import torch

+from vllm.v1.kv_cache_interface import KVCacheGroupSpec
+

 def sanity_check_mm_encoder_outputs(
     mm_embeddings: object,
@@ -73,3 +75,37 @@ def gather_mm_placeholders(
         return placeholders

     return placeholders[is_embed]
+
+
+def initialize_kv_cache_for_kv_sharing(
+    shared_kv_cache_layers: dict[str, str],
+    kv_cache_groups: list[KVCacheGroupSpec],
+    kv_caches: dict[str, torch.Tensor],
+) -> None:
+    """
+    Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
+    for layers that do not allocate its own KV cache, based on the mapping in
+    `shared_kv_cache_layers`. Adds these layers to the corresponding KV cache
+    group, which is needed to ensure that attention metadata is assigned later.
+
+    Args:
+        shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
+            If an Attention layer `layer_name` is in the keys of this dict, it
+            means this layer will perform attention using the keys and values
+            from the KV cache of `shared_kv_cache_layers[layer_name]`.
+        kv_cache_groups: The KV cache groups of the model.
+        kv_caches: The allocated kv_caches with layer names as keys.
+            Note that layers in shared_kv_cache_layers.keys() are not
+            originally included as it only contains layers which have its own
+            KV cache allocation.
+    """
+    # Record index of KV cache group for each layer that allocates a KV cache.
+    layer_to_kv_cache_group_idx: dict[str, int] = {}
+    for i, kv_cache_group in enumerate(kv_cache_groups):
+        for layer_name in kv_cache_group.layer_names:
+            layer_to_kv_cache_group_idx[layer_name] = i
+
+    for layer_name, target_layer_name in shared_kv_cache_layers.items():
+        kv_caches[layer_name] = kv_caches[target_layer_name]
+        group_idx = layer_to_kv_cache_group_idx[target_layer_name]
+        kv_cache_groups[group_idx].layer_names.append(layer_name)
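A small usage sketch of the helper above: the group object below is a plain stand-in that only provides `layer_names` (the real code passes `KVCacheGroupSpec` instances), the layer names are invented, and the import path is inferred from the `from .utils import ...` lines in the model runners earlier in this diff:

from types import SimpleNamespace

import torch

# Assumed import path, based on the worker runners' relative imports above.
from vllm.v1.worker.utils import initialize_kv_cache_for_kv_sharing

# One KV cache group containing the single layer that allocates a cache.
kv_cache_groups = [
    SimpleNamespace(layer_names=["model.layers.0.self_attn.attn"])
]
kv_caches = {"model.layers.0.self_attn.attn": torch.zeros(2, 16, 8, 64)}
shared_kv_cache_layers = {
    "model.layers.1.self_attn.attn": "model.layers.0.self_attn.attn"
}

initialize_kv_cache_for_kv_sharing(shared_kv_cache_layers, kv_cache_groups,
                                   kv_caches)

# Layer 1 now aliases layer 0's tensor and joins its KV cache group.
assert kv_caches["model.layers.1.self_attn.attn"] is kv_caches[
    "model.layers.0.self_attn.attn"]
assert kv_cache_groups[0].layer_names == [
    "model.layers.0.self_attn.attn", "model.layers.1.self_attn.attn"
]

Because the entry is an alias rather than a copy, no additional memory is reserved for the sharing layer; it simply indexes into the target layer's cache.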