[PREFIX CACHING FOLLOW UP] A bunch of fixes to block allocator performance when automatic prefix caching is disabled (#3357)

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
ElizaWszola 2024-03-20 08:11:11 +01:00 committed by GitHub
parent 20478c4d3a
commit 9474e89ba4
4 changed files with 171 additions and 127 deletions

tests/core/test_block_manager.py

@@ -4,7 +4,7 @@ from typing import List
from vllm import SamplingParams
from vllm.block import PhysicalTokenBlock
from vllm.core.block_manager import (BlockAllocator, BlockSpaceManager,
from vllm.core.block_manager import (UncachedBlockAllocator, BlockSpaceManager,
AllocStatus)
from vllm.utils import Device
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus, Logprob
@@ -15,7 +15,8 @@ from .utils import create_dummy_prompt
def test_block_allocator_allocate():
block_size = 4
num_cpu_blocks = 4
cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
num_cpu_blocks)
# Allocate all available cpu blocks.
num_free = num_cpu_blocks
@@ -24,7 +25,7 @@ def test_block_allocator_allocate():
block = cpu_allocator.allocate()
num_free -= 1
assert block.block_hash not in cpu_allocator.evictor
assert block not in cpu_allocator.free_blocks
assert cpu_allocator.get_num_free_blocks() == num_free
with pytest.raises(ValueError):
@@ -34,14 +35,15 @@ def test_block_allocator_allocate():
def test_block_allocator_free():
block_size = 4
num_cpu_blocks = 4
cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
num_cpu_blocks)
# Allocate all available cpu blocks.
blocks: List[PhysicalTokenBlock] = []
for _ in range(num_cpu_blocks):
block = cpu_allocator.allocate()
blocks.append(block)
assert block.block_hash not in cpu_allocator.evictor
assert block not in cpu_allocator.free_blocks
# Free all allocated cpu blocks.
num_free = 0
@@ -49,7 +51,7 @@ def test_block_allocator_free():
for block in blocks:
cpu_allocator.free(block)
num_free += 1
assert block.block_hash in cpu_allocator.evictor
assert block in cpu_allocator.free_blocks
assert cpu_allocator.get_num_free_blocks() == num_free
with pytest.raises(ValueError):

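For reference, the allocate/free contract these assertions exercise can be seen end to end in a short usage sketch (assuming vLLM at this commit is importable; the sizes are arbitrary):

from vllm.core.block_manager import UncachedBlockAllocator
from vllm.utils import Device

allocator = UncachedBlockAllocator(Device.CPU, 4, 4)  # block_size=4, num_blocks=4

# allocate() pops a block off a plain Python list; no hashing, no evictor involved.
blocks = [allocator.allocate() for _ in range(4)]
assert allocator.get_num_free_blocks() == 0

# free() puts a block back on the list once its reference count reaches zero;
# freeing an already-free block raises ValueError ("Double free").
for block in blocks:
    allocator.free(block)
assert allocator.get_num_free_blocks() == 4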
tests/prefix_caching/test_prefix_caching.py

@@ -4,7 +4,7 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import pytest
from vllm.core.block_manager import BlockAllocator
from vllm.core.block_manager import CachedBlockAllocator
from vllm.utils import Device
@@ -15,10 +15,7 @@ def test_block_allocator(
num_blocks: int,
):
block_hash = 1
block_allocator = BlockAllocator(Device.CPU,
block_size,
num_blocks,
enable_caching=True)
block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
# Allocate two PhysicalTokenBlocks with the same hash and check
# that they are the same PhysicalTokenBlock
@@ -45,10 +42,7 @@ def test_block_allocator(
@pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int, ):
block_size = 16
block_allocator = BlockAllocator(Device.CPU,
block_size,
num_blocks,
enable_caching=True)
block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
blocks = []
for i in range(num_blocks):

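The cached allocator keeps the hash-based reuse behaviour this test depends on; a minimal usage sketch (assuming vLLM at this commit):

from vllm.core.block_manager import CachedBlockAllocator
from vllm.utils import Device

allocator = CachedBlockAllocator(Device.CPU, 16, 16)  # block_size=16, num_blocks=16

# Two allocations with the same content hash resolve to the same physical block
# and bump its reference count instead of consuming a second block.
first = allocator.allocate(block_hash=1, num_hashed_tokens=16)
second = allocator.allocate(block_hash=1, num_hashed_tokens=16)
assert first is second and second.ref_count == 2

# A different hash gets a distinct block.
other = allocator.allocate(block_hash=2, num_hashed_tokens=16)
assert other is not first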
vllm/core/block_manager.py

@@ -3,6 +3,7 @@ import enum
from itertools import count, takewhile
from os.path import commonprefix
from typing import Dict, List, Optional, Set, Tuple
from abc import ABC, abstractmethod
from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
@@ -10,7 +11,46 @@ from vllm.utils import Device
from vllm.core.evictor import Evictor, EvictionPolicy, make_evictor
class BlockAllocator:
class BlockAllocatorBase(ABC):
"""Manages free physical token blocks for a device.
The allocator maintains a list of free blocks and allocates a block when
requested. When a block is freed, its reference count is decremented. If
the reference count becomes zero, the block is added back to the free list.
"""
@abstractmethod
def __init__(self,
device: Device,
block_size: int,
num_blocks: int,
eviction_policy: EvictionPolicy = EvictionPolicy.LRU):
pass
@abstractmethod
def allocate(self,
block_hash: Optional[int] = None,
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
pass
@abstractmethod
def free(self, block: PhysicalTokenBlock) -> None:
pass
@abstractmethod
def get_num_free_blocks(self) -> int:
pass
@abstractmethod
def contains_block(self, block_hash: int) -> bool:
pass
@abstractmethod
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
pass
class CachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device.
The allocator maintains a list of free blocks and allocates a block when
@@ -22,19 +62,14 @@ class BlockAllocator:
device: Device,
block_size: int,
num_blocks: int,
eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
enable_caching: bool = False) -> None:
eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None:
self.device = device
self.block_size = block_size
self.num_blocks = num_blocks
self.enable_caching = enable_caching
self.current_num_blocks = 0
self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}
# Switch over to FIFO eviction when caching is disabled
if not self.enable_caching:
eviction_policy = EvictionPolicy.FIFO
self.evictor: Evictor = make_evictor(eviction_policy)
self.default_hash_ctr = count()
@@ -57,13 +92,6 @@ class BlockAllocator:
def allocate(self,
block_hash: Optional[int] = None,
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
# If caching is disabled, just allocate a new block and return it
if not self.enable_caching:
block = self.allocate_block(next(self.default_hash_ctr),
num_hashed_tokens)
block.ref_count += 1
return block
if block_hash is None:
block_hash = next(self.default_hash_ctr)
if block_hash in self.evictor:
@@ -90,9 +118,8 @@ class BlockAllocator:
assert block.block_hash not in self.evictor
self.evictor.add(block)
# If caching is enabled, remove the block from the cached_blocks
if self.enable_caching:
del self.cached_blocks[block.block_hash]
# Remove the block from the cached_blocks
del self.cached_blocks[block.block_hash]
def get_num_free_blocks(self) -> int:
return (self.num_blocks - self.current_num_blocks +
@@ -102,14 +129,68 @@ class BlockAllocator:
return block_hash in self.cached_blocks or block_hash in self.evictor
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
# If caching is enabled, update the hash of block and the
# cached_blocks dictionary.
if self.enable_caching:
assert not self.contains_block(block_hash)
old_hash = block.block_hash
block.block_hash = block_hash
del self.cached_blocks[old_hash]
self.cached_blocks[block_hash] = block
# Update the hash of block and the cached_blocks dictionary.
assert not self.contains_block(block_hash)
old_hash = block.block_hash
block.block_hash = block_hash
del self.cached_blocks[old_hash]
self.cached_blocks[block_hash] = block
class UncachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device.
The allocator maintains a list of free blocks and allocates a block when
requested. When a block is freed, its reference count is decremented. If
the reference count becomes zero, the block is added back to the free list.
"""
def __init__(
self,
device: Device,
block_size: int,
num_blocks: int,
) -> None:
self.device = device
self.block_size = block_size
self.num_blocks = num_blocks
# Initialize the free blocks.
self.free_blocks: BlockTable = []
for i in range(num_blocks):
block = PhysicalTokenBlock(device=device,
block_number=i,
block_size=block_size,
block_hash=-1,
num_hashed_tokens=0)
self.free_blocks.append(block)
def allocate(self,
block_hash: Optional[int] = None,
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
if not self.free_blocks:
raise ValueError("Out of memory! No free blocks are available.")
block = self.free_blocks.pop()
block.ref_count = 1
return block
def free(self, block: PhysicalTokenBlock) -> None:
if block.ref_count == 0:
raise ValueError(f"Double free! {block} is already freed.")
block.ref_count -= 1
if block.ref_count == 0:
self.free_blocks.append(block)
def get_num_free_blocks(self) -> int:
return len(self.free_blocks)
def contains_block(self, block_hash: int) -> bool:
raise NotImplementedError(
"Invalid codepath for uncached block allocator.")
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
raise NotImplementedError(
"Invalid codepath for uncached block allocator.")
class AllocStatus(enum.Enum):
@@ -142,6 +223,10 @@ class BlockSpaceManager:
self.num_total_gpu_blocks = num_gpu_blocks
self.num_total_cpu_blocks = num_cpu_blocks
if enable_caching and sliding_window is not None:
raise NotImplementedError(
"Sliding window is not allowed with prefix caching enabled!")
self.block_sliding_window = None
if sliding_window is not None:
assert sliding_window % block_size == 0, (sliding_window,
@@ -154,14 +239,17 @@
self.enable_caching = enable_caching
self.watermark_blocks = int(watermark * num_gpu_blocks)
self.gpu_allocator = BlockAllocator(Device.GPU,
block_size,
num_gpu_blocks,
enable_caching=enable_caching)
self.cpu_allocator = BlockAllocator(Device.CPU,
block_size,
num_cpu_blocks,
enable_caching=enable_caching)
if self.enable_caching:
self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
num_gpu_blocks)
self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
num_cpu_blocks)
else:
self.gpu_allocator = UncachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks)
self.cpu_allocator = UncachedBlockAllocator(
Device.CPU, block_size, num_cpu_blocks)
# Mapping: seq_id -> BlockTable.
self.block_tables: Dict[int, BlockTable] = {}
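The dispatch above keeps the uncached path free of caching machinery. A small sketch of the selection and of the new sliding-window guard (the keyword names are taken from this hunk and assumed to match the constructor signature):

from vllm.core.block_manager import (BlockSpaceManager, CachedBlockAllocator,
                                     UncachedBlockAllocator)

cached = BlockSpaceManager(block_size=16, num_gpu_blocks=64, num_cpu_blocks=16,
                           enable_caching=True)
assert isinstance(cached.gpu_allocator, CachedBlockAllocator)

uncached = BlockSpaceManager(block_size=16, num_gpu_blocks=64, num_cpu_blocks=16)
assert isinstance(uncached.gpu_allocator, UncachedBlockAllocator)

# Prefix caching and sliding-window attention cannot be combined:
try:
    BlockSpaceManager(block_size=16, num_gpu_blocks=64, num_cpu_blocks=16,
                      sliding_window=32, enable_caching=True)
except NotImplementedError:
    pass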
@@ -198,10 +286,16 @@
if (self.block_sliding_window is not None
and logical_idx >= self.block_sliding_window):
block = block_table[logical_idx % self.block_sliding_window]
else:
# Set the reference counts of the token blocks.
block.ref_count = seq_group.num_seqs()
elif self.enable_caching:
block = self.gpu_allocator.allocate(
seq.hash_of_block(logical_idx),
seq.num_hashed_tokens_of_block(logical_idx))
else:
block = self.gpu_allocator.allocate()
# Set the reference counts of the token blocks.
block.ref_count = seq_group.num_seqs()
block_table.append(block)
# Assign the block table for each sequence.
@@ -220,8 +314,10 @@
seq: Sequence,
last_block: PhysicalTokenBlock,
) -> PhysicalTokenBlock:
# Compute a new hash for the block so that it can be shared by
# other Sequences
assert self.enable_caching
# Compute a new hash for the block so that it can be shared by other
# Sequences
new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
# if new_hash is already in the cached table, then free last_block
@@ -254,6 +350,8 @@
self,
seq: Sequence,
) -> PhysicalTokenBlock:
if not self.enable_caching:
return self.gpu_allocator.allocate()
block_hash: Optional[int] = None
if (self._is_last_block_full(seq)):
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
@@ -293,10 +391,12 @@
assert last_block.device == Device.GPU
if last_block.ref_count == 1:
# Not shared with other sequences. Appendable.
# If the last block is now complete, promote it to a full block so
# that it can be shared
new_block = self._maybe_promote_last_block(seq, last_block)
block_table[-1] = new_block
if self.enable_caching:
# If the last block is now complete, we may reuse an old block
# to save memory.
maybe_new_block = self._maybe_promote_last_block(
seq, last_block)
block_table[-1] = maybe_new_block
return None
else:
# The last block is shared with other sequences.
@@ -440,9 +540,12 @@
seq: Sequence,
access_time: float,
) -> None:
block_table = self.block_tables[seq.seq_id]
for block in block_table:
block.last_accessed = access_time
if self.enable_caching:
# Update the last accessed time of all the blocks accessed
# in this step.
block_table = self.block_tables[seq.seq_id]
for block in block_table:
block.last_accessed = access_time
def compute_full_blocks_in_seq(self, seq: Sequence):
if seq.seq_id not in self.block_tables:

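The payoff shows up on the allocator hot loop when prefix caching is off: allocation is now a plain list pop with no hash counter or evictor bookkeeping. A rough, illustrative timing harness (not part of the commit; absolute numbers are machine-dependent):

import time

from vllm.core.block_manager import UncachedBlockAllocator
from vllm.utils import Device

allocator = UncachedBlockAllocator(Device.CPU, 16, 4096)

start = time.perf_counter()
for _ in range(100):
    # Drain and refill the free list, mimicking a busy scheduling loop.
    held = [allocator.allocate() for _ in range(4096)]
    for block in held:
        allocator.free(block)
print(f"100 drain/refill cycles over 4096 blocks: {time.perf_counter() - start:.3f}s")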
vllm/core/evictor.py

@@ -1,5 +1,5 @@
import enum
from typing import Dict, List, Optional
from typing import Dict
from abc import ABC, abstractmethod, abstractproperty
from vllm.block import PhysicalTokenBlock
@@ -10,7 +10,6 @@ class EvictionPolicy(enum.Enum):
Evictor subclass.
"""
LRU = enum.auto()
FIFO = enum.auto()
class Evictor(ABC):
@@ -65,75 +64,23 @@ class LRUEvictor(Evictor):
return block_hash in self.free_table
# TODO: The performance of this evict function can be optimized further.
def evict(self) -> PhysicalTokenBlock:
free_blocks: List[PhysicalTokenBlock] = list(self.free_table.values())
if len(free_blocks) == 0:
raise ValueError("No usable cache memory left")
# Find lowest timestamp
lowest_timestamp = free_blocks[0].last_accessed
for block in free_blocks:
if block.last_accessed < lowest_timestamp:
lowest_timestamp = block.last_accessed
# Find all blocks with the lowest timestamp
least_recent: List[PhysicalTokenBlock] = []
for block in free_blocks:
if block.last_accessed == lowest_timestamp:
least_recent.append(block)
# Find highest prefix count per block
highest_num_hashed_tokens = 0
for block in least_recent:
if block.num_hashed_tokens > highest_num_hashed_tokens:
highest_num_hashed_tokens = block.num_hashed_tokens
evicted_block: Optional[PhysicalTokenBlock] = None
# Find the first block with the lowest timestamp
for block in least_recent:
if block.num_hashed_tokens == highest_num_hashed_tokens:
evicted_block = block
break
assert evicted_block is not None
del self.free_table[evicted_block.block_hash]
evicted_block.computed = False
return evicted_block
def add(self, block: PhysicalTokenBlock):
self.free_table[block.block_hash] = block
def remove(self, block_hash: int) -> PhysicalTokenBlock:
if block_hash not in self.free_table:
raise ValueError(
"Attempting to remove block that's not in the evictor")
block: PhysicalTokenBlock = self.free_table[block_hash]
del self.free_table[block_hash]
return block
@property
def num_blocks(self) -> int:
return len(self.free_table)
class RandomEvictor(Evictor):
"""Evicts in a first-in-first-out order"""
def __init__(self):
self.free_table: Dict[int, PhysicalTokenBlock] = {}
def __contains__(self, block_hash: int) -> bool:
return block_hash in self.free_table
def evict(self) -> PhysicalTokenBlock:
if len(self.free_table) == 0:
raise ValueError("No usable cache memory left")
evicted_block = next(iter(self.free_table.values()))
evicted_block.computed = False
free_blocks = self.free_table.values()
# Get evicted block
evicted_block: PhysicalTokenBlock = next(iter(free_blocks))
for block in free_blocks:
if (block.last_accessed < evicted_block.last_accessed
or block.last_accessed == evicted_block.last_accessed and
block.num_hashed_tokens > evicted_block.num_hashed_tokens):
evicted_block = block
del self.free_table[evicted_block.block_hash]
evicted_block.computed = False
return evicted_block
def add(self, block: PhysicalTokenBlock):
@@ -155,7 +102,5 @@ class RandomEvictor(Evictor):
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
if eviction_policy == EvictionPolicy.LRU:
return LRUEvictor()
elif eviction_policy == EvictionPolicy.FIFO:
return RandomEvictor()
else:
raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
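The consolidated evict() above keeps the earlier tie-breaking rule in a single pass: the least recently accessed block is evicted, and among equally old blocks the one covering the most hashed prefix tokens goes first. A small sketch against the evictor directly (assuming vLLM at this commit):

from vllm.block import PhysicalTokenBlock
from vllm.core.evictor import LRUEvictor
from vllm.utils import Device

evictor = LRUEvictor()
short_prefix = PhysicalTokenBlock(device=Device.CPU, block_number=0,
                                  block_size=16, block_hash=1,
                                  num_hashed_tokens=16)
long_prefix = PhysicalTokenBlock(device=Device.CPU, block_number=1,
                                 block_size=16, block_hash=2,
                                 num_hashed_tokens=32)
# Same last-access time, so the tie is broken on num_hashed_tokens.
short_prefix.last_accessed = 1.0
long_prefix.last_accessed = 1.0
evictor.add(short_prefix)
evictor.add(long_prefix)
assert evictor.evict() is long_prefix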