diff --git a/vllm/v1/core/sched/policy/weighted_score_softer.py b/vllm/v1/core/sched/policy/weighted_score_softer.py index cfc6c7bc55bca..b4cd54e30124e 100644 --- a/vllm/v1/core/sched/policy/weighted_score_softer.py +++ b/vllm/v1/core/sched/policy/weighted_score_softer.py @@ -4,6 +4,7 @@ import time from functools import total_ordering from vllm.v1.core.sched.policy.normalized_scorer import TimeAndLengthScorer +from typing import Optional, List, Any TimeAndLengthScorer_Instance = None @@ -23,7 +24,7 @@ class WeightedScoreSorter: self, request_length: int, request_arrival_time: float, - request_slo_requirement: list = None, + request_slo_requirement: Optional[List[Any]] = None, ): self.request_length = request_length self.request_arrival_time = request_arrival_time @@ -34,11 +35,14 @@ class WeightedScoreSorter: self.__update_stats() return self.weighted_score > other_request_weighted_score.weighted_score - def __eq__(self, other_request_weighted_score: "WeightedScoreSorter") -> bool: + def __eq__(self, other_request_weighted_score: object) -> bool: + if not isinstance(other_request_weighted_score, WeightedScoreSorter): + return NotImplemented return self.weighted_score == other_request_weighted_score.weighted_score def __update_stats(self): self.wait_time = time.time() - self.request_arrival_time + assert TimeAndLengthScorer_Instance is not None self.weighted_score = TimeAndLengthScorer_Instance.score( self.wait_time, self.request_length ) diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index 9200ea524362a..3d28fb3652c54 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -221,6 +221,7 @@ class SJFRequestQueue(RequestQueue): def add_request(self, request: Request) -> None: """Add a request to the queue according to SJF policy.""" + assert request.prompt_token_ids is not None, "prompt_token_ids cannot be None for SJF scheduling." heapq.heappush( self._heap, ( @@ -283,7 +284,7 @@ class SJFRequestQueue(RequestQueue): """Iterate over the queue according to SJF policy.""" heap_copy = self._heap[:] while heap_copy: - _, _, request = heapq.heappop(heap_copy) + _, request = heapq.heappop(heap_copy) yield request def __reversed__(self) -> Iterator[Request]: