[Core] Enable prefix caching with block manager v2 enabled (#4142)

Co-authored-by: Lei Wen <wenlei03@qiyi.com>
Co-authored-by: Sage Moore <sagemoore@utexas.edu>
leiwen83 2024-05-02 02:20:32 +08:00 committed by GitHub
parent b38e42fbca
commit 24750f4cad
11 changed files with 584 additions and 57 deletions

View File

@@ -16,20 +16,22 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):

def main(args):
-   llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat",
+   llm = LLM(model=args.model,
              tokenizer_mode='auto',
              trust_remote_code=True,
              enforce_eager=True,
+             use_v2_block_manager=args.use_v2_block_manager,
+             tensor_parallel_size=args.tensor_parallel_size,
              enable_prefix_caching=args.enable_prefix_caching)

    num_prompts = 100
    prompts = [PROMPT] * num_prompts
-   sampling_params = SamplingParams(temperature=0, max_tokens=100)
+   sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)

    print("------warm up------")
    test_prefix(
        llm=llm,
-       prompts=prompts[:1],
+       prompts=prompts,
        sampling_params=sampling_params,
    )
@@ -45,8 +47,16 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Benchmark the performance with or without automatic '
        'prefix caching.')
+   parser.add_argument('--model',
+                       type=str,
+                       default='baichuan-inc/Baichuan2-13B-Chat')
+   parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
+   parser.add_argument('--output-len', type=int, default=10)
    parser.add_argument('--enable-prefix-caching',
                        action='store_true',
                        help='enable prefix caching')
+   parser.add_argument('--use-v2-block-manager',
+                       action='store_true',
+                       help='Use BlockSpaceManagerV2')
    args = parser.parse_args()
    main(args)
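For reference, a minimal sketch of exercising both new options directly through the LLM entrypoint, mirroring the arguments the benchmark now forwards (model name and values are illustrative, not prescriptive):

    from vllm import LLM, SamplingParams

    llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat",
              enforce_eager=True,
              use_v2_block_manager=True,      # opt into BlockSpaceManagerV2
              enable_prefix_caching=True)     # turn on automatic prefix caching
    sampling_params = SamplingParams(temperature=0, max_tokens=10)
    # Identical prompts share their full prefix, so later sequences can reuse
    # cached blocks.
    outputs = llm.generate(["The capital of France is"] * 4, sampling_params)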

View File

@@ -300,6 +300,152 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
    assert baseline_token_ids == test_token_ids
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
"model": "facebook/opt-125m",
# skip cuda graph creation for fast test.
"enforce_eager": True,
# Allow only 5 sequences of ~1024 tokens in worst case.
"block_size": 16,
"num_gpu_blocks_override": 5 * (64 + 1),
        # Enable prefix caching
"enable_prefix_caching": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
        baseline_llm_generator, test_llm_generator, batch_size):
    """Verify block manager v2 produces the same outputs as block manager v1,
    even when there is preemption.

    This constructs two LLMs, each with a limited number of GPU blocks. The
    limit is decided such that as the sequences in the batch grow, sequences
    must be preempted and removed from cache.

    If the output token ids are equivalent, then we have confidence that the KV
    cache is not corrupted in the v2 block manager.

    NOTE: We want a significant number of generated tokens so that any
    incorrect KV mapping has time to build up error.
    """
output_len = 1024
temperature = 0.0
# We want to ensure equality even with preemption.
# We force the total block size to be 1 + cdiv(output_len, block_size)
# so that only one sequence can fit at a time (once the sequences grow).
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
print('Getting token ids from block manager v1')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
print('Getting token ids from block manager v2')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)
for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
test_token_ids):
assert expected_token_ids == actual_token_ids
assert baseline_token_ids == test_token_ids
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
"model": "facebook/opt-125m",
# skip cuda graph creation for fast test.
"enforce_eager": True,
# Allow only 5 sequences of ~1024 tokens in worst case.
"block_size": 16,
"num_gpu_blocks_override": 5 * (64 + 1),
        # Test APC in the v2 block manager
"use_v2_block_manager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"enable_prefix_caching": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"enable_prefix_caching": True}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
                                             test_llm_generator, batch_size):
    """Verify that block manager v2 with auto prefix caching enabled produces
    the same outputs as with auto prefix caching disabled, even when there is
    preemption.

    This constructs two LLMs, each with a limited number of GPU blocks. The
    limit is decided such that as the sequences in the batch grow, sequences
    must be preempted and removed from cache.

    If the output token ids are equivalent, then we have confidence that auto
    prefix caching itself does not, at minimum, corrupt the results.
    """
output_len = 1024
temperature = 0.0
# We want to ensure equality even with preemption.
# We force the total block size to be 1 + cdiv(output_len, block_size)
# so that only one sequence can fit at a time (once the sequences grow).
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
print('Getting token ids with APC disabled')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
print('Getting token ids with APC enabled')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)
for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
test_token_ids):
assert expected_token_ids == actual_token_ids
assert baseline_token_ids == test_token_ids
def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
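The GPU block budget used by both tests above is sized so that preemption is unavoidable; a rough back-of-the-envelope check of those numbers (illustrative only):

    from math import ceil

    block_size = 16
    output_len = 1024
    # One block for the short prompt plus cdiv(output_len, block_size) decode blocks.
    blocks_per_seq = 1 + ceil(output_len / block_size)      # 1 + 64 = 65
    num_gpu_blocks_override = 5 * (64 + 1)                  # 325 blocks total
    print(num_gpu_blocks_override // blocks_per_seq)        # 5
    # With batch_size = 10, only ~5 grown sequences fit at once, so the
    # scheduler must preempt the rest, which is exactly what the tests want.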

View File

@@ -358,6 +358,131 @@ class TestPrefixCachingBlockAllocator:
                                                  i)
            allocator.free(block)
@staticmethod
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("seed", list(range(20)))
def test_get_common_computed_block_ids(num_blocks: int, block_size: int,
                                            seed: int):
        """Verify that get_common_computed_block_ids returns the correct result
        by creating two immutable chains that share a prefix up to a specified
        position, and checking that the common computed block ids cover exactly
        that shared prefix.
        """
random.seed(seed)
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks * 2,
block_size=block_size)
num_blocks_to_consume = random.randint(1, num_blocks - 1)
# Create token ids that will exhaust all blocks.
token_ids = list(range(num_blocks_to_consume * block_size))
blocks = list(range(num_blocks_to_consume))
first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids,
allocator=allocator,
)
# mark all blocks in first chain as computed
allocator.mark_blocks_as_computed(blocks)
        # After zero_point, second_chain's token ids are set to -1, which
        # makes them differ from first_chain's.
zero_point = random.randint(1, len(token_ids) - 1)
zero_point_blocks = zero_point // block_size
token_ids[zero_point:] = [-1] * (len(token_ids) - zero_point)
second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids,
allocator=allocator,
)
first_computed_ids = [
first_chain[i].block_id for i in range(num_blocks_to_consume)
]
second_computed_ids = [
second_chain[i].block_id for i in range(num_blocks_to_consume)
]
res = allocator.get_common_computed_block_ids(
[first_computed_ids, second_computed_ids])
assert (len(res) == zero_point_blocks)
# Test case where two last accessed times are equal
@staticmethod
@pytest.mark.parametrize("num_blocks", [1024])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("seed", list(range(20)))
    def test_eviction_order(num_blocks: int, block_size: int, seed: int):
        """This test creates and frees two chains in order; together they
        exhaust the initially free blocks. The next block allocated after both
        chains are freed should therefore reuse a block from the first chain,
        since that chain has the older access time. The first chain has two
        blocks, and the last of them should be picked, as it holds the larger
        number of tokens.
"""
random.seed(seed)
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
block_size=block_size)
num_blocks_to_consume = num_blocks + 1
token_ids = list(range(num_blocks_to_consume * block_size))
num_blocks_in_first_chain = 2
num_tokens_in_first_chain = block_size * num_blocks_in_first_chain
# First chain takes the first block
first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids[:num_tokens_in_first_chain],
allocator=allocator,
)
        # Only the first chain's blocks should be allocated at this point
assert allocator.get_num_free_blocks() == (num_blocks -
num_blocks_in_first_chain)
# Set the last accessed time of the first block to 1
blocks_ids = [block.block_id for block in first_chain]
allocator.mark_blocks_as_accessed(blocks_ids, 1)
# Second chain takes the rest of the blocks
second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids[num_tokens_in_first_chain:-block_size],
allocator=allocator,
)
# There shouldn't be any blocks left at this point
assert allocator.get_num_free_blocks() == (0)
assert len(first_chain) == num_blocks_in_first_chain
last_block_id = first_chain[-1].block_id
# Free each block in the first chain.
for i, block in enumerate(first_chain):
allocator.free(block)
# Set the last accessed time on all of the blocks in the second chain
# to 2
blocks_ids = [block.block_id for block in second_chain]
allocator.mark_blocks_as_accessed(blocks_ids, 2)
# Free each block in the second chain.
for i, block in enumerate(second_chain):
allocator.free(block)
# Allocate a new block and check that it's the least recently used block
# from the first chain.
new_block = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=token_ids[-block_size:],
allocator=allocator,
)
assert new_block[0].block_id == last_block_id
    @staticmethod
    def create_immutable_chain(
        block_size: int,

View File

@@ -190,10 +190,18 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
        device = Device.GPU
        return self._allocators[device].clear_copy_on_writes()

-   def mark_blocks_as_computed(self) -> None:
+   def mark_blocks_as_accessed(self, block_ids: List[int],
+                               now: float) -> None:
+       """Mark blocks as accessed, only used for prefix caching."""
        # Prefix caching only supported on GPU.
        device = Device.GPU
-       return self._allocators[device].mark_blocks_as_computed()
+       return self._allocators[device].mark_blocks_as_accessed(block_ids, now)
+
+   def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
+       """Mark blocks as computed, only used for prefix caching."""
+       # Prefix caching only supported on GPU.
+       device = Device.GPU
+       return self._allocators[device].mark_blocks_as_computed(block_ids)

    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:

View File

@@ -81,6 +81,10 @@ class BlockAllocator(ABC):
    def clear_copy_on_writes(self) -> Dict[int, List[int]]:
        pass

+   @abstractmethod
+   def mark_blocks_as_accessed(self) -> None:
+       pass

    @abstractmethod
    def mark_blocks_as_computed(self) -> None:
        pass

View File

@@ -174,7 +174,16 @@ class NaiveBlockAllocator(BlockAllocator):
        """
        return self._cow_tracker.clear_cows()

-   def mark_blocks_as_computed(self) -> None:
+   def mark_blocks_as_accessed(self, block_ids: List[int],
+                               now: float) -> None:
+       """Mark blocks as accessed, used in prefix caching.
+
+       Since the naive allocator does not implement prefix caching, we do
+       nothing.
+       """
+       pass
+
+   def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
        """Mark blocks as computed, used in prefix caching.

        Since the naive allocator does not implement prefix caching, we do

View File

@@ -7,10 +7,16 @@ from vllm.core.block.common import (CopyOnWriteTracker,
                                    get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
+from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor

PrefixHash = int
BlockId = int
# By default, a block's last accessed time is initialized to
# _DEFAULT_LAST_ACCESSED_TIME; if a block still holds this value, we know
# it has not been accessed yet.
_DEFAULT_LAST_ACCESSED_TIME = -1
class PrefixCachingBlockAllocator(BlockAllocator):
    """A block allocator that implements prefix caching.
@@ -27,22 +33,19 @@ class PrefixCachingBlockAllocator(BlockAllocator):
        from 0 to num_blocks - 1.
    """

-   # TODO last access time / evictor integration

    def __init__(
        self,
        num_blocks: int,
        block_size: int,
        block_ids: Optional[Iterable[int]] = None,
+       eviction_policy: Optional[EvictionPolicy] = EvictionPolicy.LRU,
    ):
        # A mapping of prefix hash to block index. All blocks which have a
        # prefix hash will be in this dict, even if they have refcount 0.
        self._cached_blocks: Dict[PrefixHash, BlockId] = {}

-       # A mapping of prefix hash to block index. All blocks which have a
-       # prefix hash AND refcount 0 will be in this dict. Thus, it is a subset
-       # of self._cached_blocks.
-       self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {}
+       # A mapping of block id to Block, used to track cached blocks.
+       self._blocks: Dict[BlockId, Block] = {}

        # An allocator for blocks that do not have prefix hashes.
        self._hashless_allocator = NaiveBlockAllocator(
@@ -54,6 +57,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
        self._block_size = block_size

+       # Evictor used to decide how to handle computed blocks when memory
+       # pressure is high.
+       self.evictor: Evictor = make_evictor(eviction_policy)

        # We share the refcounter between allocators. This allows us to promote
        # blocks originally allocated in the hashless allocator to immutable
        # blocks.
@@ -72,6 +79,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
        block_size: int,
        allocator: BlockAllocator,
        block_id: Optional[int] = None,
+       computed: Optional[bool] = False,
    ) -> Block:
        # Bind block to self.
        allocator = self
@@ -82,6 +90,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
            block_size=block_size,
            block_id=block_id,
            prefix_caching_allocator=allocator,
+           computed=computed,
        )
    def allocate_immutable(self, prev_block: Optional[Block],
@@ -109,14 +118,12 @@ class PrefixCachingBlockAllocator(BlockAllocator):
        cached_block_id = self._cached_blocks.get(block.content_hash, None)
        if cached_block_id is not None:
            block.block_id = cached_block_id
-           self._incr_refcount_cached_block(block.content_hash,
-                                            block.block_id)
+           self._incr_refcount_cached_block(block, block.block_id)
            return block

        block = self.allocate_mutable(prev_block)
        block.append_token_ids(token_ids)
        assert block.content_hash is not None
-       # TODO computed bit

        return block
@@ -133,41 +140,67 @@ class PrefixCachingBlockAllocator(BlockAllocator):
        assert_prefix_caching_block_or_none(prev_block)

        try:
-           return self._hashless_allocator.allocate_mutable(
+           block = self._hashless_allocator.allocate_mutable(
                prev_block=prev_block)
+           assert block.block_id not in self._blocks
+           self._blocks[block.block_id] = block
+           return block
        except BlockAllocator.NoFreeBlocksError:
            # We must check the unused cached blocks before raising OOM.
            pass

-       if self._unused_cached_blocks:
-           # TODO policy for selecting block to remove
-           content_hash_to_evict = next(iter(self._unused_cached_blocks))
-
-           # Clear content hash mapping; the block will be overwritten.
-           del self._cached_blocks[content_hash_to_evict]
-
-           block_id = self._unused_cached_blocks.pop(content_hash_to_evict)
-           refcount = self._refcounter.incr(block_id)
-           assert refcount == 1
+       # If the evictor has blocks available for eviction, evict a block
+       # and return it.
+       if self.evictor.num_blocks > 0:
+           block_id, content_hash_to_evict = self.evictor.evict()
+
+           # Several blocks may share the same content hash: when a later
+           # block with the same content goes down the mutable-to-immutable
+           # path, its physical block is the one added to the evictor. In
+           # that case we must not pop _cached_blocks, since the content is
+           # still used by others, so we check the refcount before popping.
+           _block_id = self._cached_blocks[content_hash_to_evict]
+           refcount = self._refcounter.get(_block_id)
+           if refcount == 1:
+               self._cached_blocks.pop(content_hash_to_evict)
+               assert _block_id == block_id
+
+           self._refcounter.incr(block_id)

+           # A block coming from the evictor already contains a computed
+           # result.
            block = self._create_block(
                prev_block=prev_block,
                token_ids=[],
                block_size=self._block_size,
                allocator=self,
                block_id=block_id,
+               computed=True,
            )
            assert block.content_hash is None
+           assert block.block_id not in self._blocks
+           self._blocks[block.block_id] = block

            return block

        # No block available in hashless allocator, nor in unused cache blocks.
        raise BlockAllocator.NoFreeBlocksError()
-   def _incr_refcount_cached_block(self, content_hash: int,
+   def _incr_refcount_cached_block(self, block: Block,
                                    block_id: BlockId) -> None:
+       # Since the block is already computed, mark it as such.
+       block.computed = True
        refcount = self._refcounter.incr(block_id)
        if refcount == 1:
-           assert content_hash in self._unused_cached_blocks
-           del self._unused_cached_blocks[content_hash]
+           # Once a block is referenced it must not stay in the evictor,
+           # so remove it there and track it in _blocks instead.
+           if block_id in self.evictor:
+               self.evictor.remove(block_id)
+           self._blocks[block_id] = block
    def free(self, block: Block) -> None:
        """Decrement the refcount of the block. If the decremented refcount is
@@ -180,6 +213,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
                is not None), "freeing unallocated block is undefined"

        self._free_block_id_for_block(block.block_id, block)

        block.block_id = None
    def _free_block_id_for_block(self, block_id: BlockId,
@@ -187,15 +221,21 @@ class PrefixCachingBlockAllocator(BlockAllocator):
        assert isinstance(block, PrefixCachingBlock)

        if block.content_hash is None:
+           refcount = self._refcounter.get(block_id)
+           # In the fork case a block can hold more than one reference, so
+           # we can only stop tracking it once the refcount drops to one.
+           if refcount <= 1:
+               del self._blocks[block.block_id]
            return self._hashless_allocator.free(block)

        refcount = self._refcounter.decr(block_id)

-       # If no longer used, add the block to the unused cached blocks.
+       # If no longer used, add the block to the evictor.
        if refcount == 0:
-           assert block.content_hash not in self._unused_cached_blocks
            assert block.content_hash in self._cached_blocks
-           self._unused_cached_blocks[block.content_hash] = block_id
+           del self._blocks[block.block_id]
+           self.evictor.add(block.block_id, block.content_hash,
+                            block.num_tokens_total, block.last_accessed)
    def fork(self, last_block: Block) -> List[Block]:
        """Creates a new sequence of blocks that shares the same underlying
@@ -230,9 +270,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
    def get_num_free_blocks(self) -> int:
        # The number of free blocks is the number of hashless free blocks
-       # plus the number of hashful blocks that are unused.
-       return self._hashless_allocator.get_num_free_blocks() + len(
-           self._unused_cached_blocks)
+       # plus the number of blocks the evictor could free from its list.
+       return self._hashless_allocator.get_num_free_blocks(
+       ) + self.evictor.num_blocks
    @property
    def all_block_ids(self) -> frozenset[int]:
@@ -266,7 +306,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
        else:
            self._free_block_id_for_block(block.block_id, block)
            self._incr_refcount_cached_block(
-               block.content_hash, self._cached_blocks[block.content_hash])
+               block, self._cached_blocks[block.content_hash])

        return self._cached_blocks[block.content_hash]
@@ -293,29 +333,60 @@ class PrefixCachingBlockAllocator(BlockAllocator):
        """
        return self._cow_tracker.clear_cows()

-   def mark_blocks_as_computed(self) -> None:
+   def mark_blocks_as_accessed(self, block_ids: List[int],
+                               now: float) -> None:
+       """Mark blocks as accessed, used in prefix caching.
+
+       If a block has been moved into the evictor, its metadata in the
+       evictor needs to be updated as well.
+       """
+       for block_id in block_ids:
+           if block_id in self._blocks:
+               self._blocks[block_id].last_accessed = now
+           elif block_id in self.evictor:
+               self.evictor.update(block_id, now)
+           else:
+               raise ValueError(
+                   "Marking a block as accessed that does not belong to GPU")
+
+   def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
        """Mark blocks as computed, used in prefix caching."""
-       # TODO Track computed blocks.
-       pass
+       for block_id in block_ids:
+           if block_id in self._blocks:
+               # Only full blocks are valid for prefix caching.
+               if self._blocks[block_id].is_full:
+                   self._blocks[block_id].computed = True
+           elif block_id not in self.evictor:
+               raise ValueError(f"Marking {block_id=} as computed, but it "
+                                "does not belong to GPU")
+
+   def block_is_computed(self, block_id: int) -> bool:
+       if block_id in self._blocks:
+           return self._blocks[block_id].computed
+       else:
+           return block_id in self.evictor

    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Return the block ids that are common for a given sequence group.

-       Used in prefill (can skip prefill of some blocks).
+       Only blocks that are immutable and already marked as computed are
+       taken into consideration.
        """

-       # TODO: Track computed blocks.
-       computed = lambda block_id: False
-
        # NOTE We exclude the last block to avoid the case where the entire
        # prompt is cached. This would cause erroneous behavior in model
        # runner.
        ids_list = [
-           takewhile(lambda block_id: computed(block_id), seq[:-1])
-           for seq in seq_block_ids
+           list(
+               takewhile(lambda block_id: self.block_is_computed(block_id),
+                         seq[:-1])) for seq in seq_block_ids
        ]
-       return commonprefix([ids for ids in ids_list if ids != []])
+       res = commonprefix([ids for ids in ids_list if ids != []])
+       return res
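For intuition, a small standalone sketch of what get_common_computed_block_ids returns; the computed-id set and block layouts below are hypothetical, purely for illustration:

    from itertools import takewhile
    from os.path import commonprefix

    computed_ids = {0, 1, 2, 5}          # pretend these blocks are computed
    def block_is_computed(block_id: int) -> bool:
        return block_id in computed_ids

    seq_block_ids = [[0, 1, 2, 3], [0, 1, 5, 6]]
    # Per sequence: the leading run of computed blocks, excluding the last block.
    ids_list = [
        list(takewhile(block_is_computed, seq[:-1])) for seq in seq_block_ids
    ]                                    # [[0, 1, 2], [0, 1, 5]]
    print(commonprefix([ids for ids in ids_list if ids]))   # [0, 1]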
class PrefixCachingBlock(Block):
@@ -345,12 +416,16 @@ class PrefixCachingBlock(Block):
        block_size: int,
        prefix_caching_allocator: PrefixCachingBlockAllocator,
        block_id: Optional[int] = None,
+       computed: Optional[bool] = False,
    ):
        assert_prefix_caching_block_or_none(prev_block)

        self._prev_block = prev_block
        self._cached_content_hash: Optional[int] = None
+       self._cached_num_tokens_total: Optional[int] = None
        self._prefix_caching_allocator = prefix_caching_allocator
+       self.last_accessed = _DEFAULT_LAST_ACCESSED_TIME
+       self.computed = computed

        self._block = NaiveBlock(
            prev_block=prev_block,
@@ -398,6 +473,27 @@ class PrefixCachingBlock(Block):
    def num_empty_slots(self) -> int:
        return self._block.num_empty_slots
@property
    def num_tokens_total(self) -> int:
        """Return the total number of tokens in the chain so far.

        We walk the block chain back to the first block, caching the result
        locally to avoid repeated computation.
        """
if self._cached_num_tokens_total is not None:
return self._cached_num_tokens_total
_block = self
self._cached_num_tokens_total = 0
        # TODO: the current implementation is O(N^2) over a chain; we expect
        # to make this O(1) in the future.
while _block is not None:
self._cached_num_tokens_total += len(_block.token_ids)
_block = _block.prev_block
return self._cached_num_tokens_total
    @property
    def block_size(self) -> int:
        return self._block.block_size

View File

@@ -8,7 +8,7 @@ from typing import Sequence as GenericSequence
from typing import Set

from vllm.block import BlockTable, PhysicalTokenBlock
-from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
+from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus

View File

@@ -72,14 +72,12 @@ class BlockSpaceManagerV2(BlockSpaceManager):
        self.watermark = watermark
        assert watermark >= 0.0

-       assert not enable_caching, "Prefix caching not yet supported"
        self.enable_caching = enable_caching

        self.watermark_blocks = int(watermark * num_gpu_blocks)

        self.block_allocator = CpuGpuBlockAllocator.create(
-           # Currently, only naive blocks are supported (no prefix caching).
-           allocator_type="naive",
+           allocator_type="prefix_caching" if enable_caching else "naive",
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            block_size=block_size,
@@ -194,17 +192,26 @@ class BlockSpaceManagerV2(BlockSpaceManager):
        assert all(b is not None for b in block_ids)
        return block_ids

-   def access_all_blocks_in_seq(self, seq, now):
-       # TODO add prefix caching support.
-       # Tracked here https://github.com/vllm-project/vllm/issues/3667
-       pass
+   def access_all_blocks_in_seq(self, seq: Sequence, now: float):
+       # Update the last accessed time of all the blocks accessed in this
+       # step. The accessed time is only used by prefix caching for now,
+       # since its evictor policy allows cached blocks to be refilled,
+       # keeping cached content reusable for as long as possible.
+       if self.enable_caching:
+           block_table = self.block_tables[seq.seq_id]
+           block_ids = []
+           for block_id in block_table.physical_block_ids:
+               block_ids.append(block_id)
+           self.block_allocator.mark_blocks_as_accessed(block_ids, now)

    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
-       # We ignore the sequence group as its not necessary. After the batch is
-       # formed by the scheduler, we do not need to mark blocks from individual
-       # sequence groups as computed -- all blocks in the batch can be marked
-       # as computed.
-       self.block_allocator.mark_blocks_as_computed()
+       # Marking blocks as computed is only needed for prefix caching, and
+       # in block manager v2 a block's computed status can be derived from
+       # whether it has a content hash, so this function is a no-op here.
+       pass

    def get_common_computed_block_ids(
            self, seqs: List[Sequence]) -> GenericSequence[int]:

vllm/core/evictor_v2.py (new file, 122 lines)
View File

@@ -0,0 +1,122 @@
import enum
from abc import ABC, abstractmethod, abstractproperty
from typing import OrderedDict, Tuple
class EvictionPolicy(enum.Enum):
"""Enum for eviction policy used by make_evictor to instantiate the correct
Evictor subclass.
"""
LRU = enum.auto()
class Evictor(ABC):
"""The Evictor subclasses should be used by the BlockAllocator class to
handle eviction of freed PhysicalTokenBlocks.
"""
@abstractmethod
def __init__(self):
pass
@abstractmethod
def __contains__(self, block_id: int) -> bool:
pass
@abstractmethod
    def evict(self) -> Tuple[int, int]:
        """Runs the eviction algorithm and returns the evicted block's
        physical block id along with its content hash.
        """
pass
@abstractmethod
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
last_accessed: int):
"""Adds block to the evictor, making it a candidate for eviction"""
pass
@abstractmethod
def update(self, block_id: int, last_accessed: int):
"""Update corresponding block's access time in metadata"""
pass
@abstractproperty
def num_blocks(self) -> int:
pass
class BlockMetaData():
    """Data structure storing the key attributes of a cached block, which the
    evictor uses to decide which block to evict.

    The physical block id is used as the dict key, since several blocks may
    share the same content hash while their physical ids are unique.
    """
def __init__(self, content_hash: int, num_hashed_tokens: int,
last_accessed: int):
self.content_hash = content_hash
self.num_hashed_tokens = num_hashed_tokens
self.last_accessed = last_accessed
class LRUEvictor(Evictor):
"""Evicts in a least-recently-used order using the last_accessed timestamp
that's recorded in the PhysicalTokenBlock. If there are multiple blocks with
the same last_accessed time, then the one with the largest num_hashed_tokens
    will be evicted. If two blocks each have the lowest last_accessed time and
    the highest num_hashed_tokens value, then one is chosen arbitrarily.
    """
def __init__(self):
self.free_table: OrderedDict[int, BlockMetaData] = OrderedDict()
def __contains__(self, block_id: int) -> bool:
return block_id in self.free_table
def evict(self) -> Tuple[int, int]:
if len(self.free_table) == 0:
raise ValueError("No usable cache memory left")
evicted_block = next(iter(self.free_table.values()))
evicted_block_id = next(iter(self.free_table.keys()))
# The blocks with the lowest timestamps should be placed consecutively
# at the start of OrderedDict. Loop through all these blocks to
# find the one with maximum number of hashed tokens.
for _id, block in self.free_table.items():
if evicted_block.last_accessed > block.last_accessed or (
evicted_block.last_accessed == block.last_accessed and
evicted_block.num_hashed_tokens < block.num_hashed_tokens):
evicted_block = block
evicted_block_id = _id
self.free_table.pop(evicted_block_id)
return evicted_block_id, evicted_block.content_hash
def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
last_accessed: int):
self.free_table[block_id] = BlockMetaData(content_hash,
num_hashed_tokens,
last_accessed)
def update(self, block_id: int, last_accessed: int):
self.free_table[block_id].last_accessed = last_accessed
def remove(self, block_id: int):
if block_id not in self.free_table:
raise ValueError(
"Attempting to remove block that's not in the evictor")
self.free_table.pop(block_id)
@property
def num_blocks(self) -> int:
return len(self.free_table)
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
if eviction_policy == EvictionPolicy.LRU:
return LRUEvictor()
else:
raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
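A minimal standalone sketch of how this evictor behaves (assuming it is importable from the new vllm/core/evictor_v2.py module); it illustrates the tie-breaking rule from the LRUEvictor docstring, where equal last_accessed times fall back to the larger num_hashed_tokens:

    from vllm.core.evictor_v2 import LRUEvictor

    evictor = LRUEvictor()
    # add(block_id, content_hash, num_hashed_tokens, last_accessed)
    evictor.add(0, content_hash=100, num_hashed_tokens=16, last_accessed=1)
    evictor.add(1, content_hash=101, num_hashed_tokens=32, last_accessed=1)
    evictor.add(2, content_hash=102, num_hashed_tokens=16, last_accessed=2)

    # Blocks 0 and 1 tie on last_accessed; block 1 covers more hashed tokens,
    # so it is evicted first. evict() returns (block_id, content_hash).
    assert evictor.evict() == (1, 101)
    assert evictor.evict() == (0, 100)      # block 0 is now the oldest entry
    assert evictor.num_blocks == 1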