[Core] support LoRA and prompt adapter in content-based hashing for Block Manager v2 prefix caching (#8240)

This commit is contained in:
Sungjae Lee 2024-12-14 00:51:25 +09:00 committed by GitHub
parent d1fa714cb1
commit c31d4a57a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 244 additions and 53 deletions

View File

@ -5,7 +5,7 @@ from unittest.mock import MagicMock
import pytest
from tests.core.utils import create_dummy_sequence
from tests.core.utils import create_dummy_lora_sequence, create_dummy_sequence
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
@ -801,6 +801,7 @@ class TestPrefixCachingBlockAllocator:
block_size: int,
token_ids: List[int],
allocator: PrefixCachingBlockAllocator,
extra_hash: Optional[int] = None,
) -> List[PrefixCachingBlock]:
"""Helper method which creates a chain of blocks.
"""
@ -816,7 +817,9 @@ class TestPrefixCachingBlockAllocator:
block_size:(block_number + 1) *
block_size]
prev_block = allocator.allocate_immutable_block(
prev_block=prev_block, token_ids=block_token_ids)
prev_block=prev_block,
token_ids=block_token_ids,
extra_hash=extra_hash)
blocks.append(prev_block)
return blocks
@ -931,3 +934,61 @@ class TestComputedBlocksTracker:
allocator.mark_blocks_as_computed([])
assert tracker.get_num_cached_tokens(seq) == len(tokens)
@staticmethod
def test_correct_extra_hash():
    """
    Verify that the extra hash (LoRA id here) participates in block
    hashing: blocks cached under one adapter must not count as cached
    for sequences with no adapter or a different adapter, while the
    matching adapter sequence sees all its tokens as cached.
    """
    block_size = 4
    block_allocator = CpuGpuBlockAllocator.create(
        allocator_type="prefix_caching",
        num_gpu_blocks=16,
        num_cpu_blocks=16,
        block_size=block_size,
    )
    gpu_side = block_allocator._allocators[Device.GPU]

    computed_tracker = ComputedBlocksTracker(
        allocator=block_allocator,
        block_size=block_size,
        enable_caching=True,
    )

    token_ids = list(range(block_size * 4))

    # Populate the prefix cache with a full chain hashed under LoRA id 1.
    cached_seq = create_dummy_lora_sequence(request_id=0,
                                            token_ids=token_ids,
                                            block_size=block_size,
                                            lora_int_id=1)
    _ = TestPrefixCachingBlockAllocator.create_immutable_chain(
        block_size=block_size,
        token_ids=token_ids,
        allocator=gpu_side,
        extra_hash=cached_seq.extra_hash(),
    )
    block_allocator.mark_blocks_as_computed([])

    # Same token ids, but no adapter / a different adapter: their block
    # hashes differ from the cached chain, so nothing is cached for them.
    plain_seq = create_dummy_sequence(request_id=1,
                                      token_ids=token_ids,
                                      block_size=block_size)
    other_lora_seq = create_dummy_lora_sequence(request_id=2,
                                                token_ids=token_ids,
                                                block_size=block_size,
                                                lora_int_id=2)
    assert computed_tracker.get_num_cached_tokens(plain_seq) == 0
    assert computed_tracker.get_num_cached_tokens(other_lora_seq) == 0

    # The sequence carrying the cached adapter hits on every block.
    assert computed_tracker.get_num_cached_tokens(cached_seq) == len(
        token_ids)

View File

@ -46,6 +46,16 @@ def create_dummy_prompt(
return prompt, seq_group
def create_dummy_lora_sequence(request_id: int, token_ids: List[int],
                               block_size: int, lora_int_id: int) -> Sequence:
    """Build a Sequence that carries a dummy LoRARequest.

    Used by prefix-caching tests to exercise the extra-hash path that
    distinguishes otherwise-identical blocks belonging to different
    LoRA adapters.
    """
    dummy_lora = LoRARequest(lora_name="dummy",
                             lora_path="/dummy",
                             lora_int_id=lora_int_id)
    return Sequence(seq_id=request_id,
                    inputs=token_inputs(token_ids),
                    block_size=block_size,
                    lora_request=dummy_lora)
def create_dummy_sequence(request_id: int, token_ids: List[int],
block_size: int) -> Sequence:
return Sequence(

View File

@ -80,7 +80,8 @@ class BlockTable:
def allocate(self,
token_ids: List[int],
device: Device = Device.GPU) -> None:
device: Device = Device.GPU,
extra_hash: Optional[int] = None) -> None:
"""Allocates memory blocks for storing the given sequence of token IDs.
This method allocates the required number of blocks to store the given
@ -90,12 +91,16 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be stored.
device (Device, optional): The device on which the blocks should be
allocated. Defaults to Device.GPU.
extra_hash (Optional[int]): The hash value of additional
factors, such as adapters, that influence the block hash
in the prefix caching block.
"""
assert not self._is_allocated
assert token_ids
blocks = self._allocate_blocks_for_token_ids(prev_block=None,
token_ids=token_ids,
device=device)
device=device,
extra_hash=extra_hash)
self.update(blocks)
self._num_full_slots = len(token_ids)
@ -108,7 +113,8 @@ class BlockTable:
def append_token_ids(self,
token_ids: List[int],
num_lookahead_slots: int = 0,
num_computed_slots: Optional[int] = None) -> None:
num_computed_slots: Optional[int] = None,
extra_hash: Optional[int] = None) -> None:
"""Appends a sequence of token IDs to the existing blocks in the
BlockTable.
@ -130,6 +136,9 @@ class BlockTable:
Without sliding window, None can be passed.
Without chunked prefill, it should be the same as
_num_full_slots.
extra_hash (Optional[int]): The hash value of additional
factors such as adapters that influence the block, apart
from the token_ids.
"""
assert self._is_allocated, "no blocks have been allocated"
assert len(self._blocks) > 0
@ -149,7 +158,8 @@ class BlockTable:
# Ensure there are enough empty slots for the new tokens plus
# lookahead slots
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots)
num_lookahead_slots,
extra_hash=extra_hash)
# Update the blocks with the new tokens
first_block_idx = self._num_full_slots // self._block_size
@ -160,7 +170,9 @@ class BlockTable:
self._num_full_slots += len(token_ids)
def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
def ensure_num_empty_slots(self,
num_empty_slots: int,
extra_hash: Optional[int] = None) -> None:
"""Ensures that the BlockTable has at least the specified number of
empty slots available.
@ -171,6 +183,9 @@ class BlockTable:
Args:
num_empty_slots (int): The minimum number of empty slots required.
extra_hash (Optional[int]): The hash value of additional
factors such as adapters that influence the block, apart
from the token_ids.
"""
# Currently the block table only supports
# appending tokens to GPU blocks.
@ -187,7 +202,9 @@ class BlockTable:
assert len(self._blocks) > 0
self._blocks.append(
self._allocator.allocate_mutable_block(
prev_block=self._blocks[-1], device=device))
prev_block=self._blocks[-1],
device=device,
extra_hash=extra_hash))
def fork(self) -> "BlockTable":
"""Creates a new BlockTable instance with a copy of the blocks from the
@ -259,9 +276,12 @@ class BlockTable:
# ones after the appended ones.
return sequence_token_ids[self.num_full_slots:]
def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
token_ids: List[int],
device: Device) -> List[Block]:
def _allocate_blocks_for_token_ids(
self,
prev_block: Optional[Block],
token_ids: List[int],
device: Device,
extra_hash: Optional[int] = None) -> List[Block]:
blocks: List[Block] = []
block_token_ids = []
@ -275,8 +295,10 @@ class BlockTable:
if block_token_ids:
blocks.extend(
self._allocator.allocate_immutable_blocks(
prev_block, block_token_ids=block_token_ids,
device=device))
prev_block,
block_token_ids=block_token_ids,
device=device,
extra_hash=extra_hash))
prev_block = blocks[-1]
if tail_token_ids:
@ -284,7 +306,7 @@ class BlockTable:
cur_token_ids = tail_token_ids[0]
block = self._allocator.allocate_mutable_block(
prev_block=prev_block, device=device)
prev_block=prev_block, device=device, extra_hash=extra_hash)
block.append_token_ids(cur_token_ids)
blocks.append(block)

View File

@ -177,7 +177,8 @@ class BlockPool:
token_ids=[],
block_size=self._block_size,
allocator=self._allocator,
block_id=None))
block_id=None,
extra_hash=None))
def increase_pool(self):
"""Doubles the internal pool size
@ -194,10 +195,15 @@ class BlockPool:
token_ids=[],
block_size=self._block_size,
allocator=self._allocator,
block_id=None))
block_id=None,
extra_hash=None))
def init_block(self, prev_block: Optional[Block], token_ids: List[int],
block_size: int, physical_block_id: Optional[int]) -> Block:
def init_block(self,
prev_block: Optional[Block],
token_ids: List[int],
block_size: int,
physical_block_id: Optional[int],
extra_hash: Optional[int] = None) -> Block:
if len(self._free_ids) == 0:
self.increase_pool()
assert len(self._free_ids) > 0
@ -210,7 +216,8 @@ class BlockPool:
token_ids=token_ids,
block_size=block_size,
allocator=block._allocator, # type: ignore[attr-defined]
block_id=physical_block_id)
block_id=physical_block_id,
extra_hash=extra_hash)
block.pool_id = pool_id # type: ignore[attr-defined]
return block

View File

@ -121,23 +121,32 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
self.allocate_mutable_block(None, Device.GPU))
return self._null_block
def allocate_mutable_block(self, prev_block: Optional[Block],
device: Device) -> Block:
def allocate_mutable_block(self,
prev_block: Optional[Block],
device: Device,
extra_hash: Optional[int] = None) -> Block:
"""Allocates a new mutable block on the specified device.
Args:
prev_block (Optional[Block]): The previous block in the sequence.
Used for prefix hashing.
device (Device): The device on which to allocate the new block.
extra_hash (Optional[int]): The hash value of additional
factors, such as adapters, that influence the block hash
in the prefix caching block.
Returns:
Block: The newly allocated mutable block.
"""
return self._allocators[device].allocate_mutable_block(prev_block)
return self._allocators[device].allocate_mutable_block(
prev_block, extra_hash=extra_hash)
def allocate_immutable_blocks(self, prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Device) -> List[Block]:
def allocate_immutable_blocks(
self,
prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Device,
extra_hash: Optional[int] = None) -> List[Block]:
"""Allocates a new group of immutable blocks with the provided block
token IDs on the specified device.
@ -147,17 +156,22 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
block_token_ids (List[int]): The list of block token IDs to be
stored in the new blocks.
device (Device): The device on which to allocate the new block.
extra_hash (Optional[int]): The hash value of additional
factors, such as adapters, that influence the block hash
in the prefix caching block.
Returns:
List[Block]: The newly allocated list of immutable blocks
containing the provided block token IDs.
"""
return self._allocators[device].allocate_immutable_blocks(
prev_block, block_token_ids)
prev_block, block_token_ids, extra_hash=extra_hash)
def allocate_immutable_block(self, prev_block: Optional[Block],
def allocate_immutable_block(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Device) -> Block:
device: Device,
extra_hash: Optional[int] = None) -> Block:
"""Allocates a new immutable block with the provided token IDs on the
specified device.
@ -167,13 +181,16 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
token_ids (List[int]): The list of token IDs to be stored in the new
block.
device (Device): The device on which to allocate the new block.
extra_hash (Optional[int]): The hash value of additional
factors, such as adapters, that influence the block hash
in the prefix caching block.
Returns:
Block: The newly allocated immutable block containing the provided
token IDs.
"""
return self._allocators[device].allocate_immutable_block(
prev_block, token_ids)
prev_block, token_ids, extra_hash=extra_hash)
def free(self, block: Block) -> None:
"""Frees the memory occupied by the given block.
@ -387,6 +404,10 @@ class NullBlock(Block):
def prev_block(self):
return self._proxy.prev_block
@property
def extra_hash(self):
return None
@property
def computed(self):
return self._proxy.computed

View File

@ -50,6 +50,11 @@ class Block(ABC):
def prev_block(self) -> Optional["Block"]:
pass
@property
@abstractmethod
def extra_hash(self) -> Optional[int]:
return None
@property
@abstractmethod
def computed(self) -> bool:
@ -81,6 +86,8 @@ class Block(ABC):
block_size: int,
allocator: "BlockAllocator",
block_id: Optional[int] = None,
computed: bool = False,
extra_hash: Optional[int] = None,
) -> "Block":
pass
@ -99,18 +106,20 @@ class Block(ABC):
class BlockAllocator(ABC):
@abstractmethod
def allocate_mutable_block(self, prev_block: Optional[Block]) -> Block:
def allocate_mutable_block(self, prev_block: Optional[Block],
extra_hash: Optional[int]) -> Block:
pass
@abstractmethod
def allocate_immutable_block(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block:
token_ids: List[int],
extra_hash: Optional[int]) -> Block:
pass
@abstractmethod
def allocate_immutable_blocks(
self, prev_block: Optional[Block],
block_token_ids: List[List[int]]) -> List[Block]:
def allocate_immutable_blocks(self, prev_block: Optional[Block],
block_token_ids: List[List[int]],
extra_hash: Optional[int]) -> List[Block]:
pass
@abstractmethod
@ -197,14 +206,18 @@ class BlockAllocator(ABC):
class DeviceAwareBlockAllocator(ABC):
@abstractmethod
def allocate_mutable_block(self, prev_block: Optional[Block],
device: Device) -> Block:
def allocate_mutable_block(self,
prev_block: Optional[Block],
device: Device,
extra_hash: Optional[int] = None) -> Block:
pass
@abstractmethod
def allocate_immutable_block(self, prev_block: Optional[Block],
def allocate_immutable_block(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Device) -> Block:
device: Device,
extra_hash: Optional[int] = None) -> Block:
pass
@abstractmethod
@ -213,6 +226,7 @@ class DeviceAwareBlockAllocator(ABC):
prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Device,
extra_hash: Optional[int] = None,
) -> List[Block]:
pass

View File

@ -63,6 +63,7 @@ class NaiveBlockAllocator(BlockAllocator):
def allocate_immutable_block(self,
prev_block: Optional[Block],
token_ids: List[int],
extra_hash: Optional[int] = None,
device: Optional[Device] = None) -> Block:
"""Allocates a new immutable block with the given token IDs, linked to
the previous block.
@ -85,6 +86,7 @@ class NaiveBlockAllocator(BlockAllocator):
self,
prev_block: Optional[Block],
block_token_ids: List[List[int]],
extra_hash: Optional[int] = None,
device: Optional[Device] = None) -> List[Block]:
assert device is None
num_blocks = len(block_token_ids)
@ -106,6 +108,7 @@ class NaiveBlockAllocator(BlockAllocator):
def allocate_mutable_block(self,
prev_block: Optional[Block],
extra_hash: Optional[int] = None,
device: Optional[Device] = None) -> Block:
"""Allocates a new mutable block, linked to the previous block.
@ -355,7 +358,8 @@ class NaiveBlock(Block):
block_size: int,
allocator: BlockAllocator,
block_id: Optional[int] = None,
_cow_target: Optional[Block] = None):
_cow_target: Optional[Block] = None,
extra_hash: Optional[int] = None):
self._token_ids: List[int] = []
self._block_size = block_size
self._prev_block = prev_block
@ -441,6 +445,10 @@ class NaiveBlock(Block):
def prev_block(self) -> Optional["Block"]:
return self._prev_block
@property
def extra_hash(self):
return None
@property
def content_hash(self) -> Optional[int]:
return None

View File

@ -126,6 +126,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
allocator: BlockAllocator,
block_id: Optional[int] = None,
computed: bool = False,
extra_hash: Optional[int] = None,
) -> Block:
# Bind block to self.
allocator = self
@ -137,11 +138,13 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block_id=block_id,
allocator=allocator,
computed=computed,
extra_hash=extra_hash,
)
def allocate_immutable_block(self,
prev_block: Optional[Block],
token_ids: List[int],
extra_hash: Optional[int] = None,
device: Optional[Device] = None) -> Block:
"""Allocates an immutable block with the given token IDs, reusing cached
blocks if possible.
@ -160,7 +163,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block = self._block_pool.init_block(prev_block=prev_block,
token_ids=token_ids,
block_size=self._block_size,
physical_block_id=None)
physical_block_id=None,
extra_hash=extra_hash)
assert block.content_hash is not None
cached_block_id = self._cached_blocks.get(block.content_hash, None)
@ -173,7 +177,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self._block_pool.free_block(block)
# No cached block => Allocate a new block
block = self.allocate_mutable_block(prev_block)
block = self.allocate_mutable_block(prev_block, extra_hash=extra_hash)
block.append_token_ids(token_ids)
return block
@ -181,17 +185,20 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self,
prev_block: Optional[Block],
block_token_ids: List[List[int]],
extra_hash: Optional[int] = None,
device: Optional[Device] = None) -> List[Block]:
blocks = []
for token_ids in block_token_ids:
prev_block = self.allocate_immutable_block(prev_block=prev_block,
token_ids=token_ids,
device=device)
device=device,
extra_hash=extra_hash)
blocks.append(prev_block)
return blocks
def allocate_mutable_block(self,
prev_block: Optional[Block],
extra_hash: Optional[int] = None,
device: Optional[Device] = None) -> Block:
"""Allocates a mutable block. If there are no free blocks, this will
evict unused cached blocks.
@ -210,7 +217,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block = self._block_pool.init_block(prev_block=prev_block,
token_ids=[],
block_size=self._block_size,
physical_block_id=block_id)
physical_block_id=block_id,
extra_hash=extra_hash)
assert not block.computed
assert block.content_hash is None
return block
@ -382,7 +390,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
prev_block=prev_block,
token_ids=block.token_ids,
block_size=self._block_size,
physical_block_id=block_id)
physical_block_id=block_id,
extra_hash=block.extra_hash)
forked_blocks.append(forked_block)
prev_block = forked_blocks[-1]
@ -608,10 +617,12 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# existing "block" object
if block.is_full:
tmp_block = self.allocate_immutable_block(
prev_block=block.prev_block, token_ids=block.token_ids)
prev_block=block.prev_block,
token_ids=block.token_ids,
extra_hash=block.extra_hash)
else:
tmp_block = self.allocate_mutable_block(
prev_block=block.prev_block)
prev_block=block.prev_block, extra_hash=block.extra_hash)
tmp_block.append_token_ids(block.token_ids)
block_id = tmp_block.block_id
@ -679,6 +690,8 @@ class PrefixCachingBlock(Block):
caching block allocator associated with this block.
block_id (Optional[int], optional): The physical block index
of this block. Defaults to None.
extra_hash (Optional[int]): The hash value of additional factors
such as adapters that influence the block, apart from the token_ids.
"""
def __init__(
@ -689,6 +702,7 @@ class PrefixCachingBlock(Block):
allocator: BlockAllocator,
block_id: Optional[int] = None,
computed: bool = False,
extra_hash: Optional[int] = None,
):
assert isinstance(allocator, PrefixCachingBlockAllocator), (
"Currently this class is only tested with "
@ -702,6 +716,7 @@ class PrefixCachingBlock(Block):
self._allocator = allocator
self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
self._computed = computed
self._extra_hash = extra_hash
# On the first time, we create the block object, and next we only
# reinitialize it
@ -811,6 +826,10 @@ class PrefixCachingBlock(Block):
def prev_block(self) -> Optional[Block]:
return self._prev_block
@property
def extra_hash(self) -> Optional[int]:
return self._extra_hash
@property
def content_hash(self) -> Optional[int]:
"""Return the content-based hash of the current block, or None if it is
@ -841,18 +860,19 @@ class PrefixCachingBlock(Block):
self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
is_first_block,
prev_block_hash,
cur_block_token_ids=self.token_ids)
cur_block_token_ids=self.token_ids,
extra_hash=self._extra_hash)
return self._cached_content_hash
@staticmethod
def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
cur_block_token_ids: List[int]) -> int:
def hash_block_tokens(is_first_block: bool,
prev_block_hash: Optional[int],
cur_block_token_ids: List[int],
extra_hash: Optional[int] = None) -> int:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching.
NOTE: Content-based hashing does not yet support LoRA.
Parameters:
- is_first_block (bool): A flag indicating if the block is the first in
the sequence.
@ -860,12 +880,15 @@ class PrefixCachingBlock(Block):
if this is the first block.
- cur_block_token_ids (List[int]): A list of token ids in the current
block. The current block is assumed to be full.
- extra_hash (Optional[int]): The hash value of additional factors
such as adapters that influence the block, apart from the token_ids.
Returns:
- int: The computed hash value for the block.
"""
assert (prev_block_hash is None) == is_first_block
return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
return hash((is_first_block, prev_block_hash, *cur_block_token_ids,
extra_hash))
class ComputedBlocksTracker:
@ -935,12 +958,18 @@ class ComputedBlocksTracker:
assert len(token_ids) >= (i + 1) * self._block_size
block_token_ids = token_ids[i * self._block_size:(i + 1) *
self._block_size]
# NOTE: If there are any factors affecting the block besides
# token_ids, they should be added as input to extra_hash.
extra_hash = seq.extra_hash()
# This has to be kept in sync with the allocator's hash
# calculation.
block_hash = PrefixCachingBlock.hash_block_tokens(
is_first_block=prev_block_hash is None,
prev_block_hash=prev_block_hash,
cur_block_token_ids=block_token_ids,
extra_hash=extra_hash,
)
block_hashes_recorded.append(block_hash)
prev_block_hash = block_hash

View File

@ -151,8 +151,13 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
max_block_sliding_window=self.max_block_sliding_window,
)
if seq.get_token_ids():
# NOTE: If there are any factors affecting the block besides
# token_ids, they should be added as input to extra_hash.
extra_hash = seq.extra_hash()
# Add blocks to the block table only if the sequence is non empty.
block_table.allocate(seq.get_token_ids())
block_table.allocate(token_ids=seq.get_token_ids(),
extra_hash=extra_hash)
return block_table
@ -238,6 +243,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
num_lookahead_slots=num_lookahead_slots,
num_computed_slots=seq.data.get_num_computed_tokens(),
extra_hash=seq.extra_hash(),
)
# Return any new copy-on-writes.
new_cows = self.block_allocator.clear_copy_on_writes()

View File

@ -527,6 +527,19 @@ class Sequence:
hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
return hash((hashed_tokens, self.lora_int_id))
def extra_hash(self) -> Optional[int]:
    """
    Return a hash of adapter-related state for prefix caching, or None.

    Prefix caching folds this value into each block's content hash so
    that blocks produced under different adapters never collide, even
    when their token ids are identical.
    """
    # Sequences with neither a prompt adapter nor a LoRA adapter all
    # share the default (None) namespace.
    has_adapter = self.prompt_adapter_id != 0 or self.lora_int_id != 0
    if not has_adapter:
        return None
    # NOTE: any additional factors influencing the block beyond
    # token_ids must also be folded into this hash.
    return hash((self.prompt_adapter_id, self.lora_int_id))
def num_hashed_tokens_of_block(self, logical_idx: int):
return logical_idx * self.block_size + self.block_size