From ff54b6bfe38a18d87f8a3020229d5ecddfb9e53e Mon Sep 17 00:00:00 2001
From: Or Ozeri
Date: Mon, 22 Sep 2025 22:30:36 +0300
Subject: [PATCH] [KV offload][5/N] Add `CPUOffloadingSpec` (#24251)

Signed-off-by: Or Ozeri
Signed-off-by: yewentao256
---
 docs/features/disagg_prefill.md             |  6 ++
 .../{test_cpu.py => test_cpu_manager.py}    |  0
 tests/v1/kv_offload/test_cpu_offloading.py  | 62 +++++++++++++++
 vllm/v1/kv_offload/cpu.py                   | 75 +++++++++++++++++++
 vllm/v1/kv_offload/factory.py               |  3 +
 5 files changed, 146 insertions(+)
 rename tests/v1/kv_offload/{test_cpu.py => test_cpu_manager.py} (100%)
 create mode 100644 tests/v1/kv_offload/test_cpu_offloading.py
 create mode 100644 vllm/v1/kv_offload/cpu.py

diff --git a/docs/features/disagg_prefill.md b/docs/features/disagg_prefill.md
index 996ef00a6b960..69f70b8ff5ac3 100644
--- a/docs/features/disagg_prefill.md
+++ b/docs/features/disagg_prefill.md
@@ -31,6 +31,12 @@ Now supports 5 types of connectors:
   --kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}'
   ```
 
+- **OffloadingConnector**: enable offloading of KV data to CPU memory, customizing the CPU block size (in tokens) and number of blocks to allocate (per worker):
+
+  ```bash
+  --kv-transfer-config '{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size": 64, "num_cpu_blocks": 1000}}'
+  ```
+
 ## Benchmarks
 
 Please refer to for disaggregated prefilling benchmarks.
diff --git a/tests/v1/kv_offload/test_cpu.py b/tests/v1/kv_offload/test_cpu_manager.py
similarity index 100%
rename from tests/v1/kv_offload/test_cpu.py
rename to tests/v1/kv_offload/test_cpu_manager.py
diff --git a/tests/v1/kv_offload/test_cpu_offloading.py b/tests/v1/kv_offload/test_cpu_offloading.py
new file mode 100644
index 0000000000000..fc8ca09bea3de
--- /dev/null
+++ b/tests/v1/kv_offload/test_cpu_offloading.py
@@ -0,0 +1,62 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import time
+
+import pytest
+
+from vllm import LLM, SamplingParams
+from vllm.config import KVTransferConfig
+
+CPU_BLOCK_SIZES = [16, 48]
+
+
+@pytest.mark.parametrize("cpu_block_size", CPU_BLOCK_SIZES)
+def test_cpu_offloading(cpu_block_size: int) -> None:
+    """
+    Tests OffloadingConnector with CPUOffloadingSpec.
+    """
+
+    # configure OffloadingConnector (spec_name=CPUOffloadingSpec by default)
+    kv_transfer_config = KVTransferConfig(
+        kv_connector="OffloadingConnector",
+        kv_role="kv_both",
+        kv_connector_extra_config={
+            "num_cpu_blocks": 100,
+            "block_size": cpu_block_size
+        },
+    )
+
+    llm = LLM(
+        model="meta-llama/Llama-3.2-1B-Instruct",
+        gpu_memory_utilization=0.5,
+        kv_transfer_config=kv_transfer_config,
+    )
+
+    prompts = ["Hi " * 100]
+    sampling_params = SamplingParams(temperature=0, max_tokens=20)
+
+    # run generation - this should trigger saving KV cache
+    start_time = time.time()
+    llm.generate(prompts, sampling_params, use_tqdm=False)
+    cold_time = time.time() - start_time
+
+    # run generation again - should hit the GPU prefix cache
+    start_time = time.time()
+    llm.generate(prompts, sampling_params, use_tqdm=False)
+    gpu_hit_time = time.time() - start_time
+
+    # reset prefix cache to avoid GPU hit.
+    llm.reset_prefix_cache()
+
+    # sleep for a sec to make sure CPU finished storing
+    time.sleep(1)
+
+    # run generation again - this should trigger loading from CPU
+    start_time = time.time()
+    llm.generate(prompts, sampling_params, use_tqdm=False)
+    cpu_hit_time = time.time() - start_time
+
+    print("Generation times:")
+    print(f"  Cold: {cold_time * 1000:.2f}ms")
+    print(f"  GPU hit: {gpu_hit_time * 1000:.2f}ms")
+    print(f"  CPU hit: {cpu_hit_time * 1000:.2f}ms")
diff --git a/vllm/v1/kv_offload/cpu.py b/vllm/v1/kv_offload/cpu.py
new file mode 100644
index 0000000000000..b85d375fe63e2
--- /dev/null
+++ b/vllm/v1/kv_offload/cpu.py
@@ -0,0 +1,75 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Iterator
+from typing import Optional
+
+import torch
+
+from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.platforms import current_platform
+from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
+from vllm.v1.kv_offload.backends.cpu import CPUBackend
+from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
+from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
+from vllm.v1.kv_offload.spec import OffloadingSpec
+from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler
+from vllm.v1.kv_offload.worker.worker import OffloadingHandler
+
+
+class CPUOffloadingSpec(OffloadingSpec):
+
+    def __init__(self, vllm_config: VllmConfig):
+        super().__init__(vllm_config)
+
+        num_cpu_blocks = self.extra_config.get("num_cpu_blocks")
+        if not num_cpu_blocks:
+            raise Exception("num_cpu_blocks must be specified "
+                            "in kv_connector_extra_config")
+        self.num_cpu_blocks: int = num_cpu_blocks
+
+        # scheduler-side
+        self._manager: Optional[OffloadingManager] = None
+
+        # worker-side
+        self._handler: Optional[OffloadingHandler] = None
+
+    def get_manager(self) -> OffloadingManager:
+        if not self._manager:
+            kv_events_config = self.vllm_config.kv_events_config
+            enable_events = (kv_events_config is not None
+                             and kv_events_config.enable_kv_cache_events)
+            self._manager = LRUOffloadingManager(CPUBackend(
+                block_size=self.offloaded_block_size,
+                num_blocks=self.num_cpu_blocks),
+                                                 enable_events=enable_events)
+        return self._manager
+
+    def get_handlers(
+        self, kv_caches: dict[str, torch.Tensor]
+    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
+                        OffloadingHandler]]:
+        if not self._handler:
+            if not current_platform.is_cuda():
+                raise Exception("CPU Offloading is currently only supported"
+                                " on CUDA GPUs")
+
+            layer_names = list(kv_caches.keys())
+            layers = get_layers_from_vllm_config(self.vllm_config,
+                                                 AttentionLayerBase,
+                                                 layer_names)
+            attn_backends = {
+                layer_name: layers[layer_name].get_attn_backend()
+                for layer_name in layer_names
+            }
+
+            self._handler = CpuGpuOffloadingHandler(
+                attn_backends=attn_backends,
+                gpu_block_size=self.gpu_block_size,
+                cpu_block_size=self.offloaded_block_size,
+                num_cpu_blocks=self.num_cpu_blocks,
+                gpu_caches=kv_caches)
+
+        assert self._handler is not None
+        yield GPULoadStoreSpec, CPULoadStoreSpec, self._handler
+        yield CPULoadStoreSpec, GPULoadStoreSpec, self._handler
diff --git a/vllm/v1/kv_offload/factory.py b/vllm/v1/kv_offload/factory.py
index 6365ab4a6db75..f9bef6cea9038 100644
--- a/vllm/v1/kv_offload/factory.py
+++ b/vllm/v1/kv_offload/factory.py
@@ -51,3 +51,6 @@ class OffloadingSpecFactory:
 
 
 # Register various specs here.
+OffloadingSpecFactory.register_spec("CPUOffloadingSpec",
+                                    "vllm.v1.kv_offload.cpu",
+                                    "CPUOffloadingSpec")
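
For quick reference, here is the offline-inference equivalent of the docs snippet in this patch, distilled from the new test. This is a minimal sketch: the model name and block counts are illustrative values taken from the patch, not requirements.

```python
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# CPUOffloadingSpec is the default spec for OffloadingConnector, so no
# "spec_name" is needed; "num_cpu_blocks" is mandatory (CPUOffloadingSpec
# raises if it is missing), while "block_size" tunes the CPU block size.
kv_transfer_config = KVTransferConfig(
    kv_connector="OffloadingConnector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "block_size": 64,        # CPU block size, in tokens
        "num_cpu_blocks": 1000,  # CPU blocks to allocate, per worker
    },
)

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    kv_transfer_config=kv_transfer_config,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0, max_tokens=20))
```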
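The `register_spec` call in the factory hunk is also the natural extension point for specs defined outside the vLLM tree. A minimal sketch follows, assuming a hypothetical `my_pkg.offload.DiskOffloadingSpec` (all `my_pkg` names are invented for illustration); judging by the test's comment, a non-default spec would then be selected via the `spec_name` key in `kv_connector_extra_config`.

```python
from vllm.v1.kv_offload.factory import OffloadingSpecFactory

# Hypothetical out-of-tree spec. register_spec stores the module path and
# class name as strings, mirroring the in-tree CPUOffloadingSpec call:
#   register_spec(<spec name>, <module path>, <class name>)
OffloadingSpecFactory.register_spec("DiskOffloadingSpec",
                                    "my_pkg.offload",
                                    "DiskOffloadingSpec")
```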