diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py index 23b6c4802d106..69565f584ab89 100644 --- a/tests/v1/kv_connector/unit/test_offloading_connector.py +++ b/tests/v1/kv_connector/unit/test_offloading_connector.py @@ -19,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import ( ) from vllm.forward_context import ForwardContext from vllm.utils.hashing import sha256 +from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.core.kv_cache_utils import ( BlockHash, get_request_block_hasher, @@ -92,7 +93,7 @@ class MockOffloadingSpec(OffloadingSpec): return self.manager def get_handlers( - self, _ + self, _, __ ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]: yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler @@ -138,7 +139,10 @@ class RequestRunner: self.worker_connector = OffloadingConnector(vllm_config, KVConnectorRole.WORKER) # register worker kv_caches to enable OffloadingWorker creations - self.worker_connector.register_kv_caches(kv_caches={"a": torch.empty(0)}) + self.worker_connector.register_cross_layers_kv_cache( + kv_cache=torch.empty(0), + attn_backend=FlashAttentionBackend, + ) # extract connector of scheduler scheduler_connector = self.scheduler.connector diff --git a/tests/v1/kv_offload/test_cpu_offloading.py b/tests/v1/kv_offload/test_cpu_offloading.py index b654ea4298dbb..3ee41c40859dc 100644 --- a/tests/v1/kv_offload/test_cpu_offloading.py +++ b/tests/v1/kv_offload/test_cpu_offloading.py @@ -12,8 +12,10 @@ from tqdm import tqdm from vllm import LLM, SamplingParams, TokensPrompt from vllm.config import KVEventsConfig, KVTransferConfig from vllm.distributed.kv_events import BlockStored, KVEventBatch +from vllm.utils.system_utils import set_env_var -CPU_BLOCK_SIZES = [16, 48] +CPU_BLOCK_SIZES = [48] +ATTN_BACKENDS = ["FLASH_ATTN", "FLASHINFER"] class MockSubscriber: @@ -63,8 +65,88 @@ class MockSubscriber: self.sub.close() +def _latency_test(llm: LLM, subscriber: MockSubscriber): + sampling_params = SamplingParams(max_tokens=1) + + num_times_cpu_better_than_cold = 0 + num_tests = 10 + total_cold_time = 0.0 + total_gpu_hit_time = 0.0 + total_cpu_hit_time = 0.0 + prompt_token_ids = [0] * 10001 + for i in tqdm(range(num_tests), desc="Running tests"): + prompt_token_ids[0] = i + prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)] + + # run generation - this should trigger saving KV cache + start_time = time.time() + llm.generate(prompts, sampling_params, use_tqdm=False) + cold_time = time.time() - start_time + total_cold_time += cold_time + + # run generation again - should hit the GPU prefix cache + start_time = time.time() + llm.generate(prompts, sampling_params, use_tqdm=False) + gpu_hit_time = time.time() - start_time + total_gpu_hit_time += gpu_hit_time + + # reset prefix cache to avoid GPU hit. 
+ llm.reset_prefix_cache() + + assert subscriber.get_new_cpu_stored_events() + + # run generation again - this should trigger loading from CPU + start_time = time.time() + llm.generate(prompts, sampling_params, use_tqdm=False) + cpu_hit_time = time.time() - start_time + total_cpu_hit_time += cpu_hit_time + + if cpu_hit_time < cold_time: + num_times_cpu_better_than_cold += 1 + + print("Average times:") + print(f" Cold: {total_cold_time * 1000 / num_tests:.2f}ms") + print(f" GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms") + print(f" CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms") + + assert num_times_cpu_better_than_cold >= 0.8 * num_tests + + +def _accuracy_test(llm: LLM, subscriber: MockSubscriber): + sampling_params = SamplingParams(max_tokens=1) + cpu_block_size = ( + llm.llm_engine.vllm_config.kv_transfer_config.kv_connector_extra_config[ + "block_size" + ] + ) + + subscriber.get_new_cpu_stored_events() + + # prepend prompt to be cpu block aligned + prompt = "Let's count to 10. One, two, three, four," + while ( + len(llm.generate(prompt, use_tqdm=False)[0].prompt_token_ids) % cpu_block_size + != 0 + ): + prompt = ". " + prompt + + assert subscriber.get_new_cpu_stored_events() + + test_count = 100 + success_count = 0 + for i in range(test_count): + if ( + llm.generate(prompt, sampling_params, use_tqdm=False)[0].outputs[0].text + == " five" + ): + success_count += 1 + + assert success_count >= 0.5 * test_count + + @pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES) -def test_cpu_offloading(cpu_block_size: int) -> None: +@pytest.mark.parametrize("attn_backend", ATTN_BACKENDS) +def test_cpu_offloading(cpu_block_size: int, attn_backend: str) -> None: """ Tests OffloadingConnector with CPUOffloadingSpec. """ @@ -92,61 +174,20 @@ def test_cpu_offloading(cpu_block_size: int) -> None: topic="test", ) - llm = LLM( - model="meta-llama/Llama-3.2-1B-Instruct", - gpu_memory_utilization=0.5, - kv_events_config=kv_events_config, - kv_transfer_config=kv_transfer_config, - ) - - sampling_params = SamplingParams(temperature=0, max_tokens=1) + with set_env_var("VLLM_ATTENTION_BACKEND", attn_backend): + llm = LLM( + model="meta-llama/Llama-3.2-1B-Instruct", + gpu_memory_utilization=0.5, + kv_events_config=kv_events_config, + kv_transfer_config=kv_transfer_config, + ) events_endpoint = events_endpoint.replace("*", "127.0.0.1") subscriber = MockSubscriber(events_endpoint, topic=kv_events_config.topic) try: - num_times_cpu_better_than_cold = 0 - num_tests = 10 - total_cold_time = 0.0 - total_gpu_hit_time = 0.0 - total_cpu_hit_time = 0.0 - prompt_token_ids = [0] * 10001 - for i in tqdm(range(num_tests), desc="Running tests"): - prompt_token_ids[0] = i - prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)] - - # run generation - this should trigger saving KV cache - start_time = time.time() - llm.generate(prompts, sampling_params, use_tqdm=False) - cold_time = time.time() - start_time - total_cold_time += cold_time - - # run generation again - should hit the GPU prefix cache - start_time = time.time() - llm.generate(prompts, sampling_params, use_tqdm=False) - gpu_hit_time = time.time() - start_time - total_gpu_hit_time += gpu_hit_time - - # reset prefix cache to avoid GPU hit. 
- llm.reset_prefix_cache() - - assert subscriber.get_new_cpu_stored_events() - - # run generation again - this should trigger loading from CPU - start_time = time.time() - llm.generate(prompts, sampling_params, use_tqdm=False) - cpu_hit_time = time.time() - start_time - total_cpu_hit_time += cpu_hit_time - - if cpu_hit_time < cold_time: - num_times_cpu_better_than_cold += 1 - - print("Average times:") - print(f" Cold: {total_cold_time * 1000 / num_tests:.2f}ms") - print(f" GPU hit: {total_gpu_hit_time * 1000 / num_tests:.2f}ms") - print(f" CPU hit: {total_cpu_hit_time * 1000 / num_tests:.2f}ms") - - assert num_times_cpu_better_than_cold >= 0.8 * num_tests + _latency_test(llm, subscriber) + _accuracy_test(llm, subscriber) finally: subscriber.close() del llm diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 824e458978350..01c1364f7ee62 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -483,7 +483,10 @@ def test_kv_cache_stride_order(monkeypatch, model_runner): # Permutation that gets you back to expected kv shape for test_stride in ((1, 4, 0, 2, 3), (0, 1, 2, 3, 4)): - def rnd_stride_order(test_stride=test_stride): + def rnd_stride_order( + include_num_layers_dimension: bool = False, test_stride=test_stride + ): + assert not include_num_layers_dimension return test_stride # Patch the attention backend class and re-trigger the KV cache creation diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 188becb6ad6f0..67ded88475243 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -76,7 +76,34 @@ class AttentionBackend(ABC): raise NotImplementedError @staticmethod - def get_kv_cache_stride_order() -> tuple[int, ...]: + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: + """ + Get the physical (memory layout) ordering of the kv cache dimensions. + e.g. if the KV cache shape is + [2, num_blocks, block_size, num_heads, head_size], + and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical + ordering of dimensions is + [num_blocks, num_heads, 2, block_size, head_size]. + + If this function is unimplemented / raises NotImplementedError, + the physical layout of the KV cache will match the logical shape. + + Args: + include_num_layers_dimension: if True, includes an additional + num_layers dimension, which is assumed to be prepended + to the logical KV cache shape. + With the above example, a return value (2, 4, 0, 1, 3, 5) + corresponds to + [num_blocks, num_heads, num_layers, 2, block_size, head_size]. + + If an additional dimension is NOT included in the returned + tuple, the physical layout will not include a layers dimension. + + Returns: + A tuple of ints which is a permutation of range(len(shape)). 
+ """ raise NotImplementedError @classmethod diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index f85eb414b2222..74f09278b7bb1 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -38,7 +38,7 @@ The class provides the following primitives: import enum from abc import ABC, abstractmethod from collections.abc import Callable, Iterable -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional import torch @@ -47,7 +47,7 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import KVConnectorOutput if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionMetadata + from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.config import VllmConfig from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( @@ -142,6 +142,18 @@ class KVConnectorMetadata(ABC): # noqa: B024 class KVConnectorBase_V1(ABC): + """ + Base class for KV connectors. + + Attributes: + prefer_cross_layer_blocks (bool): Indicates whether this connector + prefers KV blocks that hold KV data for all layers (for speeding + up KV data transfers). + Defaults to False. + """ + + prefer_cross_layer_blocks: ClassVar[bool] = False + def __init__( self, vllm_config: "VllmConfig", @@ -226,6 +238,23 @@ class KVConnectorBase_V1(ABC): """ return + def register_cross_layers_kv_cache( + self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"] + ): + """ + Initialize with a single KV cache tensor used by all layers. + The first dimension should be num_layers. + This function will only be called for models with uniform layers, + and only if the prefers_cross_layer_blocks is set to True. + Only one of the functions + {register_kv_caches, register_cross_layers_kv_cache} will be called. + + Args: + kv_cache: a cross-layers kv cache tensor + attn_backend: The attention backend that corresponds to all layers + """ + return + def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): """ Set the xPU-specific ops for copying KV between host and device. 
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index 582e42cc466ae..8cd09014cab11 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -4,12 +4,12 @@ from collections import defaultdict from collections.abc import Iterable, Iterator from dataclasses import dataclass from itertools import islice -from typing import Any +from typing import Any, ClassVar import torch -from vllm.attention import AttentionMetadata -from vllm.config import VllmConfig +from vllm.attention import Attention, AttentionBackend, AttentionMetadata +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent from vllm.distributed.kv_transfer.kv_connector.v1 import ( KVConnectorBase_V1, @@ -42,6 +42,8 @@ class OffloadingConnectorMetadata(KVConnectorMetadata): class OffloadingConnector(KVConnectorBase_V1): + prefer_cross_layer_blocks: ClassVar[bool] = True + def __init__( self, vllm_config: VllmConfig, @@ -63,6 +65,12 @@ class OffloadingConnector(KVConnectorBase_V1): assert self.connector_worker is not None self.connector_worker.register_kv_caches(kv_caches) + def register_cross_layers_kv_cache( + self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] + ): + assert self.connector_worker is not None + self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend) + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert self.connector_worker is not None assert isinstance(self._connector_metadata, OffloadingConnectorMetadata) @@ -422,10 +430,35 @@ class OffloadingConnectorWorker: self._job_counter = job_id + 1 return job_id - def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): - for src_cls, dst_cls, handler in self.spec.get_handlers(kv_caches): + def _register_handlers( + self, + kv_caches: dict[str, torch.Tensor], + attn_backends: dict[str, type[AttentionBackend]], + ): + for src_cls, dst_cls, handler in self.spec.get_handlers( + kv_caches, attn_backends + ): self.worker.register_handler(src_cls, dst_cls, handler) + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + layer_names = list(kv_caches.keys()) + layers = get_layers_from_vllm_config( + self.spec.vllm_config, Attention, layer_names + ) + attn_backends = { + layer_name: layers[layer_name].get_attn_backend() + for layer_name in layer_names + } + self._register_handlers(kv_caches, attn_backends) + + def register_cross_layers_kv_cache( + self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend] + ): + cross_layer_name = "ALL_LAYERS" + kv_caches = {cross_layer_name: kv_cache} + attn_backends = {cross_layer_name: attn_backend} + self._register_handlers(kv_caches, attn_backends) + def start_load_kv(self, metadata: OffloadingConnectorMetadata): for req_id, transfer_spec in metadata.reqs_to_load.items(): job_id = self._generate_job_id() diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index cf3c1d05f5b3f..9fa6b1dfd19dd 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -99,12 +99,20 @@ class FlashAttentionBackend(AttentionBackend): return (2, num_blocks, block_size, num_kv_heads, head_size) @staticmethod - def get_kv_cache_stride_order() -> tuple[int, ...]: + def get_kv_cache_stride_order( + 
include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: # `stride_order` indicates the permutation that gets # us from `get_kv_cache_shape` to the actual memory layout we want. cache_layout = get_kv_cache_layout() - if cache_layout == "NHD": + if cache_layout == "NHD" and include_num_layers_dimension: + # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size) + return (2, 0, 1, 3, 4, 5) + elif cache_layout == "NHD": stride_order = (0, 1, 2, 3, 4) + elif cache_layout == "HND" and include_num_layers_dimension: + # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size) + return (2, 4, 0, 1, 3, 5) elif cache_layout == "HND": stride_order = (0, 1, 3, 2, 4) else: diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 4da1637d96eb6..3ad7e8c52fc1f 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -309,12 +309,20 @@ class FlashInferBackend(AttentionBackend): return (num_blocks, 2, block_size, num_kv_heads, head_size) @staticmethod - def get_kv_cache_stride_order() -> tuple[int, ...]: + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: # `stride_order` indicates the permutation that gets us from # `get_kv_cache_shape` to the actual memory layout we want. cache_layout = get_kv_cache_layout() - if cache_layout == "NHD": + if cache_layout == "NHD" and include_num_layers_dimension: + # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size) + return (1, 0, 2, 3, 4, 5) + elif cache_layout == "NHD": stride_order = (0, 1, 2, 3, 4) + elif cache_layout == "HND" and include_num_layers_dimension: + # (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size) + return (1, 2, 4, 0, 3, 5) elif cache_layout == "HND": stride_order = (0, 1, 3, 2, 4) else: diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 32f406980f2ed..43aef8a7cca91 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -308,6 +308,15 @@ class MLACommonBackend(AttentionBackend): ) -> tuple[int, ...]: return (num_blocks, block_size, head_size) + @staticmethod + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: + # `stride_order` indicates the permutation that gets + # us from `get_kv_cache_shape` to the actual memory layout we want. 
+ # (num_blocks, num_layers, block_size, head_size) + return (1, 0, 2, 3) if include_num_layers_dimension else (0, 1, 2) + @classmethod def get_supported_head_sizes(cls) -> list[int]: return [576] diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index cc0988435768c..d38361e0fcbf8 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -48,7 +48,11 @@ class DeepseekV32IndexerBackend(AttentionBackend): return (num_blocks, block_size, head_size) @staticmethod - def get_kv_cache_stride_order() -> tuple[int, ...]: + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: + if include_num_layers_dimension: + return (0, 1, 2, 3) return (0, 1, 2) diff --git a/vllm/v1/kv_offload/cpu.py b/vllm/v1/kv_offload/cpu.py index 4b1bbe6f0cc2a..86747299eb107 100644 --- a/vllm/v1/kv_offload/cpu.py +++ b/vllm/v1/kv_offload/cpu.py @@ -4,8 +4,8 @@ from collections.abc import Iterator import torch -from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.attention import AttentionBackend +from vllm.config import VllmConfig from vllm.platforms import current_platform from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager from vllm.v1.kv_offload.arc_manager import ARCOffloadingManager @@ -63,7 +63,9 @@ class CPUOffloadingSpec(OffloadingSpec): return self._manager def get_handlers( - self, kv_caches: dict[str, torch.Tensor] + self, + kv_caches: dict[str, torch.Tensor], + attn_backends: dict[str, type[AttentionBackend]], ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]: if not self._handler: if not current_platform.is_cuda_alike(): @@ -71,15 +73,6 @@ class CPUOffloadingSpec(OffloadingSpec): "CPU Offloading is currently only supported on CUDA-alike GPUs" ) - layer_names = list(kv_caches.keys()) - layers = get_layers_from_vllm_config( - self.vllm_config, AttentionLayerBase, layer_names - ) - attn_backends = { - layer_name: layers[layer_name].get_attn_backend() - for layer_name in layer_names - } - self._handler = CpuGpuOffloadingHandler( attn_backends=attn_backends, gpu_block_size=self.gpu_block_size, diff --git a/vllm/v1/kv_offload/spec.py b/vllm/v1/kv_offload/spec.py index a3c539a47d458..c1813a4ff4ea9 100644 --- a/vllm/v1/kv_offload/spec.py +++ b/vllm/v1/kv_offload/spec.py @@ -11,6 +11,7 @@ from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager from vllm.v1.kv_offload.worker.worker import OffloadingHandler if TYPE_CHECKING: + from vllm.attention import AttentionBackend from vllm.config import VllmConfig logger = init_logger(__name__) @@ -48,13 +49,16 @@ class OffloadingSpec(ABC): @abstractmethod def get_handlers( - self, kv_caches: dict[str, torch.Tensor] + self, + kv_caches: dict[str, torch.Tensor], + attn_backends: dict[str, type["AttentionBackend"]], ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]: """ Get offloading handlers along with their respective src and dst types. Args: kv_caches: A dictionary of layer_name -> gpu_kv_cache tensor. + attn_backends: A dictionary of layer_name -> AttentionBackend. Yields: Tuples of (src_type, dst_type, offloading_handler). 
diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py index 111046377a5da..bb163f0043fc6 100644 --- a/vllm/v1/kv_offload/worker/cpu_gpu.py +++ b/vllm/v1/kv_offload/worker/cpu_gpu.py @@ -83,10 +83,18 @@ class CpuGpuOffloadingHandler(OffloadingHandler): self.gpu_tensors.append(gpu_tensor) gpu_shape = gpu_tensor.shape - test_shape = attn_backends[layer_name].get_kv_cache_shape( + attn_backend = attn_backends[layer_name] + test_shape = attn_backend.get_kv_cache_shape( num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256 ) - if test_shape[0] == 1234: + + if len(gpu_shape) != len(test_shape): + # cross-layers tensor + # shape is (num_blocks, ...) + assert len(gpu_shape) == len(test_shape) + 1 + num_blocks_idx = 0 + self.kv_dim_before_num_blocks.append(False) + elif test_shape[0] == 1234: # shape is (num_blocks, ...) num_blocks_idx = 0 self.kv_dim_before_num_blocks.append(False) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0490ed39c8c78..4b0a08ab57e16 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -349,6 +349,9 @@ class GPUModelRunner( # self.model: nn.Module # Set after load_model # Initialize in initialize_kv_cache self.kv_caches: list[torch.Tensor] = [] + # Initialize in initialize_kv_cache_tensors + self.cross_layers_kv_cache: torch.Tensor | None = None + self.cross_layers_attn_backend: type[AttentionBackend] | None = None # indexes: [kv_cache_group_id][attn_group] self.attn_groups: list[list[AttentionGroup]] = [] # self.kv_cache_config: KVCacheConfig @@ -4930,12 +4933,30 @@ class GPUModelRunner( Dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. """ - # Initialize the memory buffer for KV cache - kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) - # Change the memory buffer to the desired shape - kv_caches = self._reshape_kv_cache_tensors( - kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes - ) + + # Try creating KV caches optimized for kv-connector transfers + cache_dtype = self.cache_config.cache_dtype + if self.use_uniform_kv_cache(self.attn_groups, cache_dtype): + kv_caches, cross_layers_kv_cache, attn_backend = ( + self.allocate_uniform_kv_caches( + kv_cache_config, + self.attn_groups, + cache_dtype, + self.device, + kernel_block_sizes, + ) + ) + self.cross_layers_kv_cache = cross_layers_kv_cache + self.cross_layers_attn_backend = attn_backend + else: + # Fallback to the general case + # Initialize the memory buffer for KV cache + kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) + + # Change the memory buffer to the desired shape + kv_caches = self._reshape_kv_cache_tensors( + kv_cache_config, kv_cache_raw_tensors, kernel_block_sizes + ) # Set up cross-layer KV cache sharing for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): @@ -5017,7 +5038,13 @@ class GPUModelRunner( if has_kv_transfer_group(): kv_transfer_group = get_kv_transfer_group() - kv_transfer_group.register_kv_caches(kv_caches) + if self.cross_layers_kv_cache is not None: + assert self.cross_layers_attn_backend is not None + kv_transfer_group.register_cross_layers_kv_cache( + self.cross_layers_kv_cache, self.cross_layers_attn_backend + ) + else: + kv_transfer_group.register_kv_caches(kv_caches) kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks) if self.dcp_world_size > 1: diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py 
b/vllm/v1/worker/kv_connector_model_runner_mixin.py index db037a9fccd5c..e59361f21372a 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -11,7 +11,11 @@ from typing import ( TYPE_CHECKING, # noqa: UP035 ) +import torch + +from vllm.attention import AttentionBackend from vllm.config import VllmConfig +from vllm.config.cache import CacheDType from vllm.distributed.kv_transfer import ( ensure_kv_transfer_shutdown, get_kv_transfer_group, @@ -21,11 +25,13 @@ from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger +from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig from vllm.v1.outputs import ( EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput, ModelRunnerOutput, ) +from vllm.v1.worker.utils import AttentionGroup if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -142,3 +148,162 @@ class KVConnectorModelRunnerMixin: if has_kv_transfer_group(): return get_kv_transfer_group().get_kv_connector_stats() return None + + @staticmethod + def use_uniform_kv_cache( + attn_groups: list[list[AttentionGroup]], + cache_dtype: CacheDType, + ) -> bool: + """ + Determines whether a uniform KV layout should be used. + A uniform layout means all layers KV caches will share the same + underlying tensor, where for a given block number, the respective + KV data for all layers will be contiguous. + This will allow efficient KV transfer of per-block KV data for all + layers at once. + Note this layout will only be applied given 3 conditions: + 1. The KV Cache config contains just a single group where all layers + have the same page size. + 2. A KV connector is configured, and the KV connector instance prefers + to use this layout (prefer_cross_layer_blocks() returns True) + 2. The flash attention backend supports this layout + (get_kv_cache_stride_order(True) includes a placement for a + num_layers dimension) + + Note that the actual placement of the num_layers dimensions + in the unified layers tensors will be determined by the attention + backend. + Thus, the layers KV data may still not be contiguous per block + if the attention backend does not support it. + + Args: + attn_groups: The list of attention groups for this model + cache_dtype: The KV cache dtype + Returns: + True if we should use a uniform KV cache layout. 
+ """ + + if not has_kv_transfer_group(): + return False + if not get_kv_transfer_group().prefer_cross_layer_blocks: + return False + + if len(attn_groups) != 1 or len(attn_groups[0]) != 1: + return False + + attn_group = attn_groups[0][0] + kv_cache_spec = attn_group.kv_cache_spec + if not isinstance(kv_cache_spec, AttentionSpec): + return False + + attn_backend = attn_group.backend + kv_cache_shape = attn_backend.get_kv_cache_shape( + 1234, + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str=cache_dtype, + ) + + try: + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( + include_num_layers_dimension=True + ) + except (AttributeError, NotImplementedError): + return False + + # check that attention backend include a layers dimension + return len(kv_cache_stride_order) == len(kv_cache_shape) + 1 + + @staticmethod + def allocate_uniform_kv_caches( + kv_cache_config: KVCacheConfig, + attn_groups: list[list[AttentionGroup]], + cache_dtype: CacheDType, + device: torch.device, + kernel_block_sizes: list[int], + ) -> tuple[dict[str, torch.Tensor], torch.Tensor, type[AttentionBackend]]: + """ + Initializes and reshapes KV caches for the simple case where all + layers have the same layout. + + This function assumes use_uniform_kv_cache() returned True. + + Args: + kv_cache_config: The KV cache config + attn_groups: The list of attention groups for this model + cache_dtype: The KV cache dtype + device: The torch device to allocate on. + kernel_block_sizes: The kernel block sizes for each KV cache group. + Returns: + A tuple (kv_caches, cross_layers_kv_cache, attn_backend) where: + kv_caches is a dict mapping between layer names to their + corresponding memory buffer for KV cache. + cross_layers_kv_cache is the cross layers kv cache tensor + attn_backend is the attention backend matching this tensor + """ + attn_group = attn_groups[0][0] + kv_cache_spec = attn_group.kv_cache_spec + assert isinstance(kv_cache_spec, AttentionSpec) + + tensor_sizes = set( + kv_cache_tensor.size for kv_cache_tensor in kv_cache_config.kv_cache_tensors + ) + assert len(tensor_sizes) == 1 + tensor_size = tensor_sizes.pop() + + page_size = kv_cache_spec.page_size_bytes + assert tensor_size % page_size == 0 + num_blocks = tensor_size // page_size + num_layers = len(kv_cache_config.kv_cache_tensors) + total_size = tensor_size * num_layers + + assert len(kernel_block_sizes) == 1 + kernel_block_size = kernel_block_sizes[0] + num_blocks_per_kv_block = kv_cache_spec.block_size // kernel_block_size + kernel_num_blocks = num_blocks * num_blocks_per_kv_block + + attn_backend = attn_group.backend + kv_cache_shape = attn_backend.get_kv_cache_shape( + kernel_num_blocks, + kernel_block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str=cache_dtype, + ) + + # prepend a num_layers dimension into the shape + kv_cache_shape = (num_layers,) + kv_cache_shape + + try: + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( + include_num_layers_dimension=True + ) + assert len(kv_cache_stride_order) == len(kv_cache_shape) + except (AttributeError, NotImplementedError): + kv_cache_stride_order = tuple(range(len(kv_cache_shape))) + + kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order) + + logger.info("Allocating a cross layer KV cache of shape %s", kv_cache_shape) + + # allocate one contiguous buffer for all layers + cross_layers_kv_cache = ( + torch.zeros(total_size, dtype=torch.int8, device=device) + 
.view(kv_cache_spec.dtype) + .view(kv_cache_shape) + ) + + # Maintain original KV shape view. + inv_order = [ + kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) + ] + permuted_kv_cache = cross_layers_kv_cache.permute(*inv_order) + + kv_caches = {} + for i, kv_cache_tensor in enumerate(kv_cache_config.kv_cache_tensors): + tensor = permuted_kv_cache[i] + for layer_name in kv_cache_tensor.shared_by: + kv_caches[layer_name] = tensor + + return kv_caches, cross_layers_kv_cache, attn_backend
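
Note on the layout introduced above: `allocate_uniform_kv_caches()` builds one contiguous buffer in the physical order returned by `get_kv_cache_stride_order(include_num_layers_dimension=True)`, then recovers the logical per-layer views by inverting that permutation and indexing the layers dimension. The standalone sketch below (not part of the patch) walks through that mechanic using the FlashAttention NHD ordering `(2, 0, 1, 3, 4, 5)` added in this diff; the tensor sizes, dtype, and layer names are made up for illustration.

```python
# Standalone sketch (not part of the patch): how the cross-layer buffer in
# allocate_uniform_kv_caches() relates to the per-layer views handed to the
# model. Sizes and layer names are made up; the stride order is the
# FlashAttention NHD value from this diff, i.e.
# get_kv_cache_stride_order(include_num_layers_dimension=True) == (2, 0, 1, 3, 4, 5).
import torch

num_layers, num_blocks, block_size, num_kv_heads, head_size = 4, 8, 16, 2, 64

# Logical shape with the num_layers dimension prepended.
logical_shape = (num_layers, 2, num_blocks, block_size, num_kv_heads, head_size)
stride_order = (2, 0, 1, 3, 4, 5)

# Physical (memory) shape: permute the logical shape by the stride order.
physical_shape = tuple(logical_shape[i] for i in stride_order)
assert physical_shape == (
    num_blocks, num_layers, 2, block_size, num_kv_heads, head_size
)

# One contiguous buffer shared by all layers, laid out in the physical order.
cross_layers_kv_cache = torch.zeros(physical_shape, dtype=torch.float16)

# Because num_blocks is outermost, a single block holds the KV data of every
# layer contiguously -- this is what lets a connector move a whole block
# (all layers at once) with a single copy.
assert cross_layers_kv_cache[0].is_contiguous()

# Invert the permutation to recover the logical view, then slice per layer.
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
logical_view = cross_layers_kv_cache.permute(*inv_order)
assert logical_view.shape == logical_shape

kv_caches = {f"layers.{i}": logical_view[i] for i in range(num_layers)}
assert kv_caches["layers.0"].shape == (
    2, num_blocks, block_size, num_kv_heads, head_size
)
```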
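
On the worker side, `CpuGpuOffloadingHandler` is not told whether it received per-layer tensors or the single cross-layer tensor; the `cpu_gpu.py` hunk above infers this by probing the backend's logical shape with a sentinel `num_blocks` and comparing ranks. A minimal sketch of that probe follows; `FakeBackend` and the `probe()` helper are hypothetical stand-ins, and the sentinel values match the ones used in the hunk.

```python
# Sketch (not part of the patch) of the shape probe used in cpu_gpu.py to tell
# per-layer tensors apart from the cross-layer tensor. FakeBackend is a
# hypothetical stand-in for an AttentionBackend; only get_kv_cache_shape()
# matters for the probe.
import torch


class FakeBackend:
    @staticmethod
    def get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size):
        # FlashAttention-style logical shape: (2, num_blocks, ...)
        return (2, num_blocks, block_size, num_kv_heads, head_size)


def probe(gpu_tensor: torch.Tensor, backend=FakeBackend) -> tuple[str, int]:
    # The sentinel num_blocks=1234 reveals where the block dimension lives.
    test_shape = backend.get_kv_cache_shape(
        num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256
    )
    if gpu_tensor.dim() != len(test_shape):
        # One extra dimension: cross-layer tensor, shape is (num_blocks, ...).
        assert gpu_tensor.dim() == len(test_shape) + 1
        return "cross-layer", 0
    if test_shape[0] == 1234:
        return "per-layer", 0  # (num_blocks, ...)
    return "per-layer", 1  # KV dim precedes num_blocks, e.g. (2, num_blocks, ...)


per_layer = torch.empty(2, 8, 16, 8, 256)       # (2, num_blocks, ...)
cross_layer = torch.empty(8, 4, 2, 16, 8, 256)  # (num_blocks, num_layers, 2, ...)
print(probe(per_layer))    # ('per-layer', 1)
print(probe(cross_layer))  # ('cross-layer', 0)
```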