Implement ARC KV cache eviction policy for CPU offloader (#27039)

Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
Signed-off-by: alberto <aperdomo@redhat.com>
Co-authored-by: Or Ozeri <or@ozery.com>
This commit is contained in:
alberto 2025-11-12 17:51:39 +00:00 committed by GitHub
parent 304419576a
commit bac904565f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 565 additions and 5 deletions

View File

@ -11,6 +11,7 @@ from vllm.v1.kv_offload.abstract import (
OffloadingEvent,
PrepareStoreOutput,
)
from vllm.v1.kv_offload.arc_manager import ARCOffloadingManager
from vllm.v1.kv_offload.backends.cpu import CPUBackend
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec
@ -187,3 +188,310 @@ def test_cpu_manager():
expected_stores=({3, 4, 5}, {6, 7, 8}),
expected_evictions=({2, 3, 4}, {8}),
)
def test_arc_manager_basic():
"""
Tests ARCOffloadingManager basic operations with a CPUBackend.
Verifies that ARC handles store, load, and lookup operations correctly.
"""
# initialize a CPU backend with a capacity of 4 blocks
block_size = 256
cpu_backend = CPUBackend(block_size=block_size, num_blocks=4)
arc_manager = ARCOffloadingManager(cpu_backend, enable_events=True)
# prepare store [1, 2]
prepare_store_output = arc_manager.prepare_store(to_hashes([1, 2]))
verify_store_output(
prepare_store_output,
ExpectedPrepareStoreOutput(
block_hashes_to_store=[1, 2],
store_block_ids=[0, 1],
block_hashes_evicted=[],
),
)
# lookup [1, 2] -> not ready
assert arc_manager.lookup(to_hashes([1, 2])) == 0
# no events so far
assert list(arc_manager.take_events()) == []
# complete store [1, 2]
arc_manager.complete_store(to_hashes([1, 2]))
verify_events(
arc_manager.take_events(), block_size=block_size, expected_stores=({1, 2},)
)
# lookup [1, 2]
assert arc_manager.lookup(to_hashes([1])) == 1
assert arc_manager.lookup(to_hashes([1, 2])) == 2
assert arc_manager.lookup(to_hashes([1, 2, 3])) == 2
# blocks should be in T1 (recent)
assert len(arc_manager.t1) == 2
assert len(arc_manager.t2) == 0
def test_arc_manager_t1_to_t2_promotion():
"""
Tests that accessing a block in T1 promotes it to T2 (frequent).
This is a key feature of ARC's adaptive behavior.
"""
block_size = 256
cpu_backend = CPUBackend(block_size=block_size, num_blocks=4)
arc_manager = ARCOffloadingManager(cpu_backend, enable_events=False)
# store and complete block 1
arc_manager.prepare_store(to_hashes([1]))
arc_manager.complete_store(to_hashes([1]))
# block 1 starts in T1 (recent)
assert to_hashes([1])[0] in arc_manager.t1
assert to_hashes([1])[0] not in arc_manager.t2
# touch block 1 (simulate second access)
arc_manager.touch(to_hashes([1]))
# block 1 should now be in T2 (frequent)
assert to_hashes([1])[0] not in arc_manager.t1
assert to_hashes([1])[0] in arc_manager.t2
def test_arc_manager_eviction_with_load():
"""
Tests ARC eviction behavior similar to LRU test.
Verifies that blocks being loaded (ref_cnt > 0) cannot be evicted.
"""
block_size = 256
cpu_backend = CPUBackend(block_size=block_size, num_blocks=4)
arc_manager = ARCOffloadingManager(cpu_backend, enable_events=True)
# prepare and complete store [1, 2, 3, 4]
prepare_store_output = arc_manager.prepare_store(to_hashes([1, 2, 3, 4]))
verify_store_output(
prepare_store_output,
ExpectedPrepareStoreOutput(
block_hashes_to_store=[1, 2, 3, 4],
store_block_ids=[0, 1, 2, 3],
block_hashes_evicted=[],
),
)
arc_manager.complete_store(to_hashes([1, 2, 3, 4]))
# prepare load [2, 3] (increases ref_cnt)
prepare_load_output = arc_manager.prepare_load(to_hashes([2, 3]))
verify_load_output(prepare_load_output, [1, 2])
# prepare store [5, 6, 7] with [2, 3] being loaded
# should fail because [2, 3] have ref_cnt > 0
assert arc_manager.prepare_store(to_hashes([5, 6, 7])) is None
# complete load [2, 3]
arc_manager.complete_load(to_hashes([2, 3]))
# now prepare store [5, 6, 7] should succeed
# ARC will evict blocks one at a time from T1 as needed
prepare_store_output = arc_manager.prepare_store(to_hashes([5, 6, 7]))
assert prepare_store_output is not None
# Should successfully evict enough blocks to make room (at least 1)
assert len(prepare_store_output.block_hashes_evicted) >= 1
def test_arc_manager_adaptive_target():
"""
Tests ARC's adaptive target adjustment via ghost lists.
When a block in B1 (ghost list) is accessed, target_t1_size increases.
When a block in B2 is accessed, target_t1_size decreases.
"""
block_size = 256
cpu_backend = CPUBackend(block_size=block_size, num_blocks=2)
arc_manager = ARCOffloadingManager(cpu_backend, enable_events=False)
# store blocks 1, 2 (fills cache)
arc_manager.prepare_store(to_hashes([1, 2]))
arc_manager.complete_store(to_hashes([1, 2]))
initial_target = arc_manager.target_t1_size
# store block 3, evicting block 1 (moves to B1 ghost list)
arc_manager.prepare_store(to_hashes([3]))
arc_manager.complete_store(to_hashes([3]))
# block 1 should be in B1 (ghost list)
assert to_hashes([1])[0] in arc_manager.b1
# touch block 1 (cache miss, but in B1)
# this should increase target_t1_size (favor recency)
arc_manager.touch(to_hashes([1]))
# target should have increased
assert arc_manager.target_t1_size > initial_target
def test_arc_manager_t1_t2_eviction_policy():
"""
Tests that ARC evicts from T1 or T2 based on target_t1_size.
If |T1| >= target_t1_size, evict from T1, otherwise from T2.
"""
block_size = 256
cpu_backend = CPUBackend(block_size=block_size, num_blocks=4)
arc_manager = ARCOffloadingManager(cpu_backend, enable_events=False)
# store blocks 1, 2, 3, 4
arc_manager.prepare_store(to_hashes([1, 2, 3, 4]))
arc_manager.complete_store(to_hashes([1, 2, 3, 4]))
# promote blocks 3, 4 to T2 by touching them
arc_manager.touch(to_hashes([3, 4]))
# now: T1 = {1, 2}, T2 = {3, 4}
assert len(arc_manager.t1) == 2
assert len(arc_manager.t2) == 2
# set target_t1_size to prefer evicting from T1
# (when |T1| >= target, evict from T1)
arc_manager.target_t1_size = 1
# store block 5, should evict from T1 (block 1, LRU in T1)
output = arc_manager.prepare_store(to_hashes([5]))
assert output is not None
assert to_hashes([1]) == output.block_hashes_evicted
arc_manager.complete_store(to_hashes([5]))
# block 1 should be in B1 (ghost list)
assert to_hashes([1])[0] in arc_manager.b1
# block 5 should be in T1
assert to_hashes([5])[0] in arc_manager.t1
def test_arc_manager_ghost_list_bounds():
"""
Tests that ghost lists (B1, B2) don't grow unbounded.
They should be capped at cache_capacity.
"""
block_size = 256
cpu_backend = CPUBackend(block_size=block_size, num_blocks=2)
arc_manager = ARCOffloadingManager(cpu_backend, enable_events=False)
# fill cache with blocks 1, 2
arc_manager.prepare_store(to_hashes([1, 2]))
arc_manager.complete_store(to_hashes([1, 2]))
# store many blocks to fill ghost lists
for i in range(3, 20):
arc_manager.prepare_store(to_hashes([i]))
arc_manager.complete_store(to_hashes([i]))
# ghost lists should not exceed cache_capacity
assert len(arc_manager.b1) <= arc_manager.cache_capacity
assert len(arc_manager.b2) <= arc_manager.cache_capacity
def test_arc_manager_touch_ordering():
"""
Tests that touch() correctly updates access patterns.
Similar to LRU test but verifies T1/T2 ordering.
"""
block_size = 256
cpu_backend = CPUBackend(block_size=block_size, num_blocks=4)
arc_manager = ARCOffloadingManager(cpu_backend, enable_events=True)
# store blocks 1, 2, 3, 4
arc_manager.prepare_store(to_hashes([1, 2, 3, 4]))
arc_manager.complete_store(to_hashes([1, 2, 3, 4]))
# promote 3, 4 to T2
arc_manager.touch(to_hashes([3, 4]))
# T1 = {1, 2}, T2 = {3, 4}
# touch [1, 3, 4] - should promote 1 to T2, and move 3,4 to end of T2
arc_manager.touch(to_hashes([1, 3, 4]))
# T1 = {2}, T2 = {1, 3, 4} (in that order, with 4 most recent)
assert len(arc_manager.t1) == 1
assert len(arc_manager.t2) == 3
# store block 5, should evict from T1 (block 2, only one in T1)
prepare_store_output = arc_manager.prepare_store(to_hashes([5]))
verify_store_output(
prepare_store_output,
ExpectedPrepareStoreOutput(
block_hashes_to_store=[5],
store_block_ids=[1], # reuses block 2's storage
block_hashes_evicted=[2],
),
)
def test_arc_manager_failed_store():
"""
Tests that failed store operations clean up correctly.
Similar to LRU test but for ARC.
"""
block_size = 256
cpu_backend = CPUBackend(block_size=block_size, num_blocks=4)
arc_manager = ARCOffloadingManager(cpu_backend, enable_events=True)
# store blocks 1, 2, 3, 4
arc_manager.prepare_store(to_hashes([1, 2, 3, 4]))
arc_manager.complete_store(to_hashes([1, 2, 3, 4]))
# prepare store block 5 (will evict block 1)
prepare_store_output = arc_manager.prepare_store(to_hashes([5]))
assert prepare_store_output is not None
assert len(prepare_store_output.block_hashes_evicted) == 1
# complete store with failure
arc_manager.complete_store(to_hashes([5]), success=False)
# block 5 should not be in cache
assert arc_manager.lookup(to_hashes([5])) == 0
# block 5 should not be in T1 or T2
assert to_hashes([5])[0] not in arc_manager.t1
assert to_hashes([5])[0] not in arc_manager.t2
# evicted block should still be gone (in B1 ghost list)
evicted_hash = prepare_store_output.block_hashes_evicted[0]
assert evicted_hash in arc_manager.b1
def test_arc_manager_full_scenario():
"""
Comprehensive test covering multiple ARC operations in sequence.
Similar to the full LRU test but adapted for ARC behavior.
"""
block_size = 256
cpu_backend = CPUBackend(block_size=block_size, num_blocks=4)
arc_manager = ARCOffloadingManager(cpu_backend, enable_events=True)
# store [1, 2]
arc_manager.prepare_store(to_hashes([1, 2]))
arc_manager.complete_store(to_hashes([1, 2]))
# store [3, 4, 5] -> evicts [1]
prepare_store_output = arc_manager.prepare_store(to_hashes([3, 4, 5]))
assert prepare_store_output is not None
assert len(prepare_store_output.block_hashes_evicted) == 1
arc_manager.complete_store(to_hashes([3, 4, 5]))
# promote some blocks to T2
arc_manager.touch(to_hashes([2, 3]))
# T1 has {4, 5}, T2 has {2, 3}
assert len(arc_manager.t1) == 2
assert len(arc_manager.t2) == 2
# store [6] -> should evict from T1 (4 is oldest in T1)
prepare_store_output = arc_manager.prepare_store(to_hashes([6]))
assert prepare_store_output is not None
arc_manager.complete_store(to_hashes([6]))
# verify blocks 2, 3 (in T2) are still present
assert arc_manager.lookup(to_hashes([2])) == 1
assert arc_manager.lookup(to_hashes([3])) == 1
# verify events
events = list(arc_manager.take_events())
assert len(events) > 0 # should have store and eviction events

View File

@ -0,0 +1,237 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import OrderedDict
from collections.abc import Iterable
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_offload.abstract import (
LoadStoreSpec,
OffloadingEvent,
OffloadingManager,
PrepareStoreOutput,
)
from vllm.v1.kv_offload.backend import Backend, BlockStatus
class ARCOffloadingManager(OffloadingManager):
"""
An OffloadingManager implementing the ARC (Adaptive Replacement Cache)
eviction policy with a pluggable backend.
Data Structures:
T1: Recent cache containing blocks accessed once.
T2: Frequent cache containing blocks accessed multiple times.
B1/B2: Ghost lists tracking recently evicted blocks from T1/T2.
target_t1_size: Adaptive target size for the T1 partition.
Algorithm Flow:
1. Cache lookup (lookup):
Searches T1 and T2 for block hashes and counts consecutive hits
until a miss or non-ready block is encountered.
2. Cache touch (touch) - Adaptive Learning:
For each block_hash (in reverse order):
- If in T1: Move to T2 (promotion from recent to frequent).
- If in T2: Move to MRU position (end of queue).
- If in B1 ghost list: Increase target_t1_size.
- If in B2 ghost list: Decrease target_t1_size.
3. Block eviction (prepare_store) - Adaptive Replacement:
Determines eviction source based on adaptive target:
- If T1 size > target_t1_size: Evict from T1, add to B1.
- Otherwise: Evict from T2, add to B2.
Finally, bound each ghost list size.
4. Block insertion (prepare_store):
New blocks are always inserted into T1 and removed from B1/B2 if
present. Blocks may later be promoted to T2 during touch operations.
Adaptive Behavior:
The algorithm self-tunes the recency vs. frequency trade-off:
- B1 hit: Recent access patterns matter more increase T1.
- B2 hit: Frequent access patterns matter more decrease T1.
"""
def __init__(self, backend: Backend, enable_events: bool = False):
self.backend: Backend = backend
self.target_t1_size: float = 0.0
self.t1: OrderedDict[BlockHash, BlockStatus] = OrderedDict()
self.t2: OrderedDict[BlockHash, BlockStatus] = OrderedDict()
# block_hash -> None (only care about presence)
self.b1: OrderedDict[BlockHash, None] = OrderedDict()
self.b2: OrderedDict[BlockHash, None] = OrderedDict()
self.events: list[OffloadingEvent] | None = [] if enable_events else None
self.cache_capacity: int = self.backend.get_num_free_blocks()
def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
hit_count = 0
for block_hash in block_hashes:
block = self.t1.get(block_hash) or self.t2.get(block_hash)
if block is None or not block.is_ready:
break
hit_count += 1
return hit_count
def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec:
blocks = []
for block_hash in block_hashes:
block = self.t1.get(block_hash) or self.t2.get(block_hash)
assert block is not None, f"Block {block_hash!r} not found in cache"
assert block.is_ready, f"Block {block_hash!r} is not ready for reading"
block.ref_cnt += 1
blocks.append(block)
return self.backend.get_load_store_spec(block_hashes, blocks)
def touch(self, block_hashes: Iterable[BlockHash]):
for block_hash in reversed(list(block_hashes)):
if block_hash in self.t1:
block = self.t1.pop(block_hash)
if not block.is_ready:
# block was just prepared to be stored, not really touched twice
self.t1.move_to_end(block_hash)
else:
self.t2[block_hash] = block
elif block_hash in self.t2:
self.t2.move_to_end(block_hash)
elif block_hash in self.b1:
delta = max(1, len(self.b2) / len(self.b1))
self.target_t1_size = min(
self.target_t1_size + delta, self.cache_capacity
)
# move to MRU position (end) to keep it fresh in the ghost list
self.b1.move_to_end(block_hash)
elif block_hash in self.b2:
delta = max(1, len(self.b1) / len(self.b2))
self.target_t1_size = max(self.target_t1_size - delta, 0)
# move to MRU position (end) to keep it fresh in the ghost list
self.b2.move_to_end(block_hash)
def complete_load(self, block_hashes: Iterable[BlockHash]):
for block_hash in block_hashes:
block = self.t1.get(block_hash) or self.t2.get(block_hash)
assert block is not None, f"Block {block_hash!r} not found"
assert block.ref_cnt > 0, f"Block {block_hash!r} ref_cnt is already 0"
block.ref_cnt -= 1
def prepare_store(
self, block_hashes: Iterable[BlockHash]
) -> PrepareStoreOutput | None:
block_hashes_to_store = []
for block_hash in block_hashes:
if block_hash not in self.t1 and block_hash not in self.t2:
block_hashes_to_store.append(block_hash)
if not block_hashes_to_store:
return PrepareStoreOutput(
block_hashes_to_store=[],
store_spec=self.backend.get_load_store_spec([], []),
block_hashes_evicted=[],
)
num_blocks_to_evict = (
len(block_hashes_to_store) - self.backend.get_num_free_blocks()
)
to_evict = []
while num_blocks_to_evict > 0:
block_to_evict = None
if len(self.t1) >= int(self.target_t1_size):
# try to evict the least recently used (oldest) block from T1
for block_hash, block in self.t1.items():
if block.ref_cnt == 0:
block_to_evict = (block_hash, block)
eviction_t = self.t1
eviction_b = self.b1
break
if not block_to_evict:
# try to evict the least recently used (oldest) block from T2
for block_hash, block in self.t2.items():
if block.ref_cnt == 0:
block_to_evict = (block_hash, block)
eviction_t = self.t2
eviction_b = self.b2
break
else:
# cannot evict enough blocks, cache is full of in-use items
return None
block_hash, block = block_to_evict
del eviction_t[block_hash]
eviction_b[block_hash] = None
to_evict.append(block_hash)
self.backend.free(block)
num_blocks_to_evict -= 1
for b in [self.b1, self.b2]:
for i in range(len(b) - self.cache_capacity):
b.popitem(last=False)
if to_evict and self.events is not None:
self.events.append(
OffloadingEvent(
block_hashes=to_evict,
block_size=self.backend.block_size,
medium=self.backend.medium,
removed=True,
)
)
blocks = self.backend.allocate_blocks(block_hashes_to_store)
assert len(blocks) == len(block_hashes_to_store), (
"Backend did not allocate the expected number of blocks"
)
for block_hash, block in zip(block_hashes_to_store, blocks):
self.t1[block_hash] = block
self.b1.pop(block_hash, None)
self.b2.pop(block_hash, None)
store_spec = self.backend.get_load_store_spec(block_hashes_to_store, blocks)
return PrepareStoreOutput(
block_hashes_to_store=block_hashes_to_store,
store_spec=store_spec,
block_hashes_evicted=to_evict,
)
def complete_store(self, block_hashes: Iterable[BlockHash], success: bool = True):
stored_block_hashes: list[BlockHash] = []
if success:
for block_hash in block_hashes:
block = self.t1.get(block_hash) or self.t2.get(block_hash)
if block is not None and not block.is_ready:
block.ref_cnt = 0
stored_block_hashes.append(block_hash)
else:
for block_hash in block_hashes:
block = self.t1.pop(block_hash, None)
if block is None:
block = self.t2.pop(block_hash, None)
if block is not None and not block.is_ready:
self.backend.free(block)
if stored_block_hashes and self.events is not None:
self.events.append(
OffloadingEvent(
block_hashes=stored_block_hashes,
block_size=self.backend.block_size,
medium=self.backend.medium,
removed=False,
)
)
def take_events(self) -> Iterable[OffloadingEvent]:
if self.events is not None:
yield from self.events
self.events.clear()

View File

@ -8,6 +8,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.platforms import current_platform
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.arc_manager import ARCOffloadingManager
from vllm.v1.kv_offload.backends.cpu import CPUBackend
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
@ -33,18 +34,32 @@ class CPUOffloadingSpec(OffloadingSpec):
# worker-side
self._handler: OffloadingHandler | None = None
self.eviction_policy: str = self.extra_config.get("eviction_policy", "lru")
def get_manager(self) -> OffloadingManager:
if not self._manager:
kv_events_config = self.vllm_config.kv_events_config
enable_events = (
kv_events_config is not None and kv_events_config.enable_kv_cache_events
)
self._manager = LRUOffloadingManager(
CPUBackend(
block_size=self.offloaded_block_size, num_blocks=self.num_cpu_blocks
),
enable_events=enable_events,
backend = CPUBackend(
block_size=self.offloaded_block_size, num_blocks=self.num_cpu_blocks
)
if self.eviction_policy == "lru":
self._manager = LRUOffloadingManager(
backend=backend, enable_events=enable_events
)
elif self.eviction_policy == "arc":
self._manager = ARCOffloadingManager(
backend=backend, enable_events=enable_events
)
else:
raise ValueError(
f"Unknown eviction policy: {self.eviction_policy}. "
f"Supported policies: lru, arc"
)
return self._manager
def get_handlers(