[3/N] Refactor scheduler for chunked prefill scheduling (#3550)

SangBin Cho 2024-04-04 06:13:49 +09:00 committed by GitHub
parent c64cf38673
commit 3dcb3e8b98
5 changed files with 1021 additions and 256 deletions

View File

@ -1,10 +1,15 @@
import time
from collections import deque
from typing import List
from unittest.mock import MagicMock
import pytest # noqa
from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.policy import PolicyFactory
from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, SequenceGroup
from .utils import create_dummy_prompt
@ -177,7 +182,6 @@ def test_scheduler_max_seqs():
def test_scheduler_delay_factor():
block_size = 4
scheduler_config = SchedulerConfig(100, 64, 16, delay_factor=0.5)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
@ -189,7 +193,7 @@ def test_scheduler_delay_factor():
_, seq_group = create_dummy_prompt("0", prompt_length=block_size)
scheduler.add_seq_group(seq_group)
seq_group_meta, out = scheduler.schedule()
assert out.prompt_run
assert out.num_prefill_groups > 0
assert seq_group_meta[0].request_id == '0'
# wait for a second before scheduling next prompt
@ -199,11 +203,533 @@ def test_scheduler_delay_factor():
# second prompt should *not* be scheduled
seq_group_meta, out = scheduler.schedule()
assert not out.prompt_run
assert out.num_prefill_groups == 0
assert seq_group_meta[0].request_id == '0'
# wait for more than 0.5 second and try again
time.sleep(0.6)
seq_group_meta, out = scheduler.schedule()
assert out.prompt_run
assert out.num_prefill_groups > 0
assert seq_group_meta[0].request_id == '1'
def test_swapped_out_prioritized():
scheduler = initialize_scheduler(max_num_seqs=6)
# best_of=2 * 3 == 6 sequences.
for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
scheduler.add_seq_group(seq_group)
_, out = scheduler.schedule()
# prefill scheduled now.
assert len(out.scheduled_seq_groups) == 3
# The last request should be swapped out.
scheduler.block_manager.can_append_slots = MagicMock()
def cannot_append_second_group(seq_group, num_lookahead_slots):
return seq_group.request_id != "2"
scheduler.block_manager.can_append_slots.side_effect = (
cannot_append_second_group)
_, out = scheduler.schedule()
assert len(out.scheduled_seq_groups) == 2
assert out.num_batched_tokens == 2
assert out.blocks_to_swap_out != {}
assert out.blocks_to_swap_in == {}
# Add 1 more task. Swap should be prioritized over prefill.
_, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
scheduler.add_seq_group(seq_group)
_, out = scheduler.schedule()
assert len(out.scheduled_seq_groups) == 3
# 3 decodes; the swapped-out request is swapped back in.
assert out.num_batched_tokens == 3
assert out.blocks_to_swap_in != {}
assert out.blocks_to_swap_out == {}
def initialize_scheduler(*,
max_num_seqs=1000,
max_token_budget=1000,
max_model_len=1000,
lora_config=None):
block_size = 4
scheduler_config = SchedulerConfig(max_token_budget, max_num_seqs,
max_model_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8
scheduler = Scheduler(scheduler_config, cache_config, lora_config)
return scheduler
def create_token_budget(num_batched_tokens: int = 0,
num_curr_seqs: int = 0,
token_budget: int = 10000,
max_num_seqs: int = 10000) -> SchedulingBudget:
return SchedulingBudget(
num_batched_tokens=num_batched_tokens,
num_curr_seqs=num_curr_seqs,
token_budget=token_budget,
max_num_seqs=max_num_seqs,
)
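For reference, a minimal sketch (not part of the diff) of how this helper interacts with the SchedulingBudget.can_schedule() gate introduced in the scheduler changes below; the concrete numbers are illustrative:
# Illustrative only: mirrors the can_schedule() check added in vllm/core/scheduler.py.
budget = create_token_budget(token_budget=100, max_num_seqs=2)
# 60 new tokens and 1 new sequence fit within both limits.
assert budget.can_schedule(num_new_tokens=60, num_new_seqs=1)
# 120 new tokens exceed the 100-token budget, so this cannot be scheduled.
assert not budget.can_schedule(num_new_tokens=120, num_new_seqs=1)
Note that the sub-schedulers, not can_schedule() itself, bump num_batched_tokens and num_curr_seqs after a request is admitted, which is what the tests below assert.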
def test_prefill_schedule_max_prompt_len():
"""
Test that a prompt longer than the prompt limit is ignored.
"""
scheduler = initialize_scheduler(max_model_len=30)
_, seq_group = create_dummy_prompt("0", prompt_length=60)
waiting = deque([seq_group])
budget = create_token_budget()
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
assert len(output.ignored_seq_groups) == 1
assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 0
assert budget.num_curr_seqs == 0
assert len(remaining_waiting) == 0
def test_prefill_schedule_token_budget():
"""
Test that the token budget is respected.
"""
scheduler = initialize_scheduler()
waiting = deque()
budget = create_token_budget(token_budget=0)
for i in range(2):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group)
# 0 token budget == nothing is scheduled.
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 0
assert budget.num_curr_seqs == 0
assert len(remaining_waiting) == 2
# 60 token budget == 1 request scheduled.
budget = create_token_budget(token_budget=60)
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 1
assert budget.num_batched_tokens == 60
assert budget.num_curr_seqs == 1
assert len(remaining_waiting) == 1
# Test that pre-existing num_batched_tokens is respected.
scheduler = initialize_scheduler()
waiting = deque()
budget = create_token_budget(num_batched_tokens=30, token_budget=60)
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
# Cannot schedule a prompt that doesn't fit the budget.
waiting.append(seq_group)
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 30
assert budget.num_curr_seqs == 0
assert len(remaining_waiting) == 1
budget = create_token_budget(num_batched_tokens=30, token_budget=90)
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
assert len(output.seq_groups) == 1
assert budget.num_batched_tokens == 90
assert budget.num_curr_seqs == 1
assert len(remaining_waiting) == 0
def test_prefill_schedule_max_seqs():
"""
Test that the maximum number of sequences is respected.
"""
scheduler = initialize_scheduler()
waiting = deque()
budget = create_token_budget(max_num_seqs=2)
for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group)
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 2
assert budget.num_batched_tokens == 120
assert budget.num_curr_seqs == 2
assert len(remaining_waiting) == 1
# Verify num_curr_seqs is respected.
waiting = deque()
budget = create_token_budget(num_curr_seqs=2, max_num_seqs=2)
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group)
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 0
assert budget.num_curr_seqs == 2
assert len(remaining_waiting) == 1
def test_prefill_schedule_max_lora():
"""
Test that max_loras is respected and LoRA requests are prioritized.
"""
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
scheduler = initialize_scheduler(lora_config=lora_config)
waiting = deque()
budget = create_token_budget(token_budget=120)
curr_loras = set()
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
lora_request=LoRARequest(
lora_name=str(i),
lora_int_id=i + 1,
lora_local_path="abc"))
waiting.append(seq_group)
# Add two more requests to verify lora is prioritized.
# 0: Lora, 1: Lora, 2: regular, 3: regular
# In the first iteration, indices 0 and 2 are scheduled.
# If a request is not scheduled because it hits max lora, it is
# prioritized. Verify that.
for i in range(2, 4):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group)
# Schedule 2 requests (0 and 2)
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, curr_loras)
assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 2
assert budget.num_batched_tokens == 120
assert budget.num_curr_seqs == 2
assert len(remaining_waiting) == 2
assert len(curr_loras) == 1
# The second LoRA request is scheduled next, following the FCFS policy.
# Reset curr_loras so that it can be scheduled.
curr_loras = set()
budget = create_token_budget(token_budget=60)
remaining_waiting, output = scheduler._schedule_prefills(
remaining_waiting, budget, curr_loras)
assert len(output.seq_groups) == 1
assert output.seq_groups[0].seq_group.request_id == "1"
assert len(remaining_waiting) == 1
assert len(curr_loras) == 1
assert budget.num_batched_tokens == 60
def test_prefill_schedule_no_block_manager_capacity():
"""
Test that a sequence cannot be scheduled when the block manager has no capacity.
"""
scheduler = initialize_scheduler()
waiting = deque()
budget = create_token_budget()
for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group)
scheduler.block_manager.can_allocate = MagicMock()
scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
assert len(output.ignored_seq_groups) == 0
assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 0
assert budget.num_curr_seqs == 0
assert len(remaining_waiting) == 3
scheduler = initialize_scheduler()
waiting = deque()
budget = create_token_budget()
for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
waiting.append(seq_group)
scheduler.block_manager.can_allocate = MagicMock()
scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER
remaining_waiting, output = scheduler._schedule_prefills(
waiting, budget, None)
assert len(output.ignored_seq_groups) == 3
assert len(output.seq_groups) == 0
assert budget.num_batched_tokens == 0
assert budget.num_curr_seqs == 0
assert len(remaining_waiting) == 0
def test_decode_schedule_preempted():
"""
Test that decodes that cannot be scheduled are preempted.
"""
scheduler = initialize_scheduler()
running = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
scheduler._allocate_and_set_running(seq_group)
running.append(seq_group)
scheduler.block_manager.can_append_slots = MagicMock()
def cannot_append_second_group(seq_group, num_lookahead_slots):
return seq_group.request_id != "1"
scheduler.block_manager.can_append_slots.side_effect = (
cannot_append_second_group)
# 1 cannot be scheduled, and the lowest priority (request 2)
# should be preempted. 1 will also be preempted.
budget = create_token_budget(num_batched_tokens=3, num_curr_seqs=3)
remaining_running, output = scheduler._schedule_decodes(
running, budget, curr_loras, policy)
assert len(remaining_running) == 0
assert len(output.seq_groups) == 1
assert output.seq_groups[0].seq_group.request_id == "0"
assert len(output.preempted) == 2
# Verify budgets are updated.
assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 1
# Both should be preempted, not swapped.
assert output.blocks_to_swap_out == {}
# Nothing is copied.
assert output.blocks_to_copy == {}
def test_decode_swap_beam_search():
"""
Test that requests with best_of > 1 are swapped out rather than recomputed.
"""
scheduler = initialize_scheduler()
running = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
running.append(seq_group)
# The last request should be swapped out.
scheduler.block_manager.can_append_slots = MagicMock()
def cannot_append_second_group(seq_group, num_lookahead_slots):
return seq_group.request_id != "2"
scheduler.block_manager.can_append_slots.side_effect = (
cannot_append_second_group)
scheduler.block_manager.swap_out = MagicMock()
expected_swap_mapping = {"5": "7"}
scheduler.block_manager.swap_out.return_value = expected_swap_mapping
budget = create_token_budget(num_batched_tokens=3, num_curr_seqs=3)
remaining_running, output = scheduler._schedule_decodes(
running, budget, curr_loras, policy)
assert len(remaining_running) == 0
assert len(output.seq_groups) == 2
assert output.seq_groups[0].seq_group.request_id == "0"
assert output.seq_groups[1].seq_group.request_id == "1"
assert len(output.preempted) == 0
assert len(output.swapped_out) == 1
# Budget should reflect the swapped-out request.
assert budget.num_batched_tokens == 2
# Since the swapped-out group has 2 sequences, 2 are subtracted.
assert budget.num_curr_seqs == 1
# The request is swapped out, not preempted for recompute.
assert output.blocks_to_swap_out == expected_swap_mapping
# Nothing is copied.
assert output.blocks_to_copy == {}
def test_schedule_decode_blocks_to_copy_update():
"""
Verify blocks_to_copy is updated.
"""
scheduler = initialize_scheduler()
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
running = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
scheduler._allocate_and_set_running(seq_group)
running.append(seq_group)
# Mock append_slots so that it reports a block to copy.
scheduler.block_manager.append_slots = MagicMock()
scheduler.block_manager.append_slots.return_value = {2: [3]}
budget = create_token_budget()
remaining_running, output = scheduler._schedule_decodes(
running, budget, curr_loras, policy)
assert len(remaining_running) == 0
assert len(output.seq_groups) == 1
assert len(output.preempted) == 0
assert len(output.swapped_out) == 0
# Nothing is preempted.
assert output.blocks_to_swap_out == {}
# Since append_slots returns the source -> dest mapping, it should
# be applied.
assert output.blocks_to_copy == {2: [3]}
def test_schedule_swapped_simple():
scheduler = initialize_scheduler()
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
budget = create_token_budget()
remaining_swapped, output = scheduler._schedule_swapped(
swapped, budget, curr_loras, policy)
assert len(remaining_swapped) == 0
assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 2
assert len(output.seq_groups) == 1
# swap in is the reverse of swap out
blocks_to_swap_in_reverse = {}
for swapin, swapout in output.blocks_to_swap_in.items():
blocks_to_swap_in_reverse[swapout] = swapin
assert blocks_to_swap_out == blocks_to_swap_in_reverse
def test_schedule_swapped_max_token_budget():
scheduler = initialize_scheduler()
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
budget = create_token_budget(token_budget=1)
remaining_swapped, output = scheduler._schedule_swapped(
swapped, budget, curr_loras, policy)
assert len(remaining_swapped) == 1
assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 2
assert len(output.seq_groups) == 1
# Verify num_batched_tokens is respected.
budget = create_token_budget(num_batched_tokens=1, token_budget=1)
remaining_swapped, output = scheduler._schedule_swapped(
remaining_swapped, budget, curr_loras, policy)
assert len(remaining_swapped) == 1
assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 0
assert len(output.seq_groups) == 0
def test_schedule_swapped_max_seqs():
scheduler = initialize_scheduler()
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
budget = create_token_budget(max_num_seqs=2)
remaining_swapped, output = scheduler._schedule_swapped(
swapped, budget, curr_loras, policy)
assert len(remaining_swapped) == 1
assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 2
assert len(output.seq_groups) == 1
# Verify num_curr_seqs is respected.
budget = create_token_budget(num_curr_seqs=2, max_num_seqs=2)
remaining_swapped, output = scheduler._schedule_swapped(
remaining_swapped, budget, curr_loras, policy)
assert len(remaining_swapped) == 1
assert budget.num_batched_tokens == 0
assert budget.num_curr_seqs == 2
assert len(output.seq_groups) == 0
def test_schedule_swapped_max_loras():
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
scheduler = initialize_scheduler(lora_config=lora_config)
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = set()
blocks_to_swap_out = {}
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
lora_request=LoRARequest(
lora_name=str(i),
lora_int_id=i + 1,
lora_local_path="abc"))
scheduler._allocate_and_set_running(seq_group)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
budget = create_token_budget()
remaining_swapped, output = scheduler._schedule_swapped(
swapped, budget, curr_loras, policy)
assert len(remaining_swapped) == 1
assert budget.num_batched_tokens == 1
assert budget.num_curr_seqs == 1
assert len(output.seq_groups) == 1
assert len(curr_loras) == 1
def test_schedule_swapped_cannot_swap_in():
scheduler = initialize_scheduler()
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
# Mock can_swap_in so that no request can be swapped in.
scheduler.block_manager.can_swap_in = MagicMock()
scheduler.block_manager.can_swap_in.return_value = False
# Since we cannot swap in, none of the requests are swapped in.
budget = create_token_budget()
remaining_swapped, output = scheduler._schedule_swapped(
swapped, budget, curr_loras, policy)
assert len(remaining_swapped) == 2
assert budget.num_batched_tokens == 0
assert budget.num_curr_seqs == 0
assert len(output.seq_groups) == 0
def test_schedule_swapped_blocks_to_copy():
scheduler = initialize_scheduler()
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
blocks_to_swap_out = {}
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)
# Mock append_slots so that it reports a block to copy.
scheduler.block_manager.append_slots = MagicMock()
scheduler.block_manager.append_slots.return_value = {2: [3]}
budget = create_token_budget()
remaining_swapped, output = scheduler._schedule_swapped(
swapped, budget, curr_loras, policy)
assert len(remaining_swapped) == 0
assert len(output.seq_groups) == 1
assert output.blocks_to_copy == {2: [3]}

View File

@ -1,14 +1,19 @@
import time
from typing import Tuple
from typing import Optional, Tuple
from vllm import SamplingParams
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob, Sequence, SequenceGroup
def create_dummy_prompt(
request_id: str,
prompt_length: int,
block_size: int = None) -> Tuple[Sequence, SequenceGroup]:
request_id: str,
prompt_length: int,
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
use_beam_search: bool = False,
best_of: int = 1,
) -> Tuple[Sequence, SequenceGroup]:
if not block_size:
block_size = prompt_length
@ -17,8 +22,10 @@ def create_dummy_prompt(
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
seq_group = SequenceGroup(request_id, [prompt], SamplingParams(),
time.time(), None)
seq_group = SequenceGroup(
request_id, [prompt],
SamplingParams(use_beam_search=use_beam_search, best_of=best_of),
time.time(), lora_request)
return prompt, seq_group
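A rough usage sketch of the extended helper (illustrative only; it assumes best_of flows through SamplingParams into SequenceGroup.get_max_num_running_seqs(), which the scheduler tests above rely on):
# Illustrative only: a 16-token dummy prompt whose group may run two sequences.
prompt, seq_group = create_dummy_prompt("0", prompt_length=16, best_of=2)
assert seq_group.request_id == "0"
# With best_of=2 and one prompt sequence, the group peaks at two running sequences.
assert seq_group.get_max_num_running_seqs() == 2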

View File

@ -6,11 +6,12 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.core.policy import PolicyFactory
from vllm.core.policy import Policy, PolicyFactory
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
from vllm.utils import merge_dicts
logger = init_logger(__name__)
@ -28,9 +29,19 @@ class PreemptionMode(enum.Enum):
RECOMPUTE = enum.auto()
# seq_group: SequenceGroup to schedule.
# token_chunk_size: The number of prefill tokens to be processed in the next
# step.
@dataclass
class SchedulingBudget:
"""The available slots for scheduling."""
num_batched_tokens: int
num_curr_seqs: int
token_budget: int
max_num_seqs: int
def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int):
return (self.num_batched_tokens + num_new_tokens <= self.token_budget
and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs)
@dataclass
class ScheduledSequenceGroup:
# A sequence group that's scheduled.
@ -41,53 +52,28 @@ class ScheduledSequenceGroup:
token_chunk_size: int
@dataclass
class SchedulerOutputs:
# Scheduled sequence groups.
scheduled_seq_groups: Iterable[ScheduledSequenceGroup]
# Number of prefill groups scheduled.
num_prefill_groups: int
# Total number of batched tokens.
num_batched_tokens: int
# Blocks to swap in. Dict of CPU -> GPU block number.
blocks_to_swap_in: Dict[int, int]
# Blocks to swap out. Dict of GPU -> CPU block number.
blocks_to_swap_out: Dict[int, int]
# Blocks to copy. Source to a list of dest blocks.
blocks_to_copy: Dict[int, List[int]]
# Sequence groups that are going to be ignored.
ignored_seq_groups: List[SequenceGroup]
# The number of slots for lookahead decoding.
num_lookahead_slots: int
def __init__(
self,
scheduled_seq_groups: Iterable[ScheduledSequenceGroup],
prompt_run: bool,
num_batched_tokens: int,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
ignored_seq_groups: List[SequenceGroup],
num_lookahead_slots: int,
) -> None:
"""A list of sequence groups to be scheduled as a single batch.
Args:
scheduled_seq_groups: A tuple of scheduled sequence group and its
token chunk size.
prompt_run: True if all sequence groups are in prefill phase.
If False, all sequence groups are in decoding phase.
num_batched_tokens: Total number of batched tokens.
blocks_to_swap_in: Blocks to swap in. Dict of CPU -> GPU block
number.
blocks_to_swap_out: Blocks to swap out. Dict of GPU -> CPU block
number.
blocks_to_copy: Blocks to copy. Source to a list of dest blocks.
ignored_seq_groups: Sequence groups that are going to be ignored.
"""
# A tuple of scheduled sequence group and its chunk size.
self.scheduled_seq_groups: ScheduledSequenceGroup = scheduled_seq_groups
# True if all sequence groups are in prefill phase. If False, all
# sequence groups are in decoding phase.
self.prompt_run: bool = prompt_run
# Total number of batched tokens.
self.num_batched_tokens: int = num_batched_tokens
# Blocks to swap in. Dict of CPU -> GPU block number.
self.blocks_to_swap_in: Dict[int, int] = blocks_to_swap_in
# Blocks to swap out. Dict of GPU -> CPU block number.
self.blocks_to_swap_out: Dict[int, int] = blocks_to_swap_out
# Blocks to copy. Source to a list of dest blocks.
self.blocks_to_copy: Dict[int, List[int]] = blocks_to_copy
# Sequence groups that are going to be ignored.
self.ignored_seq_groups: List[SequenceGroup] = ignored_seq_groups
def __post_init__(self):
# Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out)
self.num_lookahead_slots = num_lookahead_slots
assert not (self.blocks_to_swap_in and self.blocks_to_swap_out)
self.num_loras: int = len(self.lora_requests)
if self.num_loras > 0:
@ -108,6 +94,73 @@ class SchedulerOutputs:
return {g.seq_group.lora_request for g in self.scheduled_seq_groups}
@dataclass
class SchedulerDecodeOutputs:
"""Outputs of the decoding phase of the scheduler."""
# Selected sequence groups for decoding.
seq_groups: List[SequenceGroup]
# The preempted sequences.
preempted: List[SequenceGroup]
# Sequences that are swapped out.
swapped_out: List[SequenceGroup]
# The blocks to swap out.
blocks_to_swap_out: Dict[int, int]
# The blocks to copy.
blocks_to_copy: Dict[int, List[int]]
num_lookahead_slots: int
@classmethod
def create_empty(cls) -> "SchedulerDecodeOutputs":
return SchedulerDecodeOutputs(
seq_groups=[],
preempted=[],
swapped_out=[],
blocks_to_swap_out={},
blocks_to_copy={},
num_lookahead_slots=0,
)
@dataclass
class SchedulerSwappedInOutputs:
"""Outputs of the decoding phase of the scheduler."""
# Selected sequence groups for decoding.
seq_groups: List[SequenceGroup]
# The blocks to swap in.
blocks_to_swap_in: Dict[int, int]
# The blocks to copy.
blocks_to_copy: Dict[int, List[int]]
# The number of slots for lookahead decoding.
num_lookahead_slots: int
@classmethod
def create_empty(cls) -> "SchedulerSwappedInOutputs":
return SchedulerSwappedInOutputs(
seq_groups=[],
blocks_to_swap_in={},
blocks_to_copy={},
num_lookahead_slots=0,
)
@dataclass
class SchedulerPrefillOutputs:
"""Outputs of the prefill phase of the scheduler."""
# Selected sequence groups for prefill.
seq_groups: List[SequenceGroup]
# Ignored sequence groups.
ignored_seq_groups: List[SequenceGroup]
num_lookahead_slots: int
@classmethod
def create_empty(cls) -> "SchedulerPrefillOutputs":
return SchedulerPrefillOutputs(
seq_groups=[],
ignored_seq_groups=[],
num_lookahead_slots=0,
)
class Scheduler:
def __init__(
@ -123,6 +176,7 @@ class Scheduler:
# LoRAs. This should be improved in the future.
self.lora_config = lora_config
# TODO(sang): Fix it after chunked prefill is enabled.
self.prompt_limit = min(self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens)
@ -142,10 +196,13 @@ class Scheduler:
enable_caching=self.cache_config.enable_prefix_caching)
# Sequence groups in the WAITING state.
# Contain new prefill or preempted requests.
self.waiting: Deque[SequenceGroup] = deque()
# Sequence groups in the RUNNING state.
# Contain decode requests.
self.running: Deque[SequenceGroup] = deque()
# Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out.
self.swapped: Deque[SequenceGroup] = deque()
# Time at previous scheduling step
@ -159,8 +216,14 @@ class Scheduler:
def lora_enabled(self) -> bool:
return bool(self.lora_config)
@property
def num_decoding_tokens_per_seq(self) -> int:
"""The number of new tokens."""
return 1
def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue.
logger.debug(f"add_seq_group {seq_group.request_id}")
self.waiting.append(seq_group)
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
@ -205,214 +268,365 @@ class Scheduler:
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)
def _schedule(self) -> SchedulerOutputs:
def _schedule_decodes(
self,
running_queue: deque,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
policy: Policy,
) -> Tuple[deque, SchedulerDecodeOutputs]:
"""Schedule sequence groups in a decoding stage.
NOTE(sang): The num_batched_tokens, num_curr_seqs, and LoRA ids of all
RUNNING requests should already be included in `budget` and `curr_loras`;
this API does not add them up again.
Note that `budget` and `curr_loras` are still subtracted/popped when
any running requests are preempted from this API.
Args:
running_queue: The queue that contains running requests (i.e.,
decodes). The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any decodes are preempted.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any decodes are preempted.
policy: The sorting policy to sort running_queue.
Returns:
A tuple of the remaining running queue (which should always end up
empty) after scheduling, and SchedulerDecodeOutputs.
"""
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}
# Fix the current time.
now = time.time()
# Join waiting sequences if possible.
if not self.swapped:
ignored_seq_groups: List[SequenceGroup] = []
scheduled: List[SequenceGroup] = []
# The total number of sequences on the fly, including the
# requests in the generation phase.
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running)
curr_loras = set(
seq_group.lora_int_id
for seq_group in self.running) if self.lora_enabled else None
# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
leftover_waiting_sequences = deque()
num_batched_tokens = 0
while self._passed_delay(now) and self.waiting:
seq_group = self.waiting[0]
waiting_seqs = seq_group.get_seqs(
status=SequenceStatus.WAITING)
assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
# get_len includes output tokens if the request has been
# preempted.
num_prefill_tokens = waiting_seqs[0].get_len()
if num_prefill_tokens > self.prompt_limit:
logger.warning(
f"Input prompt ({num_prefill_tokens} tokens) is too "
f"long and exceeds limit of {self.prompt_limit}")
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.popleft()
continue
# If the sequence group cannot be allocated, stop.
can_allocate = self.block_manager.can_allocate(seq_group)
if can_allocate == AllocStatus.LATER:
break
elif can_allocate == AllocStatus.NEVER:
logger.warning(
f"Input prompt ({num_prefill_tokens} tokens) is too "
f"long and exceeds the capacity of block_manager")
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.popleft()
continue
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
if (lora_int_id > 0 and lora_int_id not in curr_loras
and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so
# we ignore this request for now.
leftover_waiting_sequences.appendleft(seq_group)
self.waiting.popleft()
continue
# If the number of batched tokens exceeds the limit, stop.
num_batched_tokens += num_prefill_tokens
if (num_batched_tokens >
self.scheduler_config.max_num_batched_tokens):
break
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.get_max_num_running_seqs()
if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break
if lora_int_id > 0:
curr_loras.add(lora_int_id)
self.waiting.popleft()
self._allocate(seq_group)
self.running.append(seq_group)
num_curr_seqs += num_new_seqs
scheduled.append(
ScheduledSequenceGroup(
seq_group=seq_group,
token_chunk_size=num_prefill_tokens))
self.waiting.extendleft(leftover_waiting_sequences)
if scheduled or ignored_seq_groups:
self.prev_prompt = True
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled,
prompt_run=True,
num_batched_tokens=num_batched_tokens,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True),
)
return scheduler_outputs
seq_groups: List[ScheduledSequenceGroup] = []
preempted: List[SequenceGroup] = []
swapped_out: List[SequenceGroup] = []
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence
# groups to preempt.
self.running = self.policy.sort_by_priority(now, self.running)
now = time.time()
running_queue = policy.sort_by_priority(now, running_queue)
# Reserve new token slots for the running sequence groups.
running: Deque[SequenceGroup] = deque()
preempted: List[SequenceGroup] = []
while self.running:
seq_group = self.running.popleft()
while running_queue:
seq_group = running_queue[0]
num_running_tokens = (
seq_group.num_seqs(status=SequenceStatus.RUNNING) *
self.num_decoding_tokens_per_seq)
num_running_seqs = seq_group.get_max_num_running_seqs()
running_queue.popleft()
while not self._can_append_slots(seq_group):
if self.running:
# Increase the budget as requests are preempted.
budget.num_batched_tokens -= num_running_tokens
budget.num_curr_seqs -= num_running_seqs
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.pop(seq_group.lora_int_id)
if running_queue:
# Preempt the lowest-priority sequence groups.
victim_seq_group = self.running.pop()
self._preempt(victim_seq_group, blocks_to_swap_out)
preempted.append(victim_seq_group)
victim_seq_group = running_queue.pop()
preempted_mode = self._preempt(victim_seq_group,
blocks_to_swap_out)
if preempted_mode == PreemptionMode.RECOMPUTE:
preempted.append(victim_seq_group)
else:
swapped_out.append(victim_seq_group)
else:
# No other sequence groups can be preempted.
# Preempt the current sequence group.
self._preempt(seq_group, blocks_to_swap_out)
preempted.append(seq_group)
preempted_mode = self._preempt(seq_group,
blocks_to_swap_out)
if preempted_mode == PreemptionMode.RECOMPUTE:
preempted.append(seq_group)
else:
swapped_out.append(seq_group)
break
else:
# Append new slots to the sequence group.
logger.debug(f"append slot for {seq_group}")
self._append_slots(seq_group, blocks_to_copy)
running.append(seq_group)
self.running = running
seq_groups.append(
ScheduledSequenceGroup(seq_group=seq_group,
token_chunk_size=1))
# Make sure all queues are updated.
assert len(running_queue) == 0
# Swap in the sequence groups in the SWAPPED state if possible.
self.swapped = self.policy.sort_by_priority(now, self.swapped)
if not preempted:
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running)
curr_loras = set(
seq_group.lora_int_id
for seq_group in self.running) if self.lora_enabled else None
leftover_swapped = deque()
while self.swapped:
seq_group = self.swapped[0]
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
if (lora_int_id > 0 and lora_int_id not in curr_loras
and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so
# we ignore this request for now.
leftover_swapped.appendleft(seq_group)
self.swapped.popleft()
continue
# If the sequence group cannot be swapped in, stop.
if not self._can_swap_in(seq_group):
break
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.get_max_num_running_seqs()
if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break
if lora_int_id > 0:
curr_loras.add(lora_int_id)
self.swapped.popleft()
self._swap_in(seq_group, blocks_to_swap_in)
self._append_slots(seq_group, blocks_to_copy)
num_curr_seqs += num_new_seqs
self.running.append(seq_group)
self.swapped.extendleft(leftover_swapped)
# Each sequence in the generation phase only takes one token slot.
# Therefore, the number of batched tokens is equal to the number of
# sequences in the RUNNING state.
num_batched_tokens = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=[
ScheduledSequenceGroup(seq_group=running_group,
token_chunk_size=1)
for running_group in self.running
],
prompt_run=False,
num_batched_tokens=num_batched_tokens,
blocks_to_swap_in=blocks_to_swap_in,
return running_queue, SchedulerDecodeOutputs(
seq_groups=seq_groups,
preempted=preempted,
swapped_out=swapped_out,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
ignored_seq_groups=[],
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=False),
is_prefill=False))
def _schedule_swapped(
self,
swapped_queue: deque,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
policy: Policy,
) -> Tuple[deque, SchedulerSwappedInOutputs]:
"""Schedule sequence groups that are swapped out.
It schedules swapped requests as long as they fit within `budget` and
the number of current LoRAs stays within max_loras from the LoRA config.
The input arguments `budget` and `curr_loras` are updated in place based
on the scheduled seq_groups.
Args:
swapped_queue: The queue that contains swapped out requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any requests are swapped in.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any requests are swapped in.
policy: The sorting policy to sort swapped_queue.
Returns:
A tuple of the remaining swapped_queue after scheduling, and
SchedulerSwappedInOutputs.
"""
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}
seq_groups: List[ScheduledSequenceGroup] = []
now = time.time()
swapped_queue = policy.sort_by_priority(now, swapped_queue)
leftover_swapped = deque()
while swapped_queue:
seq_group = swapped_queue[0]
# If the sequence group cannot be swapped in, stop.
if not self.block_manager.can_swap_in(seq_group):
break
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
if (lora_int_id > 0 and lora_int_id not in curr_loras
and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so
# we ignore this request for now.
leftover_swapped.appendleft(seq_group)
swapped_queue.popleft()
continue
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.get_max_num_running_seqs()
num_new_tokens = (
seq_group.num_seqs(status=SequenceStatus.SWAPPED) *
self.num_decoding_tokens_per_seq)
if not budget.can_schedule(num_new_tokens=num_new_tokens,
num_new_seqs=num_new_seqs):
break
if lora_int_id > 0 and curr_loras is not None:
curr_loras.add(lora_int_id)
swapped_queue.popleft()
self._swap_in(seq_group, blocks_to_swap_in)
self._append_slots(seq_group, blocks_to_copy)
seq_groups.append(
ScheduledSequenceGroup(seq_group, token_chunk_size=1))
budget.num_batched_tokens += num_new_tokens
budget.num_curr_seqs += num_new_seqs
swapped_queue.extendleft(leftover_swapped)
return swapped_queue, SchedulerSwappedInOutputs(
seq_groups=seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_copy=blocks_to_copy,
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=False))
def _schedule_prefills(
self,
waiting_queue: deque,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
) -> Tuple[deque, SchedulerPrefillOutputs]:
"""Schedule sequence groups that are in prefill stage.
Note that the current scheduler treats requests preempted with
PreemptionMode.RECOMPUTE as new prefills (re-run from the beginning up
to the most recently generated tokens).
It schedules waiting requests as long as they fit within `budget` and
the number of current LoRAs stays within max_loras from the LoRA config.
The input arguments `budget` and `curr_loras` are updated in place based
on the scheduled seq_groups.
Args:
waiting_queue: The queue that contains prefill requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any requests are scheduled.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any requests are scheduled.
Returns:
A tuple of the remaining waiting_queue after scheduling, and
SchedulerPrefillOutputs.
"""
ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[SequenceGroup] = []
# We don't sort the waiting queue because we assume it is already sorted.
# Copy the queue so that the input queue is not modified.
waiting_queue = deque([s for s in waiting_queue])
leftover_waiting_sequences = deque()
while self._passed_delay(time.time()) and waiting_queue:
seq_group = waiting_queue[0]
waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt "
"sequence.")
num_prompt_tokens = waiting_seqs[0].get_len()
if num_prompt_tokens > self.prompt_limit:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds limit of {self.prompt_limit}")
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
waiting_queue.popleft()
continue
# If the sequence group cannot be allocated, stop.
can_allocate = self.block_manager.can_allocate(seq_group)
if can_allocate == AllocStatus.LATER:
break
elif can_allocate == AllocStatus.NEVER:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds the capacity of block_manager")
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
waiting_queue.popleft()
continue
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
if (self.lora_enabled and lora_int_id > 0
and lora_int_id not in curr_loras
and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so
# we ignore this request for now.
leftover_waiting_sequences.appendleft(seq_group)
waiting_queue.popleft()
continue
num_new_seqs = seq_group.get_max_num_running_seqs()
if not budget.can_schedule(num_new_tokens=num_prompt_tokens,
num_new_seqs=num_new_seqs):
break
# Can schedule this request.
if curr_loras is not None and lora_int_id > 0:
curr_loras.add(lora_int_id)
waiting_queue.popleft()
self._allocate_and_set_running(seq_group)
seq_groups.append(
ScheduledSequenceGroup(seq_group=seq_group,
token_chunk_size=num_prompt_tokens))
budget.num_batched_tokens += num_prompt_tokens
budget.num_curr_seqs += num_new_seqs
# Queue requests that couldn't be scheduled.
waiting_queue.extendleft(leftover_waiting_sequences)
if len(seq_groups) > 0:
self.prev_prompt = True
return waiting_queue, SchedulerPrefillOutputs(
seq_groups=seq_groups,
ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
def _schedule(self) -> SchedulerOutputs:
"""Batch requests that are queued..
The current policy is designed to opimimize the throughput. First,
it batches as many prefill requests as possible. And it schedules
decodes. If there's a pressure on GPU memory, decode requests can
be swapped or preempted.
"""
# Include running requests to the budget.
budget = SchedulingBudget(
num_batched_tokens=sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running),
num_curr_seqs=sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running),
token_budget=self.scheduler_config.max_num_batched_tokens,
max_num_seqs=self.scheduler_config.max_num_seqs,
)
curr_loras = set(
seq_group.lora_int_id
for seq_group in self.running) if self.lora_enabled else None
remaining_waiting, prefills = (self.waiting,
SchedulerPrefillOutputs.create_empty())
remaining_running, decodes = (self.running,
SchedulerDecodeOutputs.create_empty())
remaining_swapped, swapped_in = (
self.swapped, SchedulerSwappedInOutputs.create_empty())
# If any requests are swapped out, prioritize them over new prefills.
if not self.swapped:
remaining_waiting, prefills = self._schedule_prefills(
self.waiting, budget, curr_loras)
# Don't schedule decodes if prefills are scheduled.
if len(prefills.seq_groups) == 0:
remaining_running, decodes = self._schedule_decodes(
self.running, budget, curr_loras, self.policy)
# If any sequence group is preempted, do not swap in any sequence
# group, because it means there's no slot for new running requests.
if len(decodes.preempted) + len(decodes.swapped_out) == 0:
remaining_swapped, swapped_in = self._schedule_swapped(
self.swapped, budget, curr_loras, self.policy)
assert (budget.num_batched_tokens <=
self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests.
self.waiting = remaining_waiting
self.waiting.extendleft(decodes.preempted)
# Update new running requests.
self.running = remaining_running
self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend([s.seq_group for s in decodes.seq_groups])
self.running.extend([s.seq_group for s in swapped_in.seq_groups])
# Update swapped requests.
self.swapped = remaining_swapped
self.swapped.extend(decodes.swapped_out)
return SchedulerOutputs(
scheduled_seq_groups=prefills.seq_groups + decodes.seq_groups +
swapped_in.seq_groups,
num_prefill_groups=len(prefills.seq_groups),
num_batched_tokens=budget.num_batched_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=decodes.blocks_to_swap_out,
blocks_to_copy=merge_dicts(decodes.blocks_to_copy,
swapped_in.blocks_to_copy),
ignored_seq_groups=prefills.ignored_seq_groups,
num_lookahead_slots=(prefills.num_lookahead_slots +
decodes.num_lookahead_slots +
swapped_in.num_lookahead_slots),
)
return scheduler_outputs
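As a rough end-to-end sketch (illustrative, using the test helpers shown earlier in this diff), the externally visible change is that prefill vs. decode is now reported through num_prefill_groups rather than prompt_run:
# Illustrative only: the first pass prefills the new request, later passes are pure decode.
scheduler = initialize_scheduler()
_, seq_group = create_dummy_prompt("0", prompt_length=16)
scheduler.add_seq_group(seq_group)
seq_group_meta, out = scheduler.schedule()
assert out.num_prefill_groups > 0
seq_group_meta, out = scheduler.schedule()
assert out.num_prefill_groups == 0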
def _can_append_slots(self, seq_group: SequenceGroup) -> bool:
"""Determine whether or not we have enough space in the KV cache to
@ -444,7 +658,8 @@ class Scheduler:
# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
for i, scheduled_seq_group in enumerate(
scheduler_outputs.scheduled_seq_groups):
seq_group = scheduled_seq_group.seq_group
token_chunk_size = scheduled_seq_group.token_chunk_size
seq_group.maybe_set_first_scheduled_time(now)
@ -464,9 +679,12 @@ class Scheduler:
self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
# This assumes scheduled_seq_groups is ordered with prefill groups
# before decode groups.
is_prompt = i < scheduler_outputs.num_prefill_groups
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=scheduler_outputs.prompt_run,
is_prompt=is_prompt,
seq_data=seq_data,
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
@ -479,7 +697,7 @@ class Scheduler:
# the subsequent comms can still use delta, but
# `multi_modal_data` will be None.
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.prompt_run else None,
if scheduler_outputs.num_prefill_groups > 0 else None,
)
seq_group_metadata_list.append(seq_group_metadata)
@ -504,7 +722,7 @@ class Scheduler:
self.running = deque(seq_group for seq_group in self.running
if not seq_group.is_finished())
def _allocate(self, seq_group: SequenceGroup) -> None:
def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
seq.status = SequenceStatus.RUNNING
@ -539,7 +757,7 @@ class Scheduler:
seq_group: SequenceGroup,
blocks_to_swap_out: Dict[int, int],
preemption_mode: Optional[PreemptionMode] = None,
) -> None:
) -> PreemptionMode:
# If preemption mode is not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than
# swapping. However, when the sequence group has multiple sequences
@ -562,6 +780,7 @@ class Scheduler:
self._preempt_by_swap(seq_group, blocks_to_swap_out)
else:
raise AssertionError("Invalid preemption mode.")
return preemption_mode
def _preempt_by_recompute(
self,
@ -573,9 +792,6 @@ class Scheduler:
seq.status = SequenceStatus.WAITING
self.free_seq(seq)
seq.reset_state_for_recompute()
# NOTE: For FCFS, we insert the preempted sequence group to the front
# of the waiting queue.
self.waiting.appendleft(seq_group)
def _preempt_by_swap(
self,
@ -583,7 +799,6 @@ class Scheduler:
blocks_to_swap_out: Dict[int, int],
) -> None:
self._swap_out(seq_group, blocks_to_swap_out)
self.swapped.append(seq_group)
def _swap_in(
self,

View File

@ -728,7 +728,7 @@ class LLMEngine:
time_per_output_tokens = []
time_e2e_requests = []
if scheduler_outputs is not None:
prompt_run = scheduler_outputs.prompt_run
prompt_run = scheduler_outputs.num_prefill_groups > 0
# Number of Tokens.
if prompt_run:

View File

@ -6,7 +6,7 @@ import socket
import subprocess
import uuid
import warnings
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from functools import lru_cache, partial
from platform import uname
from typing import (Any, Awaitable, Callable, Generic, Hashable, List,
@ -450,3 +450,20 @@ def maybe_expand_dim(tensor: torch.Tensor,
if tensor.ndim < target_dims:
tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
return tensor
def merge_dicts(dict1: dict[Any, list[Any]],
dict2: dict[Any, list[Any]]) -> dict[Any, list[Any]]:
"""Merge 2 dicts that have key -> List of items.
When a key conflicts, the values in dict1 is prioritized.
"""
merged_dict = defaultdict(list)
for key, value in dict1.items():
merged_dict[key].extend(value)
for key, value in dict2.items():
merged_dict[key].extend(value)
return dict(merged_dict)
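A small usage sketch of the merge semantics (illustrative, not part of the diff):
# Keys from both dicts are kept; on a shared key, dict1's items come first.
assert merge_dicts({1: [2]}, {1: [3], 4: [5]}) == {1: [2, 3], 4: [5]}
assert merge_dicts({}, {7: [8]}) == {7: [8]}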