[PREFIX CACHING FOLLOW UP] A bunch of fixes to block allocator performance when automatic prefix caching is disabled (#3357)

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Author: ElizaWszola, 2024-03-20 08:11:11 +01:00 (committed by GitHub)
Parent: 20478c4d3a
Commit: 9474e89ba4
4 changed files with 171 additions and 127 deletions
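
In short: the old BlockAllocator, which branched on enable_caching internally, is split into a hash-aware CachedBlockAllocator (prefix caching on) and a plain free-list UncachedBlockAllocator (prefix caching off), and BlockSpaceManager now picks one of the two at construction time. A minimal sketch of that selection, assuming only the classes introduced in this diff; the helper function itself is hypothetical:

    # Hypothetical helper mirroring the selection BlockSpaceManager.__init__ performs below.
    from vllm.core.block_manager import CachedBlockAllocator, UncachedBlockAllocator
    from vllm.utils import Device

    def make_gpu_allocator(enable_caching: bool, block_size: int,
                           num_gpu_blocks: int):
        if enable_caching:
            # Hash-based allocator with an LRU evictor (prefix caching enabled).
            return CachedBlockAllocator(Device.GPU, block_size, num_gpu_blocks)
        # Simple free-list allocator: no hashing, no evictor (prefix caching disabled).
        return UncachedBlockAllocator(Device.GPU, block_size, num_gpu_blocks)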

View File

@@ -4,7 +4,7 @@ from typing import List

 from vllm import SamplingParams
 from vllm.block import PhysicalTokenBlock
-from vllm.core.block_manager import (BlockAllocator, BlockSpaceManager,
+from vllm.core.block_manager import (UncachedBlockAllocator, BlockSpaceManager,
                                      AllocStatus)
 from vllm.utils import Device
 from vllm.sequence import Sequence, SequenceGroup, SequenceStatus, Logprob
@@ -15,7 +15,8 @@ from .utils import create_dummy_prompt
 def test_block_allocator_allocate():
     block_size = 4
     num_cpu_blocks = 4
-    cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
+    cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
+                                           num_cpu_blocks)

     # Allocate all available cpu blocks.
     num_free = num_cpu_blocks
@@ -24,7 +25,7 @@ def test_block_allocator_allocate():
         block = cpu_allocator.allocate()
         num_free -= 1
-        assert block.block_hash not in cpu_allocator.evictor
+        assert block not in cpu_allocator.free_blocks
         assert cpu_allocator.get_num_free_blocks() == num_free

     with pytest.raises(ValueError):
@@ -34,14 +35,15 @@ def test_block_allocator_allocate():
 def test_block_allocator_free():
     block_size = 4
     num_cpu_blocks = 4
-    cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
+    cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
+                                           num_cpu_blocks)

     # Allocate all available cpu blocks.
     blocks: List[PhysicalTokenBlock] = []
     for _ in range(num_cpu_blocks):
         block = cpu_allocator.allocate()
         blocks.append(block)
-        assert block.block_hash not in cpu_allocator.evictor
+        assert block not in cpu_allocator.free_blocks

     # Free all allocated cpu blocks.
     num_free = 0
@@ -49,7 +51,7 @@ def test_block_allocator_free():
     for block in blocks:
         cpu_allocator.free(block)
         num_free += 1
-        assert block.block_hash in cpu_allocator.evictor
+        assert block in cpu_allocator.free_blocks
         assert cpu_allocator.get_num_free_blocks() == num_free

     with pytest.raises(ValueError):

View File

@@ -4,7 +4,7 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`.
 """
 import pytest

-from vllm.core.block_manager import BlockAllocator
+from vllm.core.block_manager import CachedBlockAllocator
 from vllm.utils import Device
@@ -15,10 +15,7 @@ def test_block_allocator(
     num_blocks: int,
 ):
     block_hash = 1
-    block_allocator = BlockAllocator(Device.CPU,
-                                     block_size,
-                                     num_blocks,
-                                     enable_caching=True)
+    block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)

     # Allocate two PysicalTokenBlocks with the same hash and check
     # that they are the same PhysicalTokenBlock
@@ -45,10 +42,7 @@ def test_block_allocator(
 @pytest.mark.parametrize("num_blocks", [16])
 def test_eviction(num_blocks: int, ):
     block_size = 16
-    block_allocator = BlockAllocator(Device.CPU,
-                                     block_size,
-                                     num_blocks,
-                                     enable_caching=True)
+    block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)

     blocks = []
     for i in range(num_blocks):

View File

@@ -3,6 +3,7 @@ import enum
 from itertools import count, takewhile
 from os.path import commonprefix
 from typing import Dict, List, Optional, Set, Tuple
+from abc import ABC, abstractmethod

 from vllm.block import BlockTable, PhysicalTokenBlock
 from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
@@ -10,7 +11,46 @@ from vllm.utils import Device
 from vllm.core.evictor import Evictor, EvictionPolicy, make_evictor


-class BlockAllocator:
+class BlockAllocatorBase(ABC):
+    """Manages free physical token blocks for a device.
+
+    The allocator maintains a list of free blocks and allocates a block when
+    requested. When a block is freed, its reference count is decremented. If
+    the reference count becomes zero, the block is added back to the free list.
+    """
+
+    @abstractmethod
+    def __init__(self,
+                 device: Device,
+                 block_size: int,
+                 num_blocks: int,
+                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU):
+        pass
+
+    @abstractmethod
+    def allocate(self,
+                 block_hash: Optional[int] = None,
+                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
+        pass
+
+    @abstractmethod
+    def free(self, block: PhysicalTokenBlock) -> None:
+        pass
+
+    @abstractmethod
+    def get_num_free_blocks(self) -> int:
+        pass
+
+    @abstractmethod
+    def contains_block(self, block_hash: int) -> bool:
+        pass
+
+    @abstractmethod
+    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
+        pass
+
+
+class CachedBlockAllocator(BlockAllocatorBase):
     """Manages free physical token blocks for a device.

     The allocator maintains a list of free blocks and allocates a block when
@@ -22,19 +62,14 @@ class BlockAllocator:
                  device: Device,
                  block_size: int,
                  num_blocks: int,
-                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
-                 enable_caching: bool = False) -> None:
+                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None:
         self.device = device
         self.block_size = block_size
         self.num_blocks = num_blocks
-        self.enable_caching = enable_caching

         self.current_num_blocks = 0
         self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}

-        # Switch over to FIFO eviction when caching is disabled
-        if not self.enable_caching:
-            eviction_policy = EvictionPolicy.FIFO
         self.evictor: Evictor = make_evictor(eviction_policy)

         self.default_hash_ctr = count()
@@ -57,13 +92,6 @@ class BlockAllocator:
     def allocate(self,
                  block_hash: Optional[int] = None,
                  num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
-        # If caching is disabled, just allocate a new block and return it
-        if not self.enable_caching:
-            block = self.allocate_block(next(self.default_hash_ctr),
-                                        num_hashed_tokens)
-            block.ref_count += 1
-            return block
-
         if block_hash is None:
             block_hash = next(self.default_hash_ctr)
         if block_hash in self.evictor:
@@ -90,9 +118,8 @@ class BlockAllocator:
             assert block.block_hash not in self.evictor
             self.evictor.add(block)

-            # If caching is enabled, remove the block from the cached_blocks
-            if self.enable_caching:
-                del self.cached_blocks[block.block_hash]
+            # Remove the block from the cached_blocks
+            del self.cached_blocks[block.block_hash]

     def get_num_free_blocks(self) -> int:
         return (self.num_blocks - self.current_num_blocks +
@@ -102,14 +129,68 @@ class BlockAllocator:
         return block_hash in self.cached_blocks or block_hash in self.evictor

     def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
-        # If caching is enabled, update the hash of block and the
-        # cached_blocks dictionary.
-        if self.enable_caching:
-            assert not self.contains_block(block_hash)
-            old_hash = block.block_hash
-            block.block_hash = block_hash
-            del self.cached_blocks[old_hash]
-            self.cached_blocks[block_hash] = block
+        # Update the hash of block and the cached_blocks dictionary.
+        assert not self.contains_block(block_hash)
+        old_hash = block.block_hash
+        block.block_hash = block_hash
+        del self.cached_blocks[old_hash]
+        self.cached_blocks[block_hash] = block
+
+
+class UncachedBlockAllocator(BlockAllocatorBase):
+    """Manages free physical token blocks for a device.
+
+    The allocator maintains a list of free blocks and allocates a block when
+    requested. When a block is freed, its reference count is decremented. If
+    the reference count becomes zero, the block is added back to the free list.
+    """
+
+    def __init__(
+        self,
+        device: Device,
+        block_size: int,
+        num_blocks: int,
+    ) -> None:
+        self.device = device
+        self.block_size = block_size
+        self.num_blocks = num_blocks
+
+        # Initialize the free blocks.
+        self.free_blocks: BlockTable = []
+        for i in range(num_blocks):
+            block = PhysicalTokenBlock(device=device,
+                                       block_number=i,
+                                       block_size=block_size,
+                                       block_hash=-1,
+                                       num_hashed_tokens=0)
+            self.free_blocks.append(block)
+
+    def allocate(self,
+                 block_hash: Optional[int] = None,
+                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
+        if not self.free_blocks:
+            raise ValueError("Out of memory! No free blocks are available.")
+        block = self.free_blocks.pop()
+        block.ref_count = 1
+        return block
+
+    def free(self, block: PhysicalTokenBlock) -> None:
+        if block.ref_count == 0:
+            raise ValueError(f"Double free! {block} is already freed.")
+        block.ref_count -= 1
+        if block.ref_count == 0:
+            self.free_blocks.append(block)
+
+    def get_num_free_blocks(self) -> int:
+        return len(self.free_blocks)
+
+    def contains_block(self, block_hash: int) -> bool:
+        raise NotImplementedError(
+            "Invalid codepath for uncached block allocator.")
+
+    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
+        raise NotImplementedError(
+            "Invalid codepath for uncached block allocator.")


 class AllocStatus(enum.Enum):
@@ -142,6 +223,10 @@ class BlockSpaceManager:
         self.num_total_gpu_blocks = num_gpu_blocks
         self.num_total_cpu_blocks = num_cpu_blocks

+        if enable_caching and sliding_window is not None:
+            raise NotImplementedError(
+                "Sliding window is not allowed with prefix caching enabled!")
+
         self.block_sliding_window = None
         if sliding_window is not None:
             assert sliding_window % block_size == 0, (sliding_window,
@@ -154,14 +239,17 @@ class BlockSpaceManager:
         self.enable_caching = enable_caching

         self.watermark_blocks = int(watermark * num_gpu_blocks)
-        self.gpu_allocator = BlockAllocator(Device.GPU,
-                                            block_size,
-                                            num_gpu_blocks,
-                                            enable_caching=enable_caching)
-        self.cpu_allocator = BlockAllocator(Device.CPU,
-                                            block_size,
-                                            num_cpu_blocks,
-                                            enable_caching=enable_caching)
+
+        if self.enable_caching:
+            self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
+                                                      num_gpu_blocks)
+            self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
+                                                      num_cpu_blocks)
+        else:
+            self.gpu_allocator = UncachedBlockAllocator(
+                Device.GPU, block_size, num_gpu_blocks)
+            self.cpu_allocator = UncachedBlockAllocator(
+                Device.CPU, block_size, num_cpu_blocks)
+
         # Mapping: seq_id -> BlockTable.
         self.block_tables: Dict[int, BlockTable] = {}
@@ -198,10 +286,16 @@ class BlockSpaceManager:
             if (self.block_sliding_window is not None
                     and logical_idx >= self.block_sliding_window):
                 block = block_table[logical_idx % self.block_sliding_window]
-            else:
+                # Set the reference counts of the token blocks.
+                block.ref_count = seq_group.num_seqs()
+            elif self.enable_caching:
                 block = self.gpu_allocator.allocate(
                     seq.hash_of_block(logical_idx),
                     seq.num_hashed_tokens_of_block(logical_idx))
+            else:
+                block = self.gpu_allocator.allocate()
+                # Set the reference counts of the token blocks.
+                block.ref_count = seq_group.num_seqs()
             block_table.append(block)

         # Assign the block table for each sequence.
@@ -220,8 +314,10 @@ class BlockSpaceManager:
         seq: Sequence,
         last_block: PhysicalTokenBlock,
     ) -> PhysicalTokenBlock:
-        # Compute a new hash for the block so that it can be shared by
-        # other Sequences
+        assert self.enable_caching
+
+        # Compute a new hash for the block so that it can be shared by other
+        # Sequences
         new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)

         # if new_hash is already in the cached table, then free last_block
@@ -254,6 +350,8 @@ class BlockSpaceManager:
         self,
         seq: Sequence,
     ) -> PhysicalTokenBlock:
+        if not self.enable_caching:
+            return self.gpu_allocator.allocate()
         block_hash: Optional[int] = None
         if (self._is_last_block_full(seq)):
             block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
@@ -293,10 +391,12 @@ class BlockSpaceManager:
         assert last_block.device == Device.GPU
         if last_block.ref_count == 1:
             # Not shared with other sequences. Appendable.
-            # If the last block is now complete, promote it to a full block so
-            # that it can be shared
-            new_block = self._maybe_promote_last_block(seq, last_block)
-            block_table[-1] = new_block
+            if self.enable_caching:
+                # If the last block is now complete, we may reuse an old block
+                # to save memory.
+                maybe_new_block = self._maybe_promote_last_block(
+                    seq, last_block)
+                block_table[-1] = maybe_new_block
             return None
         else:
             # The last block is shared with other sequences.
@@ -440,9 +540,12 @@ class BlockSpaceManager:
         seq: Sequence,
         access_time: float,
     ) -> None:
-        block_table = self.block_tables[seq.seq_id]
-        for block in block_table:
-            block.last_accessed = access_time
+        if self.enable_caching:
+            # Update the last accessed time of all the blocks accessed
+            # in this step.
+            block_table = self.block_tables[seq.seq_id]
+            for block in block_table:
+                block.last_accessed = access_time

     def compute_full_blocks_in_seq(self, seq: Sequence):
         if seq.seq_id not in self.block_tables:
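
For reference, a small usage sketch of the uncached path defined above, mirroring the updated block allocator tests earlier in this diff (the concrete sizes are illustrative only):

    # Illustrative round trip with UncachedBlockAllocator: allocate() pops a block
    # off the free list with ref_count = 1; free() pushes it back once ref_count hits 0.
    from vllm.core.block_manager import UncachedBlockAllocator
    from vllm.utils import Device

    allocator = UncachedBlockAllocator(Device.CPU, 4, 4)  # block_size=4, num_blocks=4
    blocks = [allocator.allocate() for _ in range(4)]
    assert allocator.get_num_free_blocks() == 0
    for block in blocks:
        allocator.free(block)  # ref_count 1 -> 0, block returns to the free list
    assert allocator.get_num_free_blocks() == 4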

View File

@@ -1,5 +1,5 @@
 import enum
-from typing import Dict, List, Optional
+from typing import Dict
 from abc import ABC, abstractmethod, abstractproperty

 from vllm.block import PhysicalTokenBlock
@@ -10,7 +10,6 @@ class EvictionPolicy(enum.Enum):
     Evictor subclass.
     """
     LRU = enum.auto()
-    FIFO = enum.auto()


 class Evictor(ABC):
@@ -65,75 +64,23 @@ class LRUEvictor(Evictor):
         return block_hash in self.free_table

     # TODO: The performance of this evict function can be optimized further.
-    def evict(self) -> PhysicalTokenBlock:
-        free_blocks: List[PhysicalTokenBlock] = list(self.free_table.values())
-        if len(free_blocks) == 0:
-            raise ValueError("No usable cache memory left")
-
-        # Find lowest timestamp
-        lowest_timestamp = free_blocks[0].last_accessed
-        for block in free_blocks:
-            if block.last_accessed < lowest_timestamp:
-                lowest_timestamp = block.last_accessed
-
-        # Find all blocks with the lowest timestamp
-        least_recent: List[PhysicalTokenBlock] = []
-        for block in free_blocks:
-            if block.last_accessed == lowest_timestamp:
-                least_recent.append(block)
-
-        # Find highest prefix count per block
-        highest_num_hashed_tokens = 0
-        for block in least_recent:
-            if block.num_hashed_tokens > highest_num_hashed_tokens:
-                highest_num_hashed_tokens = block.num_hashed_tokens
-
-        evicted_block: Optional[PhysicalTokenBlock] = None
-
-        # Find the first block with the lowest timestamp
-        for block in least_recent:
-            if block.num_hashed_tokens == highest_num_hashed_tokens:
-                evicted_block = block
-                break
-
-        assert evicted_block is not None
-
-        del self.free_table[evicted_block.block_hash]
-
-        evicted_block.computed = False
-        return evicted_block
-
-    def add(self, block: PhysicalTokenBlock):
-        self.free_table[block.block_hash] = block
-
-    def remove(self, block_hash: int) -> PhysicalTokenBlock:
-        if block_hash not in self.free_table:
-            raise ValueError(
-                "Attempting to remove block that's not in the evictor")
-        block: PhysicalTokenBlock = self.free_table[block_hash]
-        del self.free_table[block_hash]
-        return block
-
-    @property
-    def num_blocks(self) -> int:
-        return len(self.free_table)
-
-
-class RandomEvictor(Evictor):
-    """Evicts in a first-in-first-out order"""
-
-    def __init__(self):
-        self.free_table: Dict[int, PhysicalTokenBlock] = {}
-
-    def __contains__(self, block_hash: int) -> bool:
-        return block_hash in self.free_table
-
     def evict(self) -> PhysicalTokenBlock:
         if len(self.free_table) == 0:
             raise ValueError("No usable cache memory left")
-        evicted_block = next(iter(self.free_table.values()))
-        evicted_block.computed = False
+        free_blocks = self.free_table.values()
+
+        # Get evicted block
+        evicted_block: PhysicalTokenBlock = next(iter(free_blocks))
+        for block in free_blocks:
+            if (block.last_accessed < evicted_block.last_accessed
+                    or block.last_accessed == evicted_block.last_accessed and
+                    block.num_hashed_tokens > evicted_block.num_hashed_tokens):
+                evicted_block = block
+
         del self.free_table[evicted_block.block_hash]
+
+        evicted_block.computed = False
         return evicted_block

     def add(self, block: PhysicalTokenBlock):
@@ -155,7 +102,5 @@ class RandomEvictor(Evictor):
 def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
     if eviction_policy == EvictionPolicy.LRU:
         return LRUEvictor()
-    elif eviction_policy == EvictionPolicy.FIFO:
-        return RandomEvictor()
     else:
         raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
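
The evictor change above collapses the old multi-pass scan (find the lowest timestamp, collect the ties, then compare prefix lengths) into a single pass, and removes the FIFO/RandomEvictor path now that the uncached allocator no longer goes through an evictor at all. A standalone sketch of the merged eviction rule, using a hypothetical Block stand-in rather than vllm's PhysicalTokenBlock:

    # Evict the block with the oldest last_accessed time; break ties in favor of
    # the block caching more hashed tokens (same rule as the new LRUEvictor.evict).
    from dataclasses import dataclass

    @dataclass
    class Block:
        block_hash: int
        last_accessed: float
        num_hashed_tokens: int

    def pick_victim(free_blocks):
        evicted = free_blocks[0]
        for block in free_blocks:
            if (block.last_accessed < evicted.last_accessed
                    or block.last_accessed == evicted.last_accessed
                    and block.num_hashed_tokens > evicted.num_hashed_tokens):
                evicted = block
        return evicted

    free = [Block(1, 10.0, 2), Block(2, 5.0, 1), Block(3, 5.0, 4)]
    assert pick_victim(free).block_hash == 3  # oldest timestamp; longer prefix wins the tie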