[Core] feat: Implement Priority Scheduling in V1 Engine (#19057)

Signed-off-by: amit <amit.man@gmail.com>
Co-authored-by: Roger Wang <Rogerw0108@gmail.com>
This commit is contained in:
amit 2025-06-23 06:18:08 +03:00 committed by GitHub
parent c4cf260677
commit 4a0f7888a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 896 additions and 30 deletions

View File

@ -45,6 +45,18 @@ For each item, our progress towards V1 support falls into one of the following s
- **🟠 Delayed**: Temporarily dropped in V1 but planned to be re-introduced later.
- **🔴 Deprecated**: Not planned for V1 unless there is strong demand.
!!! note
vLLM V1s unified scheduler treats both prompt and output tokens the same
way by using a simple dictionary (e.g., `{request_id: num_tokens}`) to dynamically
allocate a fixed token budget per request, enabling features like chunked prefills,
prefix caching, and speculative decoding without a strict separation between prefill
and decode phases.
The V1 scheduler supports multiple scheduling policies, including First-Come,
First-Served (FCFS) and priority-based scheduling (where requests are processed
based on assigned priority, with FCFS as a tie-breaker), configurable via the
`--scheduling-policy` argument.
### Hardware
| Hardware | Status |

View File

@ -1150,7 +1150,6 @@ def test_kv_connector_handles_preemption():
assert len(scheduler.running) == 1
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
# All memory should be freed since nothing is running.
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
== NUM_BLOCKS - 1
@ -1265,3 +1264,592 @@ def test_memory_leak():
# Confirm no memory leak.
assert_scheduler_empty(scheduler)
def create_scheduler_with_priority(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
enable_prefix_caching: Optional[bool] = None,
long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False,
use_kv_connector: bool = False,
num_blocks: int = 10000,
block_size: int = 16,
max_model_len: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
) -> Scheduler:
'''Create scheduler with priority policy enabled.
Args:
model: model under test
max_num_seqs: max sequences to schedule
max_num_batch_tokens: max num tokens to batch
enable_prefix_caching: optionally force APC config
(True/False) or use default
(None)
Returns:
{class}`Scheduler` instance with priority scheduling
'''
if max_model_len is None:
max_model_len = max_num_batched_tokens
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_model_len,
long_prefill_token_threshold=long_prefill_token_threshold,
disable_chunked_mm_input=disable_chunked_mm_input,
enable_chunked_prefill=True,
policy="priority", # Enable priority scheduling
)
model_config = ModelConfig(
model=model,
task="auto",
tokenizer=model,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="float16",
seed=42,
)
# Cache config, optionally force APC
kwargs_cache = ({} if enable_prefix_caching is None else {
'enable_prefix_caching': enable_prefix_caching
})
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
**kwargs_cache,
)
kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
) if use_kv_connector else None
speculative_config: Optional[SpeculativeConfig] = None
if num_speculative_tokens is not None:
speculative_config = SpeculativeConfig(
model="ngram", num_speculative_tokens=num_speculative_tokens)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,
False))
],
)
cache_config.num_gpu_blocks = num_blocks
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
def create_requests_with_priority(
num_requests: int,
priorities: list[int],
arrival_times: Optional[list[float]] = None,
num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None):
"""Create requests with specified priorities and arrival times."""
assert len(priorities) == num_requests
if arrival_times is not None:
assert len(arrival_times) == num_requests
else:
arrival_times = [float(i) for i in range(num_requests)]
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
if mm_positions is not None:
mm_position = mm_positions[i]
mm_inputs = [MultiModalKwargs({})] * len(mm_position)
else:
mm_position = None
mm_inputs = None
request = Request(
request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=arrival_times[i],
priority=priorities[i],
)
requests.append(request)
return requests
def test_priority_scheduling_basic_ordering():
"""Test that requests are scheduled in priority order
(lower value = higher priority)."""
scheduler = create_scheduler_with_priority()
# Create requests with different priorities
# Priority 0 (highest), 1, 2 (lowest)
priorities = [2, 0, 1] # Add in non-priority order
arrival_times = [1.0, 2.0, 3.0] # All different arrival times
requests = create_requests_with_priority(num_requests=3,
priorities=priorities,
arrival_times=arrival_times)
# Add requests in non-priority order
for request in requests:
scheduler.add_request(request)
# Schedule and verify priority order
output = scheduler.schedule()
# Should schedule all requests since they fit in budget
assert len(output.scheduled_new_reqs) == 3
# Verify they are scheduled in priority order:
# req_1 (priority 0), req_2 (priority 1), req_0 (priority 2)
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
assert scheduled_req_ids == ["1", "2", "0"]
def test_priority_scheduling_arrival_time_tiebreaker():
"""Test that arrival time is used
as tiebreaker when priorities are equal."""
scheduler = create_scheduler_with_priority()
# Create requests with same priority but different arrival times
priorities = [1, 1, 1] # All same priority
arrival_times = [3.0, 1.0, 2.0] # Different arrival times
requests = create_requests_with_priority(num_requests=3,
priorities=priorities,
arrival_times=arrival_times)
# Add requests in non-arrival order
for request in requests:
scheduler.add_request(request)
# Schedule and verify arrival time order
output = scheduler.schedule()
# Should schedule all requests since they fit in budget
assert len(output.scheduled_new_reqs) == 3
# Verify they are scheduled in arrival time order:
# req_1 (1.0), req_2 (2.0), req_0 (3.0)
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
assert scheduled_req_ids == ["1", "2", "0"]
def test_priority_scheduling_mixed_priority_and_arrival():
"""Test priority scheduling with mixed priorities and arrival times."""
scheduler = create_scheduler_with_priority()
# Create requests with mixed priorities and arrival times
priorities = [2, 1, 1, 0] # Mixed priorities
arrival_times = [1.0, 3.0, 2.0, 4.0] # Mixed arrival times
requests = create_requests_with_priority(num_requests=4,
priorities=priorities,
arrival_times=arrival_times)
# Add requests
for request in requests:
scheduler.add_request(request)
# Schedule and verify order
output = scheduler.schedule()
# Should schedule all requests since they fit in budget
assert len(output.scheduled_new_reqs) == 4
# Expected order:
# 1. req_3 (priority 0, arrival 4.0)
# 2. req_2 (priority 1, arrival 2.0) - earlier arrival than req_1
# 3. req_1 (priority 1, arrival 3.0)
# 4. req_0 (priority 2, arrival 1.0)
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
assert scheduled_req_ids == ["3", "2", "1", "0"]
def test_priority_scheduling_preemption():
"""Test that priority scheduling preempts
lower priority requests when memory is constrained."""
# Create scheduler with very limited memory to force preemption
scheduler = create_scheduler_with_priority(
max_num_seqs=3, # Allow multiple requests
max_num_batched_tokens=200,
num_blocks=6, # Very limited blocks to force memory pressure
block_size=16, # Standard block size
)
# Create initial low-priority requests that will consume most memory
low_priority_requests = create_requests_with_priority(
num_requests=2,
priorities=[5, 5], # Low priority
arrival_times=[1.0, 2.0],
num_tokens=30 # Large enough to consume significant memory
)
# Add and schedule low priority requests
for request in low_priority_requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 2
# Simulate model execution to move requests to running state
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in low_priority_requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(low_priority_requests)
},
sampled_token_ids=[[100] for _ in low_priority_requests],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Verify both requests are running
assert len(scheduler.running) == 2
# Now add a high-priority request that requires memory allocation
# This should trigger preemption due to memory constraints
high_priority_request = create_requests_with_priority(
num_requests=1,
priorities=[0], # High priority
arrival_times=[3.0],
num_tokens=30 # Large enough to require significant memory
)[0]
scheduler.add_request(high_priority_request)
# Schedule again - this should trigger
# preemption when trying to allocate memory
output = scheduler.schedule()
# Due to the scheduler's design, if preemption happens
# during running request scheduling,
# waiting requests won't be scheduled in the same step
# Let's check if preemption occurred by looking at the waiting queue
# If preemption happened, we should see requests in the
# waiting queue
if len(scheduler.waiting) > 1: # high priority + preempted request
# Preemption occurred - verify the high priority request
# gets scheduled next
output2 = scheduler.schedule()
assert len(output2.scheduled_new_reqs) == 1
# High priority request
assert output2.scheduled_new_reqs[0].req_id == "0"
else:
# No preemption needed - all requests fit
# This is also valid behavior if memory allows
assert len(output.scheduled_new_reqs) == 1
# High priority request
assert output.scheduled_new_reqs[0].req_id == "0"
def test_priority_scheduling_no_preemption_when_space_available():
"""Test that preemption doesn't happen
when there's space for new requests."""
scheduler = create_scheduler_with_priority(
max_num_seqs=3, # Allow 3 concurrent requests
max_num_batched_tokens=200, # Sufficient token budget
)
# Add two low-priority running requests
low_priority_requests = create_requests_with_priority(
num_requests=2,
priorities=[5, 5],
arrival_times=[1.0, 2.0],
num_tokens=30)
for request in low_priority_requests:
scheduler.add_request(request)
output = scheduler.schedule()
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in low_priority_requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(low_priority_requests)
},
sampled_token_ids=[[100] for _ in low_priority_requests],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Add high-priority request
high_priority_request = create_requests_with_priority(num_requests=1,
priorities=[0],
arrival_times=[3.0],
num_tokens=30)[0]
scheduler.add_request(high_priority_request)
# Schedule - should not preempt since there's space
output = scheduler.schedule()
# Should schedule the new request without preemption
assert len(output.scheduled_new_reqs) == 1
assert len(scheduler.running) == 3 # All three requests running
assert len(scheduler.waiting) == 0 # No requests waiting
def test_priority_scheduling_preemption_victim_selection():
"""Test that the correct victim is selected for
preemption based on priority and arrival time."""
# This test verifies the priority-based victim selection logic
# by checking the waiting queue order after adding requests with different
# priorities
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Force sequential processing to test priority order
)
# Create requests with different priorities
requests = create_requests_with_priority(
num_requests=3,
priorities=[3, 2, 0], # Different priorities: low, medium, high
arrival_times=[1.0, 2.0, 3.0],
num_tokens=10)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule - should only schedule the highest priority request
# (req_2, priority 0)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_new_reqs[0].req_id == "2" # Highest priority
# Verify the waiting queue has the remaining requests in priority order
assert len(scheduler.waiting) == 2
# Extract waiting requests and verify priority order
waiting_requests = list(scheduler.waiting)
waiting_priorities = [req.priority for req in waiting_requests]
waiting_req_ids = [req.request_id for req in waiting_requests]
# Should be req_1 (priority 2) then req_0 (priority 3)
assert waiting_priorities == [2, 3]
assert waiting_req_ids == ["1", "0"]
def test_priority_scheduling_equal_priority_preemption():
"""Test arrival time tiebreaker when requests have equal priority."""
# This test verifies that arrival time is used as a tiebreaker for equal
# priorities
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Force sequential processing
)
# Create requests with same priority but different arrival times
requests = create_requests_with_priority(
num_requests=3,
priorities=[2, 2, 2], # Same priority
arrival_times=[3.0, 1.0, 2.0], # Different arrival times
num_tokens=10)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule - should schedule the request with earliest arrival time
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_new_reqs[0].req_id == "1" # Earliest arrival (1.0)
# Verify the waiting queue has remaining requests in arrival time order
assert len(scheduler.waiting) == 2
# Extract waiting requests and verify arrival time order
waiting_requests = list(scheduler.waiting)
waiting_arrival_times = [req.arrival_time for req in waiting_requests]
waiting_req_ids = [req.request_id for req in waiting_requests]
# Should be req_2 (arrival 2.0) then req_0 (arrival 3.0)
assert waiting_arrival_times == [2.0, 3.0]
assert waiting_req_ids == ["2", "0"]
def test_priority_scheduling_waiting_queue_order():
"""Test that the waiting queue maintains priority order."""
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Only one request can run at a time
)
# Create multiple requests with different priorities
requests = create_requests_with_priority(
num_requests=4,
priorities=[3, 1, 2, 0], # Mixed priorities
arrival_times=[1.0, 2.0, 3.0, 4.0],
num_tokens=10)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule - should only schedule the highest priority request
# (req_3, priority 0)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_new_reqs[0].req_id == "3"
# Verify waiting queue has remaining requests in priority order
assert len(scheduler.waiting) == 3
# Extract requests from waiting queue
# (it's a heap, so we need to pop to see order)
waiting_requests = list(scheduler.waiting)
waiting_priorities = [req.priority for req in waiting_requests]
waiting_req_ids = [req.request_id for req in waiting_requests]
# Should be ordered by priority: req_1 (1), req_2 (2), req_0 (3)
assert waiting_req_ids == ["1", "2", "0"]
assert waiting_priorities == [1, 2, 3]
def test_priority_scheduling_fcfs_fallback():
"""Test that FCFS behavior is maintained when all
requests have same priority."""
scheduler = create_scheduler_with_priority()
# Create requests with same priority but different arrival times
priorities = [1, 1, 1, 1] # All same priority
arrival_times = [4.0, 1.0, 3.0, 2.0] # Different arrival times
requests = create_requests_with_priority(num_requests=4,
priorities=priorities,
arrival_times=arrival_times)
# Add requests
for request in requests:
scheduler.add_request(request)
# Schedule
output = scheduler.schedule()
# Should schedule all requests in arrival time order
assert len(output.scheduled_new_reqs) == 4
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
# Expected order by arrival time:
# req_1 (1.0), req_3 (2.0), req_2 (3.0), req_0 (4.0)
assert scheduled_req_ids == ["1", "3", "2", "0"]
def test_priority_scheduling_with_limited_slots():
"""Test priority scheduling when max_num_seqs limits concurrent requests."""
scheduler = create_scheduler_with_priority(
max_num_seqs=2, # Only allow 2 concurrent requests
max_num_batched_tokens=1000, # Plenty of token budget
)
# Create requests with different priorities
requests = create_requests_with_priority(
num_requests=4,
priorities=[3, 1, 2, 0], # Mixed priorities
arrival_times=[1.0, 2.0, 3.0, 4.0],
num_tokens=10)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule - should only schedule the 2 highest priority requests
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 2
# Should schedule req_3 (priority 0) and req_1 (priority 1)
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
assert "3" in scheduled_req_ids # Priority 0
assert "1" in scheduled_req_ids # Priority 1
# Remaining requests should be in waiting queue in priority order
assert len(scheduler.waiting) == 2
# Extract waiting requests and verify order
waiting_requests = list(scheduler.waiting)
waiting_priorities = [req.priority for req in waiting_requests]
waiting_req_ids = [req.request_id for req in waiting_requests]
# Should be req_2 (priority 2) then req_0 (priority 3)
assert waiting_priorities == [2, 3]
assert waiting_req_ids == ["2", "0"]
def test_priority_scheduling_heap_property():
"""Test that the waiting queue maintains heap
property for priority scheduling."""
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Only one request can run at a time
)
# Add requests in random priority order
priorities = [5, 1, 8, 3, 2, 7, 4, 6]
arrival_times = [float(i) for i in range(len(priorities))]
requests = create_requests_with_priority(num_requests=len(priorities),
priorities=priorities,
arrival_times=arrival_times,
num_tokens=10)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule one request at a time and verify priority order
scheduled_priorities = []
while scheduler.waiting:
output = scheduler.schedule()
if output.scheduled_new_reqs:
req = output.scheduled_new_reqs[0]
scheduled_priorities.append(requests[int(req.req_id)].priority)
# Simulate completion to make room for next request
model_output = ModelRunnerOutput(
req_ids=[req.req_id],
req_id_to_index={req.req_id: 0},
sampled_token_ids=[[100]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Finish the request to make room for the next one
scheduler.finish_requests(req.req_id,
RequestStatus.FINISHED_STOPPED)
# Verify requests were scheduled in priority order (lowest value first)
expected_priorities = sorted(priorities)
assert scheduled_priorities == expected_priorities

View File

@ -0,0 +1,224 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import heapq
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterable, Iterator
from enum import Enum
from vllm.v1.request import Request
class SchedulingPolicy(Enum):
"""Enum for scheduling policies."""
FCFS = "fcfs"
PRIORITY = "priority"
class RequestQueue(ABC):
"""Abstract base class for request queues."""
@abstractmethod
def add_request(self, request: Request) -> None:
"""Add a request to the queue according to the policy."""
pass
@abstractmethod
def pop_request(self) -> Request:
"""Pop a request from the queue according to the policy."""
pass
@abstractmethod
def peek_request(self) -> Request:
"""Peek at the request at the front of the queue without removing it."""
pass
@abstractmethod
def prepend_request(self, request: Request) -> None:
"""Prepend a request to the front of the queue."""
pass
@abstractmethod
def prepend_requests(self, requests: RequestQueue) -> None:
"""Prepend all requests from another queue to the front of this
queue."""
pass
@abstractmethod
def remove_request(self, request: Request) -> None:
"""Remove a specific request from the queue."""
pass
@abstractmethod
def remove_requests(self, requests: Iterable[Request]) -> None:
"""Remove multiple specific requests from the queue."""
pass
@abstractmethod
def __bool__(self) -> bool:
"""Check if queue has any requests."""
pass
@abstractmethod
def __len__(self) -> int:
"""Get number of requests in queue."""
pass
@abstractmethod
def __iter__(self) -> Iterator[Request]:
"""Iterate over the queue according to the policy."""
pass
@abstractmethod
def __reversed__(self) -> Iterator[Request]:
"""Iterate over the queue in reverse order."""
pass
class FCFSRequestQueue(deque[Request], RequestQueue):
"""A first-come-first-served queue that supports deque operations."""
def add_request(self, request: Request) -> None:
"""Add a request to the queue according to FCFS policy."""
self.append(request)
def pop_request(self) -> Request:
"""Pop a request from the queue according to FCFS policy."""
return self.popleft()
def peek_request(self) -> Request:
"""Peek at the next request in the queue without removing it."""
if not self:
raise IndexError("peek from an empty queue")
return self[0]
def prepend_request(self, request: Request) -> None:
"""Prepend a request to the front of the queue."""
self.appendleft(request)
def prepend_requests(self, requests: RequestQueue) -> None:
"""Prepend all requests from another queue to the front of this
queue."""
self.extendleft(reversed(requests))
def remove_request(self, request: Request) -> None:
"""Remove a specific request from the queue."""
self.remove(request)
def remove_requests(self, requests: Iterable[Request]) -> None:
"""Remove multiple specific requests from the queue."""
requests_to_remove = set(requests)
filtered_requests = [
req for req in self if req not in requests_to_remove
]
# deque does not support in-place filtering, so we need to clear
# and extend
self.clear()
self.extend(filtered_requests)
def __bool__(self) -> bool:
"""Check if queue has any requests."""
return len(self) > 0
def __len__(self) -> int:
"""Get number of requests in queue."""
return super().__len__()
def __iter__(self) -> Iterator[Request]:
"""Iterate over the queue according to FCFS policy."""
return super().__iter__()
def __reversed__(self) -> Iterator[Request]:
"""Iterate over the queue in reverse order."""
return super().__reversed__()
class PriorityRequestQueue(RequestQueue):
"""
A priority queue that supports heap operations.
Requests with a smaller value of `priority` are processed first.
If multiple requests have the same priority, the one with the earlier
`arrival_time` is processed first.
"""
def __init__(self) -> None:
self._heap: list[tuple[int, float, Request]] = []
def add_request(self, request: Request) -> None:
"""Add a request to the queue according to priority policy."""
heapq.heappush(self._heap,
(request.priority, request.arrival_time, request))
def pop_request(self) -> Request:
"""Pop a request from the queue according to priority policy."""
if not self._heap:
raise IndexError("pop from empty heap")
_, _, request = heapq.heappop(self._heap)
return request
def peek_request(self) -> Request:
"""Peek at the next request in the queue without removing it."""
if not self._heap:
raise IndexError("peek from empty heap")
_, _, request = self._heap[0]
return request
def prepend_request(self, request: Request) -> None:
"""Add a request to the queue according to priority policy.
Note: In a priority queue, there is no concept of prepending to the
front. Requests are ordered by (priority, arrival_time)."""
self.add_request(request)
def prepend_requests(self, requests: RequestQueue) -> None:
"""Add all requests from another queue according to priority policy.
Note: In a priority queue, there is no concept of prepending to the
front. Requests are ordered by (priority, arrival_time)."""
for request in requests:
self.add_request(request)
def remove_request(self, request: Request) -> None:
"""Remove a specific request from the queue."""
self._heap = [(p, t, r) for p, t, r in self._heap if r != request]
heapq.heapify(self._heap)
def remove_requests(self, requests: Iterable[Request]) -> None:
"""Remove multiple specific requests from the queue."""
requests_to_remove = set(requests)
self._heap = [(p, t, r) for p, t, r in self._heap
if r not in requests_to_remove]
heapq.heapify(self._heap)
def __bool__(self) -> bool:
"""Check if queue has any requests."""
return bool(self._heap)
def __len__(self) -> int:
"""Get number of requests in queue."""
return len(self._heap)
def __iter__(self) -> Iterator[Request]:
"""Iterate over the queue according to priority policy."""
heap_copy = self._heap[:]
while heap_copy:
_, _, request = heapq.heappop(heap_copy)
yield request
def __reversed__(self) -> Iterator[Request]:
"""Iterate over the queue in reverse priority order."""
return reversed(list(self))
def create_request_queue(policy: SchedulingPolicy) -> RequestQueue:
"""Create request queue based on scheduling policy."""
if policy == SchedulingPolicy.PRIORITY:
return PriorityRequestQueue()
elif policy == SchedulingPolicy.FCFS:
return FCFSRequestQueue()
else:
raise ValueError(f"Unknown scheduling policy: {policy}")

View File

@ -22,6 +22,8 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
create_request_queue)
from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs)
@ -94,8 +96,16 @@ class Scheduler(SchedulerInterface):
# req_id -> Request
self.requests: dict[str, Request] = {}
# Scheduling policy
if self.scheduler_config.policy == "priority":
self.policy = SchedulingPolicy.PRIORITY
elif self.scheduler_config.policy == "fcfs":
self.policy = SchedulingPolicy.FCFS
else:
raise ValueError(
f"Unknown scheduling policy: {self.scheduler_config.policy}")
# Priority queues for requests.
self.waiting: deque[Request] = deque()
self.waiting = create_request_queue(self.policy)
self.running: list[Request] = []
# The request IDs that are finished in between the previous and the
@ -247,7 +257,15 @@ class Scheduler(SchedulerInterface):
if new_blocks is None:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
preempted_req = self.running.pop()
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
else:
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
@ -255,7 +273,7 @@ class Scheduler(SchedulerInterface):
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
self.waiting.appendleft(preempted_req)
self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt.
@ -311,9 +329,9 @@ class Scheduler(SchedulerInterface):
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary deque to collect requests that need to be skipped
# and put back at the head of the waiting queue later
skipped_waiting_requests: deque[Request] = deque()
# Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)
# Next, schedule the WAITING requests.
if not preempted_reqs:
@ -321,7 +339,7 @@ class Scheduler(SchedulerInterface):
if len(self.running) == self.max_num_running_reqs:
break
request = self.waiting[0]
request = self.waiting.peek_request()
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
@ -332,8 +350,8 @@ class Scheduler(SchedulerInterface):
logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.",
request.request_id)
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Skip request if the structured output request is still waiting
@ -343,19 +361,18 @@ class Scheduler(SchedulerInterface):
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Check that adding the request still respects the max_loras
# constraint.
if self.lora_config and request.lora_request and (
len(scheduled_loras) == self.lora_config.max_loras
and request.lora_request.lora_int_id
not in scheduled_loras):
if (self.lora_config and request.lora_request and
(len(scheduled_loras) == self.lora_config.max_loras and
request.lora_request.lora_int_id not in scheduled_loras)):
# Scheduling would exceed max_loras, skip.
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
num_external_computed_tokens = 0
@ -407,8 +424,8 @@ class Scheduler(SchedulerInterface):
# pooling requests to be chunked
if not self.scheduler_config.chunked_prefill_enabled and \
num_new_tokens > token_budget:
self.waiting.popleft()
skipped_waiting_requests.appendleft(request)
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
num_new_tokens = min(num_new_tokens, token_budget)
@ -448,17 +465,19 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens,
)
self.waiting.popleft()
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request()
if load_kv_async:
# If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.appendleft(request)
skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
if request.use_structured_output:
structured_output_request_ids[
request.request_id] = req_index
structured_output_request_ids[request.request_id] = (
req_index)
req_index += 1
self.running.append(request)
if self.log_stats:
@ -494,7 +513,7 @@ class Scheduler(SchedulerInterface):
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.extendleft(skipped_waiting_requests)
self.waiting.prepend_requests(skipped_waiting_requests)
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
@ -896,7 +915,7 @@ class Scheduler(SchedulerInterface):
return len(self.running), len(self.waiting)
def add_request(self, request: Request) -> None:
self.waiting.append(request)
self.waiting.add_request(request)
self.requests[request.request_id] = request
if self.log_stats:
request.record_event(EngineCoreEventType.QUEUED)
@ -917,16 +936,31 @@ class Scheduler(SchedulerInterface):
else:
request_ids = set(request_ids)
running_requests_to_remove = []
waiting_requests_to_remove = []
valid_requests = []
# First pass: collect requests to remove from queues
for req_id in request_ids:
request = self.requests.get(req_id)
if request is None:
# Invalid request ID.
continue
valid_requests.append(request)
if request.status == RequestStatus.RUNNING:
self.running.remove(request)
running_requests_to_remove.append(request)
else:
self.waiting.remove(request)
waiting_requests_to_remove.append(request)
# Remove all requests from queues at once for better efficiency
for request in running_requests_to_remove:
self.running.remove(request)
if waiting_requests_to_remove:
self.waiting.remove_requests(waiting_requests_to_remove)
# Second pass: set status and free requests
for request in valid_requests:
request.status = finished_status
self._free_request(request)

View File

@ -68,6 +68,7 @@ class EngineCoreRequest(
# belong to, to cover a race condition where the request is sent before
# a wave finished notification is received.
current_wave: int = 0
priority: int = 0
class EngineCoreEventType(enum.IntEnum):

View File

@ -219,8 +219,6 @@ class Processor:
# TODO(woosuk): Support encoder-decoder models.
self._validate_lora(lora_request)
self._validate_params(params, lora_request)
if priority != 0:
raise ValueError("V1 does not support priority yet.")
if trace_headers is not None:
raise ValueError("V1 does not support tracing yet.")
if prompt_adapter_request is not None:
@ -340,6 +338,7 @@ class Processor:
arrival_time=arrival_time,
lora_request=lora_request,
cache_salt=decoder_inputs.get("cache_salt"),
priority=priority,
data_parallel_rank=data_parallel_rank,
)

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import time
from typing import TYPE_CHECKING, Any, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
@ -30,18 +31,23 @@ class Request:
pooling_params: Optional[PoolingParams],
eos_token_id: Optional[int],
client_index: int = 0,
arrival_time: Optional[float] = None,
lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None,
priority: int = 0,
) -> None:
self.request_id = request_id
self.client_index = client_index
self.priority = priority
self.sampling_params = sampling_params
self.pooling_params = pooling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.structured_output_request = structured_output_request
self.arrival_time = arrival_time if arrival_time is not None else \
time.time()
self.status = RequestStatus.WAITING
if sampling_params and sampling_params.guided_decoding is not None:
@ -118,11 +124,13 @@ class Request:
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,
lora_request=request.lora_request,
structured_output_request=StructuredOutputRequest(
sampling_params=request.sampling_params) \
if request.sampling_params else None,
cache_salt=request.cache_salt,
priority=request.priority,
)
def append_output_token_ids(