diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index 3c5618a4700fd..2c0a9d62eaae3 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from collections import deque from collections.abc import Iterable, Iterator from enum import Enum +from typing import Any from vllm.v1.core.sched.policy.weighted_score_sorter import WeightedScoreSorter from vllm.v1.request import Request @@ -135,73 +136,57 @@ class FCFSRequestQueue(deque[Request], RequestQueue): return super().__reversed__() -class HeapBasedRequestQueue(RequestQueue, ABC): - """Base class for heap-based request queues (priority and SJF).""" +class RequestHeap(RequestQueue): + """A queue that supports heap operations.""" def __init__(self) -> None: self._heap: list = [] - @abstractmethod - def _to_heap_element(self, request: Request) -> object: + def _request_to_heap(self, request: Request) -> Any: """Convert a request to the appropriate heap element.""" - pass + raise NotImplementedError - @abstractmethod - def _from_heap_element(self, heap_element: object) -> Request: + def _heap_to_request(self, element: Any) -> Request: """Extract the request from a heap element.""" - pass + raise NotImplementedError def add_request(self, request: Request) -> None: - """Add a request to the heap queue.""" - heap_element = self._to_heap_element(request) - heapq.heappush(self._heap, heap_element) + """Add a request to the queue according to heap priority.""" + heapq.heappush(self._heap, self._request_to_heap(request)) def pop_request(self) -> Request: """Pop the highest priority request from the heap.""" if not self._heap: raise IndexError("pop from empty heap") - heap_element = heapq.heappop(self._heap) - return self._from_heap_element(heap_element) + return self._heap_to_request(heapq.heappop(self._heap)) def peek_request(self) -> Request: - """Peek at the highest priority request without removing 
it.""" + """Peek at the highest priority request in the heap without removing it.""" if not self._heap: raise IndexError("peek from empty heap") - return self._from_heap_element(self._heap[0]) + return self._heap_to_request(self._heap[0]) def prepend_request(self, request: Request) -> None: - """ - Add request to the queue. In heap-based queues, "prepend" has no - special meaning as elements are ordered by priority/score. This - behaves like add_request. - """ + """Add a request to the heap. In heap-based queues there is no beginning as + elements are ordered by priority/score. This behaves like add_request.""" self.add_request(request) def prepend_requests(self, requests: RequestQueue) -> None: - """ - Add all requests from another queue. In heap-based queues, - "prepend" has no special meaning as elements are ordered by - priority/score. This behaves like adding all requests. - """ + """Add all requests from another queue to the heap. In heap-based queues there + is no beginning as elements are ordered by priority/score. 
This behaves like + add_request.""" for request in requests: self.add_request(request) def remove_request(self, request: Request) -> None: """Remove a specific request from the heap.""" - try: - self._heap.remove(request) - heapq.heapify(self._heap) - except ValueError as err: - raise ValueError( - f"Request {request.request_id} not found in queue" - ) from err + self._heap.remove(self._request_to_heap(request)) + heapq.heapify(self._heap) def remove_requests(self, requests: Iterable[Request]) -> None: """Remove multiple specific requests from the heap.""" - requests_to_remove = ( - set(requests) if not isinstance(requests, set) else requests - ) - self._heap = [r for r in self._heap if r not in requests_to_remove] + remove = requests if isinstance(requests, set) else set(requests) + self._heap = [h for h in self._heap if self._heap_to_request(h) not in remove] heapq.heapify(self._heap) def __bool__(self) -> bool: @@ -213,40 +198,37 @@ class HeapBasedRequestQueue(RequestQueue, ABC): return len(self._heap) def __iter__(self) -> Iterator[Request]: - """Iterate over requests in priority/score order.""" - heap_copy = self._heap.copy() + """Iterate over the queue in heap order.""" + heap_copy = self._heap[:] while heap_copy: - heap_element = heapq.heappop(heap_copy) - yield self._from_heap_element(heap_element) + yield self._heap_to_request(heapq.heappop(heap_copy)) def __reversed__(self) -> Iterator[Request]: - """Iterate over requests in reverse priority/score order.""" + """Iterate over the queue in reverse heap order.""" return reversed(list(self)) -class PriorityRequestQueue(HeapBasedRequestQueue): - """ - A priority queue where requests are ordered by (priority, arrival_time). - Lower priority values and earlier arrival times are processed first. - """ +class PriorityRequestQueue(RequestHeap): + """A priority queue that supports heap operations. 
- def _to_heap_element(self, request: Request) -> Request: + Respects the ordering defined in the Request class, where requests with a smaller + value of `priority` are processed first. If multiple requests have the same + priority, the one with the earlier `arrival_time` is processed first.""" + + def _request_to_heap(self, request: Request) -> Request: """For priority queue, the heap element is the request itself.""" return request - def _from_heap_element(self, heap_element: object) -> Request: + def _heap_to_request(self, element: Request) -> Request: """Extract request from heap element with type checking.""" - assert isinstance(heap_element, Request) - return heap_element + return element -class SJFRequestQueue(HeapBasedRequestQueue): - """ - A Shortest Job First (SJF) queue where requests are ordered by weighted score. - Requests with higher weighted scores (shorter jobs) are processed first. - """ +class SJFRequestQueue(RequestHeap): + """A Shortest Job First (SJF) queue where requests are ordered by weighted score. 
+ Requests with higher weighted scores (shorter jobs) are processed first.""" - def _to_heap_element(self, request: Request) -> tuple[WeightedScoreSorter, Request]: + def _request_to_heap(self, request: Request) -> tuple[WeightedScoreSorter, Request]: """Convert request to (weighted_score, request) tuple for heap.""" assert request.prompt_token_ids is not None return ( @@ -254,26 +236,11 @@ class SJFRequestQueue(HeapBasedRequestQueue): request, ) - def _from_heap_element(self, heap_element: object) -> Request: + def _heap_to_request(self, element: tuple[WeightedScoreSorter, Request]) -> Request: """Extract request from the (score, request) tuple with type checking.""" - assert isinstance(heap_element, tuple) and len(heap_element) == 2 - _, request = heap_element + _, request = element return request - def remove_request(self, request: Request) -> None: - """Remove a specific request from the SJF heap.""" - original_length = len(self._heap) - self._heap = [(ws, r) for (ws, r) in self._heap if r != request] - if len(self._heap) == original_length: - raise ValueError(f"Request {request.request_id} not found in SJF queue") - heapq.heapify(self._heap) - - def remove_requests(self, requests: Iterable[Request]) -> None: - """Remove multiple specific requests from the SJF heap.""" - requests_to_remove = set(requests) - self._heap = [(ws, r) for (ws, r) in self._heap if r not in requests_to_remove] - heapq.heapify(self._heap) - def create_request_queue(policy: SchedulingPolicy) -> RequestQueue: """Create request queue based on scheduling policy."""