diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 230c97e787a98..bc54b6ecc749e 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -4,8 +4,13 @@ import unittest.mock as mock import pytest -from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig +from vllm.attention.layer import Attention +from vllm.config import (CacheConfig, ModelConfig, SchedulerConfig, VllmConfig, + set_current_vllm_config) from vllm.sampling_params import SamplingParams +from vllm.utils import GiB_bytes +from vllm.v1.core.kv_cache_utils import (estimate_max_model_len, + get_kv_cache_config) from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) from vllm.v1.worker.tpu_model_runner import ( @@ -363,3 +368,223 @@ def test_get_req_paddings(): assert _get_req_paddings(1, 32) == [8, 16, 32] assert _get_req_paddings(8, 32) == [8, 16, 32] assert _get_req_paddings(8, 36) == [8, 16, 32, 36] + + +def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(): + layer_0 = "model.layers.0.self_attn.attn" + layer_1 = "model.layers.1.self_attn.attn" + error_msg = f"{layer_1} must come before the current layer" + with pytest.raises(ValueError, match=error_msg): + fwd_context = { + # initialization below will fail because target layer is invalid; + # the target layer needs to come before layer 1 + layer_0: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_0, + kv_sharing_target_layer_name=layer_1, + ), + layer_1: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_1, + ) + } + # suppress var not used error + assert fwd_context is not None + + +def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(): + layer_0 = "model.layers.0.self_attn.attn" + layer_1 = "model.layers.1.self_attn.attn" + invalid_layer = "model.layers.0.cross_attn.attn" + error_msg = f"{invalid_layer} is not a valid Attention layer in the model" + with pytest.raises(ValueError, match=error_msg): + fwd_context = { + layer_0: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_0, + ), + layer_1: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_1, + # invalid layer: cross_attn.atn doesn't exist! 
+ kv_sharing_target_layer_name=invalid_layer, + ) + } + # suppress var not used error + assert fwd_context is not None + + +def test_init_kv_cache_with_kv_sharing_target_same_as_current(): + layer_0 = "model.layers.0.self_attn.attn" + layer_1 = "model.layers.1.self_attn.attn" + error_msg = f"{layer_1} cannot be the same as the current layer" + with pytest.raises(ValueError, match=error_msg): + fwd_context = { + # initialization below will fail because target layer is invalid; + # the target layer needs to come before layer 1 + layer_0: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_0, + ), + layer_1: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_1, + kv_sharing_target_layer_name=layer_1, + ) + } + # suppress var not used error + assert fwd_context is not None + + +def test_init_kv_cache_without_kv_sharing(model_runner): + layer_0 = "model.layers.0.self_attn.attn" + layer_1 = "model.layers.1.self_attn.attn" + vllm_config = model_runner.vllm_config + with set_current_vllm_config(vllm_config): + fwd_context = { + layer_0: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_0, + ), + layer_1: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_1, + ) + } + # suppress var not used error + assert fwd_context is not None + # Set high context length to test max context length estimation + vllm_config.model_config.max_model_len = 3_000_000 + vllm_ctx = vllm_config.compilation_config.static_forward_context + kv_cache_spec = model_runner.get_kv_cache_spec() + assert len(kv_cache_spec) == 2 + assert len(model_runner.shared_kv_cache_layers) == 0 + + available_memory = 20 * GiB_bytes + # page size for layer 0's kv_cache_spec is 32KB + num_expected_blocks = 327680 # 20GB / 32KB / 2 (num layers) + kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, + available_memory) + assert kv_cache_config.num_blocks == num_expected_blocks + assert len(kv_cache_config.tensors) == 2 + assert kv_cache_config.tensors[layer_0].size == available_memory // 2 + assert kv_cache_config.tensors[layer_1].size == available_memory // 2 + + max_context_len =\ + estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) + # max context len with KV sharing should be 2x as large as without + assert max_context_len == 1310720 + + # important: override tensor size to prevent large mem alloc during test + # this will only allocate 2 block worth of memory (2 * 32kb) + kv_cache_config.num_blocks = 1 + for layer in kv_cache_config.tensors: + kv_cache_config.tensors[layer].size =\ + kv_cache_spec[layer].page_size_bytes + + model_runner.initialize_kv_cache(kv_cache_config) + + layer_0_kv = vllm_ctx[layer_0].kv_cache[0] + layer_1_kv = vllm_ctx[layer_1].kv_cache[0] + # check layer 1 kv cache does NOT share memory with layer 0 + assert id(layer_1_kv) != id(layer_0_kv) + + # check layer 1 added to kv cache group's layer names + assert len(kv_cache_config.kv_cache_groups) == 1 + assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2 + assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0 + assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1 + + +def test_init_kv_cache_with_kv_sharing_valid(model_runner): + layer_0 = "model.layers.0.self_attn.attn" + layer_1 = "model.layers.1.self_attn.attn" + vllm_config = model_runner.vllm_config + with set_current_vllm_config(vllm_config): + fwd_context = { + layer_0: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_0, + ), + layer_1: + 
Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_1, + kv_sharing_target_layer_name="model.layers.0.self_attn.attn", + ) + } + # suppress var not used error + assert fwd_context is not None + # Set high context length to test max context length estimation + vllm_config.model_config.max_model_len = 3_000_000 + vllm_ctx = vllm_config.compilation_config.static_forward_context + kv_cache_spec = model_runner.get_kv_cache_spec() + assert len(kv_cache_spec) == 1 + assert layer_0 in kv_cache_spec + assert model_runner.shared_kv_cache_layers[layer_1] == layer_0 + + available_memory = 20 * GiB_bytes + # page size for layer 0's kv_cache_spec is 32KB + # with KV sharing, we can allocate (available_mem//page_size//1) blocks + # which is twice as many as without KV sharing + num_expected_blocks = 655360 # 20GB / 32KB + kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, + available_memory) + assert kv_cache_config.num_blocks == num_expected_blocks + assert len(kv_cache_config.tensors) == 1 + # Each layer now has twice the available memory for KV cache + # compared to no KV sharing + assert kv_cache_config.tensors[layer_0].size == available_memory + + max_context_len =\ + estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) + # max context len with KV sharing should be 2x as large as without + assert max_context_len == 2 * 1310720 + + # important: override tensor size to prevent large mem alloc during test + # this will only allocate 1 block worth of memory (32kb) + kv_cache_config.num_blocks = 1 + kv_cache_config.tensors[layer_0].size =\ + kv_cache_spec[layer_0].page_size_bytes + + model_runner.initialize_kv_cache(kv_cache_config) + + layer_0_kv = vllm_ctx[layer_0].kv_cache[0] + layer_1_kv = vllm_ctx[layer_1].kv_cache[0] + # check layer 1 kv cache shares memory with layer 0 + assert id(layer_1_kv) == id(layer_0_kv) + + # check layer 1 added to kv cache group's layer names + assert len(kv_cache_config.kv_cache_groups) == 1 + assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2 + assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0 + assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1 diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index ceb9d4df25e62..5e2fd2fbf747b 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -7,8 +7,11 @@ import pytest from vllm.attention import Attention from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig, VllmConfig) + SchedulerConfig, VllmConfig, set_current_vllm_config) from vllm.sampling_params import SamplingParams +from vllm.utils import GiB_bytes +from vllm.v1.core.kv_cache_utils import (estimate_max_model_len, + get_kv_cache_config) from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, @@ -19,6 +22,7 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner BLOCK_SIZE = 16 NUM_BLOCKS = 10 +DEVICE = "cuda" def initialize_kv_cache(runner: GPUModelRunner): @@ -55,8 +59,7 @@ def initialize_kv_cache(runner: GPUModelRunner): runner.initialize_attn_backend(kv_cache_config) -@pytest.fixture -def model_runner(): +def get_vllm_config(): scheduler_config = SchedulerConfig( max_num_seqs=10, max_num_batched_tokens=512, @@ -84,13 +87,18 @@ def model_runner(): scheduler_config=scheduler_config, parallel_config=parallel_config, ) - num_heads = 
model_config.get_num_kv_heads(parallel_config) + return vllm_config + + +@pytest.fixture +def model_runner(): + vllm_config = get_vllm_config() + model_config = vllm_config.model_config + num_heads = model_config.get_num_kv_heads(vllm_config.parallel_config) head_size = model_config.get_head_size() vllm_config.compilation_config.static_forward_context[ "layer.0"] = Attention(num_heads, head_size, 0.1) - - device = "cuda" - runner = GPUModelRunner(vllm_config, device) + runner = GPUModelRunner(vllm_config, DEVICE) initialize_kv_cache(runner) return runner @@ -385,3 +393,225 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2): model_runner_2.load_model() # Load real weights inplace assert str(model_runner.get_model().state_dict()) == str( model_runner_2.get_model().state_dict()) + + +def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(): + layer_0 = "model.layers.0.self_attn.attn" + layer_1 = "model.layers.1.self_attn.attn" + error_msg = f"{layer_1} must come before the current layer" + with pytest.raises(ValueError, match=error_msg): + fwd_context = { + # initialization below will fail because target layer is invalid; + # the target layer needs to come before layer 1 + layer_0: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_0, + kv_sharing_target_layer_name=layer_1, + ), + layer_1: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_1, + ) + } + # suppress var not used error + assert fwd_context is not None + + +def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(): + layer_0 = "model.layers.0.self_attn.attn" + layer_1 = "model.layers.1.self_attn.attn" + invalid_layer = "model.layers.0.cross_attn.attn" + error_msg = f"{invalid_layer} is not a valid Attention layer in the model" + with pytest.raises(ValueError, match=error_msg): + fwd_context = { + layer_0: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_0, + ), + layer_1: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_1, + # invalid layer: cross_attn.atn doesn't exist! 
+ kv_sharing_target_layer_name=invalid_layer, + ) + } + # suppress var not used error + assert fwd_context is not None + + +def test_init_kv_cache_with_kv_sharing_target_same_as_current(): + layer_0 = "model.layers.0.self_attn.attn" + layer_1 = "model.layers.1.self_attn.attn" + error_msg = f"{layer_1} cannot be the same as the current layer" + with pytest.raises(ValueError, match=error_msg): + fwd_context = { + # initialization below will fail because target layer is invalid; + # the target layer needs to come before layer 1 + layer_0: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_0, + ), + layer_1: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_1, + kv_sharing_target_layer_name=layer_1, + ) + } + # suppress var not used error + assert fwd_context is not None + + +def test_init_kv_cache_without_kv_sharing(): + layer_0 = "model.layers.0.self_attn.attn" + layer_1 = "model.layers.1.self_attn.attn" + vllm_config = get_vllm_config() + with set_current_vllm_config(vllm_config): + fwd_context = { + layer_0: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_0, + ), + layer_1: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_1, + ) + } + # suppress var not used error + assert fwd_context is not None + # Set high context length to test max context length estimation + vllm_config.model_config.max_model_len = 3_000_000 + vllm_ctx = vllm_config.compilation_config.static_forward_context + runner = GPUModelRunner(vllm_config, DEVICE) + kv_cache_spec = runner.get_kv_cache_spec() + assert len(kv_cache_spec) == 2 + assert len(runner.shared_kv_cache_layers) == 0 + + available_memory = 20 * GiB_bytes + # page size for layer 0's kv_cache_spec is 32KB + num_expected_blocks = 327680 # 20GB / 32KB / 2 (num layers) + kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, + available_memory) + assert kv_cache_config.num_blocks == num_expected_blocks + assert len(kv_cache_config.tensors) == 2 + assert kv_cache_config.tensors[layer_0].size == available_memory // 2 + assert kv_cache_config.tensors[layer_1].size == available_memory // 2 + + max_context_len =\ + estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) + # max context len with KV sharing should be 2x as large as without + assert max_context_len == 1310720 + + # important: override tensor size to prevent large mem alloc during test + # this will only allocate 2 block worth of memory (2 * 32kb) + kv_cache_config.num_blocks = 1 + for layer in kv_cache_config.tensors: + kv_cache_config.tensors[layer].size =\ + kv_cache_spec[layer].page_size_bytes + + runner.initialize_kv_cache(kv_cache_config) + + layer_0_kv = vllm_ctx[layer_0].kv_cache[0] + layer_1_kv = vllm_ctx[layer_1].kv_cache[0] + # check layer 1 kv cache does NOT share memory with layer 0 + assert id(layer_1_kv) != id(layer_0_kv) + + # check layer 1 added to kv cache group's layer names + assert len(kv_cache_config.kv_cache_groups) == 1 + assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2 + assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0 + assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1 + + +def test_init_kv_cache_with_kv_sharing_valid(): + layer_0 = "model.layers.0.self_attn.attn" + layer_1 = "model.layers.1.self_attn.attn" + vllm_config = get_vllm_config() + with set_current_vllm_config(vllm_config): + fwd_context = { + layer_0: + Attention( + num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_0, + ), + layer_1: + Attention( + 
num_heads=8, + head_size=64, + scale=1.0, + prefix=layer_1, + kv_sharing_target_layer_name="model.layers.0.self_attn.attn", + ) + } + # suppress var not used error + assert fwd_context is not None + # Set high context length to test max context length estimation + vllm_config.model_config.max_model_len = 3_000_000 + vllm_ctx = vllm_config.compilation_config.static_forward_context + runner = GPUModelRunner(vllm_config, DEVICE) + kv_cache_spec = runner.get_kv_cache_spec() + assert len(kv_cache_spec) == 1 + assert layer_0 in kv_cache_spec + assert runner.shared_kv_cache_layers[layer_1] == layer_0 + + available_memory = 20 * GiB_bytes + # page size for layer 0's kv_cache_spec is 32KB + # with KV sharing, we can allocate (available_mem//page_size//1) blocks + # which is twice as many as without KV sharing + num_expected_blocks = 655360 # 20GB / 32KB + kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, + available_memory) + assert kv_cache_config.num_blocks == num_expected_blocks + assert len(kv_cache_config.tensors) == 1 + # Each layer now has twice the available memory for KV cache + # compared to no KV sharing + assert kv_cache_config.tensors[layer_0].size == available_memory + + max_context_len =\ + estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) + # max context len with KV sharing should be 2x as large as without + assert max_context_len == 2 * 1310720 + + # important: override tensor size to prevent large mem alloc during test + # this will only allocate 1 block worth of memory (32kb) + kv_cache_config.num_blocks = 1 + kv_cache_config.tensors[layer_0].size =\ + kv_cache_spec[layer_0].page_size_bytes + + runner.initialize_kv_cache(kv_cache_config) + + layer_0_kv = vllm_ctx[layer_0].kv_cache[0] + layer_1_kv = vllm_ctx[layer_1].kv_cache[0] + # check layer 1 kv cache shares memory with layer 0 + assert id(layer_1_kv) == id(layer_0_kv) + + # check layer 1 added to kv cache group's layer names + assert len(kv_cache_config.kv_cache_groups) == 1 + assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2 + assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0 + assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1 diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index deb3951d6617b..0ba5a5bf94c9b 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -270,6 +270,7 @@ class AttentionImpl(ABC, Generic[T]): blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, ) -> None: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index a2fd557f8e0cb..c1663516de358 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -306,7 +306,10 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") assert blocksparse_params is not None assert alibi_slopes is None, ValueError( "Alibi not support for blocksparse flash attention.") diff --git a/vllm/attention/backends/cpu_mla.py b/vllm/attention/backends/cpu_mla.py 
index 39e667bca9cd2..cf7883e121abb 100644 --- a/vllm/attention/backends/cpu_mla.py +++ b/vllm/attention/backends/cpu_mla.py @@ -206,12 +206,13 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]): blocksparse_params: Optional[Dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, + kv_sharing_target_layer_name: Optional[str], # MLA Specific Arguments **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap, attn_type, - **mla_args) + kv_sharing_target_layer_name, **mla_args) unsupported_features = [ alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py index 3548df88d0c5d..963bccdf21bc0 100644 --- a/vllm/attention/backends/dual_chunk_flash_attn.py +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -290,9 +290,12 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, layer_idx: int = -1, dual_chunk_attention_config: Optional[Dict[str, Any]] = None, ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 26be2c04f297e..73e3772682e69 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -618,8 +618,11 @@ class FlashAttentionImpl(AttentionImpl): blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") if blocksparse_params is not None: raise ValueError( "FlashAttention does not support block-sparse attention.") diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 7ae7ea37f4afc..a3937760f03b8 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -936,8 +936,11 @@ class FlashInferImpl(AttentionImpl): blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") if use_irope: logger.warning_once( "Using irope in FlashInfer is not supported yet, it will fall" diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py index 9a6b8a40e1311..e185d0260d0a0 100644 --- a/vllm/attention/backends/flashmla.py +++ b/vllm/attention/backends/flashmla.py @@ -184,12 +184,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): blocksparse_params: Optional[Dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, + kv_sharing_target_layer_name: Optional[str] = None, # MLA Specific Arguments **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, 
blocksparse_params, logits_soft_cap, attn_type, - **mla_args) + kv_sharing_target_layer_name, **mla_args) assert is_flashmla_supported(), \ "FlashMLA is not supported on this device" diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 5128e49752e11..9bd513fd894f5 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -110,9 +110,12 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module): blocksparse_params: Optional[Dict[str, Any]] = None, max_seq_len: int = 4096, attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, ) -> None: super(AttentionImpl, self).__init__() + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") if use_irope: logger.warning_once( "Using irope in HPU is not supported yet, it will fall back " diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 30441b3ad136a..5051c6a7cc4fd 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -123,8 +123,11 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") if use_irope: logger.warning_once( "Using irope in Ipex is not supported yet, it will fall" diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 50842abd3924f..78cf952881303 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1000,6 +1000,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): blocksparse_params: Optional[Dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, + kv_sharing_target_layer_name: Optional[str], # MLA Specific Arguments q_lora_rank: Optional[int], kv_lora_rank: int, @@ -1009,6 +1010,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): v_head_dim: int, kv_b_proj: ColumnParallelLinear, ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing not supported in V0.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index a6823ac059fb7..7ad67615d33d9 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -109,8 +109,11 @@ class PallasAttentionBackendImpl(AttentionImpl): blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") if use_irope: logger.warning_once( "Using irope in Pallas is not supported yet, it will fall back " diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index 855036071d0d1..1edf34351db3f 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -370,12 +370,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): blocksparse_params: 
Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, + kv_sharing_target_layer_name: Optional[str], # MLA Specific Arguments **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap, attn_type, - **mla_args) + kv_sharing_target_layer_name, **mla_args) unsupported_features = [ alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 755e0da06cef9..4b460dc0b58cd 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -494,8 +494,11 @@ class ROCmFlashAttentionImpl(AttentionImpl): blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") if use_irope: logger.warning_once( "Using irope in ROCm Flash Attention is not supported yet, it " diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 7606340044f1d..f3fb5adcf05ce 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -405,8 +405,11 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") if blocksparse_params is not None: raise ValueError( "Torch SPDA does not support block-sparse attention.") diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index d9fff8fac1584..e06f7d54e3421 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -38,12 +38,13 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): blocksparse_params: Optional[Dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, + kv_sharing_target_layer_name: Optional[str], # MLA Specific Arguments **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap, attn_type, - **mla_args) + kv_sharing_target_layer_name, **mla_args) unsupported_features = [ alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 8355e03977e78..04ef928b7d7b3 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -390,8 +390,11 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") if blocksparse_params is not None: raise ValueError( "XFormers does not support block-sparse attention.") diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 
6c5b05a5c7b14..a5fbd1a1c0166 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.platforms import _Backend, current_platform from vllm.utils import direct_register_custom_op +from vllm.v1.attention.backends.utils import validate_kv_sharing_target class Attention(nn.Module): @@ -50,6 +51,7 @@ class Attention(nn.Module): use_mla: bool = False, prefix: str = "", attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, **extra_impl_args, ) -> None: """ @@ -135,7 +137,7 @@ class Attention(nn.Module): self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap, attn_type, - **extra_impl_args) + kv_sharing_target_layer_name, **extra_impl_args) self.backend = backend_name_to_enum(attn_backend.get_name()) self.dtype = dtype @@ -153,6 +155,19 @@ class Attention(nn.Module): compilation_config.static_forward_context[prefix] = self self.layer_name = prefix self.attn_type = attn_type + + if kv_sharing_target_layer_name is not None: + if not envs.VLLM_USE_V1: + raise NotImplementedError( + "Cross-layer KV sharing is not supported in V0.") + + validate_kv_sharing_target( + prefix, + kv_sharing_target_layer_name, + compilation_config.static_forward_context, + ) + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + # use a placeholder kv cache tensor during init, which will be replaced # by bind_kv_cache # this variable will not be accessed if use_direct_call is True diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 9e989df1cd892..a92c51883af1c 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -485,6 +485,7 @@ class FlashAttentionImpl(AttentionImpl): blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, ) -> None: if blocksparse_params is not None: @@ -506,6 +507,7 @@ class FlashAttentionImpl(AttentionImpl): # In flash-attn, setting logits_soft_cap as 0 means no soft cap. logits_soft_cap = 0 self.logits_soft_cap = logits_soft_cap + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -569,22 +571,26 @@ class FlashAttentionImpl(AttentionImpl): # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens - # Reshape the input keys and values and store them in the cache. - # NOTE(woosuk): Here, key and value are padded while slot_mapping is - # not padded. However, we don't need to do key[:num_actual_tokens] and - # value[:num_actual_tokens] because the reshape_and_cache_flash op uses - # the slot_mapping's shape to determine the number of actual tokens. key_cache, value_cache = kv_cache.unbind(0) - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. 
+ # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(torch.float8_e4m3fn) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 8bd998eba7695..f1b61c152a9d8 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -507,6 +507,7 @@ class FlashInferImpl(AttentionImpl): blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[int] = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -521,6 +522,7 @@ class FlashInferImpl(AttentionImpl): self.sliding_window = (sliding_window - 1, 0) self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -568,21 +570,25 @@ class FlashInferImpl(AttentionImpl): # performance to make sure it does not introduce any overhead. num_actual_tokens = attn_metadata.num_actual_tokens - # Reshape the input keys and values and store them in the cache. - # NOTE(woosuk): Here, key and value are padded while slot_mapping is - # not padded. However, we don't need to do key[:num_actual_tokens] and - # value[:num_actual_tokens] because the reshape_and_cache_flash op uses - # the slot_mapping's shape to determine the number of actual tokens. - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[:, 0], - kv_cache[:, 1], - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. 
+ torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) window_left = (self.sliding_window[0] if self.sliding_window is not None else -1) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 96befca5a1e94..06acbb909a4f6 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -586,6 +586,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, + kv_sharing_target_layer_name: Optional[str], # MLA Specific Arguments q_lora_rank: Optional[int], kv_lora_rank: int, @@ -595,6 +596,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): v_head_dim: int, kv_b_proj: ColumnParallelLinear, ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported for MLA") + self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 060a7c9d8c853..318b8ede14366 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -93,12 +93,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, + kv_sharing_target_layer_name: Optional[str], # MLA Specific Arguments **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap, attn_type, - **mla_args) + kv_sharing_target_layer_name, **mla_args) assert is_flashmla_supported(), \ "FlashMLA is not supported on this device" diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 8925b5a5cd7d0..1f0406a7ac1f8 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -139,12 +139,13 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, + kv_sharing_target_layer_name: Optional[str], # MLA Specific Arguments **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap, attn_type, - **mla_args) + kv_sharing_target_layer_name, **mla_args) assert (num_heads == 16 or num_heads == 128), ( f"Aiter MLA only supports 16 or 128 number of heads.\n" f"Provided {num_heads} number of heads.\n" diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 0857fc133c431..e26d7909184b5 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -41,12 +41,13 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, + kv_sharing_target_layer_name: Optional[str], # MLA Specific Arguments **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap, attn_type, - **mla_args) + kv_sharing_target_layer_name, **mla_args) 
unsupported_features = [ alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 896f1394cfa4b..0f956ba88b9c1 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -113,6 +113,7 @@ class PallasAttentionBackendImpl(AttentionImpl): blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[int] = None, use_irope: bool = False, ) -> None: if use_irope: @@ -128,6 +129,7 @@ class PallasAttentionBackendImpl(AttentionImpl): self.num_kv_heads = num_kv_heads self.sliding_window = sliding_window self.logits_soft_cap = logits_soft_cap + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -181,7 +183,9 @@ class PallasAttentionBackendImpl(AttentionImpl): num_tokens, hidden_size = query.shape query = query.view(num_tokens, self.num_heads, self.head_size) - if kv_cache.numel() > 0: + if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0: + # Write input keys and values to the KV cache. + # Skip this if sharing KV cache with an earlier attention layer. slot_mapping = attn_metadata.slot_mapping write_to_kv_cache(key, value, kv_cache, slot_mapping) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 6a3314dd87889..968f137011186 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -88,6 +88,7 @@ class TritonAttentionImpl(AttentionImpl): blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[int] = None, use_irope: bool = False, ) -> None: if blocksparse_params is not None: @@ -109,6 +110,7 @@ class TritonAttentionImpl(AttentionImpl): # In flash-attn, setting logits_soft_cap as 0 means no soft cap. logits_soft_cap = 0 self.logits_soft_cap = logits_soft_cap + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name self.use_irope = use_irope @@ -178,31 +180,34 @@ class TritonAttentionImpl(AttentionImpl): if use_prefill_decode_attn: key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) - - # Reshape the input keys and values and store them in the cache. - PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - else: key_cache, value_cache = kv_cache.unbind(0) - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. 
+ if use_prefill_decode_attn: + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + else: + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(self.fp8_dtype) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 2e65619ed7bc8..72c7643539273 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -17,3 +17,36 @@ class CommonAttentionMetadata: seq_lens: torch.Tensor """(batch_size,), the length of each request including both computed tokens and newly scheduled tokens""" + + +def validate_kv_sharing_target(current_layer_name, target_layer_name, + static_forward_context): + error_msg = (f"Specified KV sharing target layer for {current_layer_name} " + f"is not valid: target layer {target_layer_name} ") + + if current_layer_name == target_layer_name: + raise ValueError(error_msg + + "cannot be the same as the current layer.") + + if target_layer_name not in static_forward_context: + from vllm.model_executor.models.utils import extract_layer_index + + # If target layer name is not in the static fwd context, it means either + # a) the target layer does not come BEFORE the current layer, or + # b) the target layer is not an Attention layer that exists in the model + current_layer_idx = extract_layer_index(current_layer_name) + target_layer_idx = extract_layer_index(target_layer_name) + if current_layer_idx <= target_layer_idx: + raise ValueError(error_msg + "must come before the current layer.") + else: + raise ValueError(error_msg + + "is not a valid Attention layer in the model.") + + # Currently KV sharing is only supported between layers of the same type + target_layer_attn_type = static_forward_context[ + target_layer_name].attn_type + expected = static_forward_context[current_layer_name].attn_type + if target_layer_attn_type != expected: + raise ValueError( + error_msg + + f"must be the same type as the current layer ({expected}).") diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c96ad0c015301..b7448be26f107 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -59,8 +59,8 @@ from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, - scatter_mm_placeholders) +from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, + sanity_check_mm_encoder_outputs, scatter_mm_placeholders) if TYPE_CHECKING: import xgrammar as xgr @@ -276,6 +276,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() + # Layer pairings for cross-layer KV sharing. + # If an Attention layer `layer_name` is in the keys of this dict, it + # means this layer will perform attention using the keys and values + # from the KV cache of `shared_kv_cache_layers[layer_name]`. 
+ self.shared_kv_cache_layers: dict[str, str] = {} + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: """ Update the order of requests in the batch based on the attention @@ -2097,6 +2103,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): # KV cache specs. raise ValueError("Unknown KV cache spec type.") + # Setup `kv_cache_config` and `kv_caches` for models + # with cross-layer KV sharing + if self.shared_kv_cache_layers: + initialize_kv_cache_for_kv_sharing( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + kv_caches, + ) + if self.speculative_config and self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) # validate all draft model layers belong to the same kv cache @@ -2125,6 +2140,18 @@ class GPUModelRunner(LoRAModelRunnerMixin): use_mla = self.vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} for layer_name, attn_module in layers.items(): + if (kv_tgt_layer := + attn_module.kv_sharing_target_layer_name) is not None: + # The layer doesn't need its own KV cache and will use that of + # the target layer. We skip creating a KVCacheSpec for it, so + # that KV cache management logic will act as this layer does + # not exist, and doesn't allocate KV cache for the layer. This + # enables the memory saving of cross-layer kv sharing, allowing + # a given amount of memory to accommodate longer context lengths + # or enable more requests to be processed simultaneously. + self.shared_kv_cache_layers[layer_name] = kv_tgt_layer + continue + # TODO: Support other attention modules, e.g., cross-attention if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 48ea3cb7bff0d..f15234f49ce05 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -44,7 +44,8 @@ from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from .utils import sanity_check_mm_encoder_outputs +from .utils import (initialize_kv_cache_for_kv_sharing, + sanity_check_mm_encoder_outputs) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -238,6 +239,12 @@ class TPUModelRunner(LoRAModelRunnerMixin): self.num_reqs_paddings = _get_req_paddings( min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs) + # Layer pairings for cross-layer KV sharing. + # If an Attention layer `layer_name` is in the keys of this dict, it + # means this layer will perform attention using the keys and values + # from the KV cache of `shared_kv_cache_layers[layer_name]`. + self.shared_kv_cache_layers: dict[str, str] = {} + # tensors for structured decoding self.grammar_bitmask_cpu = torch.zeros( (self.max_num_reqs, cdiv(self.vocab_size, 32)), @@ -455,6 +462,18 @@ class TPUModelRunner(LoRAModelRunnerMixin): block_size = self.vllm_config.cache_config.block_size kv_cache_spec: dict[str, KVCacheSpec] = {} for layer_name, attn_module in layers.items(): + if (kv_tgt_layer := + attn_module.kv_sharing_target_layer_name) is not None: + # The layer doesn't need its own KV cache and will use that of + # the target layer. We skip creating a KVCacheSpec for it, so + # that KV cache management logic will act as this layer does + # not exist, and doesn't allocate KV cache for the layer. 
This + # enables the memory saving of cross-layer kv sharing, allowing + # a given amount of memory to accommodate longer context lengths + # or enable more requests to be processed simultaneously. + self.shared_kv_cache_layers[layer_name] = kv_tgt_layer + continue + if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( @@ -1376,6 +1395,15 @@ class TPUModelRunner(LoRAModelRunnerMixin): else: raise NotImplementedError + # Setup `kv_cache_config` and `kv_caches` for models + # with cross-layer KV sharing + if self.shared_kv_cache_layers: + initialize_kv_cache_for_kv_sharing( + self.shared_kv_cache_layers, + kv_cache_config.kv_cache_groups, + kv_caches, + ) + bind_kv_cache( kv_caches, self.vllm_config.compilation_config.static_forward_context, diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index b23b28c1d7e9c..055cf01530f02 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -4,6 +4,8 @@ from typing import Optional import torch +from vllm.v1.kv_cache_interface import KVCacheGroupSpec + def sanity_check_mm_encoder_outputs( mm_embeddings: object, @@ -73,3 +75,37 @@ def gather_mm_placeholders( return placeholders return placeholders[is_embed] + + +def initialize_kv_cache_for_kv_sharing( + shared_kv_cache_layers: dict[str, str], + kv_cache_groups: list[KVCacheGroupSpec], + kv_caches: dict[str, torch.Tensor], +) -> None: + """ + Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches` + for layers that do not allocate its own KV cache, based on the mapping in + `shared_kv_cache_layers`. Adds these layers to the corresponding KV cache + group, which is needed to ensure that attention metadata is assigned later. + + Args: + shared_kv_cache_layers: Layer pairings for cross-layer KV sharing. + If an Attention layer `layer_name` is in the keys of this dict, it + means this layer will perform attention using the keys and values + from the KV cache of `shared_kv_cache_layers[layer_name]`. + kv_cache_groups: The KV cache groups of the model. + kv_caches: The allocated kv_caches with layer names as keys. + Note that layers in shared_kv_cache_layers.keys() are not + originally included as it only contains layers which have its own + KV cache allocation. + """ + # Record index of KV cache group for each layer that allocates a KV cache. + layer_to_kv_cache_group_idx: dict[str, int] = {} + for i, kv_cache_group in enumerate(kv_cache_groups): + for layer_name in kv_cache_group.layer_names: + layer_to_kv_cache_group_idx[layer_name] = i + + for layer_name, target_layer_name in shared_kv_cache_layers.items(): + kv_caches[layer_name] = kv_caches[target_layer_name] + group_idx = layer_to_kv_cache_group_idx[target_layer_name] + kv_cache_groups[group_idx].layer_names.append(layer_name)