From 9e8d9e1231620bb3a4579a2fb9df7865ce41ce54 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 18 Dec 2025 17:34:28 +0100 Subject: [PATCH] Consolidate SJF code and remove global variable Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: weichen --- .../core/sched/policy/shortest_job_first.py | 139 ++++++++++++++++++ .../sched/policy/weighted_score_sorter.py | 44 ------ vllm/v1/core/sched/request_queue.py | 11 +- 3 files changed, 148 insertions(+), 46 deletions(-) create mode 100644 vllm/v1/core/sched/policy/shortest_job_first.py delete mode 100644 vllm/v1/core/sched/policy/weighted_score_sorter.py diff --git a/vllm/v1/core/sched/policy/shortest_job_first.py b/vllm/v1/core/sched/policy/shortest_job_first.py new file mode 100644 index 0000000000000..4d9ba0de11b65 --- /dev/null +++ b/vllm/v1/core/sched/policy/shortest_job_first.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +import time +from functools import total_ordering + +from vllm.logger import init_logger +from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class ScoreDim: + """ + Normalized scoring dimension. + """ + + def __init__( + self, name: str, median: float, norm_scale=0.0, weight=0.5, reverse=False + ): + self.name = name + self.median = median + if norm_scale != 0.0: + self.norm_scale = norm_scale + else: + self.norm_scale = 1 / median + self.weight = weight + self.reverse = reverse + + +class NormalizedScorer: + """ + Normalize unbounded N-dimensional values into a composite score using the Sigmoid + function. + """ + + def __init__(self, dim_list: list[ScoreDim]) -> None: + """ + Initialize the scorer with a list of scoring dimensions. + + Args: + dim_list: A list of `ScoreDim` objects. Each dimension must define a + median reference point, scaling factor, and weight. + """ + self.dim_list = dim_list + self.dim_count = len(dim_list) + + @staticmethod + def _sigmoid_normalize(value, median, norm_scale): + """Sigmoid function: Maps value to (0, 1).""" + return 1 / (1 + math.exp(-norm_scale * (value - median))) + + @staticmethod + def _inv_sigmoid_normalize(value, median, norm_scale): + """Inverse Sigmoid: Used for dimensions where a larger value yields a lower + score. + """ + # Equivalent to sigmoid(-x), but more numerically stable. + return 1 / (1 + math.exp(norm_scale * (value - median))) + + def score(self, *dims: float) -> float: + """ + Compute the composite score. + Larger value → higher score → use forward Sigmoid. + Smaller value → higher score → use inverse Sigmoid. + """ + if len(dims) > self.dim_count: + raise ValueError( + f"Dim num({len(dims)}) exceeds max num dim({self.dim_count})" + ) + + final_score = 0.0 + for idx, dim_value in enumerate(dims): + dim_info = self.dim_list[idx] + if dim_info.reverse: + score = self._inv_sigmoid_normalize( + dim_value, dim_info.median, dim_info.norm_scale + ) + else: + score = self._sigmoid_normalize( + dim_value, dim_info.median, dim_info.norm_scale + ) + logger.debug("%s(%s) : %.10f", dim_info.name, dim_info.reverse, score) + + # Weighted summation. + final_score += score * dim_info.weight + return max(0.0, min(1.0, final_score)) # Clamp to [0, 1]. + + +class TimeAndLengthScorer(NormalizedScorer): + """ + Scorer for time and length dimensions; defaults to forward scoring with equal + weights (0.5 each). + """ + + def __init__( + self, + time_median=5, + length_median=1024 * 32, + time_scale=0.0, + length_scale=0.0, + time_weight=0.5, + length_weight=0.5, + reverse_time=False, + reverse_len=True, + ) -> None: + dim_list = [ + ScoreDim("time", time_median, time_scale, time_weight, reverse_time), + ScoreDim("length", length_median, length_scale, length_weight, reverse_len), + ] + super().__init__(dim_list) + + def score(self, *dims: float) -> float: + assert len(dims) == 2 + return super().score(*dims) + + +@total_ordering +class WeightedScoreSorter: + def __init__(self, request: Request, scorer: TimeAndLengthScorer): + self.request = request + self.scorer = scorer + assert request.prompt_token_ids is not None + self.request_length = len(request.prompt_token_ids) + self.request_arrival_time = request.arrival_time + self.__update_stats() + + def __lt__(self, other_request_weighted_score: "WeightedScoreSorter") -> bool: + self.__update_stats() + return self.weighted_score > other_request_weighted_score.weighted_score + + def __eq__(self, other_request_weighted_score: object) -> bool: + if not isinstance(other_request_weighted_score, WeightedScoreSorter): + return NotImplemented + return self.weighted_score == other_request_weighted_score.weighted_score + + def __update_stats(self): + self.wait_time = time.time() - self.request_arrival_time + self.weighted_score = self.scorer.score(self.wait_time, self.request_length) diff --git a/vllm/v1/core/sched/policy/weighted_score_sorter.py b/vllm/v1/core/sched/policy/weighted_score_sorter.py deleted file mode 100644 index db72de3b4dbe1..0000000000000 --- a/vllm/v1/core/sched/policy/weighted_score_sorter.py +++ /dev/null @@ -1,44 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time -from functools import total_ordering - -from vllm.v1.core.sched.policy.normalized_scorer import TimeAndLengthScorer -from vllm.v1.request import Request - -TimeAndLengthScorer_Instance = None - -if TimeAndLengthScorer_Instance is None: - TimeAndLengthScorer_Instance = TimeAndLengthScorer( - time_median=5, - time_weight=0.5, - length_median=32 * 1024, - length_weight=0.5, - reverse_len=True, - ) - - -@total_ordering -class WeightedScoreSorter: - def __init__(self, request: Request): - self.request = request - assert request.prompt_token_ids is not None - self.request_length = len(request.prompt_token_ids) - self.request_arrival_time = request.arrival_time - self.__update_stats() - - def __lt__(self, other_request_weighted_score: "WeightedScoreSorter") -> bool: - self.__update_stats() - return self.weighted_score > other_request_weighted_score.weighted_score - - def __eq__(self, other_request_weighted_score: object) -> bool: - if not isinstance(other_request_weighted_score, WeightedScoreSorter): - return NotImplemented - return self.weighted_score == other_request_weighted_score.weighted_score - - def __update_stats(self): - self.wait_time = time.time() - self.request_arrival_time - assert TimeAndLengthScorer_Instance is not None - self.weighted_score = TimeAndLengthScorer_Instance.score( - self.wait_time, self.request_length - ) diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index e339b5de6255f..a2e2ba6c67d31 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -8,7 +8,10 @@ from collections.abc import Iterable, Iterator from enum import Enum from typing import Any -from vllm.v1.core.sched.policy.weighted_score_sorter import WeightedScoreSorter +from vllm.v1.core.sched.policy.shortest_job_first import ( + TimeAndLengthScorer, + WeightedScoreSorter, +) from vllm.v1.request import Request @@ -228,9 +231,13 @@ class SJFRequestQueue(RequestHeap): """A Shortest Job First (SJF) queue where requests are ordered by weighted score. Requests with higher weighted scores (shorter jobs) are processed first.""" + def __init__(self): + super().__init__() + self.scorer = TimeAndLengthScorer() + def _request_to_heap(self, request: Request) -> WeightedScoreSorter: """Convert request to `WeightedScoreSorter` for heap.""" - return WeightedScoreSorter(request) + return WeightedScoreSorter(request, self.scorer) def _heap_to_request(self, element: WeightedScoreSorter) -> Request: """Extract request from the `WeightedScoreSorter`."""