Implement preemption via recomputation & Refactor scheduling logic (#12)

This commit is contained in:
Woosuk Kwon 2023-03-30 14:51:46 -07:00 committed by GitHub
parent 88c0268a18
commit 7a7929abe8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 277 additions and 124 deletions

View File

@ -84,8 +84,9 @@ class FastAPIFrontend:
seq = Sequence(seq_id, token_ids, block_size=self.block_size)
seqs.append(seq)
arrival_time = time.time()
group_id = next(self.seq_group_counter)
seq_group = SequenceGroup(group_id, seqs)
seq_group = SequenceGroup(group_id, seqs, arrival_time)
group_event = asyncio.Event()
self.sequence_group_events[group_id] = group_event
await self.server.add_sequence_groups.remote([(seq_group, sampling_params)])

View File

@ -76,7 +76,8 @@ class BlockSpaceManager:
self.block_tables: Dict[int, BlockTable] = {}
def can_allocate(self, seq_group: SequenceGroup) -> bool:
# NOTE: Here we assume that all sequences in the group have the same prompt.
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
seq = seq_group.seqs[0]
num_required_blocks = len(seq.logical_token_blocks)
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()

View File

@ -0,0 +1,45 @@
from typing import List
from cacheflow.sequence import SequenceGroup
class Policy:
def get_priority(
self,
now: float,
seq_group: SequenceGroup,
) -> float:
raise NotImplementedError
def sort_by_priority(
self,
now: float,
seq_groups: List[SequenceGroup],
) -> List[SequenceGroup]:
return sorted(
seq_groups,
key=lambda seq_group: self.get_priority(now, seq_group),
reverse=True,
)
class FCFS(Policy):
def get_priority(
self,
now: float,
seq_group: SequenceGroup,
) -> float:
return now - seq_group.arrival_time
class PolicyFactory:
_POLICY_REGISTRY = {
'fcfs': FCFS,
}
@classmethod
def get_policy(cls, policy_name: str, **kwargs) -> Policy:
return cls._POLICY_REGISTRY[policy_name](**kwargs)

View File

@ -1,6 +1,9 @@
from typing import Dict, List, Tuple
import enum
import time
from typing import Dict, List, Optional, Tuple
from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.policy import PolicyFactory
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
@ -9,6 +12,19 @@ from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import SequenceStatus
class PreemptionMode(enum.Enum):
"""Preemption modes.
1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
and swap them back in when the sequences are resumed.
2. Recomputation: Discard the blocks of the preempted sequences and
recompute them when the sequences are resumed, treating the sequences as
new prompts.
"""
SWAP = enum.auto()
RECOMPUTE = enum.auto()
class Scheduler:
def __init__(
@ -25,6 +41,8 @@ class Scheduler:
self.num_cpu_blocks = num_cpu_blocks
self.max_num_batched_tokens = max_num_batched_tokens
# Instantiate the scheduling policy.
self.policy = PolicyFactory.get_policy(policy_name='fcfs')
# Create the block space manager.
self.block_manager = BlockSpaceManager(
block_size=block_size,
@ -32,158 +50,140 @@ class Scheduler:
num_cpu_blocks=num_cpu_blocks,
)
# Running sequence groups (FIFO).
# Sequence groups in the WAITING state.
self.waiting: List[SequenceGroup] = []
# Sequence groups in the RUNNING state.
self.running: List[SequenceGroup] = []
# Mapping: group_id -> num_steps.
self.num_steps: Dict[int, int] = {}
# Mapping: group_id -> sampling params.
self.sampling_params: Dict[int, SamplingParams] = {}
# Swapped sequence groups (LIFO).
# Sequence groups in the SWAPPED state.
self.swapped: List[SequenceGroup] = []
# Pending sequence groups (FIFO).
self.pending: List[SequenceGroup] = []
def add_sequence_groups(
self,
sequence_groups: List[Tuple[SequenceGroup, SamplingParams]],
seq_groups: List[Tuple[SequenceGroup, SamplingParams]],
) -> None:
# Add sequence groups to the pending queue.
for seq_group, sampling_params in sequence_groups:
self.pending.append(seq_group)
# Add sequence groups to the waiting queue.
for seq_group, sampling_params in seq_groups:
self.waiting.append(seq_group)
self.sampling_params[seq_group.group_id] = sampling_params
def _free_seq(self, seq: Sequence) -> None:
seq.status = SequenceStatus.FINISHED
self.block_manager.free(seq)
def _allocate(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.seqs:
seq.status = SequenceStatus.RUNNING
self.running.append(seq_group)
# FIXME(woosuk): Support interactive generation.
self.num_steps[seq_group.group_id] = 0
def _append(
def _schedule(
self,
seq_group: SequenceGroup,
blocks_to_copy: Dict[int, List[int]],
) -> None:
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
ret = self.block_manager.append(seq)
if ret is not None:
src_block, dst_block = ret
if src_block in blocks_to_copy:
blocks_to_copy[src_block].append(dst_block)
else:
blocks_to_copy[src_block] = [dst_block]
def _swap_in(
self,
seq_group: SequenceGroup,
blocks_to_swap_in: Dict[int, int],
) -> None:
mapping = self.block_manager.swap_in(seq_group)
blocks_to_swap_in.update(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
seq.status = SequenceStatus.RUNNING
self.running.append(seq_group)
def _swap_out(
self,
seq_group: SequenceGroup,
blocks_to_swap_out: Dict[int, int],
) -> None:
assert self.block_manager.can_swap_out(seq_group)
mapping = self.block_manager.swap_out(seq_group)
blocks_to_swap_out.update(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED
self.swapped.append(seq_group)
def step(self) -> List[SequenceGroup]:
) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, List[int]], List[int]]:
# Blocks that need to be swaped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}
# 1. Reserve new slots for the running sequences.
# NOTE: Here we implicitly assume FCFS scheduling.
# That is, the most recently added sequence group is the first
# to be swapped out.
victim_idx = len(self.running) - 1
for i, seq_group in enumerate(self.running):
if i > victim_idx:
# The i-th sequence group has already been swapped out.
break
# OOM. Swap out the victim sequence groups.
# Fix the current time.
now = time.time()
# NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
# in order to minimize the preemption overheads.
# Preemption happens only when there is no available slot to keep all
# the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence
# groups to preempt.
self.running = self.policy.sort_by_priority(now, self.running)
# Reserve new token slots for the running sequence groups.
running: List[SequenceGroup] = []
preempted: List[SequenceGroup] = []
while self.running:
seq_group = self.running.pop(0)
while not self.block_manager.can_append(seq_group):
victim_seq_group = self.running[victim_idx]
self._swap_out(victim_seq_group, blocks_to_swap_out)
victim_idx -= 1
if i > victim_idx:
# No other sequence groups can be swapped out.
if self.running:
# Preempt the lowest-priority sequence groups.
victim_seq_group = self.running.pop(-1)
self._preempt(victim_seq_group, blocks_to_swap_out)
preempted.append(victim_seq_group)
else:
# No other sequence groups can be preempted.
# Preempt the current sequence group.
self._preempt(seq_group, blocks_to_swap_out)
preempted.append(seq_group)
break
else:
# Append new slots to the sequence group.
self._append(seq_group, blocks_to_copy)
self.running = self.running[:victim_idx + 1]
running.append(seq_group)
self.running = running
# 2. Swap in the swapped sequences if possible.
# NOTE: Here we implicitly assume FCFS scheduling.
# The swapped sequences are in LIFO order.
for i, seq_group in enumerate(reversed(self.swapped)):
if self.block_manager.can_swap_in(seq_group):
self._swap_in(seq_group, blocks_to_swap_in)
self._append(seq_group, blocks_to_copy)
else:
# OOM. Stop swapping.
self.swapped = self.swapped[:len(self.swapped) - i]
# Swap in the sequence groups in the SWAPPED state if possible.
self.swapped = self.policy.sort_by_priority(now, self.swapped)
while self.swapped:
seq_group = self.swapped[0]
# If the sequence group has been preempted in this step, stop.
if seq_group in preempted:
break
# If the sequence group cannot be swapped in, stop.
if not self.block_manager.can_swap_in(seq_group):
break
else:
# All swapped sequences are swapped in.
self.swapped.clear()
# Ensure that swap-in and swap-out never happen at the same timestep.
if blocks_to_swap_in:
assert not blocks_to_swap_out
seq_group = self.swapped.pop(0)
self._swap_in(seq_group, blocks_to_swap_in)
self._append(seq_group, blocks_to_copy)
self.running.append(seq_group)
num_batched_tokens = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running
)
# 3. Join new sequences if possible.
# NOTE: Here we implicitly assume FCFS scheduling.
# TODO(woosuk): Add a batching policy to control the batch size.
# Join waiting sequences if possible.
prompt_group_ids: List[int] = []
# NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
# prioritized over the sequence groups in the WAITING state.
# This is because we want to bound the amount of CPU memory taken by
# the swapped sequence groups.
if not self.swapped:
for i, seq_group in enumerate(self.pending):
self.waiting = self.policy.sort_by_priority(now, self.waiting)
while self.waiting:
seq_group = self.waiting[0]
# If the sequence group has been preempted in this step, stop.
if seq_group in preempted:
break
# If the sequence group cannot be allocated, stop.
if not self.block_manager.can_allocate(seq_group):
break
# If the number of batched tokens exceeds the limit, stop.
num_prompt_tokens = seq_group.seqs[0].get_len()
if self.block_manager.can_allocate(seq_group):
if (num_batched_tokens + num_prompt_tokens
<= self.max_num_batched_tokens):
self._allocate(seq_group)
num_batched_tokens += num_prompt_tokens
continue
if (num_batched_tokens + num_prompt_tokens
> self.max_num_batched_tokens):
break
self.pending = self.pending[i:]
break
else:
self.pending.clear()
seq_group = self.waiting.pop(0)
self._allocate(seq_group)
self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens
prompt_group_ids.append(seq_group.group_id)
# 4. Create input data structures.
return (blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
prompt_group_ids)
def step(self) -> List[SequenceGroup]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
scheduler_output = self._schedule()
blocks_to_swap_in = scheduler_output[0]
blocks_to_swap_out = scheduler_output[1]
blocks_to_copy = scheduler_output[2]
prompt_group_ids = scheduler_output[3]
# Create input data structures.
input_seq_groups: List[SequenceGroupInputs] = []
updated_seq_groups: List[SequenceGroup] = self.running.copy()
for seq_group in self.running:
group_id = seq_group.group_id
num_steps = self.num_steps[group_id]
# NOTE(woosuk): We assume that the number of steps is 0
# for the prompt sequences.
is_prompt = num_steps == 0
is_prompt = group_id in prompt_group_ids
input_tokens: Dict[int, List[int]] = {}
seq_logprobs: Dict[int, float] = {}
@ -211,13 +211,15 @@ class Scheduler:
)
input_seq_groups.append(input_seq_group)
# 5. Execute the first stage of the pipeline.
if (input_seq_groups or blocks_to_swap_in or blocks_to_swap_out):
# Execute the first stage of the pipeline.
if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out:
# Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out)
self.controllers[0].execute_stage(
input_seq_groups,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
return updated_seq_groups
@ -276,7 +278,106 @@ class Scheduler:
running.append(seq_group)
self.running = running
def _allocate(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.seqs:
seq.status = SequenceStatus.RUNNING
# FIXME(woosuk): Support interactive generation.
if seq_group.group_id not in self.num_steps:
self.num_steps[seq_group.group_id] = 0
def _append(
self,
seq_group: SequenceGroup,
blocks_to_copy: Dict[int, List[int]],
) -> None:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
ret = self.block_manager.append(seq)
if ret is not None:
src_block, dst_block = ret
if src_block in blocks_to_copy:
blocks_to_copy[src_block].append(dst_block)
else:
blocks_to_copy[src_block] = [dst_block]
def _preempt(
self,
seq_group: SequenceGroup,
blocks_to_swap_out: Dict[int, int],
preemption_mode: Optional[PreemptionMode] = None,
) -> None:
# If preemption mode is not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than
# swapping. However, when the sequence group has multiple sequences
# (e.g., beam search), recomputation is not supported. In such a case,
# we use swapping instead.
# FIXME(woosuk): This makes our scheduling policy a bit bizarre.
# As swapped sequences are prioritized over waiting sequences,
# sequence groups with multiple sequences are implicitly prioritized
# over sequence groups with a single sequence.
# TODO(woosuk): Support recomputation for sequence groups with multiple
# sequences. This may require a more sophisticated CUDA kernel.
if preemption_mode is None:
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
if len(seqs) == 1:
preemption_mode = PreemptionMode.RECOMPUTE
else:
preemption_mode = PreemptionMode.SWAP
if preemption_mode == PreemptionMode.RECOMPUTE:
self._preempt_by_recompute(seq_group)
elif preemption_mode == PreemptionMode.SWAP:
self._preempt_by_swap(seq_group, blocks_to_swap_out)
else:
assert False, 'Invalid preemption mode.'
def _preempt_by_recompute(
self,
seq_group: SequenceGroup,
) -> None:
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
assert len(seqs) == 1
for seq in seqs:
seq.status = SequenceStatus.WAITING
self.block_manager.free(seq)
self.waiting.append(seq_group)
def _preempt_by_swap(
self,
seq_group: SequenceGroup,
blocks_to_swap_out: Dict[int, int],
) -> None:
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
for seq in seqs:
seq.status = SequenceStatus.SWAPPED
self._swap_out(seq_group, blocks_to_swap_out)
self.swapped.append(seq_group)
def _free_seq(self, seq: Sequence) -> None:
seq.status = SequenceStatus.FINISHED
self.block_manager.free(seq)
def _free_seq_group(self, seq_group: SequenceGroup) -> None:
group_id = seq_group.group_id
del self.num_steps[group_id]
del self.sampling_params[group_id]
def _swap_in(
self,
seq_group: SequenceGroup,
blocks_to_swap_in: Dict[int, int],
) -> None:
mapping = self.block_manager.swap_in(seq_group)
blocks_to_swap_in.update(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
seq.status = SequenceStatus.RUNNING
def _swap_out(
self,
seq_group: SequenceGroup,
blocks_to_swap_out: Dict[int, int],
) -> None:
assert self.block_manager.can_swap_out(seq_group)
mapping = self.block_manager.swap_out(seq_group)
blocks_to_swap_out.update(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED

View File

@ -10,6 +10,7 @@ from cacheflow.worker.controller import Controller, DeviceID
from cacheflow.sequence import SequenceGroup
from cacheflow.sampling_params import SamplingParams
class Server:
def __init__(
self,
@ -91,7 +92,7 @@ class Server:
return self.scheduler.step()
def has_unfinished_requests(self):
return (self.scheduler.pending or self.scheduler.running or
return (self.scheduler.waiting or self.scheduler.running or
self.scheduler.swapped)

View File

@ -1,3 +1,4 @@
import time
from typing import List, Optional, Set, Tuple
from transformers import AutoTokenizer
@ -39,6 +40,7 @@ class SimpleFrontend:
token_ids: List[int],
sampling_params: SamplingParams,
) -> None:
arrival_time = time.time()
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)
@ -46,7 +48,7 @@ class SimpleFrontend:
seqs.append(seq)
group_id = next(self.seq_group_counter)
seq_group = SequenceGroup(group_id, seqs)
seq_group = SequenceGroup(group_id, seqs, arrival_time)
self.inputs.append((seq_group, sampling_params))
def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]:

View File

@ -7,7 +7,7 @@ from cacheflow.sampling_params import SamplingParams
class SequenceStatus(enum.Enum):
PENDING = enum.auto()
WAITING = enum.auto()
RUNNING = enum.auto()
SWAPPED = enum.auto()
FINISHED = enum.auto()
@ -28,7 +28,7 @@ class Sequence:
# Initialize the logical token blocks with the given token ids.
self.add(token_ids)
self.status = SequenceStatus.PENDING
self.status = SequenceStatus.WAITING
self.output_logprobs: List[Dict[int, float]] = []
self.cumulative_logprobs = 0.0
@ -88,9 +88,11 @@ class SequenceGroup:
self,
group_id: int,
seqs: List[Sequence],
arrival_time: float,
) -> None:
self.group_id = group_id
self.seqs = seqs
self.arrival_time = arrival_time
def get_seqs(
self,