mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:05:01 +08:00
[KV offload][3/N] Add worker-side CPU support (#21448)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
parent
ce75e15373
commit
7ac67ea525
177
tests/v1/kv_offload/test_cpu_gpu.py
Normal file
177
tests/v1/kv_offload/test_cpu_gpu.py
Normal file
@ -0,0 +1,177 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
from vllm.v1.attention.backends.flashinfer import FlashInferBackend
|
||||
from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLABackend
|
||||
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
|
||||
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler
|
||||
|
||||
NUM_GPU_BLOCKS = [64]
|
||||
NUM_CPU_BLOCKS = [256]
|
||||
GPU_BLOCK_SIZES = [16]
|
||||
GPU_BLOCKS_PER_CPU_BLOCK = [1, 3]
|
||||
HEAD_SIZES = [64]
|
||||
NUM_HEADS = [8]
|
||||
NUM_LAYERS = [4]
|
||||
DTYPES = [torch.bfloat16]
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = ['cuda:0']
|
||||
NUM_MAPPINGS = [3]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gpu_to_cpu", [True, False])
|
||||
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("gpu_block_size", GPU_BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("gpu_blocks_per_cpu_block", GPU_BLOCKS_PER_CPU_BLOCK)
|
||||
@pytest.mark.parametrize("num_gpu_blocks", NUM_GPU_BLOCKS)
|
||||
@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS)
|
||||
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_transfer(
|
||||
gpu_to_cpu: bool,
|
||||
num_mappings: int,
|
||||
head_size: int,
|
||||
num_heads: int,
|
||||
gpu_block_size: int,
|
||||
gpu_blocks_per_cpu_block: int,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
num_layers: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
# create per-layer GPU KV caches
|
||||
attn_backends_list = [
|
||||
FlashAttentionBackend, FlashInferBackend, FlashAttnMLABackend
|
||||
]
|
||||
|
||||
gpu_caches = {}
|
||||
attn_backends = {}
|
||||
for i in range(num_layers):
|
||||
layer_name = f'layer {i}'
|
||||
|
||||
attn_backend = attn_backends_list[i % len(attn_backends_list)]
|
||||
attn_backends[layer_name] = attn_backend
|
||||
|
||||
gpu_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_gpu_blocks, gpu_block_size, num_heads, head_size)
|
||||
gpu_caches[layer_name] = torch.rand(gpu_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
# create handler
|
||||
cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size
|
||||
handler = CpuGpuOffloadingHandler(attn_backends=attn_backends,
|
||||
gpu_block_size=gpu_block_size,
|
||||
cpu_block_size=cpu_block_size,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
gpu_caches=gpu_caches)
|
||||
|
||||
# select block mappings
|
||||
gpu_blocks = random.sample(range(num_gpu_blocks),
|
||||
num_mappings * gpu_blocks_per_cpu_block)
|
||||
cpu_blocks = random.sample(range(num_cpu_blocks), num_mappings)
|
||||
|
||||
# convert cpu blocks to gpu block size
|
||||
cpu_blocks_in_gpu_block_size = []
|
||||
for cpu_block in cpu_blocks:
|
||||
base_block_id = cpu_block * gpu_blocks_per_cpu_block
|
||||
for i in range(gpu_blocks_per_cpu_block):
|
||||
cpu_blocks_in_gpu_block_size.append(i + base_block_id)
|
||||
|
||||
# maybe skip a GPU block to test writing to the middle of a CPU block
|
||||
if gpu_to_cpu:
|
||||
gpu_blocks = gpu_blocks[gpu_blocks_per_cpu_block - 1:]
|
||||
cpu_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size[
|
||||
gpu_blocks_per_cpu_block - 1:]
|
||||
|
||||
# set transfer direction
|
||||
if gpu_to_cpu:
|
||||
src_kv_caches = handler.gpu_tensors
|
||||
dst_kv_caches = handler.cpu_tensors
|
||||
src_spec_class = GPULoadStoreSpec
|
||||
dst_spec_class = CPULoadStoreSpec
|
||||
src_blocks = gpu_blocks
|
||||
dst_blocks = cpu_blocks
|
||||
src_blocks_in_gpu_block_size = gpu_blocks
|
||||
dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
|
||||
dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block
|
||||
else:
|
||||
src_kv_caches = handler.cpu_tensors
|
||||
dst_kv_caches = handler.gpu_tensors
|
||||
src_spec_class = CPULoadStoreSpec
|
||||
dst_spec_class = GPULoadStoreSpec
|
||||
src_blocks = cpu_blocks
|
||||
dst_blocks = gpu_blocks
|
||||
src_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
|
||||
dst_blocks_in_gpu_block_size = gpu_blocks
|
||||
dst_size_in_gpu_blocks = num_gpu_blocks
|
||||
|
||||
# build dst -> src mapping
|
||||
dst_to_src = {}
|
||||
for src_block, dst_block in zip(src_blocks_in_gpu_block_size,
|
||||
dst_blocks_in_gpu_block_size):
|
||||
dst_to_src[dst_block] = src_block
|
||||
|
||||
# build transfer specs
|
||||
src_spec = src_spec_class(src_blocks)
|
||||
dst_spec = dst_spec_class(dst_blocks)
|
||||
|
||||
# clone src and dst tensors before transfer
|
||||
orig_src_caches = [x.clone() for x in src_kv_caches]
|
||||
orig_dst_caches = [x.clone() for x in dst_kv_caches]
|
||||
|
||||
# call transfer function
|
||||
assert handler.transfer_async(1, (src_spec, dst_spec))
|
||||
assert set(handler.transfer_events.keys()) == {1}
|
||||
|
||||
# wait for transfer to complete
|
||||
end_time = time.time() + 10
|
||||
while time.time() < end_time:
|
||||
finished = handler.get_finished()
|
||||
if finished:
|
||||
assert finished == [(1, True)]
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
# verify src tensors did not change
|
||||
for orig_tensor, tensor in zip(orig_src_caches, src_kv_caches):
|
||||
assert torch.equal(orig_tensor, tensor)
|
||||
|
||||
# verify dst tensors
|
||||
for dst_block in range(dst_size_in_gpu_blocks):
|
||||
src_block_candidate = dst_to_src.get(dst_block)
|
||||
for src_cache, dst_cache, orig_dst_cache, kv_dim in zip(
|
||||
src_kv_caches, dst_kv_caches, orig_dst_caches,
|
||||
handler.kv_dim_before_num_blocks):
|
||||
if kv_dim:
|
||||
# iterate over key, value
|
||||
for i in range(2):
|
||||
if src_block_candidate is not None:
|
||||
expected_value = src_cache[i][src_block_candidate]
|
||||
else:
|
||||
expected_value = orig_dst_cache[i][dst_block]
|
||||
torch.testing.assert_close(dst_cache[i][dst_block].cpu(),
|
||||
expected_value.cpu())
|
||||
else:
|
||||
if src_block_candidate is not None:
|
||||
expected_value = src_cache[src_block_candidate]
|
||||
else:
|
||||
expected_value = orig_dst_cache[dst_block]
|
||||
torch.testing.assert_close(dst_cache[dst_block].cpu(),
|
||||
expected_value.cpu())
|
||||
171
vllm/v1/kv_offload/worker/cpu_gpu.py
Normal file
171
vllm/v1/kv_offload/worker/cpu_gpu.py
Normal file
@ -0,0 +1,171 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention import AttentionBackend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_pin_memory_available
|
||||
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
|
||||
from vllm.v1.kv_offload.worker.worker import (OffloadingHandler,
|
||||
TransferResult, TransferSpec)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def expand_block_ids(block_ids: np.ndarray,
|
||||
block_size_factor: int,
|
||||
output: np.ndarray,
|
||||
skip_count: int = 0):
|
||||
"""
|
||||
Convert a list of block IDs to a list of matching block ids,
|
||||
assuming each block is composed of actual block_size_factor blocks.
|
||||
Outputs to output tensor.
|
||||
The first skip_count blocks will be skipped.
|
||||
Note that skip_count must be less than block_size_factor.
|
||||
|
||||
For example, if block_ids = [0, 1, 3] and block_size_factor = 4,
|
||||
then it yields [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]
|
||||
since 0 maps to [0, 1, 2, 3]
|
||||
1 maps to [4, 5, 6, 7]
|
||||
and 3 maps to [12, 13, 14, 15]
|
||||
"""
|
||||
assert skip_count < block_size_factor
|
||||
|
||||
first_range = np.arange(skip_count, block_size_factor)
|
||||
full_range = np.arange(0, block_size_factor)
|
||||
|
||||
output_idx = 0
|
||||
for i, block_id in enumerate(block_ids):
|
||||
base_block_id = block_id * block_size_factor
|
||||
indices = first_range if i == 0 else full_range
|
||||
output_end_idx = output_idx + len(indices)
|
||||
output[output_idx:output_end_idx] = base_block_id + indices
|
||||
output_idx = output_end_idx
|
||||
|
||||
|
||||
class CpuGpuOffloadingHandler(OffloadingHandler):
|
||||
|
||||
def __init__(self, gpu_block_size: int, cpu_block_size: int,
|
||||
num_cpu_blocks: int, gpu_caches: dict[str, torch.Tensor],
|
||||
attn_backends: dict[str, type[AttentionBackend]]):
|
||||
assert cpu_block_size % gpu_block_size == 0
|
||||
self.block_size_factor = cpu_block_size // gpu_block_size
|
||||
|
||||
# cuda streams for gpu->cpu and cpu->gpu
|
||||
self.d2h_stream = torch.cuda.Stream()
|
||||
self.h2d_stream = torch.cuda.Stream()
|
||||
|
||||
# job_id -> transfer cuda event
|
||||
self.transfer_events: dict[int, torch.cuda.Event] = {}
|
||||
# list of cuda events available for re-use
|
||||
self.events_pool: list[torch.cuda.Event] = []
|
||||
|
||||
pin_memory = is_pin_memory_available()
|
||||
|
||||
# allocate cpu tensors
|
||||
logger.info("Allocating %d CPU tensors...", len(gpu_caches))
|
||||
self.gpu_tensors: list[torch.Tensor] = []
|
||||
self.cpu_tensors: list[torch.Tensor] = []
|
||||
self.kv_dim_before_num_blocks: list[bool] = []
|
||||
for layer_name, gpu_tensor in gpu_caches.items():
|
||||
self.gpu_tensors.append(gpu_tensor)
|
||||
|
||||
gpu_shape = gpu_tensor.shape
|
||||
test_shape = attn_backends[layer_name].get_kv_cache_shape(
|
||||
num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256)
|
||||
if test_shape[0] == 1234:
|
||||
# shape is (num_blocks, ...)
|
||||
num_blocks_idx = 0
|
||||
self.kv_dim_before_num_blocks.append(False)
|
||||
else:
|
||||
# shape should be (2, num_blocks, ...)
|
||||
assert test_shape[0] == 2
|
||||
assert test_shape[1] == 1234
|
||||
assert gpu_shape[0] == 2
|
||||
|
||||
num_blocks_idx = 1
|
||||
self.kv_dim_before_num_blocks.append(True)
|
||||
|
||||
cpu_shape = list(gpu_shape)
|
||||
cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor
|
||||
|
||||
logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
|
||||
self.cpu_tensors.append(
|
||||
torch.zeros(cpu_shape,
|
||||
dtype=gpu_tensor.dtype,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory))
|
||||
|
||||
def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
|
||||
src_spec, dst_spec = spec
|
||||
if isinstance(src_spec, CPULoadStoreSpec):
|
||||
assert isinstance(dst_spec, GPULoadStoreSpec)
|
||||
stream = self.h2d_stream
|
||||
src_tensors = self.cpu_tensors
|
||||
dst_tensors = self.gpu_tensors
|
||||
src_block_size_factor = self.block_size_factor
|
||||
dst_block_size_factor = 1
|
||||
else:
|
||||
assert isinstance(src_spec, GPULoadStoreSpec)
|
||||
assert isinstance(dst_spec, CPULoadStoreSpec)
|
||||
stream = self.d2h_stream
|
||||
src_tensors = self.gpu_tensors
|
||||
dst_tensors = self.cpu_tensors
|
||||
src_block_size_factor = 1
|
||||
dst_block_size_factor = self.block_size_factor
|
||||
|
||||
src_blocks = src_spec.block_ids
|
||||
dst_blocks = dst_spec.block_ids
|
||||
assert src_blocks.ndim == 1
|
||||
assert dst_blocks.ndim == 1
|
||||
|
||||
dst_sub_blocks_to_skip = (-src_blocks.size % dst_block_size_factor)
|
||||
src_sub_block_count = src_blocks.size * src_block_size_factor
|
||||
|
||||
assert (
|
||||
src_sub_block_count == dst_blocks.size * dst_block_size_factor -
|
||||
dst_sub_blocks_to_skip)
|
||||
|
||||
src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64)
|
||||
expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0])
|
||||
expand_block_ids(dst_blocks,
|
||||
dst_block_size_factor,
|
||||
src_to_dst[:, 1],
|
||||
skip_count=dst_sub_blocks_to_skip)
|
||||
src_to_dst_tensor = torch.from_numpy(src_to_dst)
|
||||
|
||||
event = self.events_pool.pop() if self.events_pool \
|
||||
else torch.cuda.Event()
|
||||
with torch.cuda.stream(stream):
|
||||
for src_tensor, dst_tensor, kv_dim in zip(
|
||||
src_tensors, dst_tensors, self.kv_dim_before_num_blocks):
|
||||
if kv_dim:
|
||||
src_key_cache = src_tensor[0]
|
||||
dst_key_cache = dst_tensor[0]
|
||||
ops.swap_blocks(src_key_cache, dst_key_cache,
|
||||
src_to_dst_tensor)
|
||||
src_value_cache = src_tensor[1]
|
||||
dst_value_cache = dst_tensor[1]
|
||||
ops.swap_blocks(src_value_cache, dst_value_cache,
|
||||
src_to_dst_tensor)
|
||||
else:
|
||||
ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
|
||||
event.record(stream)
|
||||
|
||||
self.transfer_events[job_id] = event
|
||||
|
||||
# success
|
||||
return True
|
||||
|
||||
def get_finished(self) -> list[TransferResult]:
|
||||
results: list[TransferResult] = []
|
||||
for job_id, event in self.transfer_events.items():
|
||||
if event.query():
|
||||
results.append((job_id, True))
|
||||
self.events_pool.append(event)
|
||||
for job_id, _ in results:
|
||||
del self.transfer_events[job_id]
|
||||
return results
|
||||
Loading…
x
Reference in New Issue
Block a user