[KV offload][2/N] Introduce LRU-based CPU offloading management (#20075)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri 2025-09-19 03:20:51 +03:00 committed by GitHub
parent 9a4600e4dc
commit 9d1c50a5ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 464 additions and 0 deletions

View File

@ -0,0 +1,175 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional
import numpy as np
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent,
PrepareStoreOutput)
from vllm.v1.kv_offload.backends.cpu import CPUBackend
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec
@dataclass
class ExpectedPrepareStoreOutput:
block_hashes_to_store: list[int]
store_block_ids: list[int]
block_hashes_evicted: list[int]
def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
return [BlockHash(str(i).encode()) for i in int_hashes]
def verify_store_output(
prepare_store_output: Optional[PrepareStoreOutput],
expected_prepare_store_output: ExpectedPrepareStoreOutput):
assert prepare_store_output is not None
assert (prepare_store_output.block_hashes_to_store == to_hashes(
expected_prepare_store_output.block_hashes_to_store))
assert (prepare_store_output.block_hashes_evicted == to_hashes(
expected_prepare_store_output.block_hashes_evicted))
store_spec = prepare_store_output.store_spec
assert isinstance(store_spec, CPULoadStoreSpec)
expected_array = np.array(expected_prepare_store_output.store_block_ids,
dtype=np.int64)
assert np.array_equal(expected_array, store_spec.block_ids)
def verify_load_output(prepare_load_output: LoadStoreSpec,
expected_prepare_load_output: list[int]):
assert isinstance(prepare_load_output, CPULoadStoreSpec)
expected_array = np.array(expected_prepare_load_output, dtype=np.int64)
assert np.array_equal(expected_array, prepare_load_output.block_ids)
def verify_events(events: Iterable[OffloadingEvent],
block_size: int,
expected_stores: tuple[set[int], ...] = (),
expected_evictions: tuple[set[int], ...] = ()):
stores: list[set[BlockHash]] = []
evictions: list[set[BlockHash]] = []
for event in events:
assert event.medium == CPULoadStoreSpec.medium()
assert event.block_size == block_size
if event.removed:
evictions.append(set(event.block_hashes))
else:
stores.append(set(event.block_hashes))
def to_hash_sets(
int_sets: tuple[set[int], ...]) -> tuple[set[BlockHash], ...]:
return tuple([set(to_hashes(list(int_set))) for int_set in int_sets])
assert tuple(evictions) == to_hash_sets(expected_evictions)
assert tuple(stores) == to_hash_sets(expected_stores)
def test_cpu_manager():
"""
Tests LRUOffloadingManager with a CPUBackend.
"""
# initialize a CPU backend with a capacity of 4 blocks
block_size = 256
cpu_backend = CPUBackend(block_size=block_size, num_blocks=4)
cpu_manager = LRUOffloadingManager(cpu_backend, enable_events=True)
# prepare store [1, 2]
prepare_store_output = cpu_manager.prepare_store(to_hashes([1, 2]))
verify_store_output(
prepare_store_output,
ExpectedPrepareStoreOutput(
block_hashes_to_store=[1, 2],
store_block_ids=[0, 1],
block_hashes_evicted=[],
))
# lookup [1, 2] -> not ready
assert cpu_manager.lookup(to_hashes([1, 2])) == 0
# no events so far
assert list(cpu_manager.take_events()) == []
# complete store [1, 2]
cpu_manager.complete_store(to_hashes([1, 2]))
verify_events(cpu_manager.take_events(),
block_size=block_size,
expected_stores=({1, 2}, ))
# lookup [1, 2]
assert cpu_manager.lookup(to_hashes([1])) == 1
assert cpu_manager.lookup(to_hashes([1, 2])) == 2
assert cpu_manager.lookup(to_hashes([1, 2, 3])) == 2
# prepare store [2, 3, 4, 5] -> evicts [1]
prepare_store_output = cpu_manager.prepare_store(to_hashes([2, 3, 4, 5]))
verify_store_output(
prepare_store_output,
ExpectedPrepareStoreOutput(
block_hashes_to_store=[3, 4, 5],
store_block_ids=[2, 3, 0],
block_hashes_evicted=[1],
))
# verify eviction event
verify_events(cpu_manager.take_events(),
block_size=block_size,
expected_evictions=({1}, ))
# prepare store with no space
assert cpu_manager.prepare_store(to_hashes([1, 6])) is None
# complete store [2, 3, 4, 5]
cpu_manager.complete_store(to_hashes([2, 3, 4, 5]))
# prepare load [2, 3]
prepare_load_output = cpu_manager.prepare_load(to_hashes([2, 3]))
verify_load_output(prepare_load_output, [1, 2])
# prepare store with no space ([2, 3] is being loaded)
assert cpu_manager.prepare_store(to_hashes([6, 7, 8])) is None
# complete load [2, 3]
cpu_manager.complete_load(to_hashes([2, 3]))
# prepare store [6, 7, 8] -> evicts [2, 3, 4] (oldest)
prepare_store_output = cpu_manager.prepare_store(to_hashes([6, 7, 8]))
verify_store_output(
prepare_store_output,
ExpectedPrepareStoreOutput(
block_hashes_to_store=[6, 7, 8],
store_block_ids=[3, 2, 1],
block_hashes_evicted=[2, 3, 4],
))
# complete store [6, 7, 8]
cpu_manager.complete_store(to_hashes([6, 7, 8]))
# touch [5, 6, 7] (move to end of LRU order)
cpu_manager.touch(to_hashes([5, 6, 7]))
# prepare store [7, 9] -> evicts [8] (oldest following previous touch)
prepare_store_output = cpu_manager.prepare_store(to_hashes([9]))
verify_store_output(
prepare_store_output,
ExpectedPrepareStoreOutput(
block_hashes_to_store=[9],
store_block_ids=[1],
block_hashes_evicted=[8],
))
# complete store [7, 9] with failure
cpu_manager.complete_store(to_hashes([7, 9]), success=False)
# assert [7] is still stored, but [9] is not
assert cpu_manager.lookup(to_hashes([7])) == 1
assert cpu_manager.lookup(to_hashes([9])) == 0
verify_events(cpu_manager.take_events(),
block_size=block_size,
expected_stores=({3, 4, 5}, {6, 7, 8}),
expected_evictions=({2, 3, 4}, {8}))

View File

@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ctypes
from abc import ABC, abstractmethod
from collections.abc import Iterable
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_offload.abstract import LoadStoreSpec
class BlockStatus(ctypes.Structure):
"""
Offloading status for a single block of KV data.
Holds the following information:
ref_cnt - the current number of transfers using this block as a source.
A value of -1 indicates the block is not yet ready to be read.
load_store_spec - backend-specific information on how to actually
read/write the block.
"""
_fields_ = [("ref_cnt", ctypes.c_int32)]
def __init__(self):
super().__init__()
# initialize block as "not ready" (ref_cnt = -1)
self.ref_cnt = -1
@property
def is_ready(self) -> bool:
"""
Returns whether the block is ready to be read.
"""
return self.ref_cnt >= 0
class Backend(ABC):
"""
An abstract class for allocating and returning specs for writing
KV blocks to some backend.
"""
def __init__(self, block_size: int, medium: str):
self.block_size = block_size
self.medium = medium
@abstractmethod
def get_num_free_blocks(self):
"""
Returns the number of current number of blocks that can be allocated.
"""
pass
@abstractmethod
def allocate_blocks(self,
block_hashes: list[BlockHash]) -> list[BlockStatus]:
"""
Allocate space for writing blocks.
This method assumes there is enough space for allocation.
It is unsafe to use without checking get_num_free_blocks beforehand.
Args:
block_hashes: the hashes identifying the blocks to be written.
Returns:
A list of BlockStatus for the allocated blocks.
The ref_cnt of each returned item will be -1, meaning the block
is not yet ready to be read.
"""
pass
@abstractmethod
def free(self, block: BlockStatus):
"""
Free a previously allocated block.
You should only call this function with blocks returned by
allocate_blocks, and only once per each block.
Args:
block: The block to be freed.
"""
pass
def get_load_store_spec(self, block_hashes: Iterable[BlockHash],
blocks: Iterable[BlockStatus]) -> LoadStoreSpec:
"""
Get backend-specific information on how to read/write blocks.
Args:
block_hashes: the list of block hashes identifying the blocks.
blocks: the list of blocks.
Returns:
A LoadStoreSpec that can be used by a worker
to read/write the blocks.
"""
raise NotImplementedError

View File

@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ctypes
from collections.abc import Iterable
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_offload.abstract import LoadStoreSpec
from vllm.v1.kv_offload.backend import Backend, BlockStatus
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec
class CPUBlockStatus(BlockStatus):
_fields_ = BlockStatus._fields_ + [("block_id", ctypes.c_int64)
] # type: ignore
def __init__(self, block_id: int):
super().__init__()
self.block_id = block_id
class CPUBackend(Backend):
def __init__(self, block_size: int, num_blocks: int):
super().__init__(block_size=block_size,
medium=CPULoadStoreSpec.medium())
self.num_blocks: int = num_blocks
self.num_allocated_blocks: int = 0
self.allocated_blocks_free_list: list[int] = []
def get_num_free_blocks(self):
return (len(self.allocated_blocks_free_list) + self.num_blocks -
self.num_allocated_blocks)
def allocate_blocks(self,
block_hashes: list[BlockHash]) -> list[BlockStatus]:
num_fresh_blocks = min(len(block_hashes),
self.num_blocks - self.num_allocated_blocks)
num_reused_blocks = len(block_hashes) - num_fresh_blocks
assert len(self.allocated_blocks_free_list) >= num_reused_blocks
# allocate fresh blocks
blocks: list[BlockStatus] = []
for _ in range(num_fresh_blocks):
blocks.append(CPUBlockStatus(self.num_allocated_blocks))
self.num_allocated_blocks += 1
# allocate reused blocks
for _ in range(num_reused_blocks):
block_id = self.allocated_blocks_free_list.pop()
blocks.append(CPUBlockStatus(block_id))
return blocks
def free(self, block: BlockStatus):
assert isinstance(block, CPUBlockStatus)
self.allocated_blocks_free_list.append(block.block_id)
def get_load_store_spec(self, block_hashes: Iterable[BlockHash],
blocks: Iterable[BlockStatus]) -> LoadStoreSpec:
return CPULoadStoreSpec([block.block_id for block in blocks])

View File

@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import OrderedDict
from collections.abc import Iterable
from typing import Optional
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent,
OffloadingManager, PrepareStoreOutput)
from vllm.v1.kv_offload.backend import Backend, BlockStatus
class LRUOffloadingManager(OffloadingManager):
"""
An OffloadingManager with a pluggable backend, which evicts blocks by LRU.
"""
def __init__(self, backend: Backend, enable_events: bool = False):
self.backend: Backend = backend
# block_hash -> BlockStatus
self.blocks: OrderedDict[BlockHash, BlockStatus] = OrderedDict()
self.events: Optional[list[OffloadingEvent]] = \
[] if enable_events else None
def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
hit_count = 0
for block_hash in block_hashes:
block = self.blocks.get(block_hash)
if block is None or not block.is_ready:
break
hit_count += 1
return hit_count
def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec:
blocks = []
for block_hash in block_hashes:
block = self.blocks[block_hash]
assert block.is_ready
block.ref_cnt += 1
blocks.append(block)
return self.backend.get_load_store_spec(block_hashes, blocks)
def touch(self, block_hashes: Iterable[BlockHash]):
for block_hash in reversed(list(block_hashes)):
if self.blocks.get(block_hash):
self.blocks.move_to_end(block_hash)
def complete_load(self, block_hashes: Iterable[BlockHash]):
for block_hash in block_hashes:
block = self.blocks[block_hash]
assert block.ref_cnt > 0
block.ref_cnt -= 1
def prepare_store(
self,
block_hashes: Iterable[BlockHash]) -> Optional[PrepareStoreOutput]:
# filter out blocks that are already stored
block_hashes_to_store = [
block_hash for block_hash in block_hashes
if block_hash not in self.blocks
]
num_blocks_to_evict = (len(block_hashes_to_store) -
self.backend.get_num_free_blocks())
# build list of blocks to evict
to_evict = []
if num_blocks_to_evict > 0:
for block_hash, block in self.blocks.items():
if block.ref_cnt == 0:
to_evict.append(block_hash)
num_blocks_to_evict -= 1
if num_blocks_to_evict == 0:
break
else:
# we could not evict enough blocks
return None
# evict blocks
for block_hash in to_evict:
self.backend.free(self.blocks.pop(block_hash))
if to_evict and self.events is not None:
self.events.append(
OffloadingEvent(block_hashes=to_evict,
block_size=self.backend.block_size,
medium=self.backend.medium,
removed=True))
blocks = self.backend.allocate_blocks(block_hashes_to_store)
assert len(blocks) == len(block_hashes_to_store)
for block_hash, block in zip(block_hashes_to_store, blocks):
self.blocks[block_hash] = block
# build store specs for allocated blocks
store_spec = self.backend.get_load_store_spec(block_hashes_to_store,
blocks)
return PrepareStoreOutput(block_hashes_to_store=block_hashes_to_store,
store_spec=store_spec,
block_hashes_evicted=to_evict)
def complete_store(self,
block_hashes: Iterable[BlockHash],
success: bool = True):
stored_block_hashes: list[BlockHash] = []
if success:
for block_hash in block_hashes:
block = self.blocks[block_hash]
if not block.is_ready:
block.ref_cnt = 0
stored_block_hashes.append(block_hash)
else:
for block_hash in block_hashes:
block = self.blocks[block_hash]
if not block.is_ready:
self.backend.free(block)
del self.blocks[block_hash]
if stored_block_hashes and self.events is not None:
self.events.append(
OffloadingEvent(block_hashes=stored_block_hashes,
block_size=self.backend.block_size,
medium=self.backend.medium,
removed=False))
def take_events(self) -> Iterable[OffloadingEvent]:
if self.events is not None:
yield from self.events
self.events.clear()