Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-16 03:15:01 +08:00)

[Core] Enable prefix caching with block manager v2 enabled (#4142)

Co-authored-by: Lei Wen <wenlei03@qiyi.com>
Co-authored-by: Sage Moore <sagemoore@utexas.edu>

parent b38e42fbca
commit 24750f4cad
benchmarks/benchmark_prefix_caching.py
@@ -16,20 +16,22 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):


 def main(args):
-    llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat",
+    llm = LLM(model=args.model,
               tokenizer_mode='auto',
               trust_remote_code=True,
               enforce_eager=True,
+              use_v2_block_manager=args.use_v2_block_manager,
+              tensor_parallel_size=args.tensor_parallel_size,
               enable_prefix_caching=args.enable_prefix_caching)

     num_prompts = 100
     prompts = [PROMPT] * num_prompts
-    sampling_params = SamplingParams(temperature=0, max_tokens=100)
+    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)

     print("------warm up------")
     test_prefix(
         llm=llm,
-        prompts=prompts[:1],
+        prompts=prompts,
         sampling_params=sampling_params,
     )

@@ -45,8 +47,16 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description='Benchmark the performance with or without automatic '
         'prefix caching.')
+    parser.add_argument('--model',
+                        type=str,
+                        default='baichuan-inc/Baichuan2-13B-Chat')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
+    parser.add_argument('--output-len', type=int, default=10)
     parser.add_argument('--enable-prefix-caching',
                         action='store_true',
                         help='enable prefix caching')
+    parser.add_argument('--use-v2-block-manager',
+                        action='store_true',
+                        help='Use BlockSpaceManagerV2')
     args = parser.parse_args()
     main(args)
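For reference, the new flags map one-to-one onto LLM constructor arguments. A minimal sketch of the equivalent direct construction (the model id is only an example; the flag-to-kwarg mapping is taken from the diff above):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m",   # --model (any HF model id)
              tokenizer_mode='auto',
              trust_remote_code=True,
              enforce_eager=True,
              use_v2_block_manager=True,   # --use-v2-block-manager
              tensor_parallel_size=1,      # --tensor-parallel-size / -tp
              enable_prefix_caching=True)  # --enable-prefix-caching
    # --output-len controls max_tokens:
    sampling_params = SamplingParams(temperature=0, max_tokens=10)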
tests/core/block/e2e/test_correctness.py
@@ -300,6 +300,152 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
     assert baseline_token_ids == test_token_ids


+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Use a small model for a fast test.
+        "model": "facebook/opt-125m",
+
+        # skip cuda graph creation for fast test.
+        "enforce_eager": True,
+
+        # Allow only 5 sequences of ~1024 tokens in worst case.
+        "block_size": 16,
+        "num_gpu_blocks_override": 5 * (64 + 1),
+
+        # Enable prefill cache.
+        "enable_prefix_caching": True,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{
+    "use_v2_block_manager": False
+}])
+@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
+@pytest.mark.parametrize("batch_size", [10])
+@pytest.mark.parametrize("seed", [1])
+def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
+        baseline_llm_generator, test_llm_generator, batch_size):
+    """Verify block manager v2 produces the same outputs as block manager v1,
+    even when there is preemption.
+
+    This constructs two LLMs, each with a limited number of GPU blocks. The
+    limit is decided such that as the sequences in the batch grow, sequences
+    must be preempted and removed from cache.
+
+    If the output token ids are equivalent, then we have confidence that the
+    KV cache is not corrupted in the v2 block manager.
+
+    NOTE: We want a significant number of generated tokens so that any
+    incorrect KV mapping has time to build up error.
+    """
+    output_len = 1024
+    temperature = 0.0
+
+    # We want to ensure equality even with preemption.
+    # We force the total block size to be 1 + cdiv(output_len, block_size)
+    # so that only one sequence can fit at a time (once the sequences grow).
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    print('Getting token ids from block manager v1')
+    baseline_token_ids = get_token_ids_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+
+    print('Getting token ids from block manager v2')
+    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
+                                                      prompts, sampling_params)
+
+    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
+                                                    test_token_ids):
+        assert expected_token_ids == actual_token_ids
+
+    assert baseline_token_ids == test_token_ids
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Use a small model for a fast test.
+        "model": "facebook/opt-125m",
+
+        # skip cuda graph creation for fast test.
+        "enforce_eager": True,
+
+        # Allow only 5 sequences of ~1024 tokens in worst case.
+        "block_size": 16,
+        "num_gpu_blocks_override": 5 * (64 + 1),
+
+        # Test APC in the v2 block manager.
+        "use_v2_block_manager": True,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{
+    "enable_prefix_caching": False
+}])
+@pytest.mark.parametrize("test_llm_kwargs", [{"enable_prefix_caching": True}])
+@pytest.mark.parametrize("batch_size", [10])
+@pytest.mark.parametrize("seed", [1])
+def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
+                                             test_llm_generator, batch_size):
+    """Verify block manager v2 with auto prefix caching enabled produces the
+    same outputs as with auto prefix caching disabled, even when there is
+    preemption.
+
+    This constructs two LLMs, each with a limited number of GPU blocks. The
+    limit is decided such that as the sequences in the batch grow, sequences
+    must be preempted and removed from cache.
+
+    If the output token ids are equivalent, then we have confidence that, at
+    a minimum, auto prefix caching itself does not corrupt the results.
+    """
+    output_len = 1024
+    temperature = 0.0
+
+    # We want to ensure equality even with preemption.
+    # We force the total block size to be 1 + cdiv(output_len, block_size)
+    # so that only one sequence can fit at a time (once the sequences grow).
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    print('Getting token ids with APC disabled')
+    baseline_token_ids = get_token_ids_from_llm_generator(
+        baseline_llm_generator, prompts, sampling_params)
+
+    print('Getting token ids with APC enabled')
+    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
+                                                      prompts, sampling_params)
+
+    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
+                                                    test_token_ids):
+        assert expected_token_ids == actual_token_ids
+
+    assert baseline_token_ids == test_token_ids
+
+
 def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
     for llm in llm_generator:
         outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
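A note on the "num_gpu_blocks_override": 5 * (64 + 1) choice in both tests above: with a block size of 16, a sequence that grows to ~1024 tokens needs cdiv(1024, 16) = 64 full blocks, plus one block of slack. A quick self-contained check of that arithmetic (cdiv here is an assumed helper, not imported from vLLM):

    def cdiv(a: int, b: int) -> int:
        """Ceiling division."""
        return -(-a // b)

    block_size = 16
    output_len = 1024
    blocks_per_seq = cdiv(output_len, block_size) + 1  # one extra block of slack
    num_gpu_blocks = 5 * blocks_per_seq                # room for ~5 sequences

    assert blocks_per_seq == 64 + 1
    assert num_gpu_blocks == 325
    # With batch_size=10 and every sequence growing toward 1024 tokens, at most
    # ~5 sequences fit at once, so the scheduler must preempt -- exactly the
    # condition the tests want to exercise.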
tests/core/block/test_prefix_caching_block.py
@@ -358,6 +358,131 @@ class TestPrefixCachingBlockAllocator:
                                                             i)
             allocator.free(block)

+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [1024])
+    @pytest.mark.parametrize("block_size", [16])
+    @pytest.mark.parametrize("seed", list(range(20)))
+    def test_get_common_computed_block_ids(num_blocks: int, block_size: int,
+                                           seed: int):
+        """Verify that get_common_computed_block_ids returns the correct
+        result by creating two immutable chains that share a prefix up to a
+        specified position, then comparing the returned common prefix with
+        the expected one.
+        """
+        random.seed(seed)
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks * 2,
+                                                block_size=block_size)
+        num_blocks_to_consume = random.randint(1, num_blocks - 1)
+
+        # Create token ids that will exhaust all blocks.
+        token_ids = list(range(num_blocks_to_consume * block_size))
+        blocks = list(range(num_blocks_to_consume))
+
+        first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator,
+        )
+
+        # Mark all blocks in the first chain as computed.
+        allocator.mark_blocks_as_computed(blocks)
+
+        # After zero_point, the second chain's token ids are set to -1, which
+        # makes it diverge from the first chain from that point on.
+        zero_point = random.randint(1, len(token_ids) - 1)
+        zero_point_blocks = zero_point // block_size
+        token_ids[zero_point:] = [-1] * (len(token_ids) - zero_point)
+
+        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids,
+            allocator=allocator,
+        )
+
+        first_computed_ids = [
+            first_chain[i].block_id for i in range(num_blocks_to_consume)
+        ]
+        second_computed_ids = [
+            second_chain[i].block_id for i in range(num_blocks_to_consume)
+        ]
+        res = allocator.get_common_computed_block_ids(
+            [first_computed_ids, second_computed_ids])
+
+        assert (len(res) == zero_point_blocks)
+
+    # Test case where two last accessed times are equal
+    @staticmethod
+    @pytest.mark.parametrize("num_blocks", [1024])
+    @pytest.mark.parametrize("block_size", [16])
+    @pytest.mark.parametrize("seed", list(range(20)))
+    def test_eviction_order(num_blocks: int, block_size: int, seed: int):
+        """This test case simulates two chains that are created and freed in
+        order; together they exhaust the initially free blocks.
+
+        The next block created after those two chains should therefore reuse
+        a block from the first chain, since that chain has the older access
+        time. Of the first chain's two blocks, the last one should be picked,
+        as it has the larger hashed token count.
+        """
+        random.seed(seed)
+        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
+                                                block_size=block_size)
+        num_blocks_to_consume = num_blocks + 1
+
+        token_ids = list(range(num_blocks_to_consume * block_size))
+
+        num_blocks_in_first_chain = 2
+        num_tokens_in_first_chain = block_size * num_blocks_in_first_chain
+        # The first chain takes the first two blocks.
+        first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids[:num_tokens_in_first_chain],
+            allocator=allocator,
+        )
+        # Only the first chain's blocks should be allocated at this point.
+        assert allocator.get_num_free_blocks() == (num_blocks -
+                                                   num_blocks_in_first_chain)
+
+        # Set the last accessed time of the first chain's blocks to 1.
+        blocks_ids = [block.block_id for block in first_chain]
+        allocator.mark_blocks_as_accessed(blocks_ids, 1)
+
+        # The second chain takes the rest of the blocks.
+        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids[num_tokens_in_first_chain:-block_size],
+            allocator=allocator,
+        )
+
+        # There shouldn't be any blocks left at this point.
+        assert allocator.get_num_free_blocks() == (0)
+
+        assert len(first_chain) == num_blocks_in_first_chain
+        last_block_id = first_chain[-1].block_id
+        # Free each block in the first chain.
+        for i, block in enumerate(first_chain):
+            allocator.free(block)
+
+        # Set the last accessed time on all of the blocks in the second chain
+        # to 2.
+        blocks_ids = [block.block_id for block in second_chain]
+        allocator.mark_blocks_as_accessed(blocks_ids, 2)
+
+        # Free each block in the second chain.
+        for i, block in enumerate(second_chain):
+            allocator.free(block)
+
+        # Allocate a new block and check that it's the least recently used
+        # block from the first chain.
+        new_block = TestPrefixCachingBlockAllocator.create_immutable_chain(
+            block_size=block_size,
+            token_ids=token_ids[-block_size:],
+            allocator=allocator,
+        )
+
+        assert new_block[0].block_id == last_block_id
+
     @staticmethod
     def create_immutable_chain(
         block_size: int,
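The key invariant in test_get_common_computed_block_ids above: if two chains agree up to token position zero_point and diverge afterwards, only zero_point // block_size whole blocks can be common. A tiny self-contained illustration with made-up numbers:

    block_size = 16
    zero_point = 37  # hypothetical divergence position

    shared_full_blocks = zero_point // block_size
    # Tokens [0:16) and [16:32) fill two whole matching blocks; the block
    # containing position 37 already differs, so it does not count.
    assert shared_full_blocks == 2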
vllm/core/block/cpu_gpu_block_allocator.py
@@ -190,10 +190,18 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         device = Device.GPU
         return self._allocators[device].clear_copy_on_writes()

-    def mark_blocks_as_computed(self) -> None:
+    def mark_blocks_as_accessed(self, block_ids: List[int],
+                                now: float) -> None:
+        """Mark blocks as accessed, only used for prefix caching."""
+        # Prefix caching only supported on GPU.
+        device = Device.GPU
+        return self._allocators[device].mark_blocks_as_accessed(block_ids, now)
+
+    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
+        """Mark blocks as computed, only used for prefix caching."""
         # Prefix caching only supported on GPU.
         device = Device.GPU
-        return self._allocators[device].mark_blocks_as_computed()
+        return self._allocators[device].mark_blocks_as_computed(block_ids)

     def get_common_computed_block_ids(
             self, seq_block_ids: List[List[int]]) -> List[int]:
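The pattern above is plain delegation: the device-aware wrapper forwards both prefix-caching calls to its GPU allocator, since prefix caching is only supported on GPU. A self-contained toy of that dispatch shape (all names here are illustrative, not the real vLLM classes):

    import enum
    from typing import Dict, List


    class Device(enum.Enum):
        CPU = enum.auto()
        GPU = enum.auto()


    class ToyGpuAllocator:

        def __init__(self):
            self.last_accessed: Dict[int, float] = {}

        def mark_blocks_as_accessed(self, block_ids: List[int],
                                    now: float) -> None:
            for block_id in block_ids:
                self.last_accessed[block_id] = now


    class ToyDeviceAwareAllocator:

        def __init__(self, allocators: Dict[Device, ToyGpuAllocator]):
            self._allocators = allocators

        def mark_blocks_as_accessed(self, block_ids: List[int],
                                    now: float) -> None:
            # Prefix caching only supported on GPU, so always dispatch there.
            return self._allocators[Device.GPU].mark_blocks_as_accessed(
                block_ids, now)


    wrapper = ToyDeviceAwareAllocator({Device.GPU: ToyGpuAllocator()})
    wrapper.mark_blocks_as_accessed([0, 1, 2], now=42.0)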
vllm/core/block/interfaces.py
@@ -81,6 +81,10 @@ class BlockAllocator(ABC):
     def clear_copy_on_writes(self) -> Dict[int, List[int]]:
         pass

+    @abstractmethod
+    def mark_blocks_as_accessed(self) -> None:
+        pass
+
     @abstractmethod
     def mark_blocks_as_computed(self) -> None:
         pass
vllm/core/block/naive_block.py
@@ -174,7 +174,16 @@ class NaiveBlockAllocator(BlockAllocator):
         """
         return self._cow_tracker.clear_cows()

-    def mark_blocks_as_computed(self) -> None:
+    def mark_blocks_as_accessed(self, block_ids: List[int],
+                                now: float) -> None:
+        """Mark blocks as accessed, used in prefix caching.
+
+        Since the naive allocator does not implement prefix caching, we do
+        nothing.
+        """
+        pass
+
+    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
         """Mark blocks as computed, used in prefix caching.

         Since the naive allocator does not implement prefix caching, we do
vllm/core/block/prefix_caching_block.py
@@ -7,10 +7,16 @@ from vllm.core.block.common import (CopyOnWriteTracker,
                                     get_all_blocks_recursively)
 from vllm.core.block.interfaces import Block, BlockAllocator
 from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
+from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor

 PrefixHash = int
 BlockId = int

+# By default, a block's access time is initialized to
+# _DEFAULT_LAST_ACCESSED_TIME, so a block that still holds
+# _DEFAULT_LAST_ACCESSED_TIME is known to never have been accessed.
+_DEFAULT_LAST_ACCESSED_TIME = -1
+

 class PrefixCachingBlockAllocator(BlockAllocator):
     """A block allocator that implements prefix caching.
@@ -27,22 +33,19 @@ class PrefixCachingBlockAllocator(BlockAllocator):
     from 0 to num_blocks - 1.
     """

-    # TODO last access time / evictor integration
-
     def __init__(
         self,
         num_blocks: int,
         block_size: int,
         block_ids: Optional[Iterable[int]] = None,
+        eviction_policy: Optional[EvictionPolicy] = EvictionPolicy.LRU,
     ):
         # A mapping of prefix hash to block index. All blocks which have a
         # prefix hash will be in this dict, even if they have refcount 0.
         self._cached_blocks: Dict[PrefixHash, BlockId] = {}

-        # A mapping of prefix hash to block index. All blocks which have a
-        # prefix hash AND refcount 0 will be in this dict. Thus, it is a subset
-        # of self._cached_blocks.
-        self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {}
+        # A mapping of block id to Block, used to track the cached blocks.
+        self._blocks: Dict[BlockId, Block] = {}

         # An allocator for blocks that do not have prefix hashes.
         self._hashless_allocator = NaiveBlockAllocator(
@@ -54,6 +57,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):

         self._block_size = block_size

+        # Evictor used to decide how to handle computed blocks when memory
+        # pressure is high.
+        self.evictor: Evictor = make_evictor(eviction_policy)
+
         # We share the refcounter between allocators. This allows us to promote
         # blocks originally allocated in the hashless allocator to immutable
         # blocks.
@@ -72,6 +79,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         block_size: int,
         allocator: BlockAllocator,
         block_id: Optional[int] = None,
+        computed: Optional[bool] = False,
     ) -> Block:
         # Bind block to self.
         allocator = self
@@ -82,6 +90,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
             block_size=block_size,
             block_id=block_id,
             prefix_caching_allocator=allocator,
+            computed=computed,
         )

     def allocate_immutable(self, prev_block: Optional[Block],
@@ -109,14 +118,12 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         cached_block_id = self._cached_blocks.get(block.content_hash, None)
         if cached_block_id is not None:
             block.block_id = cached_block_id
-            self._incr_refcount_cached_block(block.content_hash,
-                                             block.block_id)
+            self._incr_refcount_cached_block(block, block.block_id)
             return block

         block = self.allocate_mutable(prev_block)
         block.append_token_ids(token_ids)
         assert block.content_hash is not None
-        # TODO computed bit

         return block

@@ -133,41 +140,67 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         assert_prefix_caching_block_or_none(prev_block)

         try:
-            return self._hashless_allocator.allocate_mutable(
+            block = self._hashless_allocator.allocate_mutable(
                 prev_block=prev_block)
+
+            assert block.block_id not in self._blocks
+            self._blocks[block.block_id] = block
+            return block
         except BlockAllocator.NoFreeBlocksError:
             # We must check the unused cached blocks before raising OOM.
             pass

-        if self._unused_cached_blocks:
-            # TODO policy for selecting block to remove
-            content_hash_to_evict = next(iter(self._unused_cached_blocks))
+        # If the evictor has blocks available for eviction, evict a block
+        # and return it.
+        if self.evictor.num_blocks > 0:
+            block_id, content_hash_to_evict = self.evictor.evict()

-            # Clear content hash mapping; the block will be overwritten.
-            del self._cached_blocks[content_hash_to_evict]
+            # Several blocks may share the same content hash: when a later
+            # block with identical content goes through the mutable-to-
+            # immutable promotion path, its physical block also ends up in
+            # the evictor. In that case we must not pop _cached_blocks, as
+            # the same content is still used by others, so we check the
+            # refcount before deciding to pop.
+            _block_id = self._cached_blocks[content_hash_to_evict]
+            refcount = self._refcounter.get(_block_id)
+            if refcount == 1:
+                self._cached_blocks.pop(content_hash_to_evict)
+                assert _block_id == block_id

-            block_id = self._unused_cached_blocks.pop(content_hash_to_evict)
-            refcount = self._refcounter.incr(block_id)
-            assert refcount == 1
+            self._refcounter.incr(block_id)
+
+            # A block coming from the evictor already contains a computed
+            # result.
             block = self._create_block(
                 prev_block=prev_block,
                 token_ids=[],
                 block_size=self._block_size,
                 allocator=self,
                 block_id=block_id,
+                computed=True,
             )
             assert block.content_hash is None
+
+            assert block.block_id not in self._blocks
+            self._blocks[block.block_id] = block
             return block

         # No block available in hashless allocator, nor in unused cache blocks.
         raise BlockAllocator.NoFreeBlocksError()

-    def _incr_refcount_cached_block(self, content_hash: int,
+    def _incr_refcount_cached_block(self, block: Block,
                                     block_id: BlockId) -> None:
+        # Since the block is already computed, mark it as such.
+        block.computed = True
+
         refcount = self._refcounter.incr(block_id)
         if refcount == 1:
-            assert content_hash in self._unused_cached_blocks
-            del self._unused_cached_blocks[content_hash]
+            # Once the block is referenced it must not stay in the evictor,
+            # so remove it there and track it in _blocks instead.
+            if block_id in self.evictor:
+                self.evictor.remove(block_id)
+            self._blocks[block_id] = block

     def free(self, block: Block) -> None:
         """Decrement the refcount of the block. If the decremented refcount is
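The subtle part of the eviction path above is the refcount check: several physical blocks can carry the same content hash, so the hash-to-block mapping may only be popped when the evicted block is its last user. A self-contained toy of just that guard (all data and names here are made up for illustration):

    from typing import Dict

    cached_blocks: Dict[int, int] = {0xabc: 7}  # content_hash -> block_id
    refcounts: Dict[int, int] = {7: 2}          # block_id -> refcount

    def on_evict(block_id: int, content_hash: int) -> None:
        mapped_id = cached_blocks[content_hash]
        if refcounts.get(mapped_id, 0) == 1:
            # Last user of this content: drop the hash mapping as well.
            cached_blocks.pop(content_hash)
            assert mapped_id == block_id
        # Otherwise the same content is still used elsewhere, so the mapping
        # survives even though this physical block gets reused.

    on_evict(block_id=7, content_hash=0xabc)
    assert 0xabc in cached_blocks  # refcount was 2, mapping is kept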
@@ -180,6 +213,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
                 is not None), "freeing unallocated block is undefined"

         self._free_block_id_for_block(block.block_id, block)
+
         block.block_id = None

     def _free_block_id_for_block(self, block_id: BlockId,
@@ -187,15 +221,21 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         assert isinstance(block, PrefixCachingBlock)

         if block.content_hash is None:
+            refcount = self._refcounter.get(block_id)
+            # In the fork case a block can hold more than one reference, so
+            # we can only stop tracking it once the refcount is at most 1.
+            if refcount <= 1:
+                del self._blocks[block.block_id]
             return self._hashless_allocator.free(block)

         refcount = self._refcounter.decr(block_id)

-        # If no longer used, add the block to the unused cached blocks.
+        # If no longer used, add the block to the evictor.
         if refcount == 0:
-            assert block.content_hash not in self._unused_cached_blocks
             assert block.content_hash in self._cached_blocks
-            self._unused_cached_blocks[block.content_hash] = block_id
+            del self._blocks[block.block_id]
+            self.evictor.add(block.block_id, block.content_hash,
+                             block.num_tokens_total, block.last_accessed)

     def fork(self, last_block: Block) -> List[Block]:
         """Creates a new sequence of blocks that shares the same underlying
@@ -230,9 +270,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):

     def get_num_free_blocks(self) -> int:
         # The number of free blocks is the number of hashless free blocks
-        # plus the number of hashful blocks that are unused.
-        return self._hashless_allocator.get_num_free_blocks() + len(
-            self._unused_cached_blocks)
+        # plus the number of blocks the evictor could free from its list.
+        return self._hashless_allocator.get_num_free_blocks(
+        ) + self.evictor.num_blocks

     @property
     def all_block_ids(self) -> frozenset[int]:
@@ -266,7 +306,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         else:
             self._free_block_id_for_block(block.block_id, block)
             self._incr_refcount_cached_block(
-                block.content_hash, self._cached_blocks[block.content_hash])
+                block, self._cached_blocks[block.content_hash])

         return self._cached_blocks[block.content_hash]

@@ -293,29 +333,60 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         """
         return self._cow_tracker.clear_cows()

-    def mark_blocks_as_computed(self) -> None:
+    def mark_blocks_as_accessed(self, block_ids: List[int],
+                                now: float) -> None:
+        """Mark blocks as accessed, used in prefix caching.
+
+        If a block is in the evictor, the corresponding metadata in the
+        evictor needs to be updated as well.
+        """
+        for block_id in block_ids:
+            if block_id in self._blocks:
+                self._blocks[block_id].last_accessed = now
+            elif block_id in self.evictor:
+                self.evictor.update(block_id, now)
+            else:
+                raise ValueError(
+                    "Marking a block as accessed that does not belong to GPU")
+
+    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
         """Mark blocks as computed, used in prefix caching."""
-        # TODO Track computed blocks.
-        pass
+        for block_id in block_ids:
+            if block_id in self._blocks:
+                # Only full blocks are valid for prefix caching.
+                if self._blocks[block_id].is_full:
+                    self._blocks[block_id].computed = True
+            elif block_id not in self.evictor:
+                raise ValueError(f"Marking {block_id=} as computed which "
+                                 "does not belong to GPU")
+
+    def block_is_computed(self, block_id: int) -> bool:
+        if block_id in self._blocks:
+            return self._blocks[block_id].computed
+        else:
+            return block_id in self.evictor

     def get_common_computed_block_ids(
             self, seq_block_ids: List[List[int]]) -> List[int]:
         """Return the block ids that are common for a given sequence group.

-        Used in prefill (can skip prefill of some blocks).
+        Only blocks that are immutable and already marked as computed are
+        taken into consideration.
         """
-
-        # TODO: Track computed blocks.
-        computed = lambda block_id: False
-
         # NOTE We exclude the last block to avoid the case where the entire
         # prompt is cached. This would cause erroneous behavior in model
         # runner.
+
         ids_list = [
-            takewhile(lambda block_id: computed(block_id), seq[:-1])
-            for seq in seq_block_ids
+            list(
+                takewhile(lambda block_id: self.block_is_computed(block_id),
+                          seq[:-1])) for seq in seq_block_ids
         ]
-        return commonprefix([ids for ids in ids_list if ids != []])
+        res = commonprefix([ids for ids in ids_list if ids != []])
+        return res


 class PrefixCachingBlock(Block):
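The new get_common_computed_block_ids boils down to a takewhile per sequence followed by a longest-common-prefix across sequences. A self-contained sketch of that composition on made-up data (note os.path.commonprefix happens to work element-wise on lists, which is how the real code uses it):

    from itertools import takewhile
    from os.path import commonprefix

    computed = {0, 1, 2, 5}.__contains__  # toy "is this block computed?" oracle

    seq_block_ids = [
        [0, 1, 2, 3],  # the last block (3) is excluded before the takewhile
        [0, 1, 5, 4],
    ]

    ids_list = [list(takewhile(computed, seq[:-1])) for seq in seq_block_ids]
    assert ids_list == [[0, 1, 2], [0, 1, 5]]
    assert commonprefix([ids for ids in ids_list if ids != []]) == [0, 1]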
@@ -345,12 +416,16 @@ class PrefixCachingBlock(Block):
         block_size: int,
         prefix_caching_allocator: PrefixCachingBlockAllocator,
         block_id: Optional[int] = None,
+        computed: Optional[bool] = False,
     ):
         assert_prefix_caching_block_or_none(prev_block)

         self._prev_block = prev_block
         self._cached_content_hash: Optional[int] = None
+        self._cached_num_tokens_total: Optional[int] = None
         self._prefix_caching_allocator = prefix_caching_allocator
+        self.last_accessed = _DEFAULT_LAST_ACCESSED_TIME
+        self.computed = computed

         self._block = NaiveBlock(
             prev_block=prev_block,
@@ -398,6 +473,27 @@ class PrefixCachingBlock(Block):
     def num_empty_slots(self) -> int:
         return self._block.num_empty_slots

+    @property
+    def num_tokens_total(self) -> int:
+        """Return the total number of tokens in the chain so far.
+
+        We iterate the block chain back to the first block, caching the
+        result locally to prevent repeated computation.
+        """
+        if self._cached_num_tokens_total is not None:
+            return self._cached_num_tokens_total
+
+        _block = self
+        self._cached_num_tokens_total = 0
+
+        # TODO: the current implementation is O(N^2) over the chain; we
+        # expect to make it O(1) in the future.
+        while _block is not None:
+            self._cached_num_tokens_total += len(_block.token_ids)
+            _block = _block.prev_block
+
+        return self._cached_num_tokens_total
+
     @property
     def block_size(self) -> int:
         return self._block.block_size
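The num_tokens_total property above walks prev_block links back to the head of the chain and memoizes the sum per block. A self-contained toy of the same shape (ToyBlock is illustrative, not the real Block interface):

    from typing import List, Optional


    class ToyBlock:

        def __init__(self, token_ids: List[int],
                     prev_block: Optional["ToyBlock"] = None):
            self.token_ids = token_ids
            self.prev_block = prev_block
            self._cached_num_tokens_total: Optional[int] = None

        @property
        def num_tokens_total(self) -> int:
            if self._cached_num_tokens_total is None:
                total = 0
                block: Optional["ToyBlock"] = self
                while block is not None:  # walk back to the chain head
                    total += len(block.token_ids)
                    block = block.prev_block
                self._cached_num_tokens_total = total
            return self._cached_num_tokens_total


    head = ToyBlock(list(range(16)))
    tail = ToyBlock(list(range(8)), prev_block=head)
    assert tail.num_tokens_total == 24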
vllm/core/block_manager_v1.py
@@ -8,7 +8,7 @@ from typing import Sequence as GenericSequence
 from typing import Set

 from vllm.block import BlockTable, PhysicalTokenBlock
-from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
+from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
 from vllm.core.interfaces import AllocStatus, BlockSpaceManager
 from vllm.logger import init_logger
 from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
vllm/core/block_manager_v2.py
@@ -72,14 +72,12 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         self.watermark = watermark
         assert watermark >= 0.0

-        assert not enable_caching, "Prefix caching not yet supported"
         self.enable_caching = enable_caching

         self.watermark_blocks = int(watermark * num_gpu_blocks)

         self.block_allocator = CpuGpuBlockAllocator.create(
-            # Currently, only naive blocks are supported (no prefix caching).
-            allocator_type="naive",
+            allocator_type="prefix_caching" if enable_caching else "naive",
             num_gpu_blocks=num_gpu_blocks,
             num_cpu_blocks=num_cpu_blocks,
             block_size=block_size,
@@ -194,17 +192,26 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         assert all(b is not None for b in block_ids)
         return block_ids

-    def access_all_blocks_in_seq(self, seq, now):
-        # TODO add prefix caching support.
-        # Tracked here https://github.com/vllm-project/vllm/issues/3667
-        pass
+    def access_all_blocks_in_seq(self, seq: Sequence, now: float):
+        # Update the last accessed time of all the blocks accessed in this
+        # step. The access time is currently only useful for prefix caching,
+        # whose evictor policy can refill a cached block, so that cached
+        # content is reused to the maximum extent.
+        if self.enable_caching:
+            block_table = self.block_tables[seq.seq_id]
+            block_ids = []
+            for block_id in block_table.physical_block_ids:
+                block_ids.append(block_id)
+            self.block_allocator.mark_blocks_as_accessed(block_ids, now)

     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
-        # We ignore the sequence group as its not necessary. After the batch is
-        # formed by the scheduler, we do not need to mark blocks from individual
-        # sequence groups as computed -- all blocks in the batch can be marked
-        # as computed.
-        self.block_allocator.mark_blocks_as_computed()
+        # The only need to mark blocks as computed is for prefix caching,
+        # and currently we can already determine whether a block is computed
+        # by checking whether it has a content hash, so this function is a
+        # no-op for block manager v2.
+        pass

     def get_common_computed_block_ids(
             self, seqs: List[Sequence]) -> GenericSequence[int]:
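Net effect of the access_all_blocks_in_seq change: when prefix caching is on, every scheduler step stamps the current time onto all physical blocks the sequence touched, which is what the LRU evictor later keys on. A self-contained sketch with toy stand-ins for the block table and allocator:

    from typing import Dict, List

    last_accessed: Dict[int, float] = {}  # block_id -> timestamp

    def mark_blocks_as_accessed(block_ids: List[int], now: float) -> None:
        for block_id in block_ids:
            last_accessed[block_id] = now

    block_tables: Dict[int, List[int]] = {0: [3, 4, 5]}  # seq_id -> block ids

    def access_all_blocks_in_seq(seq_id: int, now: float,
                                 enable_caching: bool = True) -> None:
        if enable_caching:
            mark_blocks_as_accessed(block_tables[seq_id], now)

    access_all_blocks_in_seq(seq_id=0, now=100.0)
    assert last_accessed == {3: 100.0, 4: 100.0, 5: 100.0}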
vllm/core/evictor_v2.py (new file, 122 lines)
@@ -0,0 +1,122 @@
+import enum
+from abc import ABC, abstractmethod, abstractproperty
+from typing import OrderedDict, Tuple
+
+
+class EvictionPolicy(enum.Enum):
+    """Enum for eviction policy used by make_evictor to instantiate the
+    correct Evictor subclass.
+    """
+    LRU = enum.auto()
+
+
+class Evictor(ABC):
+    """The Evictor subclasses should be used by the BlockAllocator class to
+    handle eviction of freed PhysicalTokenBlocks.
+    """
+
+    @abstractmethod
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def __contains__(self, block_id: int) -> bool:
+        pass
+
+    @abstractmethod
+    def evict(self) -> Tuple[int, int]:
+        """Runs the eviction algorithm and returns the evicted block's
+        physical block id along with its content hash.
+        """
+        pass
+
+    @abstractmethod
+    def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
+            last_accessed: int):
+        """Adds block to the evictor, making it a candidate for eviction"""
+        pass
+
+    @abstractmethod
+    def update(self, block_id: int, last_accessed: int):
+        """Update the corresponding block's access time in the metadata"""
+        pass
+
+    @abstractproperty
+    def num_blocks(self) -> int:
+        pass
+
+
+class BlockMetaData():
+    """Data structure for storing the key data that describe a cached block,
+    so that the evictor can use it to decide which block to choose for
+    eviction.
+
+    We use the physical block id as the dict key, as there may be several
+    blocks with the same content hash, but their physical ids are unique.
+    """
+
+    def __init__(self, content_hash: int, num_hashed_tokens: int,
+                 last_accessed: int):
+        self.content_hash = content_hash
+        self.num_hashed_tokens = num_hashed_tokens
+        self.last_accessed = last_accessed
+
+
+class LRUEvictor(Evictor):
+    """Evicts in a least-recently-used order using the last_accessed timestamp
+    that's recorded in the PhysicalTokenBlock. If there are multiple blocks
+    with the same last_accessed time, then the one with the largest
+    num_hashed_tokens will be evicted. If two blocks each have the lowest
+    last_accessed time and highest num_hashed_tokens value, then one will be
+    chosen arbitrarily.
+    """
+
+    def __init__(self):
+        self.free_table: OrderedDict[int, BlockMetaData] = OrderedDict()
+
+    def __contains__(self, block_id: int) -> bool:
+        return block_id in self.free_table
+
+    def evict(self) -> Tuple[int, int]:
+        if len(self.free_table) == 0:
+            raise ValueError("No usable cache memory left")
+
+        evicted_block = next(iter(self.free_table.values()))
+        evicted_block_id = next(iter(self.free_table.keys()))
+        # The blocks with the lowest timestamps should be placed consecutively
+        # at the start of the OrderedDict. Loop through all these blocks to
+        # find the one with the maximum number of hashed tokens.
+        for _id, block in self.free_table.items():
+            if evicted_block.last_accessed > block.last_accessed or (
+                    evicted_block.last_accessed == block.last_accessed and
+                    evicted_block.num_hashed_tokens < block.num_hashed_tokens):
+                evicted_block = block
+                evicted_block_id = _id
+
+        self.free_table.pop(evicted_block_id)
+
+        return evicted_block_id, evicted_block.content_hash
+
+    def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
+            last_accessed: int):
+        self.free_table[block_id] = BlockMetaData(content_hash,
+                                                  num_hashed_tokens,
+                                                  last_accessed)
+
+    def update(self, block_id: int, last_accessed: int):
+        self.free_table[block_id].last_accessed = last_accessed
+
+    def remove(self, block_id: int):
+        if block_id not in self.free_table:
+            raise ValueError(
+                "Attempting to remove block that's not in the evictor")
+        self.free_table.pop(block_id)
+
+    @property
+    def num_blocks(self) -> int:
+        return len(self.free_table)
+
+
+def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
+    if eviction_policy == EvictionPolicy.LRU:
+        return LRUEvictor()
+    else:
+        raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
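A short usage sketch for the new evictor, using only the API defined in the file above; it demonstrates the documented tie-break (oldest last_accessed wins, and on a timestamp tie the block with more hashed tokens is evicted first):

    from vllm.core.evictor_v2 import LRUEvictor

    evictor = LRUEvictor()
    evictor.add(block_id=1, content_hash=111, num_hashed_tokens=16,
                last_accessed=1)
    evictor.add(block_id=2, content_hash=222, num_hashed_tokens=32,
                last_accessed=1)
    evictor.add(block_id=3, content_hash=333, num_hashed_tokens=48,
                last_accessed=2)

    assert evictor.num_blocks == 3
    assert 2 in evictor

    # Blocks 1 and 2 tie on last_accessed; block 2 has more hashed tokens.
    block_id, content_hash = evictor.evict()
    assert (block_id, content_hash) == (2, 222)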