mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 12:15:56 +08:00
[Core] support LoRA and prompt adapter in content-based hashing for Block Manager v2 prefix caching (#8240)
This commit is contained in:
parent
d1fa714cb1
commit
c31d4a57a6
@ -5,7 +5,7 @@ from unittest.mock import MagicMock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.core.utils import create_dummy_sequence
|
from tests.core.utils import create_dummy_lora_sequence, create_dummy_sequence
|
||||||
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
|
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
|
||||||
from vllm.core.block.interfaces import Block, BlockAllocator
|
from vllm.core.block.interfaces import Block, BlockAllocator
|
||||||
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
|
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
|
||||||
@ -801,6 +801,7 @@ class TestPrefixCachingBlockAllocator:
|
|||||||
block_size: int,
|
block_size: int,
|
||||||
token_ids: List[int],
|
token_ids: List[int],
|
||||||
allocator: PrefixCachingBlockAllocator,
|
allocator: PrefixCachingBlockAllocator,
|
||||||
|
extra_hash: Optional[int] = None,
|
||||||
) -> List[PrefixCachingBlock]:
|
) -> List[PrefixCachingBlock]:
|
||||||
"""Helper method which creates a chain of blocks.
|
"""Helper method which creates a chain of blocks.
|
||||||
"""
|
"""
|
||||||
@ -816,7 +817,9 @@ class TestPrefixCachingBlockAllocator:
|
|||||||
block_size:(block_number + 1) *
|
block_size:(block_number + 1) *
|
||||||
block_size]
|
block_size]
|
||||||
prev_block = allocator.allocate_immutable_block(
|
prev_block = allocator.allocate_immutable_block(
|
||||||
prev_block=prev_block, token_ids=block_token_ids)
|
prev_block=prev_block,
|
||||||
|
token_ids=block_token_ids,
|
||||||
|
extra_hash=extra_hash)
|
||||||
blocks.append(prev_block)
|
blocks.append(prev_block)
|
||||||
|
|
||||||
return blocks
|
return blocks
|
||||||
@ -931,3 +934,61 @@ class TestComputedBlocksTracker:
|
|||||||
allocator.mark_blocks_as_computed([])
|
allocator.mark_blocks_as_computed([])
|
||||||
|
|
||||||
assert tracker.get_num_cached_tokens(seq) == len(tokens)
|
assert tracker.get_num_cached_tokens(seq) == len(tokens)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test_correct_extra_hash():
|
||||||
|
"""
|
||||||
|
Test that the block hash is correctly computed based on the extra hash,
|
||||||
|
ensuring it matches the allocator's block hash, specifically for the
|
||||||
|
LoRA case, and that the correct number of cached tokens is retrieved.
|
||||||
|
"""
|
||||||
|
block_size = 4
|
||||||
|
allocator = CpuGpuBlockAllocator.create(
|
||||||
|
allocator_type="prefix_caching",
|
||||||
|
num_gpu_blocks=16,
|
||||||
|
num_cpu_blocks=16,
|
||||||
|
block_size=block_size,
|
||||||
|
)
|
||||||
|
gpu_allocator = allocator._allocators[Device.GPU]
|
||||||
|
|
||||||
|
tracker = ComputedBlocksTracker(
|
||||||
|
allocator=allocator,
|
||||||
|
block_size=block_size,
|
||||||
|
enable_caching=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
tokens = list(range(block_size * 4))
|
||||||
|
|
||||||
|
# Create a dummy LoRA sequence with a specific LoRA ID.
|
||||||
|
lora_seq = create_dummy_lora_sequence(request_id=0,
|
||||||
|
token_ids=tokens,
|
||||||
|
block_size=block_size,
|
||||||
|
lora_int_id=1)
|
||||||
|
|
||||||
|
_ = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||||
|
block_size=block_size,
|
||||||
|
token_ids=tokens,
|
||||||
|
allocator=gpu_allocator,
|
||||||
|
extra_hash=lora_seq.extra_hash(),
|
||||||
|
)
|
||||||
|
|
||||||
|
allocator.mark_blocks_as_computed([])
|
||||||
|
|
||||||
|
# Create different dummy sequences that have the same token IDs
|
||||||
|
# but different LoRA IDs.
|
||||||
|
seq = create_dummy_sequence(request_id=1,
|
||||||
|
token_ids=tokens,
|
||||||
|
block_size=block_size)
|
||||||
|
|
||||||
|
different_lora_seq = create_dummy_lora_sequence(request_id=2,
|
||||||
|
token_ids=tokens,
|
||||||
|
block_size=block_size,
|
||||||
|
lora_int_id=2)
|
||||||
|
|
||||||
|
# Due to the different LoRA IDs, corresponding blocks are not cached.
|
||||||
|
assert tracker.get_num_cached_tokens(seq) == 0
|
||||||
|
assert tracker.get_num_cached_tokens(different_lora_seq) == 0
|
||||||
|
|
||||||
|
# The number of cached tokens matches the length of the tokens
|
||||||
|
# for the cached LoRA sequence.
|
||||||
|
assert tracker.get_num_cached_tokens(lora_seq) == len(tokens)
|
||||||
|
|||||||
@ -46,6 +46,16 @@ def create_dummy_prompt(
|
|||||||
return prompt, seq_group
|
return prompt, seq_group
|
||||||
|
|
||||||
|
|
||||||
|
def create_dummy_lora_sequence(request_id: int, token_ids: List[int],
|
||||||
|
block_size: int, lora_int_id: int) -> Sequence:
|
||||||
|
return Sequence(seq_id=request_id,
|
||||||
|
inputs=token_inputs(token_ids),
|
||||||
|
block_size=block_size,
|
||||||
|
lora_request=LoRARequest(lora_name="dummy",
|
||||||
|
lora_path="/dummy",
|
||||||
|
lora_int_id=lora_int_id))
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_sequence(request_id: int, token_ids: List[int],
|
def create_dummy_sequence(request_id: int, token_ids: List[int],
|
||||||
block_size: int) -> Sequence:
|
block_size: int) -> Sequence:
|
||||||
return Sequence(
|
return Sequence(
|
||||||
|
|||||||
@ -80,7 +80,8 @@ class BlockTable:
|
|||||||
|
|
||||||
def allocate(self,
|
def allocate(self,
|
||||||
token_ids: List[int],
|
token_ids: List[int],
|
||||||
device: Device = Device.GPU) -> None:
|
device: Device = Device.GPU,
|
||||||
|
extra_hash: Optional[int] = None) -> None:
|
||||||
"""Allocates memory blocks for storing the given sequence of token IDs.
|
"""Allocates memory blocks for storing the given sequence of token IDs.
|
||||||
|
|
||||||
This method allocates the required number of blocks to store the given
|
This method allocates the required number of blocks to store the given
|
||||||
@ -90,12 +91,16 @@ class BlockTable:
|
|||||||
token_ids (List[int]): The sequence of token IDs to be stored.
|
token_ids (List[int]): The sequence of token IDs to be stored.
|
||||||
device (Device, optional): The device on which the blocks should be
|
device (Device, optional): The device on which the blocks should be
|
||||||
allocated. Defaults to Device.GPU.
|
allocated. Defaults to Device.GPU.
|
||||||
|
extra_hash (Optional[int]): The hash value of additional
|
||||||
|
factors, such as adapters, that influence the block hash
|
||||||
|
in the prefix caching block.
|
||||||
"""
|
"""
|
||||||
assert not self._is_allocated
|
assert not self._is_allocated
|
||||||
assert token_ids
|
assert token_ids
|
||||||
blocks = self._allocate_blocks_for_token_ids(prev_block=None,
|
blocks = self._allocate_blocks_for_token_ids(prev_block=None,
|
||||||
token_ids=token_ids,
|
token_ids=token_ids,
|
||||||
device=device)
|
device=device,
|
||||||
|
extra_hash=extra_hash)
|
||||||
self.update(blocks)
|
self.update(blocks)
|
||||||
self._num_full_slots = len(token_ids)
|
self._num_full_slots = len(token_ids)
|
||||||
|
|
||||||
@ -108,7 +113,8 @@ class BlockTable:
|
|||||||
def append_token_ids(self,
|
def append_token_ids(self,
|
||||||
token_ids: List[int],
|
token_ids: List[int],
|
||||||
num_lookahead_slots: int = 0,
|
num_lookahead_slots: int = 0,
|
||||||
num_computed_slots: Optional[int] = None) -> None:
|
num_computed_slots: Optional[int] = None,
|
||||||
|
extra_hash: Optional[int] = None) -> None:
|
||||||
"""Appends a sequence of token IDs to the existing blocks in the
|
"""Appends a sequence of token IDs to the existing blocks in the
|
||||||
BlockTable.
|
BlockTable.
|
||||||
|
|
||||||
@ -130,6 +136,9 @@ class BlockTable:
|
|||||||
Without sliding window, None can be passed.
|
Without sliding window, None can be passed.
|
||||||
Without chunked prefill, it should be the same as
|
Without chunked prefill, it should be the same as
|
||||||
_num_full_slots.
|
_num_full_slots.
|
||||||
|
extra_hash (Optional[int]): The hash value of additional
|
||||||
|
factors such as adapters that influence the block, apart
|
||||||
|
from the token_ids.
|
||||||
"""
|
"""
|
||||||
assert self._is_allocated, "no blocks have been allocated"
|
assert self._is_allocated, "no blocks have been allocated"
|
||||||
assert len(self._blocks) > 0
|
assert len(self._blocks) > 0
|
||||||
@ -149,7 +158,8 @@ class BlockTable:
|
|||||||
# Ensure there are enough empty slots for the new tokens plus
|
# Ensure there are enough empty slots for the new tokens plus
|
||||||
# lookahead slots
|
# lookahead slots
|
||||||
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
|
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
|
||||||
num_lookahead_slots)
|
num_lookahead_slots,
|
||||||
|
extra_hash=extra_hash)
|
||||||
|
|
||||||
# Update the blocks with the new tokens
|
# Update the blocks with the new tokens
|
||||||
first_block_idx = self._num_full_slots // self._block_size
|
first_block_idx = self._num_full_slots // self._block_size
|
||||||
@ -160,7 +170,9 @@ class BlockTable:
|
|||||||
|
|
||||||
self._num_full_slots += len(token_ids)
|
self._num_full_slots += len(token_ids)
|
||||||
|
|
||||||
def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
|
def ensure_num_empty_slots(self,
|
||||||
|
num_empty_slots: int,
|
||||||
|
extra_hash: Optional[int] = None) -> None:
|
||||||
"""Ensures that the BlockTable has at least the specified number of
|
"""Ensures that the BlockTable has at least the specified number of
|
||||||
empty slots available.
|
empty slots available.
|
||||||
|
|
||||||
@ -171,6 +183,9 @@ class BlockTable:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_empty_slots (int): The minimum number of empty slots required.
|
num_empty_slots (int): The minimum number of empty slots required.
|
||||||
|
extra_hash (Optional[int]): The hash value of additional
|
||||||
|
factors such as adapters that influence the block, apart
|
||||||
|
from the token_ids.
|
||||||
"""
|
"""
|
||||||
# Currently the block table only supports
|
# Currently the block table only supports
|
||||||
# appending tokens to GPU blocks.
|
# appending tokens to GPU blocks.
|
||||||
@ -187,7 +202,9 @@ class BlockTable:
|
|||||||
assert len(self._blocks) > 0
|
assert len(self._blocks) > 0
|
||||||
self._blocks.append(
|
self._blocks.append(
|
||||||
self._allocator.allocate_mutable_block(
|
self._allocator.allocate_mutable_block(
|
||||||
prev_block=self._blocks[-1], device=device))
|
prev_block=self._blocks[-1],
|
||||||
|
device=device,
|
||||||
|
extra_hash=extra_hash))
|
||||||
|
|
||||||
def fork(self) -> "BlockTable":
|
def fork(self) -> "BlockTable":
|
||||||
"""Creates a new BlockTable instance with a copy of the blocks from the
|
"""Creates a new BlockTable instance with a copy of the blocks from the
|
||||||
@ -259,9 +276,12 @@ class BlockTable:
|
|||||||
# ones after the appended ones.
|
# ones after the appended ones.
|
||||||
return sequence_token_ids[self.num_full_slots:]
|
return sequence_token_ids[self.num_full_slots:]
|
||||||
|
|
||||||
def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
|
def _allocate_blocks_for_token_ids(
|
||||||
|
self,
|
||||||
|
prev_block: Optional[Block],
|
||||||
token_ids: List[int],
|
token_ids: List[int],
|
||||||
device: Device) -> List[Block]:
|
device: Device,
|
||||||
|
extra_hash: Optional[int] = None) -> List[Block]:
|
||||||
blocks: List[Block] = []
|
blocks: List[Block] = []
|
||||||
|
|
||||||
block_token_ids = []
|
block_token_ids = []
|
||||||
@ -275,8 +295,10 @@ class BlockTable:
|
|||||||
if block_token_ids:
|
if block_token_ids:
|
||||||
blocks.extend(
|
blocks.extend(
|
||||||
self._allocator.allocate_immutable_blocks(
|
self._allocator.allocate_immutable_blocks(
|
||||||
prev_block, block_token_ids=block_token_ids,
|
prev_block,
|
||||||
device=device))
|
block_token_ids=block_token_ids,
|
||||||
|
device=device,
|
||||||
|
extra_hash=extra_hash))
|
||||||
prev_block = blocks[-1]
|
prev_block = blocks[-1]
|
||||||
|
|
||||||
if tail_token_ids:
|
if tail_token_ids:
|
||||||
@ -284,7 +306,7 @@ class BlockTable:
|
|||||||
cur_token_ids = tail_token_ids[0]
|
cur_token_ids = tail_token_ids[0]
|
||||||
|
|
||||||
block = self._allocator.allocate_mutable_block(
|
block = self._allocator.allocate_mutable_block(
|
||||||
prev_block=prev_block, device=device)
|
prev_block=prev_block, device=device, extra_hash=extra_hash)
|
||||||
block.append_token_ids(cur_token_ids)
|
block.append_token_ids(cur_token_ids)
|
||||||
|
|
||||||
blocks.append(block)
|
blocks.append(block)
|
||||||
|
|||||||
@ -177,7 +177,8 @@ class BlockPool:
|
|||||||
token_ids=[],
|
token_ids=[],
|
||||||
block_size=self._block_size,
|
block_size=self._block_size,
|
||||||
allocator=self._allocator,
|
allocator=self._allocator,
|
||||||
block_id=None))
|
block_id=None,
|
||||||
|
extra_hash=None))
|
||||||
|
|
||||||
def increase_pool(self):
|
def increase_pool(self):
|
||||||
"""Doubles the internal pool size
|
"""Doubles the internal pool size
|
||||||
@ -194,10 +195,15 @@ class BlockPool:
|
|||||||
token_ids=[],
|
token_ids=[],
|
||||||
block_size=self._block_size,
|
block_size=self._block_size,
|
||||||
allocator=self._allocator,
|
allocator=self._allocator,
|
||||||
block_id=None))
|
block_id=None,
|
||||||
|
extra_hash=None))
|
||||||
|
|
||||||
def init_block(self, prev_block: Optional[Block], token_ids: List[int],
|
def init_block(self,
|
||||||
block_size: int, physical_block_id: Optional[int]) -> Block:
|
prev_block: Optional[Block],
|
||||||
|
token_ids: List[int],
|
||||||
|
block_size: int,
|
||||||
|
physical_block_id: Optional[int],
|
||||||
|
extra_hash: Optional[int] = None) -> Block:
|
||||||
if len(self._free_ids) == 0:
|
if len(self._free_ids) == 0:
|
||||||
self.increase_pool()
|
self.increase_pool()
|
||||||
assert len(self._free_ids) > 0
|
assert len(self._free_ids) > 0
|
||||||
@ -210,7 +216,8 @@ class BlockPool:
|
|||||||
token_ids=token_ids,
|
token_ids=token_ids,
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
allocator=block._allocator, # type: ignore[attr-defined]
|
allocator=block._allocator, # type: ignore[attr-defined]
|
||||||
block_id=physical_block_id)
|
block_id=physical_block_id,
|
||||||
|
extra_hash=extra_hash)
|
||||||
block.pool_id = pool_id # type: ignore[attr-defined]
|
block.pool_id = pool_id # type: ignore[attr-defined]
|
||||||
return block
|
return block
|
||||||
|
|
||||||
|
|||||||
@ -121,23 +121,32 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
|
|||||||
self.allocate_mutable_block(None, Device.GPU))
|
self.allocate_mutable_block(None, Device.GPU))
|
||||||
return self._null_block
|
return self._null_block
|
||||||
|
|
||||||
def allocate_mutable_block(self, prev_block: Optional[Block],
|
def allocate_mutable_block(self,
|
||||||
device: Device) -> Block:
|
prev_block: Optional[Block],
|
||||||
|
device: Device,
|
||||||
|
extra_hash: Optional[int] = None) -> Block:
|
||||||
"""Allocates a new mutable block on the specified device.
|
"""Allocates a new mutable block on the specified device.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prev_block (Optional[Block]): The previous block in the sequence.
|
prev_block (Optional[Block]): The previous block in the sequence.
|
||||||
Used for prefix hashing.
|
Used for prefix hashing.
|
||||||
device (Device): The device on which to allocate the new block.
|
device (Device): The device on which to allocate the new block.
|
||||||
|
extra_hash (Optional[int]): The hash value of additional
|
||||||
|
factors, such as adapters, that influence the block hash
|
||||||
|
in the prefix caching block.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Block: The newly allocated mutable block.
|
Block: The newly allocated mutable block.
|
||||||
"""
|
"""
|
||||||
return self._allocators[device].allocate_mutable_block(prev_block)
|
return self._allocators[device].allocate_mutable_block(
|
||||||
|
prev_block, extra_hash=extra_hash)
|
||||||
|
|
||||||
def allocate_immutable_blocks(self, prev_block: Optional[Block],
|
def allocate_immutable_blocks(
|
||||||
|
self,
|
||||||
|
prev_block: Optional[Block],
|
||||||
block_token_ids: List[List[int]],
|
block_token_ids: List[List[int]],
|
||||||
device: Device) -> List[Block]:
|
device: Device,
|
||||||
|
extra_hash: Optional[int] = None) -> List[Block]:
|
||||||
"""Allocates a new group of immutable blocks with the provided block
|
"""Allocates a new group of immutable blocks with the provided block
|
||||||
token IDs on the specified device.
|
token IDs on the specified device.
|
||||||
|
|
||||||
@ -147,17 +156,22 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
|
|||||||
block_token_ids (List[int]): The list of block token IDs to be
|
block_token_ids (List[int]): The list of block token IDs to be
|
||||||
stored in the new blocks.
|
stored in the new blocks.
|
||||||
device (Device): The device on which to allocate the new block.
|
device (Device): The device on which to allocate the new block.
|
||||||
|
extra_hash (Optional[int]): The hash value of additional
|
||||||
|
factors, such as adapters, that influence the block hash
|
||||||
|
in the prefix caching block.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Block]: The newly allocated list of immutable blocks
|
List[Block]: The newly allocated list of immutable blocks
|
||||||
containing the provided block token IDs.
|
containing the provided block token IDs.
|
||||||
"""
|
"""
|
||||||
return self._allocators[device].allocate_immutable_blocks(
|
return self._allocators[device].allocate_immutable_blocks(
|
||||||
prev_block, block_token_ids)
|
prev_block, block_token_ids, extra_hash=extra_hash)
|
||||||
|
|
||||||
def allocate_immutable_block(self, prev_block: Optional[Block],
|
def allocate_immutable_block(self,
|
||||||
|
prev_block: Optional[Block],
|
||||||
token_ids: List[int],
|
token_ids: List[int],
|
||||||
device: Device) -> Block:
|
device: Device,
|
||||||
|
extra_hash: Optional[int] = None) -> Block:
|
||||||
"""Allocates a new immutable block with the provided token IDs on the
|
"""Allocates a new immutable block with the provided token IDs on the
|
||||||
specified device.
|
specified device.
|
||||||
|
|
||||||
@ -167,13 +181,16 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
|
|||||||
token_ids (List[int]): The list of token IDs to be stored in the new
|
token_ids (List[int]): The list of token IDs to be stored in the new
|
||||||
block.
|
block.
|
||||||
device (Device): The device on which to allocate the new block.
|
device (Device): The device on which to allocate the new block.
|
||||||
|
extra_hash (Optional[int]): The hash value of additional
|
||||||
|
factors, such as adapters, that influence the block hash
|
||||||
|
in the prefix caching block.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Block: The newly allocated immutable block containing the provided
|
Block: The newly allocated immutable block containing the provided
|
||||||
token IDs.
|
token IDs.
|
||||||
"""
|
"""
|
||||||
return self._allocators[device].allocate_immutable_block(
|
return self._allocators[device].allocate_immutable_block(
|
||||||
prev_block, token_ids)
|
prev_block, token_ids, extra_hash=extra_hash)
|
||||||
|
|
||||||
def free(self, block: Block) -> None:
|
def free(self, block: Block) -> None:
|
||||||
"""Frees the memory occupied by the given block.
|
"""Frees the memory occupied by the given block.
|
||||||
@ -387,6 +404,10 @@ class NullBlock(Block):
|
|||||||
def prev_block(self):
|
def prev_block(self):
|
||||||
return self._proxy.prev_block
|
return self._proxy.prev_block
|
||||||
|
|
||||||
|
@property
|
||||||
|
def extra_hash(self):
|
||||||
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def computed(self):
|
def computed(self):
|
||||||
return self._proxy.computed
|
return self._proxy.computed
|
||||||
|
|||||||
@ -50,6 +50,11 @@ class Block(ABC):
|
|||||||
def prev_block(self) -> Optional["Block"]:
|
def prev_block(self) -> Optional["Block"]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def extra_hash(self) -> Optional[int]:
|
||||||
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def computed(self) -> bool:
|
def computed(self) -> bool:
|
||||||
@ -81,6 +86,8 @@ class Block(ABC):
|
|||||||
block_size: int,
|
block_size: int,
|
||||||
allocator: "BlockAllocator",
|
allocator: "BlockAllocator",
|
||||||
block_id: Optional[int] = None,
|
block_id: Optional[int] = None,
|
||||||
|
computed: bool = False,
|
||||||
|
extra_hash: Optional[int] = None,
|
||||||
) -> "Block":
|
) -> "Block":
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -99,18 +106,20 @@ class Block(ABC):
|
|||||||
class BlockAllocator(ABC):
|
class BlockAllocator(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def allocate_mutable_block(self, prev_block: Optional[Block]) -> Block:
|
def allocate_mutable_block(self, prev_block: Optional[Block],
|
||||||
|
extra_hash: Optional[int]) -> Block:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def allocate_immutable_block(self, prev_block: Optional[Block],
|
def allocate_immutable_block(self, prev_block: Optional[Block],
|
||||||
token_ids: List[int]) -> Block:
|
token_ids: List[int],
|
||||||
|
extra_hash: Optional[int]) -> Block:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def allocate_immutable_blocks(
|
def allocate_immutable_blocks(self, prev_block: Optional[Block],
|
||||||
self, prev_block: Optional[Block],
|
block_token_ids: List[List[int]],
|
||||||
block_token_ids: List[List[int]]) -> List[Block]:
|
extra_hash: Optional[int]) -> List[Block]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -197,14 +206,18 @@ class BlockAllocator(ABC):
|
|||||||
class DeviceAwareBlockAllocator(ABC):
|
class DeviceAwareBlockAllocator(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def allocate_mutable_block(self, prev_block: Optional[Block],
|
def allocate_mutable_block(self,
|
||||||
device: Device) -> Block:
|
prev_block: Optional[Block],
|
||||||
|
device: Device,
|
||||||
|
extra_hash: Optional[int] = None) -> Block:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def allocate_immutable_block(self, prev_block: Optional[Block],
|
def allocate_immutable_block(self,
|
||||||
|
prev_block: Optional[Block],
|
||||||
token_ids: List[int],
|
token_ids: List[int],
|
||||||
device: Device) -> Block:
|
device: Device,
|
||||||
|
extra_hash: Optional[int] = None) -> Block:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -213,6 +226,7 @@ class DeviceAwareBlockAllocator(ABC):
|
|||||||
prev_block: Optional[Block],
|
prev_block: Optional[Block],
|
||||||
block_token_ids: List[List[int]],
|
block_token_ids: List[List[int]],
|
||||||
device: Device,
|
device: Device,
|
||||||
|
extra_hash: Optional[int] = None,
|
||||||
) -> List[Block]:
|
) -> List[Block]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@ -63,6 +63,7 @@ class NaiveBlockAllocator(BlockAllocator):
|
|||||||
def allocate_immutable_block(self,
|
def allocate_immutable_block(self,
|
||||||
prev_block: Optional[Block],
|
prev_block: Optional[Block],
|
||||||
token_ids: List[int],
|
token_ids: List[int],
|
||||||
|
extra_hash: Optional[int] = None,
|
||||||
device: Optional[Device] = None) -> Block:
|
device: Optional[Device] = None) -> Block:
|
||||||
"""Allocates a new immutable block with the given token IDs, linked to
|
"""Allocates a new immutable block with the given token IDs, linked to
|
||||||
the previous block.
|
the previous block.
|
||||||
@ -85,6 +86,7 @@ class NaiveBlockAllocator(BlockAllocator):
|
|||||||
self,
|
self,
|
||||||
prev_block: Optional[Block],
|
prev_block: Optional[Block],
|
||||||
block_token_ids: List[List[int]],
|
block_token_ids: List[List[int]],
|
||||||
|
extra_hash: Optional[int] = None,
|
||||||
device: Optional[Device] = None) -> List[Block]:
|
device: Optional[Device] = None) -> List[Block]:
|
||||||
assert device is None
|
assert device is None
|
||||||
num_blocks = len(block_token_ids)
|
num_blocks = len(block_token_ids)
|
||||||
@ -106,6 +108,7 @@ class NaiveBlockAllocator(BlockAllocator):
|
|||||||
|
|
||||||
def allocate_mutable_block(self,
|
def allocate_mutable_block(self,
|
||||||
prev_block: Optional[Block],
|
prev_block: Optional[Block],
|
||||||
|
extra_hash: Optional[int] = None,
|
||||||
device: Optional[Device] = None) -> Block:
|
device: Optional[Device] = None) -> Block:
|
||||||
"""Allocates a new mutable block, linked to the previous block.
|
"""Allocates a new mutable block, linked to the previous block.
|
||||||
|
|
||||||
@ -355,7 +358,8 @@ class NaiveBlock(Block):
|
|||||||
block_size: int,
|
block_size: int,
|
||||||
allocator: BlockAllocator,
|
allocator: BlockAllocator,
|
||||||
block_id: Optional[int] = None,
|
block_id: Optional[int] = None,
|
||||||
_cow_target: Optional[Block] = None):
|
_cow_target: Optional[Block] = None,
|
||||||
|
extra_hash: Optional[int] = None):
|
||||||
self._token_ids: List[int] = []
|
self._token_ids: List[int] = []
|
||||||
self._block_size = block_size
|
self._block_size = block_size
|
||||||
self._prev_block = prev_block
|
self._prev_block = prev_block
|
||||||
@ -441,6 +445,10 @@ class NaiveBlock(Block):
|
|||||||
def prev_block(self) -> Optional["Block"]:
|
def prev_block(self) -> Optional["Block"]:
|
||||||
return self._prev_block
|
return self._prev_block
|
||||||
|
|
||||||
|
@property
|
||||||
|
def extra_hash(self):
|
||||||
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def content_hash(self) -> Optional[int]:
|
def content_hash(self) -> Optional[int]:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@ -126,6 +126,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|||||||
allocator: BlockAllocator,
|
allocator: BlockAllocator,
|
||||||
block_id: Optional[int] = None,
|
block_id: Optional[int] = None,
|
||||||
computed: bool = False,
|
computed: bool = False,
|
||||||
|
extra_hash: Optional[int] = None,
|
||||||
) -> Block:
|
) -> Block:
|
||||||
# Bind block to self.
|
# Bind block to self.
|
||||||
allocator = self
|
allocator = self
|
||||||
@ -137,11 +138,13 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|||||||
block_id=block_id,
|
block_id=block_id,
|
||||||
allocator=allocator,
|
allocator=allocator,
|
||||||
computed=computed,
|
computed=computed,
|
||||||
|
extra_hash=extra_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
def allocate_immutable_block(self,
|
def allocate_immutable_block(self,
|
||||||
prev_block: Optional[Block],
|
prev_block: Optional[Block],
|
||||||
token_ids: List[int],
|
token_ids: List[int],
|
||||||
|
extra_hash: Optional[int] = None,
|
||||||
device: Optional[Device] = None) -> Block:
|
device: Optional[Device] = None) -> Block:
|
||||||
"""Allocates an immutable block with the given token IDs, reusing cached
|
"""Allocates an immutable block with the given token IDs, reusing cached
|
||||||
blocks if possible.
|
blocks if possible.
|
||||||
@ -160,7 +163,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|||||||
block = self._block_pool.init_block(prev_block=prev_block,
|
block = self._block_pool.init_block(prev_block=prev_block,
|
||||||
token_ids=token_ids,
|
token_ids=token_ids,
|
||||||
block_size=self._block_size,
|
block_size=self._block_size,
|
||||||
physical_block_id=None)
|
physical_block_id=None,
|
||||||
|
extra_hash=extra_hash)
|
||||||
assert block.content_hash is not None
|
assert block.content_hash is not None
|
||||||
|
|
||||||
cached_block_id = self._cached_blocks.get(block.content_hash, None)
|
cached_block_id = self._cached_blocks.get(block.content_hash, None)
|
||||||
@ -173,7 +177,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|||||||
self._block_pool.free_block(block)
|
self._block_pool.free_block(block)
|
||||||
|
|
||||||
# No cached block => Allocate a new block
|
# No cached block => Allocate a new block
|
||||||
block = self.allocate_mutable_block(prev_block)
|
block = self.allocate_mutable_block(prev_block, extra_hash=extra_hash)
|
||||||
block.append_token_ids(token_ids)
|
block.append_token_ids(token_ids)
|
||||||
return block
|
return block
|
||||||
|
|
||||||
@ -181,17 +185,20 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|||||||
self,
|
self,
|
||||||
prev_block: Optional[Block],
|
prev_block: Optional[Block],
|
||||||
block_token_ids: List[List[int]],
|
block_token_ids: List[List[int]],
|
||||||
|
extra_hash: Optional[int] = None,
|
||||||
device: Optional[Device] = None) -> List[Block]:
|
device: Optional[Device] = None) -> List[Block]:
|
||||||
blocks = []
|
blocks = []
|
||||||
for token_ids in block_token_ids:
|
for token_ids in block_token_ids:
|
||||||
prev_block = self.allocate_immutable_block(prev_block=prev_block,
|
prev_block = self.allocate_immutable_block(prev_block=prev_block,
|
||||||
token_ids=token_ids,
|
token_ids=token_ids,
|
||||||
device=device)
|
device=device,
|
||||||
|
extra_hash=extra_hash)
|
||||||
blocks.append(prev_block)
|
blocks.append(prev_block)
|
||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
def allocate_mutable_block(self,
|
def allocate_mutable_block(self,
|
||||||
prev_block: Optional[Block],
|
prev_block: Optional[Block],
|
||||||
|
extra_hash: Optional[int] = None,
|
||||||
device: Optional[Device] = None) -> Block:
|
device: Optional[Device] = None) -> Block:
|
||||||
"""Allocates a mutable block. If there are no free blocks, this will
|
"""Allocates a mutable block. If there are no free blocks, this will
|
||||||
evict unused cached blocks.
|
evict unused cached blocks.
|
||||||
@ -210,7 +217,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|||||||
block = self._block_pool.init_block(prev_block=prev_block,
|
block = self._block_pool.init_block(prev_block=prev_block,
|
||||||
token_ids=[],
|
token_ids=[],
|
||||||
block_size=self._block_size,
|
block_size=self._block_size,
|
||||||
physical_block_id=block_id)
|
physical_block_id=block_id,
|
||||||
|
extra_hash=extra_hash)
|
||||||
assert not block.computed
|
assert not block.computed
|
||||||
assert block.content_hash is None
|
assert block.content_hash is None
|
||||||
return block
|
return block
|
||||||
@ -382,7 +390,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|||||||
prev_block=prev_block,
|
prev_block=prev_block,
|
||||||
token_ids=block.token_ids,
|
token_ids=block.token_ids,
|
||||||
block_size=self._block_size,
|
block_size=self._block_size,
|
||||||
physical_block_id=block_id)
|
physical_block_id=block_id,
|
||||||
|
extra_hash=block.extra_hash)
|
||||||
|
|
||||||
forked_blocks.append(forked_block)
|
forked_blocks.append(forked_block)
|
||||||
prev_block = forked_blocks[-1]
|
prev_block = forked_blocks[-1]
|
||||||
@ -608,10 +617,12 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
|||||||
# existing "block" object
|
# existing "block" object
|
||||||
if block.is_full:
|
if block.is_full:
|
||||||
tmp_block = self.allocate_immutable_block(
|
tmp_block = self.allocate_immutable_block(
|
||||||
prev_block=block.prev_block, token_ids=block.token_ids)
|
prev_block=block.prev_block,
|
||||||
|
token_ids=block.token_ids,
|
||||||
|
extra_hash=block.extra_hash)
|
||||||
else:
|
else:
|
||||||
tmp_block = self.allocate_mutable_block(
|
tmp_block = self.allocate_mutable_block(
|
||||||
prev_block=block.prev_block)
|
prev_block=block.prev_block, extra_hash=block.extra_hash)
|
||||||
tmp_block.append_token_ids(block.token_ids)
|
tmp_block.append_token_ids(block.token_ids)
|
||||||
|
|
||||||
block_id = tmp_block.block_id
|
block_id = tmp_block.block_id
|
||||||
@ -679,6 +690,8 @@ class PrefixCachingBlock(Block):
|
|||||||
caching block allocator associated with this block.
|
caching block allocator associated with this block.
|
||||||
block_id (Optional[int], optional): The physical block index
|
block_id (Optional[int], optional): The physical block index
|
||||||
of this block. Defaults to None.
|
of this block. Defaults to None.
|
||||||
|
extra_hash (Optional[int]): The hash value of additional factors
|
||||||
|
such as adapters that influence the block, apart from the token_ids.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -689,6 +702,7 @@ class PrefixCachingBlock(Block):
|
|||||||
allocator: BlockAllocator,
|
allocator: BlockAllocator,
|
||||||
block_id: Optional[int] = None,
|
block_id: Optional[int] = None,
|
||||||
computed: bool = False,
|
computed: bool = False,
|
||||||
|
extra_hash: Optional[int] = None,
|
||||||
):
|
):
|
||||||
assert isinstance(allocator, PrefixCachingBlockAllocator), (
|
assert isinstance(allocator, PrefixCachingBlockAllocator), (
|
||||||
"Currently this class is only tested with "
|
"Currently this class is only tested with "
|
||||||
@ -702,6 +716,7 @@ class PrefixCachingBlock(Block):
|
|||||||
self._allocator = allocator
|
self._allocator = allocator
|
||||||
self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
|
self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
|
||||||
self._computed = computed
|
self._computed = computed
|
||||||
|
self._extra_hash = extra_hash
|
||||||
|
|
||||||
# On the first time, we create the block object, and next we only
|
# On the first time, we create the block object, and next we only
|
||||||
# reinitialize it
|
# reinitialize it
|
||||||
@ -811,6 +826,10 @@ class PrefixCachingBlock(Block):
|
|||||||
def prev_block(self) -> Optional[Block]:
|
def prev_block(self) -> Optional[Block]:
|
||||||
return self._prev_block
|
return self._prev_block
|
||||||
|
|
||||||
|
@property
|
||||||
|
def extra_hash(self) -> Optional[int]:
|
||||||
|
return self._extra_hash
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def content_hash(self) -> Optional[int]:
|
def content_hash(self) -> Optional[int]:
|
||||||
"""Return the content-based hash of the current block, or None if it is
|
"""Return the content-based hash of the current block, or None if it is
|
||||||
@ -841,18 +860,19 @@ class PrefixCachingBlock(Block):
|
|||||||
self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
|
self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
|
||||||
is_first_block,
|
is_first_block,
|
||||||
prev_block_hash,
|
prev_block_hash,
|
||||||
cur_block_token_ids=self.token_ids)
|
cur_block_token_ids=self.token_ids,
|
||||||
|
extra_hash=self._extra_hash)
|
||||||
return self._cached_content_hash
|
return self._cached_content_hash
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
|
def hash_block_tokens(is_first_block: bool,
|
||||||
cur_block_token_ids: List[int]) -> int:
|
prev_block_hash: Optional[int],
|
||||||
|
cur_block_token_ids: List[int],
|
||||||
|
extra_hash: Optional[int] = None) -> int:
|
||||||
"""Computes a hash value corresponding to the contents of a block and
|
"""Computes a hash value corresponding to the contents of a block and
|
||||||
the contents of the preceding block(s). The hash value is used for
|
the contents of the preceding block(s). The hash value is used for
|
||||||
prefix caching.
|
prefix caching.
|
||||||
|
|
||||||
NOTE: Content-based hashing does not yet support LoRA.
|
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- is_first_block (bool): A flag indicating if the block is the first in
|
- is_first_block (bool): A flag indicating if the block is the first in
|
||||||
the sequence.
|
the sequence.
|
||||||
@ -860,12 +880,15 @@ class PrefixCachingBlock(Block):
|
|||||||
if this is the first block.
|
if this is the first block.
|
||||||
- cur_block_token_ids (List[int]): A list of token ids in the current
|
- cur_block_token_ids (List[int]): A list of token ids in the current
|
||||||
block. The current block is assumed to be full.
|
block. The current block is assumed to be full.
|
||||||
|
- extra_hash (Optional[int]): The hash value of additional factors
|
||||||
|
such as adapters that influence the block, apart from the token_ids.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- int: The computed hash value for the block.
|
- int: The computed hash value for the block.
|
||||||
"""
|
"""
|
||||||
assert (prev_block_hash is None) == is_first_block
|
assert (prev_block_hash is None) == is_first_block
|
||||||
return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
|
return hash((is_first_block, prev_block_hash, *cur_block_token_ids,
|
||||||
|
extra_hash))
|
||||||
|
|
||||||
|
|
||||||
class ComputedBlocksTracker:
|
class ComputedBlocksTracker:
|
||||||
@ -935,12 +958,18 @@ class ComputedBlocksTracker:
|
|||||||
assert len(token_ids) >= (i + 1) * self._block_size
|
assert len(token_ids) >= (i + 1) * self._block_size
|
||||||
block_token_ids = token_ids[i * self._block_size:(i + 1) *
|
block_token_ids = token_ids[i * self._block_size:(i + 1) *
|
||||||
self._block_size]
|
self._block_size]
|
||||||
|
|
||||||
|
# NOTE: If there are any factors affecting the block besides
|
||||||
|
# token_ids, they should be added as input to extra_hash.
|
||||||
|
extra_hash = seq.extra_hash()
|
||||||
|
|
||||||
# This has to be kept in sync with the allocator's hash
|
# This has to be kept in sync with the allocator's hash
|
||||||
# calculation.
|
# calculation.
|
||||||
block_hash = PrefixCachingBlock.hash_block_tokens(
|
block_hash = PrefixCachingBlock.hash_block_tokens(
|
||||||
is_first_block=prev_block_hash is None,
|
is_first_block=prev_block_hash is None,
|
||||||
prev_block_hash=prev_block_hash,
|
prev_block_hash=prev_block_hash,
|
||||||
cur_block_token_ids=block_token_ids,
|
cur_block_token_ids=block_token_ids,
|
||||||
|
extra_hash=extra_hash,
|
||||||
)
|
)
|
||||||
block_hashes_recorded.append(block_hash)
|
block_hashes_recorded.append(block_hash)
|
||||||
prev_block_hash = block_hash
|
prev_block_hash = block_hash
|
||||||
|
|||||||
@ -151,8 +151,13 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
|
|||||||
max_block_sliding_window=self.max_block_sliding_window,
|
max_block_sliding_window=self.max_block_sliding_window,
|
||||||
)
|
)
|
||||||
if seq.get_token_ids():
|
if seq.get_token_ids():
|
||||||
|
# NOTE: If there are any factors affecting the block besides
|
||||||
|
# token_ids, they should be added as input to extra_hash.
|
||||||
|
extra_hash = seq.extra_hash()
|
||||||
|
|
||||||
# Add blocks to the block table only if the sequence is non empty.
|
# Add blocks to the block table only if the sequence is non empty.
|
||||||
block_table.allocate(seq.get_token_ids())
|
block_table.allocate(token_ids=seq.get_token_ids(),
|
||||||
|
extra_hash=extra_hash)
|
||||||
|
|
||||||
return block_table
|
return block_table
|
||||||
|
|
||||||
@ -238,6 +243,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
|
|||||||
token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
|
token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()),
|
||||||
num_lookahead_slots=num_lookahead_slots,
|
num_lookahead_slots=num_lookahead_slots,
|
||||||
num_computed_slots=seq.data.get_num_computed_tokens(),
|
num_computed_slots=seq.data.get_num_computed_tokens(),
|
||||||
|
extra_hash=seq.extra_hash(),
|
||||||
)
|
)
|
||||||
# Return any new copy-on-writes.
|
# Return any new copy-on-writes.
|
||||||
new_cows = self.block_allocator.clear_copy_on_writes()
|
new_cows = self.block_allocator.clear_copy_on_writes()
|
||||||
|
|||||||
@ -527,6 +527,19 @@ class Sequence:
|
|||||||
hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
|
hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
|
||||||
return hash((hashed_tokens, self.lora_int_id))
|
return hash((hashed_tokens, self.lora_int_id))
|
||||||
|
|
||||||
|
def extra_hash(self) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
Compute an extra hash for this sequence, used in prefix caching mode. It
|
||||||
|
covers the prompt adapter and LoRA ids and is None when both are 0. The
|
||||||
|
final hash of each block combines this value with the block's token_ids.
|
||||||
|
"""
|
||||||
|
if self.prompt_adapter_id == 0 and self.lora_int_id == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# NOTE: If there are additional factors influencing the block aside from
|
||||||
|
# token_ids, include them as input parameters to the hash.
|
||||||
|
return hash((self.prompt_adapter_id, self.lora_int_id))
|
||||||
|
|
||||||
def num_hashed_tokens_of_block(self, logical_idx: int):
|
def num_hashed_tokens_of_block(self, logical_idx: int):
|
||||||
return logical_idx * self.block_size + self.block_size
|
return logical_idx * self.block_size + self.block_size
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user