diff --git a/vllm/v1/core/sched/policy/weighted_score_sorter.py b/vllm/v1/core/sched/policy/weighted_score_sorter.py index 4828559d27695..db72de3b4dbe1 100644 --- a/vllm/v1/core/sched/policy/weighted_score_sorter.py +++ b/vllm/v1/core/sched/policy/weighted_score_sorter.py @@ -2,9 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time from functools import total_ordering -from typing import Any from vllm.v1.core.sched.policy.normalized_scorer import TimeAndLengthScorer +from vllm.v1.request import Request TimeAndLengthScorer_Instance = None @@ -20,15 +20,11 @@ if TimeAndLengthScorer_Instance is None: @total_ordering class WeightedScoreSorter: - def __init__( - self, - request_length: int, - request_arrival_time: float, - request_slo_requirement: list[Any] | None = None, - ): - self.request_length = request_length - self.request_arrival_time = request_arrival_time - self.request_slo_requirement = request_slo_requirement + def __init__(self, request: Request): + self.request = request + assert request.prompt_token_ids is not None + self.request_length = len(request.prompt_token_ids) + self.request_arrival_time = request.arrival_time self.__update_stats() def __lt__(self, other_request_weighted_score: "WeightedScoreSorter") -> bool: diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index 2c0a9d62eaae3..e339b5de6255f 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -228,18 +228,13 @@ class SJFRequestQueue(RequestHeap): """A Shortest Job First (SJF) queue where requests are ordered by weighted score. Requests with higher weighted scores (shorter jobs) are processed first.""" - def _request_to_heap(self, request: Request) -> tuple[WeightedScoreSorter, Request]: - """Convert request to (weighted_score, request) tuple for heap.""" - assert request.prompt_token_ids is not None - return ( - WeightedScoreSorter(len(request.prompt_token_ids), request.arrival_time), - request, - ) + def _request_to_heap(self, request: Request) -> WeightedScoreSorter: + """Convert request to `WeightedScoreSorter` for heap.""" + return WeightedScoreSorter(request) - def _heap_to_request(self, element: tuple[WeightedScoreSorter, Request]) -> Request: - """Extract request from the (score, request) tuple with type checking.""" - _, request = element - return request + def _heap_to_request(self, element: WeightedScoreSorter) -> Request: + """Extract request from the `WeightedScoreSorter`.""" + return element.request def create_request_queue(policy: SchedulingPolicy) -> RequestQueue: