diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 88c2c37f4fb3..431b4534a2a1 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -1,10 +1,15 @@ import time +from collections import deque from typing import List +from unittest.mock import MagicMock import pytest # noqa -from vllm.config import CacheConfig, SchedulerConfig -from vllm.core.scheduler import Scheduler +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.core.interfaces import AllocStatus +from vllm.core.policy import PolicyFactory +from vllm.core.scheduler import Scheduler, SchedulingBudget +from vllm.lora.request import LoRARequest from vllm.sequence import Logprob, SequenceGroup from .utils import create_dummy_prompt @@ -177,7 +182,6 @@ def test_scheduler_max_seqs(): def test_scheduler_delay_factor(): - block_size = 4 scheduler_config = SchedulerConfig(100, 64, 16, delay_factor=0.5) cache_config = CacheConfig(block_size, 1.0, 1, "auto") @@ -189,7 +193,7 @@ def test_scheduler_delay_factor(): _, seq_group = create_dummy_prompt("0", prompt_length=block_size) scheduler.add_seq_group(seq_group) seq_group_meta, out = scheduler.schedule() - assert out.prompt_run + assert out.num_prefill_groups > 0 assert seq_group_meta[0].request_id == '0' # wait for a second before scheduling next prompt @@ -199,11 +203,533 @@ def test_scheduler_delay_factor(): # second prompt should *not* be scheduled seq_group_meta, out = scheduler.schedule() - assert not out.prompt_run + assert out.num_prefill_groups == 0 assert seq_group_meta[0].request_id == '0' # wait for more than 0.5 second and try again time.sleep(0.6) seq_group_meta, out = scheduler.schedule() - assert out.prompt_run + assert out.num_prefill_groups > 0 assert seq_group_meta[0].request_id == '1' + + +def test_swapped_out_prioritized(): + scheduler = initialize_scheduler(max_num_seqs=6) + # best_of=2 * 3 == 6 sequences. + for i in range(3): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) + scheduler.add_seq_group(seq_group) + _, out = scheduler.schedule() + # prefill scheduled now. + assert len(out.scheduled_seq_groups) == 3 + + # The last request should be swapped out. + scheduler.block_manager.can_append_slots = MagicMock() + + def cannot_append_second_group(seq_group, num_lookahead_slots): + return seq_group.request_id != "2" + + scheduler.block_manager.can_append_slots.side_effect = ( + cannot_append_second_group) + + _, out = scheduler.schedule() + assert len(out.scheduled_seq_groups) == 2 + assert out.num_batched_tokens == 2 + assert out.blocks_to_swap_out != {} + assert out.blocks_to_swap_in == {} + + # Add 1 more task. Swap should be prioritized over prefill. + _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) + scheduler.add_seq_group(seq_group) + _, out = scheduler.schedule() + assert len(out.scheduled_seq_groups) == 3 + # 3 decodes. It is swapped in. 
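+    # Three decoding groups contribute one token each; no new prefill runs.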
+ assert out.num_batched_tokens == 3 + assert out.blocks_to_swap_in != {} + assert out.blocks_to_swap_out == {} + + +def initialize_scheduler(*, + max_num_seqs=1000, + max_token_budget=1000, + max_model_len=1000, + lora_config=None): + block_size = 4 + scheduler_config = SchedulerConfig(max_token_budget, max_num_seqs, + max_model_len) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, lora_config) + return scheduler + + +def create_token_budget(num_batched_tokens: int = 0, + num_curr_seqs: int = 0, + token_budget: int = 10000, + max_num_seqs: int = 10000) -> SchedulingBudget: + return SchedulingBudget( + num_batched_tokens=num_batched_tokens, + num_curr_seqs=num_curr_seqs, + token_budget=token_budget, + max_num_seqs=max_num_seqs, + ) + + +def test_prefill_schedule_max_prompt_len(): + """ + Test prompt longer than max_prompt_len is aborted. + """ + scheduler = initialize_scheduler(max_model_len=30) + _, seq_group = create_dummy_prompt(0, prompt_length=60) + waiting = deque([seq_group]) + budget = create_token_budget() + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.ignored_seq_groups) == 1 + assert len(output.seq_groups) == 0 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 0 + assert len(remaining_waiting) == 0 + + +def test_prefill_schedule_token_budget(): + """ + Test token budget respected. + """ + scheduler = initialize_scheduler() + waiting = deque() + budget = create_token_budget(token_budget=0) + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + waiting.append(seq_group) + + # 0 token budget == nothing is scheduled. + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.ignored_seq_groups) == 0 + assert len(output.seq_groups) == 0 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 0 + assert len(remaining_waiting) == 2 + + # 60 token budget == 1 request scheduled. + budget = create_token_budget(token_budget=60) + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.ignored_seq_groups) == 0 + assert len(output.seq_groups) == 1 + assert budget.num_batched_tokens == 60 + assert budget.num_curr_seqs == 1 + assert len(remaining_waiting) == 1 + + # Test when current_batched_tokens respected. + scheduler = initialize_scheduler() + waiting = deque() + budget = create_token_budget(num_batched_tokens=30, token_budget=60) + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + # Cannot schedule a prompt that doesn't fit the budget. + waiting.append(seq_group) + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.ignored_seq_groups) == 0 + assert len(output.seq_groups) == 0 + assert budget.num_batched_tokens == 30 + assert budget.num_curr_seqs == 0 + assert len(remaining_waiting) == 1 + budget = create_token_budget(num_batched_tokens=30, token_budget=90) + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.seq_groups) == 1 + assert budget.num_batched_tokens == 90 + assert budget.num_curr_seqs == 1 + assert len(remaining_waiting) == 0 + + +def test_prefill_schedule_max_seqs(): + """ + Test max seq respected. 
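+    The budget's max_num_seqs caps how many sequence groups are admitted.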
+ """ + scheduler = initialize_scheduler() + waiting = deque() + budget = create_token_budget(max_num_seqs=2) + for i in range(3): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + waiting.append(seq_group) + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.ignored_seq_groups) == 0 + assert len(output.seq_groups) == 2 + assert budget.num_batched_tokens == 120 + assert budget.num_curr_seqs == 2 + assert len(remaining_waiting) == 1 + + # Verify curr_num_seqs respected. + waiting = deque() + budget = create_token_budget(num_curr_seqs=2, max_num_seqs=2) + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + waiting.append(seq_group) + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.ignored_seq_groups) == 0 + assert len(output.seq_groups) == 0 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 2 + assert len(remaining_waiting) == 1 + + +def test_prefill_schedule_max_lora(): + """ + Test max lora is respected and prioritized. + """ + lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) + scheduler = initialize_scheduler(lora_config=lora_config) + waiting = deque() + budget = create_token_budget(token_budget=120) + curr_loras = set() + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + lora_request=LoRARequest( + lora_name=str(i), + lora_int_id=i + 1, + lora_local_path="abc")) + waiting.append(seq_group) + # Add two more requests to verify lora is prioritized. + # 0: Lora, 1: Lora, 2: regular, 3: regular + # In the first iteration, index 0, 2 is scheduled. + # If a request is not scheduled because it hits max lora, it is + # prioritized. Verify that. + for i in range(2, 4): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + waiting.append(seq_group) + # Schedule 2 requests (0 and 2) + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, curr_loras) + assert len(output.ignored_seq_groups) == 0 + assert len(output.seq_groups) == 2 + assert budget.num_batched_tokens == 120 + assert budget.num_curr_seqs == 2 + assert len(remaining_waiting) == 2 + assert len(curr_loras) == 1 + # The second lora request is scheduled next as FCFS policy. + # Reset curr_loras so that it can be scheduled. + curr_loras = set() + budget = create_token_budget(token_budget=60) + remaining_waiting, output = scheduler._schedule_prefills( + remaining_waiting, budget, curr_loras) + assert len(output.seq_groups) == 1 + assert output.seq_groups[0].seq_group.request_id == "1" + assert len(remaining_waiting) == 1 + assert len(curr_loras) == 1 + assert budget.num_batched_tokens == 60 + + +def test_prefill_schedule_no_block_manager_capacity(): + """ + Test sequence cannot be scheduled due to block manager has no capacity. 
+ """ + scheduler = initialize_scheduler() + waiting = deque() + budget = create_token_budget() + for i in range(3): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + waiting.append(seq_group) + scheduler.block_manager.can_allocate = MagicMock() + scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER + remainig_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.ignored_seq_groups) == 0 + assert len(output.seq_groups) == 0 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 0 + assert len(remainig_waiting) == 3 + + scheduler = initialize_scheduler() + waiting = deque() + budget = create_token_budget() + for i in range(3): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + waiting.append(seq_group) + scheduler.block_manager.can_allocate = MagicMock() + scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER + remaining_waiting, output = scheduler._schedule_prefills( + waiting, budget, None) + assert len(output.ignored_seq_groups) == 3 + assert len(output.seq_groups) == 0 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 0 + assert len(remaining_waiting) == 0 + + +def test_decode_schedule_preempted(): + """ + Test decodes cannot be scheduled and preempted. + """ + scheduler = initialize_scheduler() + running = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + for i in range(3): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + scheduler._allocate_and_set_running(seq_group) + running.append(seq_group) + scheduler.block_manager.can_append_slots = MagicMock() + + def cannot_append_second_group(seq_group, num_lookahead_slots): + return seq_group.request_id != "1" + + scheduler.block_manager.can_append_slots.side_effect = ( + cannot_append_second_group) + + # 1 cannot be scheduled, and the lowest priority (request 2) + # should be preempted. 1 will also be preempted. + budget = create_token_budget(num_batched_tokens=3, num_curr_seqs=3) + remainig_running, output = scheduler._schedule_decodes( + running, budget, curr_loras, policy) + assert len(remainig_running) == 0 + assert len(output.seq_groups) == 1 + assert output.seq_groups[0].seq_group.request_id == "0" + assert len(output.preempted) == 2 + # Verify budgets are updated. + assert budget.num_batched_tokens == 1 + assert budget.num_curr_seqs == 1 + # Both should be preempted, not swapped. + assert output.blocks_to_swap_out == {} + # Nothing is copied. + assert output.blocks_to_copy == {} + + +def test_decode_swap_beam_search(): + """ + Test best_of > 1 swap out blocks + """ + scheduler = initialize_scheduler() + running = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + for i in range(3): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) + scheduler._allocate_and_set_running(seq_group) + running.append(seq_group) + + # The last request should be swapped out. 
+    scheduler.block_manager.can_append_slots = MagicMock()
+
+    def cannot_append_second_group(seq_group, num_lookahead_slots):
+        return seq_group.request_id != "2"
+
+    scheduler.block_manager.can_append_slots.side_effect = (
+        cannot_append_second_group)
+    scheduler.block_manager.swap_out = MagicMock()
+    expected_swap_mapping = {"5": "7"}
+    scheduler.block_manager.swap_out.return_value = expected_swap_mapping
+
+    budget = create_token_budget(num_batched_tokens=3, num_curr_seqs=3)
+    remaining_running, output = scheduler._schedule_decodes(
+        running, budget, curr_loras, policy)
+    assert len(remaining_running) == 0
+    assert len(output.seq_groups) == 2
+    assert output.seq_groups[0].seq_group.request_id == "0"
+    assert output.seq_groups[1].seq_group.request_id == "1"
+    assert len(output.preempted) == 0
+    assert len(output.swapped_out) == 1
+    # Budget should reflect the swapped-out request.
+    assert budget.num_batched_tokens == 2
+    # Since the group has 2 sequences, 2 should be subtracted.
+    assert budget.num_curr_seqs == 1
+    # The group is swapped out, not recomputed, so the swap mapping is set.
+    assert output.blocks_to_swap_out == expected_swap_mapping
+    # Nothing is copied.
+    assert output.blocks_to_copy == {}
+
+
+def test_schedule_decode_blocks_to_copy_update():
+    """
+    Verify blocks_to_copy is updated.
+    """
+    scheduler = initialize_scheduler()
+    _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
+    running = deque()
+    policy = PolicyFactory.get_policy(policy_name="fcfs")
+    curr_loras = None
+    scheduler._allocate_and_set_running(seq_group)
+    running.append(seq_group)
+
+    # Mock append_slots to report a block copy.
+    scheduler.block_manager.append_slots = MagicMock()
+    scheduler.block_manager.append_slots.return_value = {2: [3]}
+
+    budget = create_token_budget()
+    remaining_running, output = scheduler._schedule_decodes(
+        running, budget, curr_loras, policy)
+    assert len(remaining_running) == 0
+    assert len(output.seq_groups) == 1
+    assert len(output.preempted) == 0
+    assert len(output.swapped_out) == 0
+    # Nothing is preempted.
+    assert output.blocks_to_swap_out == {}
+    # Since append_slots returns the source -> dest mapping, it should
+    # be applied.
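+    # The mocked mapping asks for block 2 to be copied into block 3.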
+ assert output.blocks_to_copy == {2: [3]} + + +def test_schedule_swapped_simple(): + scheduler = initialize_scheduler() + swapped = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + blocks_to_swap_out = {} + _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + scheduler._allocate_and_set_running(seq_group) + scheduler._swap_out(seq_group, blocks_to_swap_out) + swapped.append(seq_group) + + budget = create_token_budget() + remaining_swapped, output = scheduler._schedule_swapped( + swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 0 + assert budget.num_batched_tokens == 1 + assert budget.num_curr_seqs == 2 + assert len(output.seq_groups) == 1 + # swap in is the reverse of swap out + blocks_to_swap_in_reverse = {} + for swapin, swapout in output.blocks_to_swap_in.items(): + blocks_to_swap_in_reverse[swapout] = swapin + assert blocks_to_swap_out == blocks_to_swap_in_reverse + + +def test_schedule_swapped_max_token_budget(): + scheduler = initialize_scheduler() + swapped = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + blocks_to_swap_out = {} + for _ in range(2): + _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + scheduler._allocate_and_set_running(seq_group) + scheduler._swap_out(seq_group, blocks_to_swap_out) + swapped.append(seq_group) + + budget = create_token_budget(token_budget=1) + remaining_swapped, output = scheduler._schedule_swapped( + swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 1 + assert budget.num_batched_tokens == 1 + assert budget.num_curr_seqs == 2 + assert len(output.seq_groups) == 1 + + # Verify num_batched_tokens are respected. + budget = create_token_budget(num_batched_tokens=1, token_budget=1) + remaining_swapped, output = scheduler._schedule_swapped( + remaining_swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 1 + assert budget.num_batched_tokens == 1 + assert budget.num_curr_seqs == 0 + assert len(output.seq_groups) == 0 + + +def test_schedule_swapped_max_seqs(): + scheduler = initialize_scheduler() + swapped = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + blocks_to_swap_out = {} + for _ in range(2): + _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + scheduler._allocate_and_set_running(seq_group) + scheduler._swap_out(seq_group, blocks_to_swap_out) + swapped.append(seq_group) + + budget = create_token_budget(max_num_seqs=2) + remaining_swapped, output = scheduler._schedule_swapped( + swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 1 + assert budget.num_batched_tokens == 1 + assert budget.num_curr_seqs == 2 + assert len(output.seq_groups) == 1 + + # Verify num_curr_seqs are respected. 
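+    # A budget already at max_num_seqs cannot admit another swapped group.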
+ budget = create_token_budget(num_curr_seqs=2, max_num_seqs=2) + remaining_swapped, output = scheduler._schedule_swapped( + remaining_swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 1 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 2 + assert len(output.seq_groups) == 0 + + +def test_schedule_swapped_max_loras(): + lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) + scheduler = initialize_scheduler(lora_config=lora_config) + swapped = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = set() + blocks_to_swap_out = {} + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + lora_request=LoRARequest( + lora_name=str(i), + lora_int_id=i + 1, + lora_local_path="abc")) + scheduler._allocate_and_set_running(seq_group) + scheduler._swap_out(seq_group, blocks_to_swap_out) + swapped.append(seq_group) + + budget = create_token_budget() + remaining_swapped, output = scheduler._schedule_swapped( + swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 1 + assert budget.num_batched_tokens == 1 + assert budget.num_curr_seqs == 1 + assert len(output.seq_groups) == 1 + assert len(curr_loras) == 1 + + +def test_schedule_swapped_cannot_swap_in(): + scheduler = initialize_scheduler() + swapped = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + blocks_to_swap_out = {} + for _ in range(2): + _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + scheduler._allocate_and_set_running(seq_group) + scheduler._swap_out(seq_group, blocks_to_swap_out) + swapped.append(seq_group) + + # The last request should be swapped out. + scheduler.block_manager.can_swap_in = MagicMock() + scheduler.block_manager.can_swap_in.return_value = False + # Since we cannot swap in, none of the requests are swapped in. + budget = create_token_budget() + remaining_swapped, output = scheduler._schedule_swapped( + swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 2 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 0 + assert len(output.seq_groups) == 0 + + +def test_schedule_swapped_blocks_to_copy(): + scheduler = initialize_scheduler() + swapped = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + scheduler._allocate_and_set_running(seq_group) + blocks_to_swap_out = {} + scheduler._swap_out(seq_group, blocks_to_swap_out) + swapped.append(seq_group) + + # The last request should be swapped out. 
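+    # append_slots is mocked to report a copy from block 2 into block 3.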
+ scheduler.block_manager.append_slots = MagicMock() + scheduler.block_manager.append_slots.return_value = {2: [3]} + + budget = create_token_budget() + remaining_swapped, output = scheduler._schedule_swapped( + swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 0 + assert len(output.seq_groups) == 1 + assert output.blocks_to_copy == {2: [3]} diff --git a/tests/core/utils.py b/tests/core/utils.py index 9482c7761c28..fbbdb07cb8e6 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -1,14 +1,19 @@ import time -from typing import Tuple +from typing import Optional, Tuple from vllm import SamplingParams +from vllm.lora.request import LoRARequest from vllm.sequence import Logprob, Sequence, SequenceGroup def create_dummy_prompt( - request_id: str, - prompt_length: int, - block_size: int = None) -> Tuple[Sequence, SequenceGroup]: + request_id: str, + prompt_length: int, + block_size: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + use_beam_search: bool = False, + best_of: int = 1, +) -> Tuple[Sequence, SequenceGroup]: if not block_size: block_size = prompt_length @@ -17,8 +22,10 @@ def create_dummy_prompt( prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) - seq_group = SequenceGroup(request_id, [prompt], SamplingParams(), - time.time(), None) + seq_group = SequenceGroup( + request_id, [prompt], + SamplingParams(use_beam_search=use_beam_search, best_of=best_of), + time.time(), lora_request) return prompt, seq_group diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 9d098801233e..730d549671f6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -6,11 +6,12 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager -from vllm.core.policy import PolicyFactory +from vllm.core.policy import Policy, PolicyFactory from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) +from vllm.utils import merge_dicts logger = init_logger(__name__) @@ -28,9 +29,19 @@ class PreemptionMode(enum.Enum): RECOMPUTE = enum.auto() -# seq_group: SequenceGroup to schedule. -# token_chunk_size: The number of prefill tokens to be processed in the next -# step. +@dataclass +class SchedulingBudget: + """The available slots for scheduling.""" + num_batched_tokens: int + num_curr_seqs: int + token_budget: int + max_num_seqs: int + + def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): + return (self.num_batched_tokens + num_new_tokens <= self.token_budget + and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs) + + @dataclass class ScheduledSequenceGroup: # A sequence group that's scheduled. @@ -41,53 +52,28 @@ class ScheduledSequenceGroup: token_chunk_size: int +@dataclass class SchedulerOutputs: + # Scheduled sequence groups. + scheduled_seq_groups: Iterable[ScheduledSequenceGroup] + # Number of prefill groups scheduled. + num_prefill_groups: int + # Total number of batched tokens. + num_batched_tokens: int + # Blocks to swap in. Dict of CPU -> GPU block number. + blocks_to_swap_in: Dict[int, int] + # Blocks to swap out. Dict of GPU -> CPU block number. + blocks_to_swap_out: Dict[int, int] + # Blocks to copy. 
Source to a list of dest blocks. + blocks_to_copy: Dict[int, List[int]] + # Sequence groups that are going to be ignored. + ignored_seq_groups: List[SequenceGroup] + # The number of slots for lookahead decoding. + num_lookahead_slots: int - def __init__( - self, - scheduled_seq_groups: Iterable[ScheduledSequenceGroup], - prompt_run: bool, - num_batched_tokens: int, - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ignored_seq_groups: List[SequenceGroup], - num_lookahead_slots: int, - ) -> None: - """A list of sequence groups to be scheduled as a single batch. - - Args: - scheduled_seq_groups: A tuple of scheduled sequence group and its - token chunk size. - prompt_run: True if all sequence groups are in prefill phase. - If False, all sequence groups are in decoding phase. - num_batched_tokens: Total number of batched tokens. - blocks_to_swap_in: Blocks to swap in. Dict of CPU -> GPU block - number. - blocks_to_swap_out: Blocks to swap out. Dict of GPU -> CPU block - number. - blocks_to_copy: Blocks to copy. Source to a list of dest blocks. - ignored_seq_groups: Sequence groups that are going to be ignored. - """ - # A tuple of scheduled sequence group and its chunk size. - self.scheduled_seq_groups: ScheduledSequenceGroup = scheduled_seq_groups - # True if all sequence groups are in prefill phase. If False, all - # sequence groups are in decoding phase. - self.prompt_run: bool = prompt_run - # Total number of batched tokens. - self.num_batched_tokens: int = num_batched_tokens - # Blocks to swap in. Dict of CPU -> GPU block number. - self.blocks_to_swap_in: Dict[int, int] = blocks_to_swap_in - # Blocks to swap out. Dict of GPU -> CPU block number. - self.blocks_to_swap_out: Dict[int, int] = blocks_to_swap_out - # Blocks to copy. Source to a list of dest blocks. - self.blocks_to_copy: Dict[int, List[int]] = blocks_to_copy - # Sequence groups that are going to be ignored. - self.ignored_seq_groups: List[SequenceGroup] = ignored_seq_groups - + def __post_init__(self): # Swap in and swap out should never happen at the same time. - assert not (blocks_to_swap_in and blocks_to_swap_out) - self.num_lookahead_slots = num_lookahead_slots + assert not (self.blocks_to_swap_in and self.blocks_to_swap_out) self.num_loras: int = len(self.lora_requests) if self.num_loras > 0: @@ -108,6 +94,73 @@ class SchedulerOutputs: return {g.seq_group.lora_request for g in self.scheduled_seq_groups} +@dataclass +class SchedulerDecodeOutputs: + """Outputs of the decoding phase of the scheduler.""" + # Selected sequence groups for decoding. + seq_groups: List[SequenceGroup] + # The preempted sequences. + preempted: List[SequenceGroup] + # Sequences that are swapped out. + swapped_out: List[SequenceGroup] + # The blocks to swap out. + blocks_to_swap_out: Dict[int, int] + # The blocks to copy. + blocks_to_copy: Dict[int, List[int]] + num_lookahead_slots: int + + @classmethod + def create_empty(cls) -> "SchedulerDecodeOutputs": + return SchedulerDecodeOutputs( + seq_groups=[], + preempted=[], + swapped_out=[], + blocks_to_swap_out={}, + blocks_to_copy={}, + num_lookahead_slots=0, + ) + + +@dataclass +class SchedulerSwappedInOutputs: + """Outputs of the decoding phase of the scheduler.""" + # Selected sequence groups for decoding. + seq_groups: List[SequenceGroup] + # The blocks to swap in. + blocks_to_swap_in: Dict[int, int] + # The blocks to copy. + blocks_to_copy: Dict[int, List[int]] + # # The number of batched tokens. 
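+    # The number of slots for lookahead decoding.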
+ num_lookahead_slots: int + + @classmethod + def create_empty(cls) -> "SchedulerSwappedInOutputs": + return SchedulerSwappedInOutputs( + seq_groups=[], + blocks_to_swap_in={}, + blocks_to_copy={}, + num_lookahead_slots=0, + ) + + +@dataclass +class SchedulerPrefillOutputs: + """Outputs of the prefill phase of the scheduler.""" + # Selected sequence groups for prefill. + seq_groups: List[SequenceGroup] + # Ignored sequence groups. + ignored_seq_groups: List[SequenceGroup] + num_lookahead_slots: int + + @classmethod + def create_empty(cls) -> "SchedulerPrefillOutputs": + return SchedulerPrefillOutputs( + seq_groups=[], + ignored_seq_groups=[], + num_lookahead_slots=0, + ) + + class Scheduler: def __init__( @@ -123,6 +176,7 @@ class Scheduler: # LoRAs. This should be improved in the future. self.lora_config = lora_config + # TODO(sang): Fix it after chunked prefill is enabled. self.prompt_limit = min(self.scheduler_config.max_model_len, self.scheduler_config.max_num_batched_tokens) @@ -142,10 +196,13 @@ class Scheduler: enable_caching=self.cache_config.enable_prefix_caching) # Sequence groups in the WAITING state. + # Contain new prefill or preempted requests. self.waiting: Deque[SequenceGroup] = deque() # Sequence groups in the RUNNING state. + # Contain decode requests. self.running: Deque[SequenceGroup] = deque() # Sequence groups in the SWAPPED state. + # Contain decode requests that are swapped out. self.swapped: Deque[SequenceGroup] = deque() # Time at previous scheduling step @@ -159,8 +216,14 @@ class Scheduler: def lora_enabled(self) -> bool: return bool(self.lora_config) + @property + def num_decoding_tokens_per_seq(self) -> int: + """The number of new tokens.""" + return 1 + def add_seq_group(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the waiting queue. + logger.debug(f"add_seq_group {seq_group.request_id}") self.waiting.append(seq_group) def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: @@ -205,214 +268,365 @@ class Scheduler: def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) - def _schedule(self) -> SchedulerOutputs: + def _schedule_decodes( + self, + running_queue: deque, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + policy: Policy, + ) -> Tuple[deque, SchedulerDecodeOutputs]: + """Schedule sequence groups in a decoding stage. + + NOTE(sang): All the RUNNING num_batched_tokens, num_curr_seqs, + and curr_loras should be already included in `budget` and `curr_loras`. + The API doesn't ADD UP these values. + + Note that `budget` and `curr_loras` are still subtracted/popped when + any running requests are preempted from this API. + + Args: + running_queue: The queue that contains running requests (i.e., + decodes). The given arguments are NOT in-place modified. + budget: The scheduling budget. The argument is in-place updated + when any decodes are preempted. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any decodes are preempted. + policy: The sorting policy to sort running_queue. + + Returns: + A tuple of remaining running queue (should be always 0) after + scheduling and SchedulerDecodeOutputs. + """ # Blocks that need to be swapped or copied before model execution. - blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {} blocks_to_copy: Dict[int, List[int]] = {} - # Fix the current time. - now = time.time() - - # Join waiting sequences if possible. 
- if not self.swapped: - ignored_seq_groups: List[SequenceGroup] = [] - scheduled: List[SequenceGroup] = [] - # The total number of sequences on the fly, including the - # requests in the generation phase. - num_curr_seqs = sum(seq_group.get_max_num_running_seqs() - for seq_group in self.running) - curr_loras = set( - seq_group.lora_int_id - for seq_group in self.running) if self.lora_enabled else None - - # Optimization: We do not sort the waiting queue since the preempted - # sequence groups are added to the front and the new sequence groups - # are added to the back. - leftover_waiting_sequences = deque() - num_batched_tokens = 0 - while self._passed_delay(now) and self.waiting: - seq_group = self.waiting[0] - waiting_seqs = seq_group.get_seqs( - status=SequenceStatus.WAITING) - assert len(waiting_seqs) == 1, ( - "Waiting sequence group should have only one prompt " - "sequence.") - # get_len includes output tokens if the request has been - # preempted. - num_prefill_tokens = waiting_seqs[0].get_len() - if num_prefill_tokens > self.prompt_limit: - logger.warning( - f"Input prompt ({num_prefill_tokens} tokens) is too " - f"long and exceeds limit of {self.prompt_limit}") - for seq in waiting_seqs: - seq.status = SequenceStatus.FINISHED_IGNORED - ignored_seq_groups.append(seq_group) - self.waiting.popleft() - continue - - # If the sequence group cannot be allocated, stop. - can_allocate = self.block_manager.can_allocate(seq_group) - if can_allocate == AllocStatus.LATER: - break - elif can_allocate == AllocStatus.NEVER: - logger.warning( - f"Input prompt ({num_prefill_tokens} tokens) is too " - f"long and exceeds the capacity of block_manager") - for seq in waiting_seqs: - seq.status = SequenceStatus.FINISHED_IGNORED - ignored_seq_groups.append(seq_group) - self.waiting.popleft() - continue - - lora_int_id = 0 - if self.lora_enabled: - lora_int_id = seq_group.lora_int_id - if (lora_int_id > 0 and lora_int_id not in curr_loras - and len(curr_loras) >= self.lora_config.max_loras): - # We don't have a space for another LoRA, so - # we ignore this request for now. - leftover_waiting_sequences.appendleft(seq_group) - self.waiting.popleft() - continue - - # If the number of batched tokens exceeds the limit, stop. - num_batched_tokens += num_prefill_tokens - if (num_batched_tokens > - self.scheduler_config.max_num_batched_tokens): - break - - # The total number of sequences in the RUNNING state should not - # exceed the maximum number of sequences. 
- num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_curr_seqs + num_new_seqs > - self.scheduler_config.max_num_seqs): - break - - if lora_int_id > 0: - curr_loras.add(lora_int_id) - self.waiting.popleft() - self._allocate(seq_group) - self.running.append(seq_group) - num_curr_seqs += num_new_seqs - scheduled.append( - ScheduledSequenceGroup( - seq_group=seq_group, - token_chunk_size=num_prefill_tokens)) - self.waiting.extendleft(leftover_waiting_sequences) - - if scheduled or ignored_seq_groups: - self.prev_prompt = True - scheduler_outputs = SchedulerOutputs( - scheduled_seq_groups=scheduled, - prompt_run=True, - num_batched_tokens=num_batched_tokens, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ignored_seq_groups=ignored_seq_groups, - num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=True), - ) - return scheduler_outputs + seq_groups: List[ScheduledSequenceGroup] = [] + preempted: List[SequenceGroup] = [] + swapped_out: List[SequenceGroup] = [] # NOTE(woosuk): Preemption happens only when there is no available slot # to keep all the sequence groups in the RUNNING state. # In this case, the policy is responsible for deciding which sequence # groups to preempt. - self.running = self.policy.sort_by_priority(now, self.running) + now = time.time() + running_queue = policy.sort_by_priority(now, running_queue) - # Reserve new token slots for the running sequence groups. - running: Deque[SequenceGroup] = deque() - preempted: List[SequenceGroup] = [] - while self.running: - seq_group = self.running.popleft() + while running_queue: + # NOTE: running + seq_group = running_queue[0] + num_running_tokens = ( + seq_group.num_seqs(status=SequenceStatus.RUNNING) * + self.num_decoding_tokens_per_seq) + num_running_seqs = seq_group.get_max_num_running_seqs() + + running_queue.popleft() while not self._can_append_slots(seq_group): - if self.running: + # Increase the budget as requests are preempted. + budget.num_batched_tokens -= num_running_tokens + budget.num_curr_seqs -= num_running_seqs + if curr_loras is not None and seq_group.lora_int_id > 0: + curr_loras.pop(seq_group.lora_int_id) + + if running_queue: # Preempt the lowest-priority sequence groups. - victim_seq_group = self.running.pop() - self._preempt(victim_seq_group, blocks_to_swap_out) - preempted.append(victim_seq_group) + victim_seq_group = running_queue.pop() + preempted_mode = self._preempt(victim_seq_group, + blocks_to_swap_out) + if preempted_mode == PreemptionMode.RECOMPUTE: + preempted.append(victim_seq_group) + else: + swapped_out.append(victim_seq_group) else: # No other sequence groups can be preempted. # Preempt the current sequence group. - self._preempt(seq_group, blocks_to_swap_out) - preempted.append(seq_group) + preempted_mode = self._preempt(seq_group, + blocks_to_swap_out) + if preempted_mode == PreemptionMode.RECOMPUTE: + preempted.append(seq_group) + else: + swapped_out.append(seq_group) break else: - # Append new slots to the sequence group. + logger.debug(f"append slot for {seq_group}") self._append_slots(seq_group, blocks_to_copy) - running.append(seq_group) - self.running = running + seq_groups.append( + ScheduledSequenceGroup(seq_group=seq_group, + token_chunk_size=1)) + # Make sure all queues are updated. + assert len(running_queue) == 0 - # Swap in the sequence groups in the SWAPPED state if possible. 
- self.swapped = self.policy.sort_by_priority(now, self.swapped) - if not preempted: - num_curr_seqs = sum(seq_group.get_max_num_running_seqs() - for seq_group in self.running) - curr_loras = set( - seq_group.lora_int_id - for seq_group in self.running) if self.lora_enabled else None - - leftover_swapped = deque() - - while self.swapped: - seq_group = self.swapped[0] - lora_int_id = 0 - if self.lora_enabled: - lora_int_id = seq_group.lora_int_id - if (lora_int_id > 0 and lora_int_id not in curr_loras - and len(curr_loras) >= self.lora_config.max_loras): - # We don't have a space for another LoRA, so - # we ignore this request for now. - leftover_swapped.appendleft(seq_group) - self.swapped.popleft() - continue - - # If the sequence group cannot be swapped in, stop. - if not self._can_swap_in(seq_group): - break - - # The total number of sequences in the RUNNING state should not - # exceed the maximum number of sequences. - num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_curr_seqs + num_new_seqs > - self.scheduler_config.max_num_seqs): - break - - if lora_int_id > 0: - curr_loras.add(lora_int_id) - self.swapped.popleft() - self._swap_in(seq_group, blocks_to_swap_in) - self._append_slots(seq_group, blocks_to_copy) - num_curr_seqs += num_new_seqs - self.running.append(seq_group) - - self.swapped.extendleft(leftover_swapped) - - # Each sequence in the generation phase only takes one token slot. - # Therefore, the number of batched tokens is equal to the number of - # sequences in the RUNNING state. - num_batched_tokens = sum( - seq_group.num_seqs(status=SequenceStatus.RUNNING) - for seq_group in self.running) - - scheduler_outputs = SchedulerOutputs( - scheduled_seq_groups=[ - ScheduledSequenceGroup(seq_group=running_group, - token_chunk_size=1) - for running_group in self.running - ], - prompt_run=False, - num_batched_tokens=num_batched_tokens, - blocks_to_swap_in=blocks_to_swap_in, + return running_queue, SchedulerDecodeOutputs( + seq_groups=seq_groups, + preempted=preempted, + swapped_out=swapped_out, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, - ignored_seq_groups=[], num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=False), + is_prefill=False)) + + def _schedule_swapped( + self, + swapped_queue: deque, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + policy: Policy, + ) -> Tuple[deque, SchedulerSwappedInOutputs]: + """Schedule sequence groups that are swapped out. + + It schedules swapped requests as long as it fits `budget` and + curr_loras <= max_lora from the scheduling config. The input arguments + `budget` and `curr_loras` are updated based on scheduled seq_groups. + + Args: + swapped_queue: The queue that contains swapped out requests. + The given arguments are NOT in-place modified. + budget: The scheduling budget. The argument is in-place updated + when any requests are swapped in. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any requests are swapped in. + policy: The sorting policy to sort swapped_queue. + + Returns: + A tuple of remaining swapped_queue after scheduling and + SchedulerSwappedInOutputs. + """ + # Blocks that need to be swapped or copied before model execution. 
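+        # Maps CPU block numbers to the GPU blocks they return to.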
+ blocks_to_swap_in: Dict[int, int] = {} + blocks_to_copy: Dict[int, List[int]] = {} + seq_groups: List[ScheduledSequenceGroup] = [] + now = time.time() + swapped_queue = policy.sort_by_priority(now, swapped_queue) + + leftover_swapped = deque() + while swapped_queue: + seq_group = swapped_queue[0] + + # If the sequence group cannot be swapped in, stop. + if not self.block_manager.can_swap_in(seq_group): + break + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + if (lora_int_id > 0 and lora_int_id not in curr_loras + and len(curr_loras) >= self.lora_config.max_loras): + # We don't have a space for another LoRA, so + # we ignore this request for now. + leftover_swapped.appendleft(seq_group) + swapped_queue.popleft() + continue + + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. + num_new_seqs = seq_group.get_max_num_running_seqs() + num_new_tokens = ( + seq_group.num_seqs(status=SequenceStatus.SWAPPED) * + self.num_decoding_tokens_per_seq) + + if not budget.can_schedule(num_new_tokens=num_new_tokens, + num_new_seqs=num_new_seqs): + break + + if lora_int_id > 0 and curr_loras is not None: + curr_loras.add(lora_int_id) + swapped_queue.popleft() + self._swap_in(seq_group, blocks_to_swap_in) + self._append_slots(seq_group, blocks_to_copy) + seq_groups.append( + ScheduledSequenceGroup(seq_group, token_chunk_size=1)) + budget.num_batched_tokens += num_new_tokens + budget.num_curr_seqs += num_new_seqs + + swapped_queue.extendleft(leftover_swapped) + + return swapped_queue, SchedulerSwappedInOutputs( + seq_groups=seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_copy=blocks_to_copy, + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=False)) + + def _schedule_prefills( + self, + waiting_queue: deque, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + ) -> Tuple[deque, SchedulerPrefillOutputs]: + """Schedule sequence groups that are in prefill stage. + + Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE + as a new prefill (that starts from beginning -> most recently generated + tokens). + + It schedules waiting requests as long as it fits `budget` and + curr_loras <= max_lora from the scheduling config. The input arguments + `budget` and `curr_loras` are updated based on scheduled seq_groups. + + Args: + waiting_queue: The queue that contains prefill requests. + The given arguments are NOT in-place modified. + budget: The scheduling budget. The argument is in-place updated + when any requests are scheduled. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any requests are scheduled. + + Returns: + A tuple of remaining waiting_queue after scheduling and + SchedulerSwappedInOutputs. + """ + ignored_seq_groups: List[SequenceGroup] = [] + seq_groups: List[SequenceGroup] = [] + # We don't sort waiting queue because we assume it is sorted. + # Copy the queue so that the input queue is not modified. 
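+        # The copy is shallow; SequenceGroups are shared with the caller.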
+ waiting_queue = deque([s for s in waiting_queue]) + + leftover_waiting_sequences = deque() + while self._passed_delay(time.time()) and waiting_queue: + seq_group = waiting_queue[0] + + waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + assert len(waiting_seqs) == 1, ( + "Waiting sequence group should have only one prompt " + "sequence.") + + num_prompt_tokens = waiting_seqs[0].get_len() + if num_prompt_tokens > self.prompt_limit: + logger.warning( + f"Input prompt ({num_prompt_tokens} tokens) is too long" + f" and exceeds limit of {self.prompt_limit}") + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + waiting_queue.popleft() + continue + + # If the sequence group cannot be allocated, stop. + can_allocate = self.block_manager.can_allocate(seq_group) + if can_allocate == AllocStatus.LATER: + break + elif can_allocate == AllocStatus.NEVER: + logger.warning( + f"Input prompt ({num_prompt_tokens} tokens) is too long" + f" and exceeds the capacity of block_manager") + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + waiting_queue.popleft() + continue + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + if (self.lora_enabled and lora_int_id > 0 + and lora_int_id not in curr_loras + and len(curr_loras) >= self.lora_config.max_loras): + # We don't have a space for another LoRA, so + # we ignore this request for now. + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + + num_new_seqs = seq_group.get_max_num_running_seqs() + if not budget.can_schedule(num_new_tokens=num_prompt_tokens, + num_new_seqs=num_new_seqs): + break + + # Can schedule this request. + if curr_loras is not None and lora_int_id > 0: + curr_loras.add(lora_int_id) + waiting_queue.popleft() + self._allocate_and_set_running(seq_group) + seq_groups.append( + ScheduledSequenceGroup(seq_group=seq_group, + token_chunk_size=num_prompt_tokens)) + budget.num_batched_tokens += num_prompt_tokens + budget.num_curr_seqs += num_new_seqs + + # Queue requests that couldn't be scheduled. + waiting_queue.extendleft(leftover_waiting_sequences) + if len(seq_groups) > 0: + self.prev_prompt = True + + return waiting_queue, SchedulerPrefillOutputs( + seq_groups=seq_groups, + ignored_seq_groups=ignored_seq_groups, + num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True)) + + def _schedule(self) -> SchedulerOutputs: + """Batch requests that are queued.. + + The current policy is designed to opimimize the throughput. First, + it batches as many prefill requests as possible. And it schedules + decodes. If there's a pressure on GPU memory, decode requests can + be swapped or preempted. + """ + # Include running requests to the budget. 
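+        # Requests already RUNNING consume budget up front; prefills and
+        # swap-ins must fit in what remains.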
+ budget = SchedulingBudget( + num_batched_tokens=sum( + seq_group.num_seqs(status=SequenceStatus.RUNNING) + for seq_group in self.running), + num_curr_seqs=sum(seq_group.get_max_num_running_seqs() + for seq_group in self.running), + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + ) + curr_loras = set( + seq_group.lora_int_id + for seq_group in self.running) if self.lora_enabled else None + + remaining_waiting, prefills = (self.waiting, + SchedulerPrefillOutputs.create_empty()) + remaining_running, decodes = (self.running, + SchedulerDecodeOutputs.create_empty()) + remaining_swapped, swapped_in = ( + self.swapped, SchedulerSwappedInOutputs.create_empty()) + + # If any requests are swapped, prioritized swapped requests. + if not self.swapped: + remaining_waiting, prefills = self._schedule_prefills( + self.waiting, budget, curr_loras) + + # Don't schedule decodes if prefills are scheduled. + if len(prefills.seq_groups) == 0: + remaining_running, decodes = self._schedule_decodes( + self.running, budget, curr_loras, self.policy) + # If any sequence group is preempted, do not swap in any sequence + # group. because it means there's no slot for new running requests. + if len(decodes.preempted) + len(decodes.swapped_out) == 0: + remaining_swapped, swapped_in = self._schedule_swapped( + self.swapped, budget, curr_loras, self.policy) + + assert (budget.num_batched_tokens <= + self.scheduler_config.max_num_batched_tokens) + assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs + + # Update waiting requests. + self.waiting = remaining_waiting + self.waiting.extendleft(decodes.preempted) + # Update new running requests. + self.running = remaining_running + self.running.extend([s.seq_group for s in prefills.seq_groups]) + self.running.extend([s.seq_group for s in decodes.seq_groups]) + self.running.extend([s.seq_group for s in swapped_in.seq_groups]) + # Update swapped requests. + self.swapped = remaining_swapped + self.swapped.extend(decodes.swapped_out) + + return SchedulerOutputs( + scheduled_seq_groups=prefills.seq_groups + decodes.seq_groups + + swapped_in.seq_groups, + num_prefill_groups=len(prefills.seq_groups), + num_batched_tokens=budget.num_batched_tokens, + blocks_to_swap_in=swapped_in.blocks_to_swap_in, + blocks_to_swap_out=decodes.blocks_to_swap_out, + blocks_to_copy=merge_dicts(decodes.blocks_to_copy, + swapped_in.blocks_to_copy), + ignored_seq_groups=prefills.ignored_seq_groups, + num_lookahead_slots=(prefills.num_lookahead_slots + + decodes.num_lookahead_slots + + swapped_in.num_lookahead_slots), ) - return scheduler_outputs def _can_append_slots(self, seq_group: SequenceGroup) -> bool: """Determine whether or not we have enough space in the KV cache to @@ -444,7 +658,8 @@ class Scheduler: # Create input data structures. seq_group_metadata_list: List[SequenceGroupMetadata] = [] - for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + for i, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): seq_group = scheduled_seq_group.seq_group token_chunk_size = scheduled_seq_group.token_chunk_size seq_group.maybe_set_first_scheduled_time(now) @@ -464,9 +679,12 @@ class Scheduler: self.block_manager.get_common_computed_block_ids( seq_group.get_seqs(status=SequenceStatus.RUNNING))) + # It assumes the scheduled_seq_groups is ordered by + # prefill < decoding. 
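+            # _schedule() emits prefill groups first, so indices below
+            # num_prefill_groups are prompts.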
+ is_prompt = i < scheduler_outputs.num_prefill_groups seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, - is_prompt=scheduler_outputs.prompt_run, + is_prompt=is_prompt, seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, @@ -479,7 +697,7 @@ class Scheduler: # the subsequent comms can still use delta, but # `multi_modal_data` will be None. multi_modal_data=seq_group.multi_modal_data - if scheduler_outputs.prompt_run else None, + if scheduler_outputs.num_prefill_groups > 0 else None, ) seq_group_metadata_list.append(seq_group_metadata) @@ -504,7 +722,7 @@ class Scheduler: self.running = deque(seq_group for seq_group in self.running if not seq_group.is_finished()) - def _allocate(self, seq_group: SequenceGroup) -> None: + def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING @@ -539,7 +757,7 @@ class Scheduler: seq_group: SequenceGroup, blocks_to_swap_out: Dict[int, int], preemption_mode: Optional[PreemptionMode] = None, - ) -> None: + ) -> PreemptionMode: # If preemption mode is not specified, we determine the mode as follows: # We use recomputation by default since it incurs lower overhead than # swapping. However, when the sequence group has multiple sequences @@ -562,6 +780,7 @@ class Scheduler: self._preempt_by_swap(seq_group, blocks_to_swap_out) else: raise AssertionError("Invalid preemption mode.") + return preemption_mode def _preempt_by_recompute( self, @@ -573,9 +792,6 @@ class Scheduler: seq.status = SequenceStatus.WAITING self.free_seq(seq) seq.reset_state_for_recompute() - # NOTE: For FCFS, we insert the preempted sequence group to the front - # of the waiting queue. - self.waiting.appendleft(seq_group) def _preempt_by_swap( self, @@ -583,7 +799,6 @@ class Scheduler: blocks_to_swap_out: Dict[int, int], ) -> None: self._swap_out(seq_group, blocks_to_swap_out) - self.swapped.append(seq_group) def _swap_in( self, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4cac9c5dec1f..2da2c79e0d50 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -728,7 +728,7 @@ class LLMEngine: time_per_output_tokens = [] time_e2e_requests = [] if scheduler_outputs is not None: - prompt_run = scheduler_outputs.prompt_run + prompt_run = scheduler_outputs.num_prefill_groups > 0 # Number of Tokens. if prompt_run: diff --git a/vllm/utils.py b/vllm/utils.py index 1db57bc50c83..3b229f1191dd 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -6,7 +6,7 @@ import socket import subprocess import uuid import warnings -from collections import OrderedDict +from collections import OrderedDict, defaultdict from functools import lru_cache, partial from platform import uname from typing import (Any, Awaitable, Callable, Generic, Hashable, List, @@ -450,3 +450,20 @@ def maybe_expand_dim(tensor: torch.Tensor, if tensor.ndim < target_dims: tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim))) return tensor + + +def merge_dicts(dict1: dict[Any, list[Any]], + dict2: dict[Any, list[Any]]) -> dict[Any, list[Any]]: + """Merge 2 dicts that have key -> List of items. + + When a key conflicts, the values in dict1 is prioritized. + """ + merged_dict = defaultdict(list) + + for key, value in dict1.items(): + merged_dict[key].extend(value) + + for key, value in dict2.items(): + merged_dict[key].extend(value) + + return dict(merged_dict)
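As a quick illustration of the two new building blocks introduced above, the following standalone sketch shows how `SchedulingBudget.can_schedule` gates admission and how `merge_dicts` combines the per-phase `blocks_to_copy` maps. It assumes this patch is applied so the imports resolve; the concrete numbers are made up for the example and are not part of the patch.

```python
from vllm.core.scheduler import SchedulingBudget
from vllm.utils import merge_dicts

# A budget with room for 100 tokens and 2 sequences, nothing admitted yet.
budget = SchedulingBudget(num_batched_tokens=0,
                          num_curr_seqs=0,
                          token_budget=100,
                          max_num_seqs=2)

# A 60-token prefill with one sequence fits ...
assert budget.can_schedule(num_new_tokens=60, num_new_seqs=1)
# ... but a 120-token prefill exceeds the token budget.
assert not budget.can_schedule(num_new_tokens=120, num_new_seqs=1)

# merge_dicts concatenates the value lists, keeping dict1's entries first.
decode_copies = {2: [3]}
swapped_in_copies = {2: [4], 5: [6]}
assert merge_dicts(decode_copies, swapped_in_copies) == {2: [3, 4], 5: [6]}
```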