[Misc]: Implement CPU/GPU swapping in BlockManagerV2 (#3834)

This commit is contained in:
Kaiyang Chen 2024-06-04 04:37:11 +08:00 committed by GitHub
parent cafb8e06c5
commit 10c38e3e46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 529 additions and 49 deletions

View File

@ -118,7 +118,7 @@ mypy vllm/model_executor --config-file pyproject.toml
# https://github.com/codespell-project/codespell/issues/1915
# Avoiding the "./" prefix and using "/**" globs for directories appears to solve the problem
CODESPELL_EXCLUDES=(
'--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,tests/lora/data/**,build/**'
'--skip' 'tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**'
)
# check spelling of specified files

View File

@ -24,7 +24,13 @@ from .conftest import get_token_ids_from_llm_generator
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"use_v2_block_manager": True,
"preemption_mode": "swap"
}, {
"use_v2_block_manager": True,
"preemption_mode": "recompute"
}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
@ -95,7 +101,13 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"use_v2_block_manager": True,
"preemption_mode": "swap"
}, {
"use_v2_block_manager": True,
"preemption_mode": "recompute"
}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
@ -179,11 +191,18 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[{
# We run one test with block_size < lookahead_slots, one test with
# block_size > lookahead_slots
"num_lookahead_slots": 10,
}])
[
{
# We run one test with block_size < lookahead_slots, one test with
# block_size > lookahead_slots
"num_lookahead_slots": 10,
"preemption_mode": "swap",
},
{
"num_lookahead_slots": 10,
"preemption_mode": "recompute",
}
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
@ -322,7 +341,13 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"use_v2_block_manager": True,
"preemption_mode": "swap"
}, {
"use_v2_block_manager": True,
"preemption_mode": "recompute"
}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
@ -397,7 +422,13 @@ def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"enable_prefix_caching": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"enable_prefix_caching": True}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"enable_prefix_caching": True,
"preemption_mode": "swap"
}, {
"enable_prefix_caching": True,
"preemption_mode": "recompute"
}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_auto_prefix_caching_with_preemption(baseline_llm_generator,

View File

@ -7,7 +7,8 @@ from vllm.core.interfaces import AllocStatus
from vllm.sequence import Logprob, SequenceStatus
from vllm.utils import chunk_list
from ..utils import create_seq_group, create_seq_group_encoder_decoder
from ..utils import (create_dummy_prompt, create_seq_group,
create_seq_group_encoder_decoder)
@pytest.mark.parametrize("block_size", [16])
@ -255,6 +256,61 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append,
assert num_consumed_blocks == expected_consumed_blocks
@pytest.mark.parametrize("block_size", [8])
@pytest.mark.parametrize("num_cpu_blocks", [4])
@pytest.mark.parametrize("num_gpu_blocks", [4])
@pytest.mark.parametrize("num_lookahead_slots", [0, 2, 10])
@pytest.mark.parametrize("enable_caching", [False, True])
def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots,
              enable_caching):
    """Verify the number of blocks on the src/dest device is correct after
    swapping a sequence group in/out (no missing or extra blocks).
    """
    block_manager = BlockSpaceManagerV2(block_size,
                                        num_cpu_blocks,
                                        num_gpu_blocks,
                                        watermark=0,
                                        enable_caching=enable_caching)
    # One block's worth of prompt minus one token, so the single appended
    # token below exactly fills the block.
    prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1)
    prompt.status = SequenceStatus.WAITING
    block_manager.allocate(seq_group)

    # Emulate a forward pass by appending a single token.
    # The block manager then knows how many unprocessed
    # tokens will be written in the next forward pass.
    token_id = 0
    prompt.status = SequenceStatus.RUNNING
    prompt.append_token_id(token_id, {token_id: Logprob(0.0)})

    # Swap seq group from GPU -> CPU.
    gpu_blocks = block_manager.get_block_table(prompt)
    assert block_manager.can_swap_out(seq_group)
    before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
    before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
    mapping = block_manager.swap_out(seq_group)
    # Every GPU block of the sequence must appear exactly once as a swap
    # source, in block-table order.
    mapping_keys = [key for key, _ in mapping]
    assert mapping_keys == gpu_blocks
    after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
    after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
    # Exactly len(gpu_blocks) blocks moved: CPU free count shrinks and GPU
    # free count grows by the same amount.
    assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
    assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks
    prompt.status = SequenceStatus.SWAPPED

    # Swap seq group from CPU -> GPU.
    assert block_manager.can_swap_in(seq_group, num_lookahead_slots)
    before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
    before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
    mapping = block_manager.swap_in(seq_group)
    # NOTE(review): `cpu_blocks` is read AFTER the swap, so it holds the
    # post-swap block table; the single-element comparison appears to rely
    # on the sequence occupying one block — confirm against allocator
    # semantics.
    cpu_blocks = block_manager.get_block_table(prompt)
    mapping_keys = [key for key, _ in mapping]
    assert mapping_keys == [cpu_blocks[0]]
    after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
    after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
    assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks)
# TODO(cade/kaiyang): add comprehensive tests for swapping at allocator level.
@pytest.mark.parametrize("block_size", [8, 16])
@pytest.mark.parametrize("prompt_len", [10, 300, 1000])
@pytest.mark.parametrize("num_slots_to_append", [50])

View File

@ -651,19 +651,24 @@ class SchedulerConfig:
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
embedding_mode: Whether the running model is for embedding.
preemption_mode: Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than
swapping. However, when the sequence group has multiple sequences
(e.g., beam search), recomputation is not currently supported. In
such a case, we use swapping instead.
"""
def __init__(
self,
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
use_v2_block_manager: bool = False,
num_lookahead_slots: int = 0,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
embedding_mode: Optional[bool] = False,
) -> None:
def __init__(self,
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
use_v2_block_manager: bool = False,
num_lookahead_slots: int = 0,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
embedding_mode: Optional[bool] = False,
preemption_mode: Optional[str] = None) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
else:
@ -689,6 +694,7 @@ class SchedulerConfig:
self.delay_factor = delay_factor
self.chunked_prefill_enabled = enable_chunked_prefill
self.embedding_mode = embedding_mode
self.preemption_mode = preemption_mode
self._verify_args()

View File

@ -283,6 +283,10 @@ class BlockTable:
def _is_allocated(self) -> bool:
return len(self._blocks) > 0
@property
def blocks(self) -> Optional[List[Block]]:
    # Expose the underlying block list so the block manager can hand the
    # sequence's blocks to the allocator (e.g. for swapping). Presumably
    # None/empty when nothing is allocated — callers guard with
    # `is not None`.
    return self._blocks
@property
def _num_empty_slots(self) -> int:
assert self._is_allocated

View File

@ -140,7 +140,6 @@ class CopyOnWriteTracker:
assert refcount != 0
if refcount > 1:
src_block_id = block_id
# Decrement refcount of the old block.
self._allocator.free(block)

View File

@ -90,11 +90,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
gpu_block_allocator=gpu_allocator,
)
def __init__(
self,
cpu_block_allocator: BlockAllocator,
gpu_block_allocator: BlockAllocator,
):
def __init__(self, cpu_block_allocator: BlockAllocator,
gpu_block_allocator: BlockAllocator):
assert not (
cpu_block_allocator.all_block_ids
& gpu_block_allocator.all_block_ids
@ -105,6 +102,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Device.GPU: gpu_block_allocator,
}
self._swap_mapping: Dict[int, int] = {}
self._null_block: Optional[Block] = None
self._block_ids_to_allocator: Dict[int, BlockAllocator] = {}
@ -198,6 +196,68 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
def get_num_total_blocks(self, device: Device) -> int:
return self._allocators[device].get_num_total_blocks()
def get_physical_block_id(self, device: Device, absolute_id: int) -> int:
    """Translate an allocator-wide (absolute) block id into the
    zero-based id local to the given device.

    Args:
        device (Device): Device whose allocator owns the block.
        absolute_id (int): Absolute id of the block across the whole
            combined allocator.

    Returns:
        int: The zero-offset block id on the given device.
    """
    device_allocator = self._allocators[device]
    return device_allocator.get_physical_block_id(absolute_id)
def swap(self, blocks: List[Block], source_device: Device,
         dest_device: Device) -> Dict[int, int]:
    """Move ``blocks`` from ``source_device`` to ``dest_device``.

    The (source id, destination id) pairs produced by this call are both
    returned and appended to the accumulated ``self._swap_mapping`` so
    they can later be retrieved via ``get_and_reset_swaps``.

    Args:
        blocks: Blocks to be swapped.
        source_device (Device): Device the blocks currently live on.
        dest_device (Device): Device the blocks are moved to.

    Returns:
        Dict[int, int]: Swap mapping from source-device block ids to
        dest-device block ids for this call only.
    """
    # Capture ids before the swap; the destination allocator's swap_in
    # rewrites each block's block_id in place.
    ids_before = [blk.block_id for blk in blocks]
    self._allocators[source_device].swap_out(blocks)
    self._allocators[dest_device].swap_in(blocks)
    ids_after = [blk.block_id for blk in blocks]

    this_swap: Dict[int, int] = {}
    for src_id, dst_id in zip(ids_before, ids_after):
        # Skip entries where either side carries no allocated id.
        if src_id is None or dst_id is None:
            continue
        self._swap_mapping[src_id] = dst_id
        this_swap[src_id] = dst_id
    return this_swap
def get_num_blocks_touched(self,
                           blocks: List[Block],
                           device: Device,
                           num_lookahead_slots: int = 0) -> int:
    """Count the blocks touched by swapping the given blocks on ``device``.

    Args:
        blocks: Blocks to be swapped.
        device (Device): Device to swap the blocks on.
        num_lookahead_slots (int): Number of lookahead slots used in
            speculative decoding, defaults to 0.

    Returns:
        int: Number of blocks that will be touched by swapping the given
        blocks in/out on ``device``.
    """
    # Pure delegation: each device allocator knows its own touch count.
    device_allocator = self._allocators[device]
    return device_allocator.get_num_blocks_touched(blocks,
                                                   num_lookahead_slots)
def clear_copy_on_writes(self) -> List[Tuple[int, int]]:
"""Clears the copy-on-write (CoW) state and returns the mapping of
source to destination block IDs.
@ -240,6 +300,18 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
raise NotImplementedError
def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
    """Return the accumulated swap mapping and clear it.

    Currently called after every swap operation; once BlockManagerV2
    becomes the default it will be called after every schedule step.

    Returns:
        List[Tuple[int, int]]: (source, destination) block-id pairs
        recorded since the last reset.
    """
    pairs = list(self._swap_mapping.items())
    self._swap_mapping.clear()
    return pairs
class NullBlock(Block):
"""

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import FrozenSet, List, Optional, Protocol, Tuple
from typing import Dict, FrozenSet, List, Optional, Protocol, Tuple
from vllm.utils import Device
@ -116,6 +116,18 @@ class BlockAllocator(ABC):
def get_num_free_blocks(self) -> int:
pass
@abstractmethod
def get_physical_block_id(self, absolute_id: int) -> int:
pass
@abstractmethod
def swap_out(self, blocks: List[Block]) -> None:
pass
@abstractmethod
def swap_in(self, blocks: List[Block]) -> None:
pass
@property
@abstractmethod
def all_block_ids(self) -> FrozenSet[int]:
@ -149,6 +161,12 @@ class BlockAllocator(ABC):
"""NOTE: This should not be used besides Block"""
pass
@abstractmethod
def get_num_blocks_touched(self,
blocks: List[Block],
num_lookahead_slots: int = 0) -> int:
pass
class NoFreeBlocksError(ValueError):
pass
@ -204,6 +222,22 @@ class DeviceAwareBlockAllocator(ABC):
self, seq_block_ids: List[List[int]]) -> List[int]:
pass
@abstractmethod
def get_num_blocks_touched(self,
blocks: List[Block],
device: Device,
num_lookahead_slots: int = 0) -> int:
pass
@abstractmethod
def swap(self, blocks: List[Block], source_device: Device,
dest_device: Device) -> Dict[int, int]:
pass
@abstractmethod
def get_physical_block_id(self, device: Device, absolute_id: int) -> int:
pass
@abstractmethod
def allocate_or_get_null_block(self) -> Block:
"""

View File

@ -3,6 +3,7 @@ from typing import FrozenSet, Iterable, List, Optional, Set, Tuple
from vllm.core.block.common import (CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.utils import cdiv
Refcount = int
@ -95,8 +96,6 @@ class NaiveBlockAllocator(BlockAllocator):
def free(self, block: Block) -> None:
assert block.block_id is not None
self._free_block_id(block.block_id)
# Mark the block as having no allocation.
block.block_id = None
def fork(self, last_block: Block) -> List[Block]:
@ -153,6 +152,19 @@ class NaiveBlockAllocator(BlockAllocator):
if refcount == 0:
self._free_block_indices.add(block_id)
def get_physical_block_id(self, absolute_id: int) -> int:
    """Map an absolute block id to its zero-offset id within this
    allocator.

    Args:
        absolute_id (int): The absolute block id for the block in the
            whole allocator.

    Returns:
        int: The zero-offset block id on this allocator's device.
    """
    # Rank of the id among this allocator's ids, in ascending order.
    ordered_ids = sorted(self._all_block_indices)
    return ordered_ids.index(absolute_id)
@property
def refcounter(self):
return self._refcounter
@ -213,6 +225,56 @@ class NaiveBlockAllocator(BlockAllocator):
def promote_to_immutable_block(self, block: Block) -> BlockId:
raise NotImplementedError
def get_num_blocks_touched(self,
                           blocks: List[Block],
                           num_lookahead_slots: int = 0) -> int:
    """Count the blocks touched by swapping the given blocks with the
    provided number of lookahead slots.

    Args:
        blocks (List[Block]): The potential blocks to swap.
        num_lookahead_slots (int): Number of lookahead slots (0 for
            swap out).

    Returns:
        int: The number of blocks that will be touched by swapping
        in/out the given blocks and num_lookahead_slots.
    """
    # Already-allocated blocks are deduplicated via their ids (blocks may
    # be shared across sequences); partially-filled blocks may require
    # extra fresh blocks to host the lookahead slots.
    # TODO(cade): make sure the logic is correct and clean it up.
    shared_ids = set()
    extra_blocks = 0
    for blk in blocks:
        if blk.is_full or num_lookahead_slots == 0:
            shared_ids.add(blk.block_id)
        elif blk.num_empty_slots >= num_lookahead_slots:
            # Lookahead fits in the existing partial block.
            extra_blocks += 1
        else:
            # Overflow spills into ceil(remaining / block_size) new blocks.
            extra_blocks += cdiv(
                num_lookahead_slots - blk.num_empty_slots,
                self._block_size)
    return extra_blocks + len(shared_ids)
def swap_out(self, blocks: List[Block]) -> None:
    """Swap the given blocks out of this allocator by freeing them."""
    for swapped_block in blocks:
        self.free(swapped_block)
def swap_in(self, blocks: List[Block]) -> None:
    """Swap the given blocks into this allocator.

    Each source block is re-allocated here (immutable when already full,
    otherwise mutable with its tokens appended) and its ``block_id`` is
    rewritten to the newly allocated block's id so the caller's block
    table stays valid.
    """
    for src in blocks:
        if src.is_full:
            dst = self.allocate_immutable(src.prev_block, src.token_ids)
        else:
            dst = self.allocate_mutable(src.prev_block)
            dst.append_token_ids(src.token_ids)
        src.block_id = dst.block_id
class NaiveBlock(Block):
"""An implementation of the Block class that does not support prefix

View File

@ -1,4 +1,5 @@
"""Token blocks."""
from itertools import takewhile
from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
@ -8,6 +9,7 @@ from vllm.core.block.common import (CopyOnWriteTracker,
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor
from vllm.utils import cdiv
PrefixHash = int
@ -294,10 +296,29 @@ class PrefixCachingBlockAllocator(BlockAllocator):
def get_num_total_blocks(self) -> int:
return self._hashless_allocator.get_num_total_blocks()
def get_physical_block_id(self, absolute_id: int) -> int:
    """Map an absolute block id to its zero-offset id within this
    allocator.

    Args:
        absolute_id (int): The absolute block id for the block in the
            whole allocator.

    Returns:
        int: The zero-offset block id on this allocator's device.
    """
    # Rank of the id among all of this allocator's ids, ascending.
    ordered_ids = sorted(self.all_block_ids)
    return ordered_ids.index(absolute_id)
@property
def all_block_ids(self) -> FrozenSet[int]:
return self._hashless_allocator.all_block_ids
def is_block_cached(self, block: Block) -> bool:
    """Return whether a block with identical content is already cached.

    Args:
        block (Block): Block to look up; its ``content_hash`` must have
            been computed (i.e. not None).

    Returns:
        bool: True if the block's content hash is in the prefix cache.
    """
    assert block.content_hash is not None
    # The membership test already yields the boolean; no if/else needed.
    return block.content_hash in self._cached_blocks
def promote_to_immutable_block(self, block: Block) -> BlockId:
"""Once a mutable block is full, it can be promoted to an immutable
block. This means that its content can be referenced by future blocks
@ -411,6 +432,63 @@ class PrefixCachingBlockAllocator(BlockAllocator):
if ids != []
])
def get_num_blocks_touched(self,
                           blocks: List[Block],
                           num_lookahead_slots: int = 0) -> int:
    """Count the blocks touched by swapping the given blocks with the
    provided number of lookahead slots.

    Args:
        blocks (List[Block]): The potential blocks to swap.
        num_lookahead_slots (int): Number of lookahead slots (0 for
            swap out).

    Returns:
        int: The number of blocks that will be touched by swapping
        in/out the given blocks and num_lookahead_slots.
    """
    total = 0
    for blk in blocks:
        if blk.is_full:
            # Full blocks are only touched when their content is not
            # already present in the prefix cache.
            if not self.is_block_cached(blk):
                total += 1
        elif blk.num_empty_slots >= num_lookahead_slots:
            # Lookahead fits in the existing partial block.
            total += 1
        else:
            # Overflow spills into ceil(remaining / block_size) new blocks.
            total += cdiv(num_lookahead_slots - blk.num_empty_slots,
                          self._block_size)
    return total
def swap_out(self, blocks: List[Block]) -> None:
    """Swap the given blocks out by releasing them from this allocator.

    Args:
        blocks: Blocks to be swapped out.
    """
    for evicted in blocks:
        self.free(evicted)
def swap_in(self, blocks: List[Block]) -> None:
    """Execute the swap-in actions for the given blocks.

    Each source block is re-allocated in this allocator (immutable when
    already full, otherwise mutable with its tokens appended), and its
    ``block_id`` is rewritten to the new allocation's id to complete the
    block-table update.

    Args:
        blocks: Blocks to be swapped in.
    """
    for src in blocks:
        if src.is_full:
            dst = self.allocate_immutable(src.prev_block, src.token_ids)
        else:
            dst = self.allocate_mutable(src.prev_block)
            dst.append_token_ids(src.token_ids)
        src.block_id = dst.block_id
class PrefixCachingBlock(Block):
"""A block implementation that supports prefix caching.

View File

@ -541,11 +541,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
return new_block_table
def swap_in(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> List[Tuple[int, int]]:
assert (num_lookahead_slots == 0
), "BlockSpaceManagerV1 does not support lookahead allocation"
def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
request_id = seq_group.request_id

View File

@ -1,10 +1,12 @@
"""A block manager that manages token blocks."""
from itertools import chain
from typing import Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple
from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.block.interfaces import Block
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
@ -217,7 +219,6 @@ class BlockSpaceManagerV2(BlockSpaceManager):
num_lookahead_slots=num_lookahead_slots,
num_computed_slots=seq.data.get_num_computed_tokens(),
)
# Return any new copy-on-writes.
new_cows = self.block_allocator.clear_copy_on_writes()
return new_cows
@ -297,20 +298,145 @@ class BlockSpaceManagerV2(BlockSpaceManager):
def can_swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> AllocStatus:
return AllocStatus.LATER
"""Returns the AllocStatus for the given sequence_group
with num_lookahead_slots.
def swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> List[Tuple[int, int]]:
raise NotImplementedError
Args:
sequence_group (SequenceGroup): The sequence group to swap in.
num_lookahead_slots (int): Number of lookahead slots used in
speculative decoding, default to 0.
Returns:
AllocStatus: The AllocStatus for the given sequence group.
"""
return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED,
num_lookahead_slots)
def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
"""Returns the block id mapping (from CPU to GPU) generated by
swapping in the given seq_group with num_lookahead_slots.
Args:
seq_group (SequenceGroup): The sequence group to swap in.
Returns:
List[Tuple[int, int]]: The mapping of swapping block from CPU
to GPU.
"""
blocks = self._get_blocks_for_swap(seq_group, SequenceStatus.SWAPPED)
current_swap_mapping = self.block_allocator.swap(
blocks=blocks, source_device=Device.CPU, dest_device=Device.GPU)
block_number_mapping = {
self.block_allocator.get_physical_block_id(Device.CPU,
cpu_block_id):
self.block_allocator.get_physical_block_id(Device.GPU,
gpu_block_id)
for cpu_block_id, gpu_block_id in current_swap_mapping.items()
}
# convert to list of tuples once here
return list(block_number_mapping.items())
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
    """Return whether the given sequence group can be swapped out to CPU.

    Args:
        seq_group (SequenceGroup): The sequence group to swap out.

    Returns:
        bool: Whether it is possible to swap out the sequence group.
    """
    alloc_status = self._can_swap(seq_group, Device.CPU,
                                  SequenceStatus.RUNNING)
    # Swap-out has no lookahead requirement, so OK is the only
    # acceptable status; compare directly instead of if/return True.
    return alloc_status == AllocStatus.OK
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
raise NotImplementedError
def swap_out(self, sequence_group: SequenceGroup) -> List[Tuple[int, int]]:
    """Swap the given sequence group's blocks from GPU to CPU.

    Args:
        sequence_group (SequenceGroup): The sequence group to swap out.

    Returns:
        List[Tuple[int, int]]: Mapping of swapped blocks as
        (GPU physical block id, CPU physical block id) pairs.
    """
    running_blocks = self._get_blocks_for_swap(sequence_group,
                                               SequenceStatus.RUNNING)
    swap_map = self.block_allocator.swap(blocks=running_blocks,
                                         source_device=Device.GPU,
                                         dest_device=Device.CPU)
    # Translate allocator-wide absolute ids into per-device physical ids
    # before handing the mapping to the scheduler.
    physical_pairs = []
    for gpu_id, cpu_id in swap_map.items():
        src = self.block_allocator.get_physical_block_id(
            Device.GPU, gpu_id)
        dst = self.block_allocator.get_physical_block_id(
            Device.CPU, cpu_id)
        physical_pairs.append((src, dst))
    return physical_pairs
def get_num_free_gpu_blocks(self) -> int:
    # Current free-block count on the GPU allocator.
    return self.block_allocator.get_num_free_blocks(Device.GPU)
def get_num_free_cpu_blocks(self) -> int:
    # Current free-block count on the CPU allocator.
    return self.block_allocator.get_num_free_blocks(Device.CPU)
def _can_swap(self,
              seq_group: SequenceGroup,
              device: Device,
              status: SequenceStatus,
              num_lookahead_slots: int = 0) -> AllocStatus:
    """Compute the AllocStatus for swapping ``seq_group`` onto ``device``.

    Args:
        seq_group (SequenceGroup): The sequence group to swap.
        device (Device): Device to swap the sequence group onto.
        status (SequenceStatus): Status of the sequences taking part:
            RUNNING for swap out, SWAPPED for swap in.
        num_lookahead_slots (int): Number of lookahead slots used in
            speculative decoding, defaults to 0.

    Returns:
        AllocStatus: NEVER when the device can never hold the blocks,
        OK when they fit above the watermark, LATER otherwise.
    """
    swap_blocks = self._get_blocks_for_swap(seq_group, status)
    touched_blocks = self.block_allocator.get_num_blocks_touched(
        swap_blocks, device, num_lookahead_slots)
    # Only the GPU keeps a watermark reserve; CPU swaps may use every block.
    watermark_blocks = (self.watermark_blocks
                        if device == Device.GPU else 0)
    if self.block_allocator.get_num_total_blocks(device) < touched_blocks:
        return AllocStatus.NEVER
    if (self.block_allocator.get_num_free_blocks(device) - touched_blocks
            >= watermark_blocks):
        return AllocStatus.OK
    return AllocStatus.LATER
def _get_blocks_for_swap(self, seq_group: SequenceGroup,
status: SequenceStatus) -> List[Block]:
"""Returns the list of blocks those are touched by the seq_group
Args:
sequence_group (SequenceGroup): The sequence group to swap in.
status (SequenceStatus): The status of sequence which is needed
for action. RUNNING for swap out and SWAPPED for swap in
Returns:
The list of blocks those are touched by the seq_group.
"""
blocks: Dict[int, List[Block]] = {}
for seq in seq_group.get_seqs(status=status):
block_table = self.block_tables[seq.seq_id]
if block_table.blocks is not None:
blocks[seq.seq_id] = block_table.blocks
combined_blocks = list(chain(*blocks.values()))
return combined_blocks

View File

@ -46,8 +46,7 @@ class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
num_lookahead_slots: int) -> AllocStatus:
return AllocStatus.OK
def swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> List[Tuple[int, int]]:
def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
return None # type: ignore
def can_swap_out(self, seq_group: SequenceGroup) -> bool:

View File

@ -73,8 +73,7 @@ class BlockSpaceManager(ABC):
pass
@abstractmethod
def swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> List[Tuple[int, int]]:
def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
pass
@abstractmethod

View File

@ -297,6 +297,8 @@ class Scheduler:
self.prev_prompt = False
# Latency of the last prompt step
self.last_prompt_latency = 0.0
# preemption mode, RECOMPUTE or SWAP
self.user_specified_preemption_mode = scheduler_config.preemption_mode
# The following field is test-only. It is used to inject artificial
# preemption.
@ -522,7 +524,9 @@ class Scheduler:
seq_group = swapped_queue[0]
# If the sequence group cannot be swapped in, stop.
alloc_status = self.block_manager.can_swap_in(seq_group)
is_prefill = seq_group.is_prefill()
alloc_status = self.block_manager.can_swap_in(
seq_group, self._get_num_lookahead_slots(is_prefill))
if alloc_status == AllocStatus.LATER:
break
elif alloc_status == AllocStatus.NEVER:
@ -1067,12 +1071,17 @@ class Scheduler:
# over sequence groups with a single sequence.
# TODO(woosuk): Support recomputation for sequence groups with multiple
# sequences. This may require a more sophisticated CUDA kernel.
if preemption_mode is None:
if self.user_specified_preemption_mode is None:
if seq_group.get_max_num_running_seqs() == 1:
preemption_mode = PreemptionMode.RECOMPUTE
else:
preemption_mode = PreemptionMode.SWAP
elif self.user_specified_preemption_mode == "swap":
preemption_mode = PreemptionMode.SWAP
else:
preemption_mode = PreemptionMode.RECOMPUTE
if self.num_cumulative_preemption % 50 == 0:
logger.warning(
"Sequence group %s is preempted by %s mode because there is "

View File

@ -75,6 +75,7 @@ class EngineArgs:
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0
model_loader_extra_config: Optional[dict] = None
preemption_mode: Optional[str] = None
# Related to Vision-language models such as llava
image_input_type: Optional[str] = None
@ -564,6 +565,13 @@ class EngineArgs:
'corresponding to the chosen load_format. '
'This should be a JSON string that will be '
'parsed into a dictionary.')
parser.add_argument(
    '--preemption_mode',
    type=str,
    default=None,
    help='If \'recompute\', the engine performs preemption by '
    'recomputing; if \'swap\', the engine performs preemption by block '
    'swapping.')
parser.add_argument(
"--served-model-name",
@ -667,6 +675,7 @@ class EngineArgs:
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
embedding_mode=model_config.embedding_mode,
preemption_mode=self.preemption_mode,
)
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,