[Core] support LoRA and prompt adapter in content-based hashing for Block Manager v2 prefix caching (#8240)

This commit is contained in:
Sungjae Lee 2024-12-14 00:51:25 +09:00 committed by GitHub
parent d1fa714cb1
commit c31d4a57a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 244 additions and 53 deletions

View File

@ -5,7 +5,7 @@ from unittest.mock import MagicMock
import pytest import pytest
from tests.core.utils import create_dummy_sequence from tests.core.utils import create_dummy_lora_sequence, create_dummy_sequence
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.block.interfaces import Block, BlockAllocator from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
@ -801,6 +801,7 @@ class TestPrefixCachingBlockAllocator:
block_size: int, block_size: int,
token_ids: List[int], token_ids: List[int],
allocator: PrefixCachingBlockAllocator, allocator: PrefixCachingBlockAllocator,
extra_hash: Optional[int] = None,
) -> List[PrefixCachingBlock]: ) -> List[PrefixCachingBlock]:
"""Helper method which creates a chain of blocks. """Helper method which creates a chain of blocks.
""" """
@ -816,7 +817,9 @@ class TestPrefixCachingBlockAllocator:
block_size:(block_number + 1) * block_size:(block_number + 1) *
block_size] block_size]
prev_block = allocator.allocate_immutable_block( prev_block = allocator.allocate_immutable_block(
prev_block=prev_block, token_ids=block_token_ids) prev_block=prev_block,
token_ids=block_token_ids,
extra_hash=extra_hash)
blocks.append(prev_block) blocks.append(prev_block)
return blocks return blocks
@ -931,3 +934,61 @@ class TestComputedBlocksTracker:
allocator.mark_blocks_as_computed([]) allocator.mark_blocks_as_computed([])
assert tracker.get_num_cached_tokens(seq) == len(tokens) assert tracker.get_num_cached_tokens(seq) == len(tokens)
@staticmethod
def test_correct_extra_hash():
    """
    Check that the extra hash takes part in block hashing for LoRA.

    A chain of immutable blocks is allocated with the extra hash of one
    LoRA sequence. Sequences holding the same token IDs but no LoRA, or a
    different LoRA ID, must see no cached tokens, while the original LoRA
    sequence must see every token as cached.
    """
    block_size = 4
    tokens = list(range(block_size * 4))

    allocator = CpuGpuBlockAllocator.create(
        allocator_type="prefix_caching",
        num_gpu_blocks=16,
        num_cpu_blocks=16,
        block_size=block_size,
    )
    tracker = ComputedBlocksTracker(
        allocator=allocator,
        block_size=block_size,
        enable_caching=True,
    )

    # Three sequences over identical token IDs: one without LoRA and two
    # with distinct LoRA IDs.
    cached_lora_seq = create_dummy_lora_sequence(request_id=0,
                                                 token_ids=tokens,
                                                 block_size=block_size,
                                                 lora_int_id=1)
    plain_seq = create_dummy_sequence(request_id=1,
                                      token_ids=tokens,
                                      block_size=block_size)
    other_lora_seq = create_dummy_lora_sequence(request_id=2,
                                                token_ids=tokens,
                                                block_size=block_size,
                                                lora_int_id=2)

    # Populate the GPU allocator using only the first LoRA sequence's
    # extra hash.
    TestPrefixCachingBlockAllocator.create_immutable_chain(
        block_size=block_size,
        token_ids=tokens,
        allocator=allocator._allocators[Device.GPU],
        extra_hash=cached_lora_seq.extra_hash(),
    )
    allocator.mark_blocks_as_computed([])

    # A differing (or absent) LoRA ID must not hit the cached blocks.
    for uncached_seq in (plain_seq, other_lora_seq):
        assert tracker.get_num_cached_tokens(uncached_seq) == 0

    # The sequence whose extra hash was used for allocation is fully
    # cached.
    assert tracker.get_num_cached_tokens(cached_lora_seq) == len(tokens)

View File

@ -46,6 +46,16 @@ def create_dummy_prompt(
return prompt, seq_group return prompt, seq_group
def create_dummy_lora_sequence(request_id: int, token_ids: List[int],
                               block_size: int, lora_int_id: int) -> Sequence:
    """Build a Sequence over ``token_ids`` that carries a dummy LoRA request.

    Args:
        request_id: Used as the sequence's ``seq_id``.
        token_ids: Prompt token IDs wrapped via ``token_inputs``.
        block_size: Block size passed to the Sequence.
        lora_int_id: Integer ID given to the dummy LoRA request.
    """
    prompt_inputs = token_inputs(token_ids)
    dummy_lora = LoRARequest(lora_name="dummy",
                             lora_path="/dummy",
                             lora_int_id=lora_int_id)
    return Sequence(seq_id=request_id,
                    inputs=prompt_inputs,
                    block_size=block_size,
                    lora_request=dummy_lora)
def create_dummy_sequence(request_id: int, token_ids: List[int], def create_dummy_sequence(request_id: int, token_ids: List[int],
block_size: int) -> Sequence: block_size: int) -> Sequence:
return Sequence( return Sequence(

View File

@ -80,7 +80,8 @@ class BlockTable:
def allocate(self, def allocate(self,
token_ids: List[int], token_ids: List[int],
device: Device = Device.GPU) -> None: device: Device = Device.GPU,
extra_hash: Optional[int] = None) -> None:
"""Allocates memory blocks for storing the given sequence of token IDs. """Allocates memory blocks for storing the given sequence of token IDs.
This method allocates the required number of blocks to store the given This method allocates the required number of blocks to store the given
@ -90,12 +91,16 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be stored. token_ids (List[int]): The sequence of token IDs to be stored.
device (Device, optional): The device on which the blocks should be device (Device, optional): The device on which the blocks should be
allocated. Defaults to Device.GPU. allocated. Defaults to Device.GPU.
extra_hash (Optional[int]): The hash value of additional
factors, such as adapters, that influence the block hash
in the prefix caching block.
""" """
assert not self._is_allocated assert not self._is_allocated
assert token_ids assert token_ids
blocks = self._allocate_blocks_for_token_ids(prev_block=None, blocks = self._allocate_blocks_for_token_ids(prev_block=None,
token_ids=token_ids, token_ids=token_ids,
device=device) device=device,
extra_hash=extra_hash)
self.update(blocks) self.update(blocks)
self._num_full_slots = len(token_ids) self._num_full_slots = len(token_ids)
@ -108,7 +113,8 @@ class BlockTable:
def append_token_ids(self, def append_token_ids(self,
token_ids: List[int], token_ids: List[int],
num_lookahead_slots: int = 0, num_lookahead_slots: int = 0,
num_computed_slots: Optional[int] = None) -> None: num_computed_slots: Optional[int] = None,
extra_hash: Optional[int] = None) -> None:
"""Appends a sequence of token IDs to the existing blocks in the """Appends a sequence of token IDs to the existing blocks in the
BlockTable. BlockTable.
@ -130,6 +136,9 @@ class BlockTable:
Without sliding window, None can be passed. Without sliding window, None can be passed.
Without chunked prefill, it should be the same as Without chunked prefill, it should be the same as
_num_full_slots. _num_full_slots.
extra_hash (Optional[int]): The hash value of additional
factors such as adapters that influence the block, apart
from the token_ids.
""" """
assert self._is_allocated, "no blocks have been allocated" assert self._is_allocated, "no blocks have been allocated"
assert len(self._blocks) > 0 assert len(self._blocks) > 0
@ -149,7 +158,8 @@ class BlockTable:
# Ensure there are enough empty slots for the new tokens plus # Ensure there are enough empty slots for the new tokens plus
# lookahead slots # lookahead slots
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots) num_lookahead_slots,
extra_hash=extra_hash)
# Update the blocks with the new tokens # Update the blocks with the new tokens
first_block_idx = self._num_full_slots // self._block_size first_block_idx = self._num_full_slots // self._block_size
@ -160,7 +170,9 @@ class BlockTable:
self._num_full_slots += len(token_ids) self._num_full_slots += len(token_ids)
def ensure_num_empty_slots(self, num_empty_slots: int) -> None: def ensure_num_empty_slots(self,
num_empty_slots: int,
extra_hash: Optional[int] = None) -> None:
"""Ensures that the BlockTable has at least the specified number of """Ensures that the BlockTable has at least the specified number of
empty slots available. empty slots available.
@ -171,6 +183,9 @@ class BlockTable:
Args: Args:
num_empty_slots (int): The minimum number of empty slots required. num_empty_slots (int): The minimum number of empty slots required.
extra_hash (Optional[int]): The hash value of additional
factors such as adapters that influence the block, apart
from the token_ids.
""" """
# Currently the block table only supports # Currently the block table only supports
# appending tokens to GPU blocks. # appending tokens to GPU blocks.
@ -187,7 +202,9 @@ class BlockTable:
assert len(self._blocks) > 0 assert len(self._blocks) > 0
self._blocks.append( self._blocks.append(
self._allocator.allocate_mutable_block( self._allocator.allocate_mutable_block(
prev_block=self._blocks[-1], device=device)) prev_block=self._blocks[-1],
device=device,
extra_hash=extra_hash))
def fork(self) -> "BlockTable": def fork(self) -> "BlockTable":
"""Creates a new BlockTable instance with a copy of the blocks from the """Creates a new BlockTable instance with a copy of the blocks from the
@ -259,9 +276,12 @@ class BlockTable:
# ones after the appended ones. # ones after the appended ones.
return sequence_token_ids[self.num_full_slots:] return sequence_token_ids[self.num_full_slots:]
def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], def _allocate_blocks_for_token_ids(
self,
prev_block: Optional[Block],
token_ids: List[int], token_ids: List[int],
device: Device) -> List[Block]: device: Device,
extra_hash: Optional[int] = None) -> List[Block]:
blocks: List[Block] = [] blocks: List[Block] = []
block_token_ids = [] block_token_ids = []
@ -275,8 +295,10 @@ class BlockTable:
if block_token_ids: if block_token_ids:
blocks.extend( blocks.extend(
self._allocator.allocate_immutable_blocks( self._allocator.allocate_immutable_blocks(
prev_block, block_token_ids=block_token_ids, prev_block,
device=device)) block_token_ids=block_token_ids,
device=device,
extra_hash=extra_hash))
prev_block = blocks[-1] prev_block = blocks[-1]
if tail_token_ids: if tail_token_ids:
@ -284,7 +306,7 @@ class BlockTable:
cur_token_ids = tail_token_ids[0] cur_token_ids = tail_token_ids[0]
block = self._allocator.allocate_mutable_block( block = self._allocator.allocate_mutable_block(
prev_block=prev_block, device=device) prev_block=prev_block, device=device, extra_hash=extra_hash)
block.append_token_ids(cur_token_ids) block.append_token_ids(cur_token_ids)
blocks.append(block) blocks.append(block)

View File

@ -177,7 +177,8 @@ class BlockPool:
token_ids=[], token_ids=[],
block_size=self._block_size, block_size=self._block_size,
allocator=self._allocator, allocator=self._allocator,
block_id=None)) block_id=None,
extra_hash=None))
def increase_pool(self): def increase_pool(self):
"""Doubles the internal pool size """Doubles the internal pool size
@ -194,10 +195,15 @@ class BlockPool:
token_ids=[], token_ids=[],
block_size=self._block_size, block_size=self._block_size,
allocator=self._allocator, allocator=self._allocator,
block_id=None)) block_id=None,
extra_hash=None))
def init_block(self, prev_block: Optional[Block], token_ids: List[int], def init_block(self,
block_size: int, physical_block_id: Optional[int]) -> Block: prev_block: Optional[Block],
token_ids: List[int],
block_size: int,
physical_block_id: Optional[int],
extra_hash: Optional[int] = None) -> Block:
if len(self._free_ids) == 0: if len(self._free_ids) == 0:
self.increase_pool() self.increase_pool()
assert len(self._free_ids) > 0 assert len(self._free_ids) > 0
@ -210,7 +216,8 @@ class BlockPool:
token_ids=token_ids, token_ids=token_ids,
block_size=block_size, block_size=block_size,
allocator=block._allocator, # type: ignore[attr-defined] allocator=block._allocator, # type: ignore[attr-defined]
block_id=physical_block_id) block_id=physical_block_id,
extra_hash=extra_hash)
block.pool_id = pool_id # type: ignore[attr-defined] block.pool_id = pool_id # type: ignore[attr-defined]
return block return block

View File

@ -121,23 +121,32 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
self.allocate_mutable_block(None, Device.GPU)) self.allocate_mutable_block(None, Device.GPU))
return self._null_block return self._null_block
def allocate_mutable_block(self, prev_block: Optional[Block], def allocate_mutable_block(self,
device: Device) -> Block: prev_block: Optional[Block],
device: Device,
extra_hash: Optional[int] = None) -> Block:
"""Allocates a new mutable block on the specified device. """Allocates a new mutable block on the specified device.
Args: Args:
prev_block (Optional[Block]): The previous block to in the sequence. prev_block (Optional[Block]): The previous block to in the sequence.
Used for prefix hashing. Used for prefix hashing.
device (Device): The device on which to allocate the new block. device (Device): The device on which to allocate the new block.
extra_hash (Optional[int]): The hash value of additional
factors, such as adapters, that influence the block hash
in the prefix caching block.
Returns: Returns:
Block: The newly allocated mutable block. Block: The newly allocated mutable block.
""" """
return self._allocators[device].allocate_mutable_block(prev_block) return self._allocators[device].allocate_mutable_block(
prev_block, extra_hash=extra_hash)
def allocate_immutable_blocks(self, prev_block: Optional[Block], def allocate_immutable_blocks(
self,
prev_block: Optional[Block],
block_token_ids: List[List[int]], block_token_ids: List[List[int]],
device: Device) -> List[Block]: device: Device,
extra_hash: Optional[int] = None) -> List[Block]:
"""Allocates a new group of immutable blocks with the provided block """Allocates a new group of immutable blocks with the provided block
token IDs on the specified device. token IDs on the specified device.
@ -147,17 +156,22 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
block_token_ids (List[int]): The list of block token IDs to be block_token_ids (List[int]): The list of block token IDs to be
stored in the new blocks. stored in the new blocks.
device (Device): The device on which to allocate the new block. device (Device): The device on which to allocate the new block.
extra_hash (Optional[int]): The hash value of additional
factors, such as adapters, that influence the block hash
in the prefix caching block.
Returns: Returns:
List[Block]: The newly allocated list of immutable blocks List[Block]: The newly allocated list of immutable blocks
containing the provided block token IDs. containing the provided block token IDs.
""" """
return self._allocators[device].allocate_immutable_blocks( return self._allocators[device].allocate_immutable_blocks(
prev_block, block_token_ids) prev_block, block_token_ids, extra_hash=extra_hash)
def allocate_immutable_block(self, prev_block: Optional[Block], def allocate_immutable_block(self,
prev_block: Optional[Block],
token_ids: List[int], token_ids: List[int],
device: Device) -> Block: device: Device,
extra_hash: Optional[int] = None) -> Block:
"""Allocates a new immutable block with the provided token IDs on the """Allocates a new immutable block with the provided token IDs on the
specified device. specified device.
@ -167,13 +181,16 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
token_ids (List[int]): The list of token IDs to be stored in the new token_ids (List[int]): The list of token IDs to be stored in the new
block. block.
device (Device): The device on which to allocate the new block. device (Device): The device on which to allocate the new block.
extra_hash (Optional[int]): The hash value of additional
factors, such as adapters, that influence the block hash
in the prefix caching block.
Returns: Returns:
Block: The newly allocated immutable block containing the provided Block: The newly allocated immutable block containing the provided
token IDs. token IDs.
""" """
return self._allocators[device].allocate_immutable_block( return self._allocators[device].allocate_immutable_block(
prev_block, token_ids) prev_block, token_ids, extra_hash=extra_hash)
def free(self, block: Block) -> None: def free(self, block: Block) -> None:
"""Frees the memory occupied by the given block. """Frees the memory occupied by the given block.
@ -387,6 +404,10 @@ class NullBlock(Block):
def prev_block(self): def prev_block(self):
return self._proxy.prev_block return self._proxy.prev_block
@property
def extra_hash(self) -> Optional[int]:
    """Extra hash of the null block: always None.

    Unlike the neighbouring properties, this value is not read from the
    proxied block.
    """
    return None
@property @property
def computed(self): def computed(self):
return self._proxy.computed return self._proxy.computed

View File

@ -50,6 +50,11 @@ class Block(ABC):
def prev_block(self) -> Optional["Block"]: def prev_block(self) -> Optional["Block"]:
pass pass
@property
@abstractmethod
def extra_hash(self) -> Optional[int]:
    """Hash of factors other than token IDs (such as adapters) that
    influence the block, or None.

    Declared abstract, so concrete Block subclasses must provide their own
    implementation.
    """
    return None
@property @property
@abstractmethod @abstractmethod
def computed(self) -> bool: def computed(self) -> bool:
@ -81,6 +86,8 @@ class Block(ABC):
block_size: int, block_size: int,
allocator: "BlockAllocator", allocator: "BlockAllocator",
block_id: Optional[int] = None, block_id: Optional[int] = None,
computed: bool = False,
extra_hash: Optional[int] = None,
) -> "Block": ) -> "Block":
pass pass
@ -99,18 +106,20 @@ class Block(ABC):
class BlockAllocator(ABC): class BlockAllocator(ABC):
@abstractmethod @abstractmethod
def allocate_mutable_block(self, prev_block: Optional[Block]) -> Block: def allocate_mutable_block(self, prev_block: Optional[Block],
extra_hash: Optional[int]) -> Block:
pass pass
@abstractmethod @abstractmethod
def allocate_immutable_block(self, prev_block: Optional[Block], def allocate_immutable_block(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block: token_ids: List[int],
extra_hash: Optional[int]) -> Block:
pass pass
@abstractmethod @abstractmethod
def allocate_immutable_blocks( def allocate_immutable_blocks(self, prev_block: Optional[Block],
self, prev_block: Optional[Block], block_token_ids: List[List[int]],
block_token_ids: List[List[int]]) -> List[Block]: extra_hash: Optional[int]) -> List[Block]:
pass pass
@abstractmethod @abstractmethod
@ -197,14 +206,18 @@ class BlockAllocator(ABC):
class DeviceAwareBlockAllocator(ABC): class DeviceAwareBlockAllocator(ABC):
@abstractmethod @abstractmethod
def allocate_mutable_block(self, prev_block: Optional[Block], def allocate_mutable_block(self,
device: Device) -> Block: prev_block: Optional[Block],
device: Device,
extra_hash: Optional[int] = None) -> Block:
pass pass
@abstractmethod @abstractmethod
def allocate_immutable_block(self, prev_block: Optional[Block], def allocate_immutable_block(self,
prev_block: Optional[Block],
token_ids: List[int], token_ids: List[int],
device: Device) -> Block: device: Device,
extra_hash: Optional[int] = None) -> Block:
pass pass
@abstractmethod @abstractmethod
@ -213,6 +226,7 @@ class DeviceAwareBlockAllocator(ABC):
prev_block: Optional[Block], prev_block: Optional[Block],
block_token_ids: List[List[int]], block_token_ids: List[List[int]],
device: Device, device: Device,
extra_hash: Optional[int] = None,
) -> List[Block]: ) -> List[Block]:
pass pass

View File

@ -63,6 +63,7 @@ class NaiveBlockAllocator(BlockAllocator):
def allocate_immutable_block(self, def allocate_immutable_block(self,
prev_block: Optional[Block], prev_block: Optional[Block],
token_ids: List[int], token_ids: List[int],
extra_hash: Optional[int] = None,
device: Optional[Device] = None) -> Block: device: Optional[Device] = None) -> Block:
"""Allocates a new immutable block with the given token IDs, linked to """Allocates a new immutable block with the given token IDs, linked to
the previous block. the previous block.
@ -85,6 +86,7 @@ class NaiveBlockAllocator(BlockAllocator):
self, self,
prev_block: Optional[Block], prev_block: Optional[Block],
block_token_ids: List[List[int]], block_token_ids: List[List[int]],
extra_hash: Optional[int] = None,
device: Optional[Device] = None) -> List[Block]: device: Optional[Device] = None) -> List[Block]:
assert device is None assert device is None
num_blocks = len(block_token_ids) num_blocks = len(block_token_ids)
@ -106,6 +108,7 @@ class NaiveBlockAllocator(BlockAllocator):
def allocate_mutable_block(self, def allocate_mutable_block(self,
prev_block: Optional[Block], prev_block: Optional[Block],
extra_hash: Optional[int] = None,
device: Optional[Device] = None) -> Block: device: Optional[Device] = None) -> Block:
"""Allocates a new mutable block, linked to the previous block. """Allocates a new mutable block, linked to the previous block.
@ -355,7 +358,8 @@ class NaiveBlock(Block):
block_size: int, block_size: int,
allocator: BlockAllocator, allocator: BlockAllocator,
block_id: Optional[int] = None, block_id: Optional[int] = None,
_cow_target: Optional[Block] = None): _cow_target: Optional[Block] = None,
extra_hash: Optional[int] = None):
self._token_ids: List[int] = [] self._token_ids: List[int] = []
self._block_size = block_size self._block_size = block_size
self._prev_block = prev_block self._prev_block = prev_block
@ -441,6 +445,10 @@ class NaiveBlock(Block):
def prev_block(self) -> Optional["Block"]: def prev_block(self) -> Optional["Block"]:
return self._prev_block return self._prev_block
@property
def extra_hash(self) -> Optional[int]:
    """Extra hash of this block: always None for NaiveBlock."""
    return None
@property @property
def content_hash(self) -> Optional[int]: def content_hash(self) -> Optional[int]:
return None return None

View File

@ -126,6 +126,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
allocator: BlockAllocator, allocator: BlockAllocator,
block_id: Optional[int] = None, block_id: Optional[int] = None,
computed: bool = False, computed: bool = False,
extra_hash: Optional[int] = None,
) -> Block: ) -> Block:
# Bind block to self. # Bind block to self.
allocator = self allocator = self
@ -137,11 +138,13 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block_id=block_id, block_id=block_id,
allocator=allocator, allocator=allocator,
computed=computed, computed=computed,
extra_hash=extra_hash,
) )
def allocate_immutable_block(self, def allocate_immutable_block(self,
prev_block: Optional[Block], prev_block: Optional[Block],
token_ids: List[int], token_ids: List[int],
extra_hash: Optional[int] = None,
device: Optional[Device] = None) -> Block: device: Optional[Device] = None) -> Block:
"""Allocates an immutable block with the given token IDs, reusing cached """Allocates an immutable block with the given token IDs, reusing cached
blocks if possible. blocks if possible.
@ -160,7 +163,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block = self._block_pool.init_block(prev_block=prev_block, block = self._block_pool.init_block(prev_block=prev_block,
token_ids=token_ids, token_ids=token_ids,
block_size=self._block_size, block_size=self._block_size,
physical_block_id=None) physical_block_id=None,
extra_hash=extra_hash)
assert block.content_hash is not None assert block.content_hash is not None
cached_block_id = self._cached_blocks.get(block.content_hash, None) cached_block_id = self._cached_blocks.get(block.content_hash, None)
@ -173,7 +177,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self._block_pool.free_block(block) self._block_pool.free_block(block)
# No cached block => Allocate a new block # No cached block => Allocate a new block
block = self.allocate_mutable_block(prev_block) block = self.allocate_mutable_block(prev_block, extra_hash=extra_hash)
block.append_token_ids(token_ids) block.append_token_ids(token_ids)
return block return block
@ -181,17 +185,20 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self, self,
prev_block: Optional[Block], prev_block: Optional[Block],
block_token_ids: List[List[int]], block_token_ids: List[List[int]],
extra_hash: Optional[int] = None,
device: Optional[Device] = None) -> List[Block]: device: Optional[Device] = None) -> List[Block]:
blocks = [] blocks = []
for token_ids in block_token_ids: for token_ids in block_token_ids:
prev_block = self.allocate_immutable_block(prev_block=prev_block, prev_block = self.allocate_immutable_block(prev_block=prev_block,
token_ids=token_ids, token_ids=token_ids,
device=device) device=device,
extra_hash=extra_hash)
blocks.append(prev_block) blocks.append(prev_block)
return blocks return blocks
def allocate_mutable_block(self, def allocate_mutable_block(self,
prev_block: Optional[Block], prev_block: Optional[Block],
extra_hash: Optional[int] = None,
device: Optional[Device] = None) -> Block: device: Optional[Device] = None) -> Block:
"""Allocates a mutable block. If there are no free blocks, this will """Allocates a mutable block. If there are no free blocks, this will
evict unused cached blocks. evict unused cached blocks.
@ -210,7 +217,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block = self._block_pool.init_block(prev_block=prev_block, block = self._block_pool.init_block(prev_block=prev_block,
token_ids=[], token_ids=[],
block_size=self._block_size, block_size=self._block_size,
physical_block_id=block_id) physical_block_id=block_id,
extra_hash=extra_hash)
assert not block.computed assert not block.computed
assert block.content_hash is None assert block.content_hash is None
return block return block
@ -382,7 +390,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
prev_block=prev_block, prev_block=prev_block,
token_ids=block.token_ids, token_ids=block.token_ids,
block_size=self._block_size, block_size=self._block_size,
physical_block_id=block_id) physical_block_id=block_id,
extra_hash=block.extra_hash)
forked_blocks.append(forked_block) forked_blocks.append(forked_block)
prev_block = forked_blocks[-1] prev_block = forked_blocks[-1]
@ -608,10 +617,12 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# existing "block" object # existing "block" object
if block.is_full: if block.is_full:
tmp_block = self.allocate_immutable_block( tmp_block = self.allocate_immutable_block(
prev_block=block.prev_block, token_ids=block.token_ids) prev_block=block.prev_block,
token_ids=block.token_ids,
extra_hash=block.extra_hash)
else: else:
tmp_block = self.allocate_mutable_block( tmp_block = self.allocate_mutable_block(
prev_block=block.prev_block) prev_block=block.prev_block, extra_hash=block.extra_hash)
tmp_block.append_token_ids(block.token_ids) tmp_block.append_token_ids(block.token_ids)
block_id = tmp_block.block_id block_id = tmp_block.block_id
@ -679,6 +690,8 @@ class PrefixCachingBlock(Block):
caching block allocator associated with this block. caching block allocator associated with this block.
block_id (Optional[int], optional): The physical block index block_id (Optional[int], optional): The physical block index
of this block. Defaults to None. of this block. Defaults to None.
extra_hash (Optional[int]): The hash value of additional factors
such as adapters that influence the block, apart from the token_ids.
""" """
def __init__( def __init__(
@ -689,6 +702,7 @@ class PrefixCachingBlock(Block):
allocator: BlockAllocator, allocator: BlockAllocator,
block_id: Optional[int] = None, block_id: Optional[int] = None,
computed: bool = False, computed: bool = False,
extra_hash: Optional[int] = None,
): ):
assert isinstance(allocator, PrefixCachingBlockAllocator), ( assert isinstance(allocator, PrefixCachingBlockAllocator), (
"Currently this class is only tested with " "Currently this class is only tested with "
@ -702,6 +716,7 @@ class PrefixCachingBlock(Block):
self._allocator = allocator self._allocator = allocator
self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
self._computed = computed self._computed = computed
self._extra_hash = extra_hash
# On the first time, we create the block object, and next we only # On the first time, we create the block object, and next we only
# reinitialize it # reinitialize it
@ -811,6 +826,10 @@ class PrefixCachingBlock(Block):
def prev_block(self) -> Optional[Block]: def prev_block(self) -> Optional[Block]:
return self._prev_block return self._prev_block
@property
def extra_hash(self) -> Optional[int]:
    """Return the extra hash supplied at construction time.

    This is the hash of factors other than token IDs (such as adapters);
    it is fed into the block's content hash.
    """
    return self._extra_hash
@property @property
def content_hash(self) -> Optional[int]: def content_hash(self) -> Optional[int]:
"""Return the content-based hash of the current block, or None if it is """Return the content-based hash of the current block, or None if it is
@ -841,18 +860,19 @@ class PrefixCachingBlock(Block):
self._cached_content_hash = PrefixCachingBlock.hash_block_tokens( self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
is_first_block, is_first_block,
prev_block_hash, prev_block_hash,
cur_block_token_ids=self.token_ids) cur_block_token_ids=self.token_ids,
extra_hash=self._extra_hash)
return self._cached_content_hash return self._cached_content_hash
@staticmethod @staticmethod
def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int], def hash_block_tokens(is_first_block: bool,
cur_block_token_ids: List[int]) -> int: prev_block_hash: Optional[int],
cur_block_token_ids: List[int],
extra_hash: Optional[int] = None) -> int:
"""Computes a hash value corresponding to the contents of a block and """Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for the contents of the preceding block(s). The hash value is used for
prefix caching. prefix caching.
NOTE: Content-based hashing does not yet support LoRA.
Parameters: Parameters:
- is_first_block (bool): A flag indicating if the block is the first in - is_first_block (bool): A flag indicating if the block is the first in
the sequence. the sequence.
@ -860,12 +880,15 @@ class PrefixCachingBlock(Block):
if this is the first block. if this is the first block.
- cur_block_token_ids (List[int]): A list of token ids in the current - cur_block_token_ids (List[int]): A list of token ids in the current
block. The current block is assumed to be full. block. The current block is assumed to be full.
- extra_hash (Optional[int]): The hash value of additional factors
such as adapters that influence the block, apart from the token_ids.
Returns: Returns:
- int: The computed hash value for the block. - int: The computed hash value for the block.
""" """
assert (prev_block_hash is None) == is_first_block assert (prev_block_hash is None) == is_first_block
return hash((is_first_block, prev_block_hash, *cur_block_token_ids)) return hash((is_first_block, prev_block_hash, *cur_block_token_ids,
extra_hash))
class ComputedBlocksTracker: class ComputedBlocksTracker:
@ -935,12 +958,18 @@ class ComputedBlocksTracker:
assert len(token_ids) >= (i + 1) * self._block_size assert len(token_ids) >= (i + 1) * self._block_size
block_token_ids = token_ids[i * self._block_size:(i + 1) * block_token_ids = token_ids[i * self._block_size:(i + 1) *
self._block_size] self._block_size]
# NOTE: If there are any factors affecting the block besides
# token_ids, they should be added as input to extra_hash.
extra_hash = seq.extra_hash()
# This has to be kept in sync with the allocator's hash # This has to be kept in sync with the allocator's hash
# calculation. # calculation.
block_hash = PrefixCachingBlock.hash_block_tokens( block_hash = PrefixCachingBlock.hash_block_tokens(
is_first_block=prev_block_hash is None, is_first_block=prev_block_hash is None,
prev_block_hash=prev_block_hash, prev_block_hash=prev_block_hash,
cur_block_token_ids=block_token_ids, cur_block_token_ids=block_token_ids,
extra_hash=extra_hash,
) )
block_hashes_recorded.append(block_hash) block_hashes_recorded.append(block_hash)
prev_block_hash = block_hash prev_block_hash = block_hash

View File

@ -151,8 +151,13 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
max_block_sliding_window=self.max_block_sliding_window, max_block_sliding_window=self.max_block_sliding_window,
) )
if seq.get_token_ids(): if seq.get_token_ids():
# NOTE: If there are any factors affecting the block besides
# token_ids, they should be added as input to extra_hash.
extra_hash = seq.extra_hash()
# Add blocks to the block table only if the sequence is non empty. # Add blocks to the block table only if the sequence is non empty.
block_table.allocate(seq.get_token_ids()) block_table.allocate(token_ids=seq.get_token_ids(),
extra_hash=extra_hash)
return block_table return block_table
@ -238,6 +243,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
num_lookahead_slots=num_lookahead_slots, num_lookahead_slots=num_lookahead_slots,
num_computed_slots=seq.data.get_num_computed_tokens(), num_computed_slots=seq.data.get_num_computed_tokens(),
extra_hash=seq.extra_hash(),
) )
# Return any new copy-on-writes. # Return any new copy-on-writes.
new_cows = self.block_allocator.clear_copy_on_writes() new_cows = self.block_allocator.clear_copy_on_writes()

View File

@ -527,6 +527,19 @@ class Sequence:
hashed_tokens = self.data.get_prefix_token_ids(num_tokens) hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
return hash((hashed_tokens, self.lora_int_id)) return hash((hashed_tokens, self.lora_int_id))
def extra_hash(self) -> Optional[int]:
    """
    Compute the extra hash used for this sequence in prefix caching mode.

    The value covers the factors, other than token IDs, that influence a
    block: currently the prompt adapter ID and the LoRA ID.

    Returns:
        None when both the prompt adapter ID and the LoRA ID are 0;
        otherwise the hash of the ``(prompt_adapter_id, lora_int_id)``
        pair.
    """
    adapter_ids = (self.prompt_adapter_id, self.lora_int_id)
    if adapter_ids == (0, 0):
        return None

    # NOTE: If there are additional factors influencing the block aside from
    # token_ids, include them as input parameters to the hash.
    return hash(adapter_ids)
def num_hashed_tokens_of_block(self, logical_idx: int): def num_hashed_tokens_of_block(self, logical_idx: int):
return logical_idx * self.block_size + self.block_size return logical_idx * self.block_size + self.block_size