From 7ac67ea5255c764e87bdfc5c712bfaa35f491764 Mon Sep 17 00:00:00 2001
From: Or Ozeri
Date: Fri, 19 Sep 2025 19:53:45 +0300
Subject: [PATCH] [KV offload][3/N] Add worker-side CPU support (#21448)

Signed-off-by: Or Ozeri
---
 tests/v1/kv_offload/test_cpu_gpu.py  | 177 +++++++++++++++++++++++++++
 vllm/v1/kv_offload/worker/cpu_gpu.py | 171 ++++++++++++++++++++++++++
 2 files changed, 348 insertions(+)
 create mode 100644 tests/v1/kv_offload/test_cpu_gpu.py
 create mode 100644 vllm/v1/kv_offload/worker/cpu_gpu.py

diff --git a/tests/v1/kv_offload/test_cpu_gpu.py b/tests/v1/kv_offload/test_cpu_gpu.py
new file mode 100644
index 000000000000..0edb9513e3ff
--- /dev/null
+++ b/tests/v1/kv_offload/test_cpu_gpu.py
@@ -0,0 +1,177 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import random
+import time
+
+import pytest
+import torch
+
+from vllm.platforms import current_platform
+from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
+from vllm.v1.attention.backends.flashinfer import FlashInferBackend
+from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLABackend
+from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
+from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler
+
+NUM_GPU_BLOCKS = [64]
+NUM_CPU_BLOCKS = [256]
+GPU_BLOCK_SIZES = [16]
+GPU_BLOCKS_PER_CPU_BLOCK = [1, 3]
+HEAD_SIZES = [64]
+NUM_HEADS = [8]
+NUM_LAYERS = [4]
+DTYPES = [torch.bfloat16]
+SEEDS = [0]
+CUDA_DEVICES = ['cuda:0']
+NUM_MAPPINGS = [3]
+
+
+@pytest.mark.parametrize("gpu_to_cpu", [True, False])
+@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("gpu_block_size", GPU_BLOCK_SIZES)
+@pytest.mark.parametrize("gpu_blocks_per_cpu_block", GPU_BLOCKS_PER_CPU_BLOCK)
+@pytest.mark.parametrize("num_gpu_blocks", NUM_GPU_BLOCKS)
+@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS)
+@pytest.mark.parametrize("num_layers", NUM_LAYERS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_transfer(
+    gpu_to_cpu: bool,
+    num_mappings: int,
+    head_size: int,
+    num_heads: int,
+    gpu_block_size: int,
+    gpu_blocks_per_cpu_block: int,
+    num_gpu_blocks: int,
+    num_cpu_blocks: int,
+    num_layers: int,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+) -> None:
+    current_platform.seed_everything(seed)
+
+    # create per-layer GPU KV caches
+    attn_backends_list = [
+        FlashAttentionBackend, FlashInferBackend, FlashAttnMLABackend
+    ]
+
+    gpu_caches = {}
+    attn_backends = {}
+    for i in range(num_layers):
+        layer_name = f'layer {i}'
+
+        attn_backend = attn_backends_list[i % len(attn_backends_list)]
+        attn_backends[layer_name] = attn_backend
+
+        gpu_cache_shape = attn_backend.get_kv_cache_shape(
+            num_gpu_blocks, gpu_block_size, num_heads, head_size)
+        gpu_caches[layer_name] = torch.rand(gpu_cache_shape,
+                                            dtype=dtype,
+                                            device=device)
+
+    # create handler
+    cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size
+    handler = CpuGpuOffloadingHandler(attn_backends=attn_backends,
+                                      gpu_block_size=gpu_block_size,
+                                      cpu_block_size=cpu_block_size,
+                                      num_cpu_blocks=num_cpu_blocks,
+                                      gpu_caches=gpu_caches)
+
+    # select block mappings
+    gpu_blocks = random.sample(range(num_gpu_blocks),
+                               num_mappings * gpu_blocks_per_cpu_block)
+    cpu_blocks = random.sample(range(num_cpu_blocks),
+                               num_mappings)
+
+    # convert cpu blocks to gpu block size
+    cpu_blocks_in_gpu_block_size = []
+    for cpu_block in cpu_blocks:
+        base_block_id = cpu_block * gpu_blocks_per_cpu_block
+        for i in range(gpu_blocks_per_cpu_block):
+            cpu_blocks_in_gpu_block_size.append(i + base_block_id)
+
+    # maybe skip a GPU block to test writing to the middle of a CPU block
+    if gpu_to_cpu:
+        gpu_blocks = gpu_blocks[gpu_blocks_per_cpu_block - 1:]
+        cpu_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size[
+            gpu_blocks_per_cpu_block - 1:]
+
+    # set transfer direction
+    if gpu_to_cpu:
+        src_kv_caches = handler.gpu_tensors
+        dst_kv_caches = handler.cpu_tensors
+        src_spec_class = GPULoadStoreSpec
+        dst_spec_class = CPULoadStoreSpec
+        src_blocks = gpu_blocks
+        dst_blocks = cpu_blocks
+        src_blocks_in_gpu_block_size = gpu_blocks
+        dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
+        dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block
+    else:
+        src_kv_caches = handler.cpu_tensors
+        dst_kv_caches = handler.gpu_tensors
+        src_spec_class = CPULoadStoreSpec
+        dst_spec_class = GPULoadStoreSpec
+        src_blocks = cpu_blocks
+        dst_blocks = gpu_blocks
+        src_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
+        dst_blocks_in_gpu_block_size = gpu_blocks
+        dst_size_in_gpu_blocks = num_gpu_blocks
+
+    # build dst -> src mapping
+    dst_to_src = {}
+    for src_block, dst_block in zip(src_blocks_in_gpu_block_size,
+                                    dst_blocks_in_gpu_block_size):
+        dst_to_src[dst_block] = src_block
+
+    # build transfer specs
+    src_spec = src_spec_class(src_blocks)
+    dst_spec = dst_spec_class(dst_blocks)
+
+    # clone src and dst tensors before transfer
+    orig_src_caches = [x.clone() for x in src_kv_caches]
+    orig_dst_caches = [x.clone() for x in dst_kv_caches]
+
+    # call transfer function
+    assert handler.transfer_async(1, (src_spec, dst_spec))
+    assert set(handler.transfer_events.keys()) == {1}
+
+    # wait for transfer to complete
+    end_time = time.time() + 10
+    while time.time() < end_time:
+        finished = handler.get_finished()
+        if finished:
+            assert finished == [(1, True)]
+            break
+        time.sleep(0.1)
+
+    # verify src tensors did not change
+    for orig_tensor, tensor in zip(orig_src_caches, src_kv_caches):
+        assert torch.equal(orig_tensor, tensor)
+
+    # verify dst tensors
+    for dst_block in range(dst_size_in_gpu_blocks):
+        src_block_candidate = dst_to_src.get(dst_block)
+        for src_cache, dst_cache, orig_dst_cache, kv_dim in zip(
+                src_kv_caches, dst_kv_caches, orig_dst_caches,
+                handler.kv_dim_before_num_blocks):
+            if kv_dim:
+                # iterate over key, value
+                for i in range(2):
+                    if src_block_candidate is not None:
+                        expected_value = src_cache[i][src_block_candidate]
+                    else:
+                        expected_value = orig_dst_cache[i][dst_block]
+                    torch.testing.assert_close(dst_cache[i][dst_block].cpu(),
+                                               expected_value.cpu())
+            else:
+                if src_block_candidate is not None:
+                    expected_value = src_cache[src_block_candidate]
+                else:
+                    expected_value = orig_dst_cache[dst_block]
+                torch.testing.assert_close(dst_cache[dst_block].cpu(),
+                                           expected_value.cpu())
diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py
new file mode 100644
index 000000000000..556c29247e5e
--- /dev/null
+++ b/vllm/v1/kv_offload/worker/cpu_gpu.py
@@ -0,0 +1,171 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import numpy as np
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.attention import AttentionBackend
+from vllm.logger import init_logger
+from vllm.utils import is_pin_memory_available
+from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
+from vllm.v1.kv_offload.worker.worker import (OffloadingHandler,
+                                              TransferResult, TransferSpec)
+
+logger = init_logger(__name__)
+
+
+def expand_block_ids(block_ids: np.ndarray,
+                     block_size_factor: int,
+                     output: np.ndarray,
+                     skip_count: int = 0):
+    """
+    Convert a list of block IDs to the matching list of actual block IDs,
+    assuming each block is composed of block_size_factor actual blocks.
+    The result is written to the output array.
+    The first skip_count actual blocks will be skipped.
+    Note that skip_count must be less than block_size_factor.
+
+    For example, if block_ids = [0, 1, 3] and block_size_factor = 4,
+    this yields [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15],
+    since 0 maps to [0, 1, 2, 3],
+    1 maps to [4, 5, 6, 7]
+    and 3 maps to [12, 13, 14, 15].
+    """
+    assert skip_count < block_size_factor
+
+    first_range = np.arange(skip_count, block_size_factor)
+    full_range = np.arange(0, block_size_factor)
+
+    output_idx = 0
+    for i, block_id in enumerate(block_ids):
+        base_block_id = block_id * block_size_factor
+        indices = first_range if i == 0 else full_range
+        output_end_idx = output_idx + len(indices)
+        output[output_idx:output_end_idx] = base_block_id + indices
+        output_idx = output_end_idx
+
+
+class CpuGpuOffloadingHandler(OffloadingHandler):
+
+    def __init__(self, gpu_block_size: int, cpu_block_size: int,
+                 num_cpu_blocks: int, gpu_caches: dict[str, torch.Tensor],
+                 attn_backends: dict[str, type[AttentionBackend]]):
+        assert cpu_block_size % gpu_block_size == 0
+        self.block_size_factor = cpu_block_size // gpu_block_size
+
+        # cuda streams for gpu->cpu and cpu->gpu
+        self.d2h_stream = torch.cuda.Stream()
+        self.h2d_stream = torch.cuda.Stream()
+
+        # job_id -> transfer cuda event
+        self.transfer_events: dict[int, torch.cuda.Event] = {}
+        # list of cuda events available for re-use
+        self.events_pool: list[torch.cuda.Event] = []
+
+        pin_memory = is_pin_memory_available()
+
+        # allocate cpu tensors
+        logger.info("Allocating %d CPU tensors...", len(gpu_caches))
+        self.gpu_tensors: list[torch.Tensor] = []
+        self.cpu_tensors: list[torch.Tensor] = []
+        self.kv_dim_before_num_blocks: list[bool] = []
+        for layer_name, gpu_tensor in gpu_caches.items():
+            self.gpu_tensors.append(gpu_tensor)
+
+            gpu_shape = gpu_tensor.shape
+            test_shape = attn_backends[layer_name].get_kv_cache_shape(
+                num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256)
+            if test_shape[0] == 1234:
+                # shape is (num_blocks, ...)
+                num_blocks_idx = 0
+                self.kv_dim_before_num_blocks.append(False)
+            else:
+                # shape should be (2, num_blocks, ...)
+                assert test_shape[0] == 2
+                assert test_shape[1] == 1234
+                assert gpu_shape[0] == 2
+
+                num_blocks_idx = 1
+                self.kv_dim_before_num_blocks.append(True)
+
+            cpu_shape = list(gpu_shape)
+            cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor
+
+            logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
+            self.cpu_tensors.append(
+                torch.zeros(cpu_shape,
+                            dtype=gpu_tensor.dtype,
+                            device="cpu",
+                            pin_memory=pin_memory))
+
+    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
+        src_spec, dst_spec = spec
+        if isinstance(src_spec, CPULoadStoreSpec):
+            assert isinstance(dst_spec, GPULoadStoreSpec)
+            stream = self.h2d_stream
+            src_tensors = self.cpu_tensors
+            dst_tensors = self.gpu_tensors
+            src_block_size_factor = self.block_size_factor
+            dst_block_size_factor = 1
+        else:
+            assert isinstance(src_spec, GPULoadStoreSpec)
+            assert isinstance(dst_spec, CPULoadStoreSpec)
+            stream = self.d2h_stream
+            src_tensors = self.gpu_tensors
+            dst_tensors = self.cpu_tensors
+            src_block_size_factor = 1
+            dst_block_size_factor = self.block_size_factor
+
+        src_blocks = src_spec.block_ids
+        dst_blocks = dst_spec.block_ids
+        assert src_blocks.ndim == 1
+        assert dst_blocks.ndim == 1
+
+        dst_sub_blocks_to_skip = (-src_blocks.size % dst_block_size_factor)
+        src_sub_block_count = src_blocks.size * src_block_size_factor
+
+        assert (
+            src_sub_block_count == dst_blocks.size * dst_block_size_factor -
+            dst_sub_blocks_to_skip)
+
+        src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64)
+        expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0])
+        expand_block_ids(dst_blocks,
+                         dst_block_size_factor,
+                         src_to_dst[:, 1],
+                         skip_count=dst_sub_blocks_to_skip)
+        src_to_dst_tensor = torch.from_numpy(src_to_dst)
+
+        event = self.events_pool.pop() if self.events_pool \
+            else torch.cuda.Event()
+        with torch.cuda.stream(stream):
+            for src_tensor, dst_tensor, kv_dim in zip(
+                    src_tensors, dst_tensors, self.kv_dim_before_num_blocks):
+                if kv_dim:
+                    src_key_cache = src_tensor[0]
+                    dst_key_cache = dst_tensor[0]
+                    ops.swap_blocks(src_key_cache, dst_key_cache,
+                                    src_to_dst_tensor)
+                    src_value_cache = src_tensor[1]
+                    dst_value_cache = dst_tensor[1]
+                    ops.swap_blocks(src_value_cache, dst_value_cache,
+                                    src_to_dst_tensor)
+                else:
+                    ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
+            event.record(stream)
+
+        self.transfer_events[job_id] = event
+
+        # success
+        return True
+
+    def get_finished(self) -> list[TransferResult]:
+        results: list[TransferResult] = []
+        for job_id, event in self.transfer_events.items():
+            if event.query():
+                results.append((job_id, True))
+                self.events_pool.append(event)
+        for job_id, _ in results:
+            del self.transfer_events[job_id]
+        return results
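
Note (illustrative sketch, not part of the patch): expand_block_ids() above maps coarse block IDs (e.g. CPU blocks spanning several GPU blocks) to the GPU-block-sized sub-block IDs that fill the two columns of the src_to_dst index tensor passed to ops.swap_blocks(). The standalone re-implementation below mirrors that expansion logic under stated assumptions: it depends only on NumPy, returns a new array instead of writing into a caller-provided output buffer, and the name expand_block_ids_sketch is hypothetical.

import numpy as np


def expand_block_ids_sketch(block_ids: np.ndarray,
                            block_size_factor: int,
                            skip_count: int = 0) -> np.ndarray:
    # Each coarse block id b expands to the sub-block ids
    # [b * factor, ..., b * factor + factor - 1], where
    # factor = block_size_factor. The first skip_count sub-blocks
    # of the first coarse block are dropped, mirroring the
    # skip_count semantics in the patch.
    assert skip_count < block_size_factor
    out: list[int] = []
    for i, block_id in enumerate(block_ids):
        start = skip_count if i == 0 else 0
        base = int(block_id) * block_size_factor
        out.extend(range(base + start, base + block_size_factor))
    return np.array(out, dtype=np.int64)


# Matches the docstring example: block_ids = [0, 1, 3], factor = 4
print(expand_block_ids_sketch(np.array([0, 1, 3]), 4))
# -> [ 0  1  2  3  4  5  6  7 12 13 14 15]

In transfer_async(), two such expansions (one for the source blocks, one for the destination blocks, with skip_count used only on the destination side) produce equal-length ID lists that are paired row by row into the (N, 2) src_to_dst tensor consumed by ops.swap_blocks().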