Fix removal from heap

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: weichen <calvin_zhu0210@outlook.com>
This commit is contained in:
Harry Mellor 2025-12-18 17:20:49 +01:00 committed by weichen
parent 4fe722fae5
commit 601387735c

View File

@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterable, Iterator
from enum import Enum
from typing import Any
from vllm.v1.core.sched.policy.weighted_score_sorter import WeightedScoreSorter
from vllm.v1.request import Request
@ -135,73 +136,57 @@ class FCFSRequestQueue(deque[Request], RequestQueue):
return super().__reversed__()
class HeapBasedRequestQueue(RequestQueue, ABC):
"""Base class for heap-based request queues (priority and SJF)."""
class RequestHeap(RequestQueue):
"""A queue that supports heap operations."""
def __init__(self) -> None:
self._heap: list = []
@abstractmethod
def _to_heap_element(self, request: Request) -> object:
def _request_to_heap(self, request: Request) -> Any:
"""Convert a request to the appropriate heap element."""
pass
raise NotImplementedError
@abstractmethod
def _from_heap_element(self, heap_element: object) -> Request:
def _heap_to_request(self, element: Any) -> Request:
"""Extract the request from a heap element."""
pass
raise NotImplementedError
def add_request(self, request: Request) -> None:
"""Add a request to the heap queue."""
heap_element = self._to_heap_element(request)
heapq.heappush(self._heap, heap_element)
"""Add a request to the queue according to heap priority."""
heapq.heappush(self._heap, self._request_to_heap(request))
def pop_request(self) -> Request:
"""Pop the highest priority request from the heap."""
if not self._heap:
raise IndexError("pop from empty heap")
heap_element = heapq.heappop(self._heap)
return self._from_heap_element(heap_element)
return self._heap_to_request(heapq.heappop(self._heap))
def peek_request(self) -> Request:
"""Peek at the highest priority request without removing it."""
"""Peek at the highest priority request in the heap without removing it."""
if not self._heap:
raise IndexError("peek from empty heap")
return self._from_heap_element(self._heap[0])
return self._heap_to_request(self._heap[0])
def prepend_request(self, request: Request) -> None:
"""
Add request to the queue. In heap-based queues, "prepend" has no
special meaning as elements are ordered by priority/score. This
behaves like add_request.
"""
"""Add a request to the heap. In heap-based queues there is no beginning as
elements are ordered by priority/score. This behaves like add_request."""
self.add_request(request)
def prepend_requests(self, requests: RequestQueue) -> None:
"""
Add all requests from another queue. In heap-based queues,
"prepend" has no special meaning as elements are ordered by
priority/score. This behaves like adding all requests.
"""
"""Add all requests from another queue to the heap. In heap-based queues there
is no beginning as elements are ordered by priority/score. This behaves like
calling add_request for each request."""
for request in requests:
self.add_request(request)
def remove_request(self, request: Request) -> None:
"""Remove a specific request from the heap."""
try:
self._heap.remove(request)
heapq.heapify(self._heap)
except ValueError as err:
raise ValueError(
f"Request {request.request_id} not found in queue"
) from err
self._heap.remove(self._request_to_heap(request))
heapq.heapify(self._heap)
def remove_requests(self, requests: Iterable[Request]) -> None:
"""Remove multiple specific requests from the heap."""
requests_to_remove = (
set(requests) if not isinstance(requests, set) else requests
)
self._heap = [r for r in self._heap if r not in requests_to_remove]
remove = requests if isinstance(requests, set) else set(requests)
self._heap = [h for h in self._heap if self._heap_to_request(h) not in remove]
heapq.heapify(self._heap)
def __bool__(self) -> bool:
@ -213,40 +198,37 @@ class HeapBasedRequestQueue(RequestQueue, ABC):
return len(self._heap)
def __iter__(self) -> Iterator[Request]:
"""Iterate over requests in priority/score order."""
heap_copy = self._heap.copy()
"""Iterate over the queue in heap order."""
heap_copy = self._heap[:]
while heap_copy:
heap_element = heapq.heappop(heap_copy)
yield self._from_heap_element(heap_element)
yield self._heap_to_request(heapq.heappop(heap_copy))
def __reversed__(self) -> Iterator[Request]:
"""Iterate over requests in reverse priority/score order."""
"""Iterate over the queue in reverse heap order."""
return reversed(list(self))
class PriorityRequestQueue(HeapBasedRequestQueue):
"""
A priority queue where requests are ordered by (priority, arrival_time).
Lower priority values and earlier arrival times are processed first.
"""
class PriorityRequestQueue(RequestHeap):
"""A priority queue that supports heap operations.
def _to_heap_element(self, request: Request) -> Request:
Respects the ordering defined in the Request class, where requests with a smaller
value of `priority` are processed first. If multiple requests have the same
priority, the one with the earlier `arrival_time` is processed first."""
def _request_to_heap(self, request: Request) -> Request:
"""For priority queue, the heap element is the request itself."""
return request
def _from_heap_element(self, heap_element: object) -> Request:
def _heap_to_request(self, element: Request) -> Request:
"""For priority queue, the heap element is already the request."""
assert isinstance(heap_element, Request)
return heap_element
return element
class SJFRequestQueue(HeapBasedRequestQueue):
"""
A Shortest Job First (SJF) queue where requests are ordered by weighted score.
Requests with higher weighted scores (shorter jobs) are processed first.
"""
class SJFRequestQueue(RequestHeap):
"""A Shortest Job First (SJF) queue where requests are ordered by weighted score.
Requests with higher weighted scores (shorter jobs) are processed first."""
def _to_heap_element(self, request: Request) -> tuple[WeightedScoreSorter, Request]:
def _request_to_heap(self, request: Request) -> tuple[WeightedScoreSorter, Request]:
"""Convert request to (weighted_score, request) tuple for heap."""
assert request.prompt_token_ids is not None
return (
@ -254,26 +236,11 @@ class SJFRequestQueue(HeapBasedRequestQueue):
request,
)
def _from_heap_element(self, heap_element: object) -> Request:
def _heap_to_request(self, element: tuple[WeightedScoreSorter, Request]) -> Request:
"""Extract the request from the (score, request) heap tuple."""
assert isinstance(heap_element, tuple) and len(heap_element) == 2
_, request = heap_element
_, request = element
return request
def remove_request(self, request: Request) -> None:
"""Remove a specific request from the SJF heap."""
original_length = len(self._heap)
self._heap = [(ws, r) for (ws, r) in self._heap if r != request]
if len(self._heap) == original_length:
raise ValueError(f"Request {request.request_id} not found in SJF queue")
heapq.heapify(self._heap)
def remove_requests(self, requests: Iterable[Request]) -> None:
"""Remove multiple specific requests from the SJF heap."""
requests_to_remove = set(requests)
self._heap = [(ws, r) for (ws, r) in self._heap if r not in requests_to_remove]
heapq.heapify(self._heap)
def create_request_queue(policy: SchedulingPolicy) -> RequestQueue:
"""Create request queue based on scheduling policy."""