From 7a7929abe8e2fd6a4688487c471a1ee1fde0edd2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 30 Mar 2023 14:51:46 -0700 Subject: [PATCH] Implement preemption via recomputation & Refactor scheduling logic (#12) --- cacheflow/http_frontend/fastapi_frontend.py | 3 +- cacheflow/master/block_manager.py | 3 +- cacheflow/master/policy.py | 45 +++ cacheflow/master/scheduler.py | 337 +++++++++++++------- cacheflow/master/server.py | 3 +- cacheflow/master/simple_frontend.py | 4 +- cacheflow/sequence.py | 6 +- 7 files changed, 277 insertions(+), 124 deletions(-) create mode 100644 cacheflow/master/policy.py diff --git a/cacheflow/http_frontend/fastapi_frontend.py b/cacheflow/http_frontend/fastapi_frontend.py index dff7f7526ac68..d901baac82fc5 100644 --- a/cacheflow/http_frontend/fastapi_frontend.py +++ b/cacheflow/http_frontend/fastapi_frontend.py @@ -84,8 +84,9 @@ class FastAPIFrontend: seq = Sequence(seq_id, token_ids, block_size=self.block_size) seqs.append(seq) + arrival_time = time.time() group_id = next(self.seq_group_counter) - seq_group = SequenceGroup(group_id, seqs) + seq_group = SequenceGroup(group_id, seqs, arrival_time) group_event = asyncio.Event() self.sequence_group_events[group_id] = group_event await self.server.add_sequence_groups.remote([(seq_group, sampling_params)]) diff --git a/cacheflow/master/block_manager.py b/cacheflow/master/block_manager.py index a4e346ec6afb3..1616b7c785173 100644 --- a/cacheflow/master/block_manager.py +++ b/cacheflow/master/block_manager.py @@ -76,7 +76,8 @@ class BlockSpaceManager: self.block_tables: Dict[int, BlockTable] = {} def can_allocate(self, seq_group: SequenceGroup) -> bool: - # NOTE: Here we assume that all sequences in the group have the same prompt. + # FIXME(woosuk): Here we assume that all sequences in the group share + # the same prompt. This may not be true for preempted sequences. seq = seq_group.seqs[0] num_required_blocks = len(seq.logical_token_blocks) num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() diff --git a/cacheflow/master/policy.py b/cacheflow/master/policy.py new file mode 100644 index 0000000000000..7d8afbff10c71 --- /dev/null +++ b/cacheflow/master/policy.py @@ -0,0 +1,45 @@ +from typing import List + +from cacheflow.sequence import SequenceGroup + + +class Policy: + + def get_priority( + self, + now: float, + seq_group: SequenceGroup, + ) -> float: + raise NotImplementedError + + def sort_by_priority( + self, + now: float, + seq_groups: List[SequenceGroup], + ) -> List[SequenceGroup]: + return sorted( + seq_groups, + key=lambda seq_group: self.get_priority(now, seq_group), + reverse=True, + ) + + +class FCFS(Policy): + + def get_priority( + self, + now: float, + seq_group: SequenceGroup, + ) -> float: + return now - seq_group.arrival_time + + +class PolicyFactory: + + _POLICY_REGISTRY = { + 'fcfs': FCFS, + } + + @classmethod + def get_policy(cls, policy_name: str, **kwargs) -> Policy: + return cls._POLICY_REGISTRY[policy_name](**kwargs) diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index c71d0768bcb24..c0ab33066c977 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -1,6 +1,9 @@ -from typing import Dict, List, Tuple +import enum +import time +from typing import Dict, List, Optional, Tuple from cacheflow.master.block_manager import BlockSpaceManager +from cacheflow.master.policy import PolicyFactory from cacheflow.sampling_params import SamplingParams from cacheflow.sequence import Sequence from cacheflow.sequence import SequenceGroup @@ -9,6 +12,19 @@ from cacheflow.sequence import SequenceOutputs from cacheflow.sequence import SequenceStatus +class PreemptionMode(enum.Enum): + """Preemption modes. + + 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory + and swap them back in when the sequences are resumed. + 2. Recomputation: Discard the blocks of the preempted sequences and + recompute them when the sequences are resumed, treating the sequences as + new prompts. + """ + SWAP = enum.auto() + RECOMPUTE = enum.auto() + + class Scheduler: def __init__( @@ -25,6 +41,8 @@ class Scheduler: self.num_cpu_blocks = num_cpu_blocks self.max_num_batched_tokens = max_num_batched_tokens + # Instantiate the scheduling policy. + self.policy = PolicyFactory.get_policy(policy_name='fcfs') # Create the block space manager. self.block_manager = BlockSpaceManager( block_size=block_size, @@ -32,158 +50,140 @@ class Scheduler: num_cpu_blocks=num_cpu_blocks, ) - # Running sequence groups (FIFO). + # Sequence groups in the WAITING state. + self.waiting: List[SequenceGroup] = [] + # Sequence groups in the RUNNING state. self.running: List[SequenceGroup] = [] # Mapping: group_id -> num_steps. self.num_steps: Dict[int, int] = {} # Mapping: group_id -> sampling params. self.sampling_params: Dict[int, SamplingParams] = {} - - # Swapped sequence groups (LIFO). + # Sequence groups in the SWAPPED state. self.swapped: List[SequenceGroup] = [] - # Pending sequence groups (FIFO). - self.pending: List[SequenceGroup] = [] def add_sequence_groups( self, - sequence_groups: List[Tuple[SequenceGroup, SamplingParams]], + seq_groups: List[Tuple[SequenceGroup, SamplingParams]], ) -> None: - # Add sequence groups to the pending queue. - for seq_group, sampling_params in sequence_groups: - self.pending.append(seq_group) + # Add sequence groups to the waiting queue. + for seq_group, sampling_params in seq_groups: + self.waiting.append(seq_group) self.sampling_params[seq_group.group_id] = sampling_params - def _free_seq(self, seq: Sequence) -> None: - seq.status = SequenceStatus.FINISHED - self.block_manager.free(seq) - - def _allocate(self, seq_group: SequenceGroup) -> None: - self.block_manager.allocate(seq_group) - for seq in seq_group.seqs: - seq.status = SequenceStatus.RUNNING - self.running.append(seq_group) - # FIXME(woosuk): Support interactive generation. - self.num_steps[seq_group.group_id] = 0 - - def _append( + def _schedule( self, - seq_group: SequenceGroup, - blocks_to_copy: Dict[int, List[int]], - ) -> None: - for seq in seq_group.seqs: - if seq.status == SequenceStatus.FINISHED: - continue - ret = self.block_manager.append(seq) - if ret is not None: - src_block, dst_block = ret - if src_block in blocks_to_copy: - blocks_to_copy[src_block].append(dst_block) - else: - blocks_to_copy[src_block] = [dst_block] - - def _swap_in( - self, - seq_group: SequenceGroup, - blocks_to_swap_in: Dict[int, int], - ) -> None: - mapping = self.block_manager.swap_in(seq_group) - blocks_to_swap_in.update(mapping) - for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - seq.status = SequenceStatus.RUNNING - self.running.append(seq_group) - - def _swap_out( - self, - seq_group: SequenceGroup, - blocks_to_swap_out: Dict[int, int], - ) -> None: - assert self.block_manager.can_swap_out(seq_group) - mapping = self.block_manager.swap_out(seq_group) - blocks_to_swap_out.update(mapping) - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - seq.status = SequenceStatus.SWAPPED - self.swapped.append(seq_group) - - def step(self) -> List[SequenceGroup]: + ) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, List[int]], List[int]]: # Blocks that need to be swaped or copied before model execution. blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {} blocks_to_copy: Dict[int, List[int]] = {} - # 1. Reserve new slots for the running sequences. - # NOTE: Here we implicitly assume FCFS scheduling. - # That is, the most recently added sequence group is the first - # to be swapped out. - victim_idx = len(self.running) - 1 - for i, seq_group in enumerate(self.running): - if i > victim_idx: - # The i-th sequence group has already been swapped out. - break - # OOM. Swap out the victim sequence groups. + # Fix the current time. + now = time.time() + + # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state + # in order to minimize the preemption overheads. + # Preemption happens only when there is no available slot to keep all + # the sequence groups in the RUNNING state. + # In this case, the policy is responsible for deciding which sequence + # groups to preempt. + self.running = self.policy.sort_by_priority(now, self.running) + + # Reserve new token slots for the running sequence groups. + running: List[SequenceGroup] = [] + preempted: List[SequenceGroup] = [] + while self.running: + seq_group = self.running.pop(0) while not self.block_manager.can_append(seq_group): - victim_seq_group = self.running[victim_idx] - self._swap_out(victim_seq_group, blocks_to_swap_out) - victim_idx -= 1 - if i > victim_idx: - # No other sequence groups can be swapped out. + if self.running: + # Preempt the lowest-priority sequence groups. + victim_seq_group = self.running.pop(-1) + self._preempt(victim_seq_group, blocks_to_swap_out) + preempted.append(victim_seq_group) + else: + # No other sequence groups can be preempted. + # Preempt the current sequence group. + self._preempt(seq_group, blocks_to_swap_out) + preempted.append(seq_group) break else: + # Append new slots to the sequence group. self._append(seq_group, blocks_to_copy) - self.running = self.running[:victim_idx + 1] + running.append(seq_group) + self.running = running - # 2. Swap in the swapped sequences if possible. - # NOTE: Here we implicitly assume FCFS scheduling. - # The swapped sequences are in LIFO order. - for i, seq_group in enumerate(reversed(self.swapped)): - if self.block_manager.can_swap_in(seq_group): - self._swap_in(seq_group, blocks_to_swap_in) - self._append(seq_group, blocks_to_copy) - else: - # OOM. Stop swapping. - self.swapped = self.swapped[:len(self.swapped) - i] + # Swap in the sequence groups in the SWAPPED state if possible. + self.swapped = self.policy.sort_by_priority(now, self.swapped) + while self.swapped: + seq_group = self.swapped[0] + # If the sequence group has been preempted in this step, stop. + if seq_group in preempted: + break + # If the sequence group cannot be swapped in, stop. + if not self.block_manager.can_swap_in(seq_group): break - else: - # All swapped sequences are swapped in. - self.swapped.clear() - # Ensure that swap-in and swap-out never happen at the same timestep. - if blocks_to_swap_in: - assert not blocks_to_swap_out + seq_group = self.swapped.pop(0) + self._swap_in(seq_group, blocks_to_swap_in) + self._append(seq_group, blocks_to_copy) + self.running.append(seq_group) num_batched_tokens = sum( seq_group.num_seqs(status=SequenceStatus.RUNNING) for seq_group in self.running ) - # 3. Join new sequences if possible. - # NOTE: Here we implicitly assume FCFS scheduling. - # TODO(woosuk): Add a batching policy to control the batch size. + # Join waiting sequences if possible. + prompt_group_ids: List[int] = [] + # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly + # prioritized over the sequence groups in the WAITING state. + # This is because we want to bound the amount of CPU memory taken by + # the swapped sequence groups. if not self.swapped: - for i, seq_group in enumerate(self.pending): + self.waiting = self.policy.sort_by_priority(now, self.waiting) + while self.waiting: + seq_group = self.waiting[0] + # If the sequence group has been preempted in this step, stop. + if seq_group in preempted: + break + # If the sequence group cannot be allocated, stop. + if not self.block_manager.can_allocate(seq_group): + break + + # If the number of batched tokens exceeds the limit, stop. num_prompt_tokens = seq_group.seqs[0].get_len() - if self.block_manager.can_allocate(seq_group): - if (num_batched_tokens + num_prompt_tokens - <= self.max_num_batched_tokens): - self._allocate(seq_group) - num_batched_tokens += num_prompt_tokens - continue + if (num_batched_tokens + num_prompt_tokens + > self.max_num_batched_tokens): + break - self.pending = self.pending[i:] - break - else: - self.pending.clear() + seq_group = self.waiting.pop(0) + self._allocate(seq_group) + self.running.append(seq_group) + num_batched_tokens += num_prompt_tokens + prompt_group_ids.append(seq_group.group_id) - # 4. Create input data structures. + return (blocks_to_swap_in, + blocks_to_swap_out, + blocks_to_copy, + prompt_group_ids) + + def step(self) -> List[SequenceGroup]: + # Schedule sequence groups. + # This function call changes the internal states of the scheduler + # such as self.running, self.swapped, and self.waiting. + scheduler_output = self._schedule() + blocks_to_swap_in = scheduler_output[0] + blocks_to_swap_out = scheduler_output[1] + blocks_to_copy = scheduler_output[2] + prompt_group_ids = scheduler_output[3] + + # Create input data structures. input_seq_groups: List[SequenceGroupInputs] = [] updated_seq_groups: List[SequenceGroup] = self.running.copy() for seq_group in self.running: group_id = seq_group.group_id - num_steps = self.num_steps[group_id] - - # NOTE(woosuk): We assume that the number of steps is 0 - # for the prompt sequences. - is_prompt = num_steps == 0 + is_prompt = group_id in prompt_group_ids input_tokens: Dict[int, List[int]] = {} seq_logprobs: Dict[int, float] = {} @@ -211,13 +211,15 @@ class Scheduler: ) input_seq_groups.append(input_seq_group) - # 5. Execute the first stage of the pipeline. - if (input_seq_groups or blocks_to_swap_in or blocks_to_swap_out): + # Execute the first stage of the pipeline. + if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out: + # Swap in and swap out should never happen at the same time. + assert not (blocks_to_swap_in and blocks_to_swap_out) self.controllers[0].execute_stage( input_seq_groups, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, ) return updated_seq_groups @@ -276,7 +278,106 @@ class Scheduler: running.append(seq_group) self.running = running + def _allocate(self, seq_group: SequenceGroup) -> None: + self.block_manager.allocate(seq_group) + for seq in seq_group.seqs: + seq.status = SequenceStatus.RUNNING + # FIXME(woosuk): Support interactive generation. + if seq_group.group_id not in self.num_steps: + self.num_steps[seq_group.group_id] = 0 + + def _append( + self, + seq_group: SequenceGroup, + blocks_to_copy: Dict[int, List[int]], + ) -> None: + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + ret = self.block_manager.append(seq) + if ret is not None: + src_block, dst_block = ret + if src_block in blocks_to_copy: + blocks_to_copy[src_block].append(dst_block) + else: + blocks_to_copy[src_block] = [dst_block] + + def _preempt( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: Dict[int, int], + preemption_mode: Optional[PreemptionMode] = None, + ) -> None: + # If preemption mode is not specified, we determine the mode as follows: + # We use recomputation by default since it incurs lower overhead than + # swapping. However, when the sequence group has multiple sequences + # (e.g., beam search), recomputation is not supported. In such a case, + # we use swapping instead. + # FIXME(woosuk): This makes our scheduling policy a bit bizarre. + # As swapped sequences are prioritized over waiting sequences, + # sequence groups with multiple sequences are implicitly prioritized + # over sequence groups with a single sequence. + # TODO(woosuk): Support recomputation for sequence groups with multiple + # sequences. This may require a more sophisticated CUDA kernel. + if preemption_mode is None: + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + if len(seqs) == 1: + preemption_mode = PreemptionMode.RECOMPUTE + else: + preemption_mode = PreemptionMode.SWAP + if preemption_mode == PreemptionMode.RECOMPUTE: + self._preempt_by_recompute(seq_group) + elif preemption_mode == PreemptionMode.SWAP: + self._preempt_by_swap(seq_group, blocks_to_swap_out) + else: + assert False, 'Invalid preemption mode.' + + def _preempt_by_recompute( + self, + seq_group: SequenceGroup, + ) -> None: + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + assert len(seqs) == 1 + for seq in seqs: + seq.status = SequenceStatus.WAITING + self.block_manager.free(seq) + self.waiting.append(seq_group) + + def _preempt_by_swap( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: Dict[int, int], + ) -> None: + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + for seq in seqs: + seq.status = SequenceStatus.SWAPPED + self._swap_out(seq_group, blocks_to_swap_out) + self.swapped.append(seq_group) + + def _free_seq(self, seq: Sequence) -> None: + seq.status = SequenceStatus.FINISHED + self.block_manager.free(seq) + def _free_seq_group(self, seq_group: SequenceGroup) -> None: group_id = seq_group.group_id del self.num_steps[group_id] del self.sampling_params[group_id] + + def _swap_in( + self, + seq_group: SequenceGroup, + blocks_to_swap_in: Dict[int, int], + ) -> None: + mapping = self.block_manager.swap_in(seq_group) + blocks_to_swap_in.update(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + seq.status = SequenceStatus.RUNNING + + def _swap_out( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: Dict[int, int], + ) -> None: + assert self.block_manager.can_swap_out(seq_group) + mapping = self.block_manager.swap_out(seq_group) + blocks_to_swap_out.update(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq.status = SequenceStatus.SWAPPED diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index 92b9858c375f6..1f224316c01b4 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -10,6 +10,7 @@ from cacheflow.worker.controller import Controller, DeviceID from cacheflow.sequence import SequenceGroup from cacheflow.sampling_params import SamplingParams + class Server: def __init__( self, @@ -91,7 +92,7 @@ class Server: return self.scheduler.step() def has_unfinished_requests(self): - return (self.scheduler.pending or self.scheduler.running or + return (self.scheduler.waiting or self.scheduler.running or self.scheduler.swapped) diff --git a/cacheflow/master/simple_frontend.py b/cacheflow/master/simple_frontend.py index 9a86226b90e8e..c0459333b8a44 100644 --- a/cacheflow/master/simple_frontend.py +++ b/cacheflow/master/simple_frontend.py @@ -1,3 +1,4 @@ +import time from typing import List, Optional, Set, Tuple from transformers import AutoTokenizer @@ -39,6 +40,7 @@ class SimpleFrontend: token_ids: List[int], sampling_params: SamplingParams, ) -> None: + arrival_time = time.time() seqs: List[Sequence] = [] for _ in range(sampling_params.n): seq_id = next(self.seq_counter) @@ -46,7 +48,7 @@ class SimpleFrontend: seqs.append(seq) group_id = next(self.seq_group_counter) - seq_group = SequenceGroup(group_id, seqs) + seq_group = SequenceGroup(group_id, seqs, arrival_time) self.inputs.append((seq_group, sampling_params)) def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]: diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py index 8cdd977237f1e..607ea328c30a8 100644 --- a/cacheflow/sequence.py +++ b/cacheflow/sequence.py @@ -7,7 +7,7 @@ from cacheflow.sampling_params import SamplingParams class SequenceStatus(enum.Enum): - PENDING = enum.auto() + WAITING = enum.auto() RUNNING = enum.auto() SWAPPED = enum.auto() FINISHED = enum.auto() @@ -28,7 +28,7 @@ class Sequence: # Initialize the logical token blocks with the given token ids. self.add(token_ids) - self.status = SequenceStatus.PENDING + self.status = SequenceStatus.WAITING self.output_logprobs: List[Dict[int, float]] = [] self.cumulative_logprobs = 0.0 @@ -88,9 +88,11 @@ class SequenceGroup: self, group_id: int, seqs: List[Sequence], + arrival_time: float, ) -> None: self.group_id = group_id self.seqs = seqs + self.arrival_time = arrival_time def get_seqs( self,