abstracting common code to HeapBasedRequestQueue

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Signed-off-by: weichen <calvin_zhu0210@outlook.com>
This commit is contained in:
Pr0Wh1teGivee 2025-12-09 15:26:05 +08:00 committed by weichen
parent cc0a8ae572
commit 4fe722fae5

View File

@ -135,58 +135,72 @@ class FCFSRequestQueue(deque[Request], RequestQueue):
return super().__reversed__()
class PriorityRequestQueue(RequestQueue):
"""
A priority queue that supports heap operations.
Respects the ordering defined in the Request class, where
requests with a smaller value of `priority` are processed first.
If multiple requests have the same priority, the one with the earlier
`arrival_time` is processed first.
"""
class HeapBasedRequestQueue(RequestQueue, ABC):
"""Base class for heap-based request queues (priority and SJF)."""
def __init__(self) -> None:
self._heap: list[Request] = []
self._heap: list = []
@abstractmethod
def _to_heap_element(self, request: Request) -> object:
"""Convert a request to the appropriate heap element."""
pass
@abstractmethod
def _from_heap_element(self, heap_element: object) -> Request:
"""Extract the request from a heap element."""
pass
def add_request(self, request: Request) -> None:
"""Add a request to the queue according to priority policy."""
heapq.heappush(self._heap, request)
"""Add a request to the heap queue."""
heap_element = self._to_heap_element(request)
heapq.heappush(self._heap, heap_element)
def pop_request(self) -> Request:
"""Pop a request from the queue according to priority policy."""
"""Pop the highest priority request from the heap."""
if not self._heap:
raise IndexError("pop from empty heap")
return heapq.heappop(self._heap)
heap_element = heapq.heappop(self._heap)
return self._from_heap_element(heap_element)
def peek_request(self) -> Request:
"""Peek at the next request in the queue without removing it."""
"""Peek at the highest priority request without removing it."""
if not self._heap:
raise IndexError("peek from empty heap")
return self._heap[0]
return self._from_heap_element(self._heap[0])
def prepend_request(self, request: Request) -> None:
"""Add a request to the queue according to priority policy.
Note: In a priority queue, there is no concept of prepending to the
front. Requests are ordered by (priority, arrival_time)."""
"""
Add request to the queue. In heap-based queues, "prepend" has no
special meaning as elements are ordered by priority/score. This
behaves like add_request.
"""
self.add_request(request)
def prepend_requests(self, requests: RequestQueue) -> None:
"""Add all requests from another queue according to priority policy.
Note: In a priority queue, there is no concept of prepending to the
front. Requests are ordered by (priority, arrival_time)."""
"""
Add all requests from another queue. In heap-based queues,
"prepend" has no special meaning as elements are ordered by
priority/score. This behaves like adding all requests.
"""
for request in requests:
self.add_request(request)
def remove_request(self, request: Request) -> None:
"""Remove a specific request from the queue."""
self._heap.remove(request)
heapq.heapify(self._heap)
"""Remove a specific request from the heap."""
try:
self._heap.remove(request)
heapq.heapify(self._heap)
except ValueError as err:
raise ValueError(
f"Request {request.request_id} not found in queue"
) from err
def remove_requests(self, requests: Iterable[Request]) -> None:
"""Remove multiple specific requests from the queue."""
requests_to_remove = requests if isinstance(requests, set) else set(requests)
"""Remove multiple specific requests from the heap."""
requests_to_remove = (
set(requests) if not isinstance(requests, set) else requests
)
self._heap = [r for r in self._heap if r not in requests_to_remove]
heapq.heapify(self._heap)
@ -199,98 +213,67 @@ class PriorityRequestQueue(RequestQueue):
return len(self._heap)
def __iter__(self) -> Iterator[Request]:
"""Iterate over the queue according to priority policy."""
heap_copy = self._heap[:]
"""Iterate over requests in priority/score order."""
heap_copy = self._heap.copy()
while heap_copy:
yield heapq.heappop(heap_copy)
heap_element = heapq.heappop(heap_copy)
yield self._from_heap_element(heap_element)
def __reversed__(self) -> Iterator[Request]:
"""Iterate over the queue in reverse priority order."""
"""Iterate over requests in reverse priority/score order."""
return reversed(list(self))
class SJFRequestQueue(RequestQueue):
class PriorityRequestQueue(HeapBasedRequestQueue):
"""
A SJF queue that supports heap operations.
Requests with a larger value of weighted score value are processed first.
A priority queue where requests are ordered by (priority, arrival_time).
Lower priority values and earlier arrival times are processed first.
"""
def __init__(self) -> None:
self._heap: list[tuple[WeightedScoreSorter, Request]] = []
def _to_heap_element(self, request: Request) -> Request:
"""For priority queue, the heap element is the request itself."""
return request
def add_request(self, request: Request) -> None:
"""Add a request to the queue according to SJF policy."""
def _from_heap_element(self, heap_element: object) -> Request:
"""Extract request from heap element with type checking."""
assert isinstance(heap_element, Request)
return heap_element
class SJFRequestQueue(HeapBasedRequestQueue):
"""
A Shortest Job First (SJF) queue where requests are ordered by weighted score.
Requests with higher weighted scores (shorter jobs) are processed first.
"""
def _to_heap_element(self, request: Request) -> tuple[WeightedScoreSorter, Request]:
"""Convert request to (weighted_score, request) tuple for heap."""
assert request.prompt_token_ids is not None
heapq.heappush(
self._heap,
(
WeightedScoreSorter(
len(request.prompt_token_ids), request.arrival_time
),
request,
),
return (
WeightedScoreSorter(len(request.prompt_token_ids), request.arrival_time),
request,
)
def pop_request(self) -> Request:
"""Pop a request from the queue according to SJF policy."""
if not self._heap:
raise IndexError("pop from empty heap")
_, request = heapq.heappop(self._heap)
def _from_heap_element(self, heap_element: object) -> Request:
"""Extract request from the (score, request) tuple with type checking."""
assert isinstance(heap_element, tuple) and len(heap_element) == 2
_, request = heap_element
return request
def peek_request(self) -> Request:
"""Peek at the next request in the queue without removing it."""
if not self._heap:
raise IndexError("peek from empty heap")
_, request = self._heap[0]
return request
def prepend_request(self, request: Request) -> None:
"""Add a request to the queue according to SJF policy.
Note: In a SJF queue, there is no concept of prepending to the
front. Requests are ordered by (priority, arrival_time)."""
self.add_request(request)
def prepend_requests(self, requests: RequestQueue) -> None:
"""Add all requests from another queue according to SJF policy.
Note: In a SJF queue, there is no concept of prepending to the
front. Requests are ordered by weighted score."""
for request in requests:
self.add_request(request)
def remove_request(self, request: Request) -> None:
"""Remove a specific request from the queue."""
self._heap = [(ws, r) for ws, r in self._heap if r != request]
"""Remove a specific request from the SJF heap."""
original_length = len(self._heap)
self._heap = [(ws, r) for (ws, r) in self._heap if r != request]
if len(self._heap) == original_length:
raise ValueError(f"Request {request.request_id} not found in SJF queue")
heapq.heapify(self._heap)
def remove_requests(self, requests: Iterable[Request]) -> None:
"""Remove multiple specific requests from the queue."""
"""Remove multiple specific requests from the SJF heap."""
requests_to_remove = set(requests)
self._heap = [(ws, r) for ws, r in self._heap if r not in requests_to_remove]
self._heap = [(ws, r) for (ws, r) in self._heap if r not in requests_to_remove]
heapq.heapify(self._heap)
def __bool__(self) -> bool:
"""Check if queue has any requests."""
return bool(self._heap)
def __len__(self) -> int:
"""Get number of requests in queue."""
return len(self._heap)
def __iter__(self) -> Iterator[Request]:
"""Iterate over the queue according to SJF policy."""
heap_copy = self._heap[:]
while heap_copy:
_, request = heapq.heappop(heap_copy)
yield request
def __reversed__(self) -> Iterator[Request]:
"""Iterate over the queue in reverse SJF order."""
return reversed(list(self))
def create_request_queue(policy: SchedulingPolicy) -> RequestQueue:
"""Create request queue based on scheduling policy."""