Consolidate SJF code and remove global variable

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: weichen <calvin_zhu0210@outlook.com>
This commit is contained in:
Harry Mellor 2025-12-18 17:34:28 +01:00 committed by weichen
parent 53d57d9dca
commit 9e8d9e1231
3 changed files with 148 additions and 46 deletions

View File

@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import time
from functools import total_ordering
from vllm.logger import init_logger
from vllm.v1.request import Request
logger = init_logger(__name__)
class ScoreDim:
"""
Normalized scoring dimension.
"""
def __init__(
self, name: str, median: float, norm_scale=0.0, weight=0.5, reverse=False
):
self.name = name
self.median = median
if norm_scale != 0.0:
self.norm_scale = norm_scale
else:
self.norm_scale = 1 / median
self.weight = weight
self.reverse = reverse
class NormalizedScorer:
"""
Normalize unbounded N-dimensional values into a composite score using the Sigmoid
function.
"""
def __init__(self, dim_list: list[ScoreDim]) -> None:
"""
Initialize the scorer with a list of scoring dimensions.
Args:
dim_list: A list of `ScoreDim` objects. Each dimension must define a
median reference point, scaling factor, and weight.
"""
self.dim_list = dim_list
self.dim_count = len(dim_list)
@staticmethod
def _sigmoid_normalize(value, median, norm_scale):
"""Sigmoid function: Maps value to (0, 1)."""
return 1 / (1 + math.exp(-norm_scale * (value - median)))
@staticmethod
def _inv_sigmoid_normalize(value, median, norm_scale):
"""Inverse Sigmoid: Used for dimensions where a larger value yields a lower
score.
"""
# Equivalent to sigmoid(-x), but more numerically stable.
return 1 / (1 + math.exp(norm_scale * (value - median)))
def score(self, *dims: float) -> float:
"""
Compute the composite score.
Larger value higher score use forward Sigmoid.
Smaller value higher score use inverse Sigmoid.
"""
if len(dims) > self.dim_count:
raise ValueError(
f"Dim num({len(dims)}) exceeds max num dim({self.dim_count})"
)
final_score = 0.0
for idx, dim_value in enumerate(dims):
dim_info = self.dim_list[idx]
if dim_info.reverse:
score = self._inv_sigmoid_normalize(
dim_value, dim_info.median, dim_info.norm_scale
)
else:
score = self._sigmoid_normalize(
dim_value, dim_info.median, dim_info.norm_scale
)
logger.debug("%s(%s) : %.10f", dim_info.name, dim_info.reverse, score)
# Weighted summation.
final_score += score * dim_info.weight
return max(0.0, min(1.0, final_score)) # Clamp to [0, 1].
class TimeAndLengthScorer(NormalizedScorer):
"""
Scorer for time and length dimensions; defaults to forward scoring with equal
weights (0.5 each).
"""
def __init__(
self,
time_median=5,
length_median=1024 * 32,
time_scale=0.0,
length_scale=0.0,
time_weight=0.5,
length_weight=0.5,
reverse_time=False,
reverse_len=True,
) -> None:
dim_list = [
ScoreDim("time", time_median, time_scale, time_weight, reverse_time),
ScoreDim("length", length_median, length_scale, length_weight, reverse_len),
]
super().__init__(dim_list)
def score(self, *dims: float) -> float:
assert len(dims) == 2
return super().score(*dims)
@total_ordering
class WeightedScoreSorter:
def __init__(self, request: Request, scorer: TimeAndLengthScorer):
self.request = request
self.scorer = scorer
assert request.prompt_token_ids is not None
self.request_length = len(request.prompt_token_ids)
self.request_arrival_time = request.arrival_time
self.__update_stats()
def __lt__(self, other_request_weighted_score: "WeightedScoreSorter") -> bool:
self.__update_stats()
return self.weighted_score > other_request_weighted_score.weighted_score
def __eq__(self, other_request_weighted_score: object) -> bool:
if not isinstance(other_request_weighted_score, WeightedScoreSorter):
return NotImplemented
return self.weighted_score == other_request_weighted_score.weighted_score
def __update_stats(self):
self.wait_time = time.time() - self.request_arrival_time
self.weighted_score = self.scorer.score(self.wait_time, self.request_length)

View File

@ -1,44 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from functools import total_ordering
from vllm.v1.core.sched.policy.normalized_scorer import TimeAndLengthScorer
from vllm.v1.request import Request
TimeAndLengthScorer_Instance = None
if TimeAndLengthScorer_Instance is None:
TimeAndLengthScorer_Instance = TimeAndLengthScorer(
time_median=5,
time_weight=0.5,
length_median=32 * 1024,
length_weight=0.5,
reverse_len=True,
)
@total_ordering
class WeightedScoreSorter:
def __init__(self, request: Request):
self.request = request
assert request.prompt_token_ids is not None
self.request_length = len(request.prompt_token_ids)
self.request_arrival_time = request.arrival_time
self.__update_stats()
def __lt__(self, other_request_weighted_score: "WeightedScoreSorter") -> bool:
self.__update_stats()
return self.weighted_score > other_request_weighted_score.weighted_score
def __eq__(self, other_request_weighted_score: object) -> bool:
if not isinstance(other_request_weighted_score, WeightedScoreSorter):
return NotImplemented
return self.weighted_score == other_request_weighted_score.weighted_score
def __update_stats(self):
self.wait_time = time.time() - self.request_arrival_time
assert TimeAndLengthScorer_Instance is not None
self.weighted_score = TimeAndLengthScorer_Instance.score(
self.wait_time, self.request_length
)

View File

@ -8,7 +8,10 @@ from collections.abc import Iterable, Iterator
from enum import Enum from enum import Enum
from typing import Any from typing import Any
from vllm.v1.core.sched.policy.weighted_score_sorter import WeightedScoreSorter from vllm.v1.core.sched.policy.shortest_job_first import (
TimeAndLengthScorer,
WeightedScoreSorter,
)
from vllm.v1.request import Request from vllm.v1.request import Request
@ -228,9 +231,13 @@ class SJFRequestQueue(RequestHeap):
"""A Shortest Job First (SJF) queue where requests are ordered by weighted score. """A Shortest Job First (SJF) queue where requests are ordered by weighted score.
Requests with higher weighted scores (shorter jobs) are processed first.""" Requests with higher weighted scores (shorter jobs) are processed first."""
def __init__(self):
super().__init__()
self.scorer = TimeAndLengthScorer()
def _request_to_heap(self, request: Request) -> WeightedScoreSorter: def _request_to_heap(self, request: Request) -> WeightedScoreSorter:
"""Convert request to `WeightedScoreSorter` for heap.""" """Convert request to `WeightedScoreSorter` for heap."""
return WeightedScoreSorter(request) return WeightedScoreSorter(request, self.scorer)
def _heap_to_request(self, element: WeightedScoreSorter) -> Request: def _heap_to_request(self, element: WeightedScoreSorter) -> Request:
"""Extract request from the `WeightedScoreSorter`.""" """Extract request from the `WeightedScoreSorter`."""