mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-29 12:07:12 +08:00
[KV offload][2/N] Introduce LRU-based CPU offloading management (#20075)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
parent
9a4600e4dc
commit
9d1c50a5ac
175
tests/v1/kv_offload/test_cpu.py
Normal file
175
tests/v1/kv_offload/test_cpu.py
Normal file
@ -0,0 +1,175 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent,
|
||||
PrepareStoreOutput)
|
||||
from vllm.v1.kv_offload.backends.cpu import CPUBackend
|
||||
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
|
||||
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpectedPrepareStoreOutput:
|
||||
block_hashes_to_store: list[int]
|
||||
store_block_ids: list[int]
|
||||
block_hashes_evicted: list[int]
|
||||
|
||||
|
||||
def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
|
||||
return [BlockHash(str(i).encode()) for i in int_hashes]
|
||||
|
||||
|
||||
def verify_store_output(
|
||||
prepare_store_output: Optional[PrepareStoreOutput],
|
||||
expected_prepare_store_output: ExpectedPrepareStoreOutput):
|
||||
assert prepare_store_output is not None
|
||||
assert (prepare_store_output.block_hashes_to_store == to_hashes(
|
||||
expected_prepare_store_output.block_hashes_to_store))
|
||||
assert (prepare_store_output.block_hashes_evicted == to_hashes(
|
||||
expected_prepare_store_output.block_hashes_evicted))
|
||||
store_spec = prepare_store_output.store_spec
|
||||
assert isinstance(store_spec, CPULoadStoreSpec)
|
||||
expected_array = np.array(expected_prepare_store_output.store_block_ids,
|
||||
dtype=np.int64)
|
||||
assert np.array_equal(expected_array, store_spec.block_ids)
|
||||
|
||||
|
||||
def verify_load_output(prepare_load_output: LoadStoreSpec,
|
||||
expected_prepare_load_output: list[int]):
|
||||
assert isinstance(prepare_load_output, CPULoadStoreSpec)
|
||||
expected_array = np.array(expected_prepare_load_output, dtype=np.int64)
|
||||
assert np.array_equal(expected_array, prepare_load_output.block_ids)
|
||||
|
||||
|
||||
def verify_events(events: Iterable[OffloadingEvent],
|
||||
block_size: int,
|
||||
expected_stores: tuple[set[int], ...] = (),
|
||||
expected_evictions: tuple[set[int], ...] = ()):
|
||||
stores: list[set[BlockHash]] = []
|
||||
evictions: list[set[BlockHash]] = []
|
||||
for event in events:
|
||||
assert event.medium == CPULoadStoreSpec.medium()
|
||||
assert event.block_size == block_size
|
||||
if event.removed:
|
||||
evictions.append(set(event.block_hashes))
|
||||
else:
|
||||
stores.append(set(event.block_hashes))
|
||||
|
||||
def to_hash_sets(
|
||||
int_sets: tuple[set[int], ...]) -> tuple[set[BlockHash], ...]:
|
||||
return tuple([set(to_hashes(list(int_set))) for int_set in int_sets])
|
||||
|
||||
assert tuple(evictions) == to_hash_sets(expected_evictions)
|
||||
assert tuple(stores) == to_hash_sets(expected_stores)
|
||||
|
||||
|
||||
def test_cpu_manager():
|
||||
"""
|
||||
Tests LRUOffloadingManager with a CPUBackend.
|
||||
"""
|
||||
# initialize a CPU backend with a capacity of 4 blocks
|
||||
block_size = 256
|
||||
cpu_backend = CPUBackend(block_size=block_size, num_blocks=4)
|
||||
cpu_manager = LRUOffloadingManager(cpu_backend, enable_events=True)
|
||||
|
||||
# prepare store [1, 2]
|
||||
prepare_store_output = cpu_manager.prepare_store(to_hashes([1, 2]))
|
||||
verify_store_output(
|
||||
prepare_store_output,
|
||||
ExpectedPrepareStoreOutput(
|
||||
block_hashes_to_store=[1, 2],
|
||||
store_block_ids=[0, 1],
|
||||
block_hashes_evicted=[],
|
||||
))
|
||||
|
||||
# lookup [1, 2] -> not ready
|
||||
assert cpu_manager.lookup(to_hashes([1, 2])) == 0
|
||||
|
||||
# no events so far
|
||||
assert list(cpu_manager.take_events()) == []
|
||||
|
||||
# complete store [1, 2]
|
||||
cpu_manager.complete_store(to_hashes([1, 2]))
|
||||
verify_events(cpu_manager.take_events(),
|
||||
block_size=block_size,
|
||||
expected_stores=({1, 2}, ))
|
||||
|
||||
# lookup [1, 2]
|
||||
assert cpu_manager.lookup(to_hashes([1])) == 1
|
||||
assert cpu_manager.lookup(to_hashes([1, 2])) == 2
|
||||
assert cpu_manager.lookup(to_hashes([1, 2, 3])) == 2
|
||||
|
||||
# prepare store [2, 3, 4, 5] -> evicts [1]
|
||||
prepare_store_output = cpu_manager.prepare_store(to_hashes([2, 3, 4, 5]))
|
||||
verify_store_output(
|
||||
prepare_store_output,
|
||||
ExpectedPrepareStoreOutput(
|
||||
block_hashes_to_store=[3, 4, 5],
|
||||
store_block_ids=[2, 3, 0],
|
||||
block_hashes_evicted=[1],
|
||||
))
|
||||
|
||||
# verify eviction event
|
||||
verify_events(cpu_manager.take_events(),
|
||||
block_size=block_size,
|
||||
expected_evictions=({1}, ))
|
||||
|
||||
# prepare store with no space
|
||||
assert cpu_manager.prepare_store(to_hashes([1, 6])) is None
|
||||
|
||||
# complete store [2, 3, 4, 5]
|
||||
cpu_manager.complete_store(to_hashes([2, 3, 4, 5]))
|
||||
|
||||
# prepare load [2, 3]
|
||||
prepare_load_output = cpu_manager.prepare_load(to_hashes([2, 3]))
|
||||
verify_load_output(prepare_load_output, [1, 2])
|
||||
|
||||
# prepare store with no space ([2, 3] is being loaded)
|
||||
assert cpu_manager.prepare_store(to_hashes([6, 7, 8])) is None
|
||||
|
||||
# complete load [2, 3]
|
||||
cpu_manager.complete_load(to_hashes([2, 3]))
|
||||
|
||||
# prepare store [6, 7, 8] -> evicts [2, 3, 4] (oldest)
|
||||
prepare_store_output = cpu_manager.prepare_store(to_hashes([6, 7, 8]))
|
||||
verify_store_output(
|
||||
prepare_store_output,
|
||||
ExpectedPrepareStoreOutput(
|
||||
block_hashes_to_store=[6, 7, 8],
|
||||
store_block_ids=[3, 2, 1],
|
||||
block_hashes_evicted=[2, 3, 4],
|
||||
))
|
||||
|
||||
# complete store [6, 7, 8]
|
||||
cpu_manager.complete_store(to_hashes([6, 7, 8]))
|
||||
|
||||
# touch [5, 6, 7] (move to end of LRU order)
|
||||
cpu_manager.touch(to_hashes([5, 6, 7]))
|
||||
|
||||
# prepare store [7, 9] -> evicts [8] (oldest following previous touch)
|
||||
prepare_store_output = cpu_manager.prepare_store(to_hashes([9]))
|
||||
verify_store_output(
|
||||
prepare_store_output,
|
||||
ExpectedPrepareStoreOutput(
|
||||
block_hashes_to_store=[9],
|
||||
store_block_ids=[1],
|
||||
block_hashes_evicted=[8],
|
||||
))
|
||||
|
||||
# complete store [7, 9] with failure
|
||||
cpu_manager.complete_store(to_hashes([7, 9]), success=False)
|
||||
|
||||
# assert [7] is still stored, but [9] is not
|
||||
assert cpu_manager.lookup(to_hashes([7])) == 1
|
||||
assert cpu_manager.lookup(to_hashes([9])) == 0
|
||||
|
||||
verify_events(cpu_manager.take_events(),
|
||||
block_size=block_size,
|
||||
expected_stores=({3, 4, 5}, {6, 7, 8}),
|
||||
expected_evictions=({2, 3, 4}, {8}))
|
||||
96
vllm/v1/kv_offload/backend.py
Normal file
96
vllm/v1/kv_offload/backend.py
Normal file
@ -0,0 +1,96 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ctypes
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.kv_offload.abstract import LoadStoreSpec
|
||||
|
||||
|
||||
class BlockStatus(ctypes.Structure):
|
||||
"""
|
||||
Offloading status for a single block of KV data.
|
||||
Holds the following information:
|
||||
|
||||
ref_cnt - the current number of transfers using this block as a source.
|
||||
A value of -1 indicates the block is not yet ready to be read.
|
||||
load_store_spec - backend-specific information on how to actually
|
||||
read/write the block.
|
||||
"""
|
||||
_fields_ = [("ref_cnt", ctypes.c_int32)]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# initialize block as "not ready" (ref_cnt = -1)
|
||||
self.ref_cnt = -1
|
||||
|
||||
@property
|
||||
def is_ready(self) -> bool:
|
||||
"""
|
||||
Returns whether the block is ready to be read.
|
||||
"""
|
||||
return self.ref_cnt >= 0
|
||||
|
||||
|
||||
class Backend(ABC):
|
||||
"""
|
||||
An abstract class for allocating and returning specs for writing
|
||||
KV blocks to some backend.
|
||||
"""
|
||||
|
||||
def __init__(self, block_size: int, medium: str):
|
||||
self.block_size = block_size
|
||||
self.medium = medium
|
||||
|
||||
@abstractmethod
|
||||
def get_num_free_blocks(self):
|
||||
"""
|
||||
Returns the number of current number of blocks that can be allocated.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def allocate_blocks(self,
|
||||
block_hashes: list[BlockHash]) -> list[BlockStatus]:
|
||||
"""
|
||||
Allocate space for writing blocks.
|
||||
This method assumes there is enough space for allocation.
|
||||
It is unsafe to use without checking get_num_free_blocks beforehand.
|
||||
|
||||
Args:
|
||||
block_hashes: the hashes identifying the blocks to be written.
|
||||
|
||||
Returns:
|
||||
A list of BlockStatus for the allocated blocks.
|
||||
The ref_cnt of each returned item will be -1, meaning the block
|
||||
is not yet ready to be read.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def free(self, block: BlockStatus):
|
||||
"""
|
||||
Free a previously allocated block.
|
||||
You should only call this function with blocks returned by
|
||||
allocate_blocks, and only once per each block.
|
||||
|
||||
Args:
|
||||
block: The block to be freed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_load_store_spec(self, block_hashes: Iterable[BlockHash],
|
||||
blocks: Iterable[BlockStatus]) -> LoadStoreSpec:
|
||||
"""
|
||||
Get backend-specific information on how to read/write blocks.
|
||||
|
||||
Args:
|
||||
block_hashes: the list of block hashes identifying the blocks.
|
||||
blocks: the list of blocks.
|
||||
|
||||
Returns:
|
||||
A LoadStoreSpec that can be used by a worker
|
||||
to read/write the blocks.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
61
vllm/v1/kv_offload/backends/cpu.py
Normal file
61
vllm/v1/kv_offload/backends/cpu.py
Normal file
@ -0,0 +1,61 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ctypes
|
||||
from collections.abc import Iterable
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.kv_offload.abstract import LoadStoreSpec
|
||||
from vllm.v1.kv_offload.backend import Backend, BlockStatus
|
||||
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec
|
||||
|
||||
|
||||
class CPUBlockStatus(BlockStatus):
|
||||
_fields_ = BlockStatus._fields_ + [("block_id", ctypes.c_int64)
|
||||
] # type: ignore
|
||||
|
||||
def __init__(self, block_id: int):
|
||||
super().__init__()
|
||||
self.block_id = block_id
|
||||
|
||||
|
||||
class CPUBackend(Backend):
|
||||
|
||||
def __init__(self, block_size: int, num_blocks: int):
|
||||
super().__init__(block_size=block_size,
|
||||
medium=CPULoadStoreSpec.medium())
|
||||
|
||||
self.num_blocks: int = num_blocks
|
||||
self.num_allocated_blocks: int = 0
|
||||
self.allocated_blocks_free_list: list[int] = []
|
||||
|
||||
def get_num_free_blocks(self):
|
||||
return (len(self.allocated_blocks_free_list) + self.num_blocks -
|
||||
self.num_allocated_blocks)
|
||||
|
||||
def allocate_blocks(self,
|
||||
block_hashes: list[BlockHash]) -> list[BlockStatus]:
|
||||
num_fresh_blocks = min(len(block_hashes),
|
||||
self.num_blocks - self.num_allocated_blocks)
|
||||
num_reused_blocks = len(block_hashes) - num_fresh_blocks
|
||||
assert len(self.allocated_blocks_free_list) >= num_reused_blocks
|
||||
|
||||
# allocate fresh blocks
|
||||
blocks: list[BlockStatus] = []
|
||||
for _ in range(num_fresh_blocks):
|
||||
blocks.append(CPUBlockStatus(self.num_allocated_blocks))
|
||||
self.num_allocated_blocks += 1
|
||||
|
||||
# allocate reused blocks
|
||||
for _ in range(num_reused_blocks):
|
||||
block_id = self.allocated_blocks_free_list.pop()
|
||||
blocks.append(CPUBlockStatus(block_id))
|
||||
|
||||
return blocks
|
||||
|
||||
def free(self, block: BlockStatus):
|
||||
assert isinstance(block, CPUBlockStatus)
|
||||
self.allocated_blocks_free_list.append(block.block_id)
|
||||
|
||||
def get_load_store_spec(self, block_hashes: Iterable[BlockHash],
|
||||
blocks: Iterable[BlockStatus]) -> LoadStoreSpec:
|
||||
return CPULoadStoreSpec([block.block_id for block in blocks])
|
||||
132
vllm/v1/kv_offload/lru_manager.py
Normal file
132
vllm/v1/kv_offload/lru_manager.py
Normal file
@ -0,0 +1,132 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent,
|
||||
OffloadingManager, PrepareStoreOutput)
|
||||
from vllm.v1.kv_offload.backend import Backend, BlockStatus
|
||||
|
||||
|
||||
class LRUOffloadingManager(OffloadingManager):
|
||||
"""
|
||||
An OffloadingManager with a pluggable backend, which evicts blocks by LRU.
|
||||
"""
|
||||
|
||||
def __init__(self, backend: Backend, enable_events: bool = False):
|
||||
self.backend: Backend = backend
|
||||
# block_hash -> BlockStatus
|
||||
self.blocks: OrderedDict[BlockHash, BlockStatus] = OrderedDict()
|
||||
self.events: Optional[list[OffloadingEvent]] = \
|
||||
[] if enable_events else None
|
||||
|
||||
def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
|
||||
hit_count = 0
|
||||
for block_hash in block_hashes:
|
||||
block = self.blocks.get(block_hash)
|
||||
if block is None or not block.is_ready:
|
||||
break
|
||||
hit_count += 1
|
||||
return hit_count
|
||||
|
||||
def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec:
|
||||
blocks = []
|
||||
for block_hash in block_hashes:
|
||||
block = self.blocks[block_hash]
|
||||
assert block.is_ready
|
||||
block.ref_cnt += 1
|
||||
blocks.append(block)
|
||||
|
||||
return self.backend.get_load_store_spec(block_hashes, blocks)
|
||||
|
||||
def touch(self, block_hashes: Iterable[BlockHash]):
|
||||
for block_hash in reversed(list(block_hashes)):
|
||||
if self.blocks.get(block_hash):
|
||||
self.blocks.move_to_end(block_hash)
|
||||
|
||||
def complete_load(self, block_hashes: Iterable[BlockHash]):
|
||||
for block_hash in block_hashes:
|
||||
block = self.blocks[block_hash]
|
||||
assert block.ref_cnt > 0
|
||||
block.ref_cnt -= 1
|
||||
|
||||
def prepare_store(
|
||||
self,
|
||||
block_hashes: Iterable[BlockHash]) -> Optional[PrepareStoreOutput]:
|
||||
# filter out blocks that are already stored
|
||||
block_hashes_to_store = [
|
||||
block_hash for block_hash in block_hashes
|
||||
if block_hash not in self.blocks
|
||||
]
|
||||
|
||||
num_blocks_to_evict = (len(block_hashes_to_store) -
|
||||
self.backend.get_num_free_blocks())
|
||||
|
||||
# build list of blocks to evict
|
||||
to_evict = []
|
||||
if num_blocks_to_evict > 0:
|
||||
for block_hash, block in self.blocks.items():
|
||||
if block.ref_cnt == 0:
|
||||
to_evict.append(block_hash)
|
||||
num_blocks_to_evict -= 1
|
||||
if num_blocks_to_evict == 0:
|
||||
break
|
||||
else:
|
||||
# we could not evict enough blocks
|
||||
return None
|
||||
|
||||
# evict blocks
|
||||
for block_hash in to_evict:
|
||||
self.backend.free(self.blocks.pop(block_hash))
|
||||
|
||||
if to_evict and self.events is not None:
|
||||
self.events.append(
|
||||
OffloadingEvent(block_hashes=to_evict,
|
||||
block_size=self.backend.block_size,
|
||||
medium=self.backend.medium,
|
||||
removed=True))
|
||||
|
||||
blocks = self.backend.allocate_blocks(block_hashes_to_store)
|
||||
assert len(blocks) == len(block_hashes_to_store)
|
||||
|
||||
for block_hash, block in zip(block_hashes_to_store, blocks):
|
||||
self.blocks[block_hash] = block
|
||||
|
||||
# build store specs for allocated blocks
|
||||
store_spec = self.backend.get_load_store_spec(block_hashes_to_store,
|
||||
blocks)
|
||||
|
||||
return PrepareStoreOutput(block_hashes_to_store=block_hashes_to_store,
|
||||
store_spec=store_spec,
|
||||
block_hashes_evicted=to_evict)
|
||||
|
||||
def complete_store(self,
|
||||
block_hashes: Iterable[BlockHash],
|
||||
success: bool = True):
|
||||
stored_block_hashes: list[BlockHash] = []
|
||||
if success:
|
||||
for block_hash in block_hashes:
|
||||
block = self.blocks[block_hash]
|
||||
if not block.is_ready:
|
||||
block.ref_cnt = 0
|
||||
stored_block_hashes.append(block_hash)
|
||||
else:
|
||||
for block_hash in block_hashes:
|
||||
block = self.blocks[block_hash]
|
||||
if not block.is_ready:
|
||||
self.backend.free(block)
|
||||
del self.blocks[block_hash]
|
||||
|
||||
if stored_block_hashes and self.events is not None:
|
||||
self.events.append(
|
||||
OffloadingEvent(block_hashes=stored_block_hashes,
|
||||
block_size=self.backend.block_size,
|
||||
medium=self.backend.medium,
|
||||
removed=False))
|
||||
|
||||
def take_events(self) -> Iterable[OffloadingEvent]:
|
||||
if self.events is not None:
|
||||
yield from self.events
|
||||
self.events.clear()
|
||||
Loading…
x
Reference in New Issue
Block a user