mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:16:06 +08:00
[PREFIX CACHING FOLLOW UP] A bunch of fixes to block allocator performance when automatic prefix caching is disabled (#3357)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
parent
20478c4d3a
commit
9474e89ba4
@ -4,7 +4,7 @@ from typing import List
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.block import PhysicalTokenBlock
|
||||
from vllm.core.block_manager import (BlockAllocator, BlockSpaceManager,
|
||||
from vllm.core.block_manager import (UncachedBlockAllocator, BlockSpaceManager,
|
||||
AllocStatus)
|
||||
from vllm.utils import Device
|
||||
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus, Logprob
|
||||
@ -15,7 +15,8 @@ from .utils import create_dummy_prompt
|
||||
def test_block_allocator_allocate():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
|
||||
cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
|
||||
num_cpu_blocks)
|
||||
|
||||
# Allocate all available cpu blocks.
|
||||
num_free = num_cpu_blocks
|
||||
@ -24,7 +25,7 @@ def test_block_allocator_allocate():
|
||||
block = cpu_allocator.allocate()
|
||||
num_free -= 1
|
||||
|
||||
assert block.block_hash not in cpu_allocator.evictor
|
||||
assert block not in cpu_allocator.free_blocks
|
||||
assert cpu_allocator.get_num_free_blocks() == num_free
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
@ -34,14 +35,15 @@ def test_block_allocator_allocate():
|
||||
def test_block_allocator_free():
|
||||
block_size = 4
|
||||
num_cpu_blocks = 4
|
||||
cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
|
||||
cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size,
|
||||
num_cpu_blocks)
|
||||
|
||||
# Allocate all available cpu blocks.
|
||||
blocks: List[PhysicalTokenBlock] = []
|
||||
for _ in range(num_cpu_blocks):
|
||||
block = cpu_allocator.allocate()
|
||||
blocks.append(block)
|
||||
assert block.block_hash not in cpu_allocator.evictor
|
||||
assert block not in cpu_allocator.free_blocks
|
||||
|
||||
# Free all allocated cpu blocks.
|
||||
num_free = 0
|
||||
@ -49,7 +51,7 @@ def test_block_allocator_free():
|
||||
for block in blocks:
|
||||
cpu_allocator.free(block)
|
||||
num_free += 1
|
||||
assert block.block_hash in cpu_allocator.evictor
|
||||
assert block in cpu_allocator.free_blocks
|
||||
assert cpu_allocator.get_num_free_blocks() == num_free
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
@ -4,7 +4,7 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from vllm.core.block_manager import BlockAllocator
|
||||
from vllm.core.block_manager import CachedBlockAllocator
|
||||
from vllm.utils import Device
|
||||
|
||||
|
||||
@ -15,10 +15,7 @@ def test_block_allocator(
|
||||
num_blocks: int,
|
||||
):
|
||||
block_hash = 1
|
||||
block_allocator = BlockAllocator(Device.CPU,
|
||||
block_size,
|
||||
num_blocks,
|
||||
enable_caching=True)
|
||||
block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
|
||||
|
||||
# Allocate two PysicalTokenBlocks with the same hash and check
|
||||
# that they are the same PhysicalTokenBlock
|
||||
@ -45,10 +42,7 @@ def test_block_allocator(
|
||||
@pytest.mark.parametrize("num_blocks", [16])
|
||||
def test_eviction(num_blocks: int, ):
|
||||
block_size = 16
|
||||
block_allocator = BlockAllocator(Device.CPU,
|
||||
block_size,
|
||||
num_blocks,
|
||||
enable_caching=True)
|
||||
block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)
|
||||
blocks = []
|
||||
|
||||
for i in range(num_blocks):
|
||||
|
||||
@ -3,6 +3,7 @@ import enum
|
||||
from itertools import count, takewhile
|
||||
from os.path import commonprefix
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from vllm.block import BlockTable, PhysicalTokenBlock
|
||||
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
||||
@ -10,7 +11,46 @@ from vllm.utils import Device
|
||||
from vllm.core.evictor import Evictor, EvictionPolicy, make_evictor
|
||||
|
||||
|
||||
class BlockAllocator:
|
||||
class BlockAllocatorBase(ABC):
|
||||
"""Manages free physical token blocks for a device.
|
||||
|
||||
The allocator maintains a list of free blocks and allocates a block when
|
||||
requested. When a block is freed, its reference count is decremented. If
|
||||
the reference count becomes zero, the block is added back to the free list.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self,
|
||||
device: Device,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
eviction_policy: EvictionPolicy = EvictionPolicy.LRU):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def allocate(self,
|
||||
block_hash: Optional[int] = None,
|
||||
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def free(self, block: PhysicalTokenBlock) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_num_free_blocks(self) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def contains_block(self, block_hash: int) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
|
||||
pass
|
||||
|
||||
|
||||
class CachedBlockAllocator(BlockAllocatorBase):
|
||||
"""Manages free physical token blocks for a device.
|
||||
|
||||
The allocator maintains a list of free blocks and allocates a block when
|
||||
@ -22,19 +62,14 @@ class BlockAllocator:
|
||||
device: Device,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
|
||||
enable_caching: bool = False) -> None:
|
||||
eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None:
|
||||
self.device = device
|
||||
self.block_size = block_size
|
||||
self.num_blocks = num_blocks
|
||||
self.enable_caching = enable_caching
|
||||
|
||||
self.current_num_blocks = 0
|
||||
self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}
|
||||
|
||||
# Switch over to FIFO eviction when caching is disabled
|
||||
if not self.enable_caching:
|
||||
eviction_policy = EvictionPolicy.FIFO
|
||||
self.evictor: Evictor = make_evictor(eviction_policy)
|
||||
|
||||
self.default_hash_ctr = count()
|
||||
@ -57,13 +92,6 @@ class BlockAllocator:
|
||||
def allocate(self,
|
||||
block_hash: Optional[int] = None,
|
||||
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
|
||||
# If caching is disabled, just allocate a new block and return it
|
||||
if not self.enable_caching:
|
||||
block = self.allocate_block(next(self.default_hash_ctr),
|
||||
num_hashed_tokens)
|
||||
block.ref_count += 1
|
||||
return block
|
||||
|
||||
if block_hash is None:
|
||||
block_hash = next(self.default_hash_ctr)
|
||||
if block_hash in self.evictor:
|
||||
@ -90,9 +118,8 @@ class BlockAllocator:
|
||||
assert block.block_hash not in self.evictor
|
||||
self.evictor.add(block)
|
||||
|
||||
# If caching is enabled, remove the block from the cached_blocks
|
||||
if self.enable_caching:
|
||||
del self.cached_blocks[block.block_hash]
|
||||
# Remove the block from the cached_blocks
|
||||
del self.cached_blocks[block.block_hash]
|
||||
|
||||
def get_num_free_blocks(self) -> int:
|
||||
return (self.num_blocks - self.current_num_blocks +
|
||||
@ -102,14 +129,68 @@ class BlockAllocator:
|
||||
return block_hash in self.cached_blocks or block_hash in self.evictor
|
||||
|
||||
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
|
||||
# If caching is enabled, update the hash of block and the
|
||||
# cached_blocks dictionary.
|
||||
if self.enable_caching:
|
||||
assert not self.contains_block(block_hash)
|
||||
old_hash = block.block_hash
|
||||
block.block_hash = block_hash
|
||||
del self.cached_blocks[old_hash]
|
||||
self.cached_blocks[block_hash] = block
|
||||
# Update the hash of block and the cached_blocks dictionary.
|
||||
assert not self.contains_block(block_hash)
|
||||
old_hash = block.block_hash
|
||||
block.block_hash = block_hash
|
||||
del self.cached_blocks[old_hash]
|
||||
self.cached_blocks[block_hash] = block
|
||||
|
||||
|
||||
class UncachedBlockAllocator(BlockAllocatorBase):
|
||||
"""Manages free physical token blocks for a device.
|
||||
|
||||
The allocator maintains a list of free blocks and allocates a block when
|
||||
requested. When a block is freed, its reference count is decremented. If
|
||||
the reference count becomes zero, the block is added back to the free list.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: Device,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
) -> None:
|
||||
self.device = device
|
||||
self.block_size = block_size
|
||||
self.num_blocks = num_blocks
|
||||
|
||||
# Initialize the free blocks.
|
||||
self.free_blocks: BlockTable = []
|
||||
for i in range(num_blocks):
|
||||
block = PhysicalTokenBlock(device=device,
|
||||
block_number=i,
|
||||
block_size=block_size,
|
||||
block_hash=-1,
|
||||
num_hashed_tokens=0)
|
||||
self.free_blocks.append(block)
|
||||
|
||||
def allocate(self,
|
||||
block_hash: Optional[int] = None,
|
||||
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
|
||||
if not self.free_blocks:
|
||||
raise ValueError("Out of memory! No free blocks are available.")
|
||||
block = self.free_blocks.pop()
|
||||
block.ref_count = 1
|
||||
return block
|
||||
|
||||
def free(self, block: PhysicalTokenBlock) -> None:
|
||||
if block.ref_count == 0:
|
||||
raise ValueError(f"Double free! {block} is already freed.")
|
||||
block.ref_count -= 1
|
||||
if block.ref_count == 0:
|
||||
self.free_blocks.append(block)
|
||||
|
||||
def get_num_free_blocks(self) -> int:
|
||||
return len(self.free_blocks)
|
||||
|
||||
def contains_block(self, block_hash: int) -> bool:
|
||||
raise NotImplementedError(
|
||||
"Invalid codepath for uncached block allocator.")
|
||||
|
||||
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
|
||||
raise NotImplementedError(
|
||||
"Invalid codepath for uncached block allocator.")
|
||||
|
||||
|
||||
class AllocStatus(enum.Enum):
|
||||
@ -142,6 +223,10 @@ class BlockSpaceManager:
|
||||
self.num_total_gpu_blocks = num_gpu_blocks
|
||||
self.num_total_cpu_blocks = num_cpu_blocks
|
||||
|
||||
if enable_caching and sliding_window is not None:
|
||||
raise NotImplementedError(
|
||||
"Sliding window is not allowed with prefix caching enabled!")
|
||||
|
||||
self.block_sliding_window = None
|
||||
if sliding_window is not None:
|
||||
assert sliding_window % block_size == 0, (sliding_window,
|
||||
@ -154,14 +239,17 @@ class BlockSpaceManager:
|
||||
self.enable_caching = enable_caching
|
||||
|
||||
self.watermark_blocks = int(watermark * num_gpu_blocks)
|
||||
self.gpu_allocator = BlockAllocator(Device.GPU,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
enable_caching=enable_caching)
|
||||
self.cpu_allocator = BlockAllocator(Device.CPU,
|
||||
block_size,
|
||||
num_cpu_blocks,
|
||||
enable_caching=enable_caching)
|
||||
|
||||
if self.enable_caching:
|
||||
self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
|
||||
num_gpu_blocks)
|
||||
self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
|
||||
num_cpu_blocks)
|
||||
else:
|
||||
self.gpu_allocator = UncachedBlockAllocator(
|
||||
Device.GPU, block_size, num_gpu_blocks)
|
||||
self.cpu_allocator = UncachedBlockAllocator(
|
||||
Device.CPU, block_size, num_cpu_blocks)
|
||||
# Mapping: seq_id -> BlockTable.
|
||||
self.block_tables: Dict[int, BlockTable] = {}
|
||||
|
||||
@ -198,10 +286,16 @@ class BlockSpaceManager:
|
||||
if (self.block_sliding_window is not None
|
||||
and logical_idx >= self.block_sliding_window):
|
||||
block = block_table[logical_idx % self.block_sliding_window]
|
||||
else:
|
||||
# Set the reference counts of the token blocks.
|
||||
block.ref_count = seq_group.num_seqs()
|
||||
elif self.enable_caching:
|
||||
block = self.gpu_allocator.allocate(
|
||||
seq.hash_of_block(logical_idx),
|
||||
seq.num_hashed_tokens_of_block(logical_idx))
|
||||
else:
|
||||
block = self.gpu_allocator.allocate()
|
||||
# Set the reference counts of the token blocks.
|
||||
block.ref_count = seq_group.num_seqs()
|
||||
block_table.append(block)
|
||||
|
||||
# Assign the block table for each sequence.
|
||||
@ -220,8 +314,10 @@ class BlockSpaceManager:
|
||||
seq: Sequence,
|
||||
last_block: PhysicalTokenBlock,
|
||||
) -> PhysicalTokenBlock:
|
||||
# Compute a new hash for the block so that it can be shared by
|
||||
# other Sequences
|
||||
assert self.enable_caching
|
||||
|
||||
# Compute a new hash for the block so that it can be shared by other
|
||||
# Sequences
|
||||
new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
|
||||
|
||||
# if new_hash is already in the cached table, then free last_block
|
||||
@ -254,6 +350,8 @@ class BlockSpaceManager:
|
||||
self,
|
||||
seq: Sequence,
|
||||
) -> PhysicalTokenBlock:
|
||||
if not self.enable_caching:
|
||||
return self.gpu_allocator.allocate()
|
||||
block_hash: Optional[int] = None
|
||||
if (self._is_last_block_full(seq)):
|
||||
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
|
||||
@ -293,10 +391,12 @@ class BlockSpaceManager:
|
||||
assert last_block.device == Device.GPU
|
||||
if last_block.ref_count == 1:
|
||||
# Not shared with other sequences. Appendable.
|
||||
# If the last block is now complete, promote it to a full block so
|
||||
# that it can be shared
|
||||
new_block = self._maybe_promote_last_block(seq, last_block)
|
||||
block_table[-1] = new_block
|
||||
if self.enable_caching:
|
||||
# If the last block is now complete, we may reuse an old block
|
||||
# to save memory.
|
||||
maybe_new_block = self._maybe_promote_last_block(
|
||||
seq, last_block)
|
||||
block_table[-1] = maybe_new_block
|
||||
return None
|
||||
else:
|
||||
# The last block is shared with other sequences.
|
||||
@ -440,9 +540,12 @@ class BlockSpaceManager:
|
||||
seq: Sequence,
|
||||
access_time: float,
|
||||
) -> None:
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
for block in block_table:
|
||||
block.last_accessed = access_time
|
||||
if self.enable_caching:
|
||||
# Update the last accessed time of all the blocks accessed
|
||||
# in this step.
|
||||
block_table = self.block_tables[seq.seq_id]
|
||||
for block in block_table:
|
||||
block.last_accessed = access_time
|
||||
|
||||
def compute_full_blocks_in_seq(self, seq: Sequence):
|
||||
if seq.seq_id not in self.block_tables:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import enum
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict
|
||||
from abc import ABC, abstractmethod, abstractproperty
|
||||
|
||||
from vllm.block import PhysicalTokenBlock
|
||||
@ -10,7 +10,6 @@ class EvictionPolicy(enum.Enum):
|
||||
Evictor subclass.
|
||||
"""
|
||||
LRU = enum.auto()
|
||||
FIFO = enum.auto()
|
||||
|
||||
|
||||
class Evictor(ABC):
|
||||
@ -65,75 +64,23 @@ class LRUEvictor(Evictor):
|
||||
return block_hash in self.free_table
|
||||
|
||||
# TODO: The performance of this evict function can be optimized further.
|
||||
def evict(self) -> PhysicalTokenBlock:
|
||||
free_blocks: List[PhysicalTokenBlock] = list(self.free_table.values())
|
||||
if len(free_blocks) == 0:
|
||||
raise ValueError("No usable cache memory left")
|
||||
|
||||
# Find lowest timestamp
|
||||
lowest_timestamp = free_blocks[0].last_accessed
|
||||
for block in free_blocks:
|
||||
if block.last_accessed < lowest_timestamp:
|
||||
lowest_timestamp = block.last_accessed
|
||||
|
||||
# Find all blocks with the lowest timestamp
|
||||
least_recent: List[PhysicalTokenBlock] = []
|
||||
for block in free_blocks:
|
||||
if block.last_accessed == lowest_timestamp:
|
||||
least_recent.append(block)
|
||||
|
||||
# Find highest prefix count per block
|
||||
highest_num_hashed_tokens = 0
|
||||
for block in least_recent:
|
||||
if block.num_hashed_tokens > highest_num_hashed_tokens:
|
||||
highest_num_hashed_tokens = block.num_hashed_tokens
|
||||
|
||||
evicted_block: Optional[PhysicalTokenBlock] = None
|
||||
|
||||
# Find the first block with the lowest timestamp
|
||||
for block in least_recent:
|
||||
if block.num_hashed_tokens == highest_num_hashed_tokens:
|
||||
evicted_block = block
|
||||
break
|
||||
|
||||
assert evicted_block is not None
|
||||
|
||||
del self.free_table[evicted_block.block_hash]
|
||||
|
||||
evicted_block.computed = False
|
||||
return evicted_block
|
||||
|
||||
def add(self, block: PhysicalTokenBlock):
|
||||
self.free_table[block.block_hash] = block
|
||||
|
||||
def remove(self, block_hash: int) -> PhysicalTokenBlock:
|
||||
if block_hash not in self.free_table:
|
||||
raise ValueError(
|
||||
"Attempting to remove block that's not in the evictor")
|
||||
block: PhysicalTokenBlock = self.free_table[block_hash]
|
||||
del self.free_table[block_hash]
|
||||
return block
|
||||
|
||||
@property
|
||||
def num_blocks(self) -> int:
|
||||
return len(self.free_table)
|
||||
|
||||
|
||||
class RandomEvictor(Evictor):
|
||||
"""Evicts in a first-in-first-out order"""
|
||||
|
||||
def __init__(self):
|
||||
self.free_table: Dict[int, PhysicalTokenBlock] = {}
|
||||
|
||||
def __contains__(self, block_hash: int) -> bool:
|
||||
return block_hash in self.free_table
|
||||
|
||||
def evict(self) -> PhysicalTokenBlock:
|
||||
if len(self.free_table) == 0:
|
||||
raise ValueError("No usable cache memory left")
|
||||
evicted_block = next(iter(self.free_table.values()))
|
||||
evicted_block.computed = False
|
||||
free_blocks = self.free_table.values()
|
||||
|
||||
# Get evicted block
|
||||
evicted_block: PhysicalTokenBlock = next(iter(free_blocks))
|
||||
|
||||
for block in free_blocks:
|
||||
if (block.last_accessed < evicted_block.last_accessed
|
||||
or block.last_accessed == evicted_block.last_accessed and
|
||||
block.num_hashed_tokens > evicted_block.num_hashed_tokens):
|
||||
evicted_block = block
|
||||
|
||||
del self.free_table[evicted_block.block_hash]
|
||||
|
||||
evicted_block.computed = False
|
||||
return evicted_block
|
||||
|
||||
def add(self, block: PhysicalTokenBlock):
|
||||
@ -155,7 +102,5 @@ class RandomEvictor(Evictor):
|
||||
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
|
||||
if eviction_policy == EvictionPolicy.LRU:
|
||||
return LRUEvictor()
|
||||
elif eviction_policy == EvictionPolicy.FIFO:
|
||||
return RandomEvictor()
|
||||
else:
|
||||
raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user