From 0098c3fb936ea941b3d806fc2ada13ea4731d9b5 Mon Sep 17 00:00:00 2001 From: Pr0Wh1teGivee Date: Tue, 25 Nov 2025 11:13:52 +0800 Subject: [PATCH] [Feat][Sched] Add SJF Scheduling Policy Co-authored-by: HiC4Sh1e Co-authored-by: JiahongZhang-Work Signed-off-by: Pr0Wh1teGivee Signed-off-by: weichen --- vllm/config/scheduler.py | 2 +- .../v1/core/sched/policy/normalized_scorer.py | 82 ++++++++++ .../sched/policy/weighted_score_softer.py | 28 ++++ vllm/v1/core/sched/request_queue.py | 149 ++++++++++++++++++ 4 files changed, 260 insertions(+), 1 deletion(-) create mode 100644 vllm/v1/core/sched/policy/normalized_scorer.py create mode 100644 vllm/v1/core/sched/policy/weighted_score_softer.py diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 8abbe8ba0103e..1fe09a6ae2ce3 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: logger = init_logger(__name__) RunnerType = Literal["generate", "pooling", "draft"] -SchedulerPolicy = Literal["fcfs", "priority"] +SchedulerPolicy = Literal["fcfs", "priority", "sjf"] @config diff --git a/vllm/v1/core/sched/policy/normalized_scorer.py b/vllm/v1/core/sched/policy/normalized_scorer.py new file mode 100644 index 0000000000000..7b7e83cbbd708 --- /dev/null +++ b/vllm/v1/core/sched/policy/normalized_scorer.py @@ -0,0 +1,82 @@ +from typing import List + +from vllm.logger import init_logger + +import math + +logger = init_logger(__name__) + +class ScoreDim: + """ + Normalized scoring dimension. + """ + def __init__(self, name: str, median: float, norm_scale=0.0, weight=0.5, reverse=False): + self.name = name + self.median = median + if norm_scale != 0.0: + self.norm_scale = norm_scale + else: + self.norm_scale = 1/median + self.weight = weight + self.reverse = reverse + +class NormalizedScorer: + """ + Normalize unbounded N-dimensional values into a composite score using the Sigmoid function. + """ + + def __init__(self, dim_list: List[ScoreDim]) -> None: + """ + :param dim_list: Scoring dimensions; each dimension must define a median reference point, scaling factor, and weight. + """ + self.dim_list = dim_list + self.dim_count = len(dim_list) + + @staticmethod + def _sigmoid_normalize(value, median, norm_scale): + """Sigmoid function: Maps value to (0, 1).""" + return 1 / (1 + math.exp(-norm_scale * (value - median))) + + @staticmethod + def _inv_sigmoid_normalize(value, median, norm_scale): + """Inverse Sigmoid: Used for dimensions where a larger value yields a lower score.""" + # Equivalent to sigmoid(-x), but more numerically stable. + return 1 / (1 + math.exp(norm_scale * (value - median))) + + def score(self, *dims: float) -> float: + """ + Compute the composite score. + Larger value → higher score → use forward Sigmoid. + Smaller value → higher score → use inverse Sigmoid. + """ + if len(dims) > self.dim_count: + raise ValueError(f"Dim num({len(dims)}) exceeds max num dim({self.dim_count})") + + final_score = 0.0 + for idx, dim_value in enumerate(dims): + dim_info = self.dim_list[idx] + if dim_info.reverse: + score = self._inv_sigmoid_normalize(dim_value, dim_info.median, dim_info.norm_scale) + else: + score = self._sigmoid_normalize(dim_value, dim_info.median, dim_info.norm_scale) + logger.debug(f"{dim_info.name}({dim_info.reverse}) : {score:.10f}") + + # Weighted summation. + final_score += score * dim_info.weight + return max(0.0, min(1.0, final_score)) # Clamp to [0, 1]. + +class TimeAndLengthScorer(NormalizedScorer): + """ + Scorer for time and length dimensions; defaults to forward scoring with equal weights (0.5 each). + """ + def __init__(self, + time_median, length_median, + time_scale=0.0, length_scale=0.0, + time_weight=0.5, length_weight=0.5, + reverse_time=False, reverse_len=False) -> None: + dim_list = [ScoreDim("time", time_median, time_scale, time_weight, reverse_time), + ScoreDim("length", length_median, length_scale, length_weight, reverse_len)] + super().__init__(dim_list) + + def score(self, time: float, length: float) -> float: + return super().score(time, length) diff --git a/vllm/v1/core/sched/policy/weighted_score_softer.py b/vllm/v1/core/sched/policy/weighted_score_softer.py new file mode 100644 index 0000000000000..17be66a9c6754 --- /dev/null +++ b/vllm/v1/core/sched/policy/weighted_score_softer.py @@ -0,0 +1,28 @@ +from functools import total_ordering + +from vllm.v1.core.sched.policy.normalized_scorer import TimeAndLengthScorer +import time + +TimeAndLengthScorer_Instance = None + +if TimeAndLengthScorer_Instance == None: + TimeAndLengthScorer_Instance = TimeAndLengthScorer(time_median=5, time_weight=0.5, length_median=32 * 1024, + length_weight=0.5, reverse_len=True) +@total_ordering +class WeightedScoreSorter: + def __init__(self, request_length: int, request_arrival_time: float, request_slo_requirement: list = None): + self.request_length = request_length + self.request_arrival_time = request_arrival_time + self.request_slo_requirement = request_slo_requirement + self.__update_stats() + + def __lt__(self, other_request_weighted_score: 'WeightedScoreSorter') -> bool: + self.__update_stats() + return self.weighted_score > other_request_weighted_score.weighted_score + + def __eq__(self, other_request_weighted_score: 'WeightedScoreSorter') -> bool: + return self.weighted_score == other_request_weighted_score.weighted_score + + def __update_stats(self): + self.wait_time = time.time() - self.request_arrival_time + self.weighted_score = TimeAndLengthScorer_Instance.score(self.wait_time, self.request_length) diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index a00ca1912b0f3..792b5098578ab 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -8,6 +8,7 @@ from collections.abc import Iterable, Iterator from enum import Enum from vllm.v1.request import Request +from vllm.v1.core.sched.policy.weighted_score_softer import WeightedScoreSorter class SchedulingPolicy(Enum): @@ -15,6 +16,7 @@ class SchedulingPolicy(Enum): FCFS = "fcfs" PRIORITY = "priority" + SJF = "sjf" class RequestQueue(ABC): @@ -207,11 +209,158 @@ class PriorityRequestQueue(RequestQueue): return reversed(list(self)) +class SJFRequestQueue(deque[Request], RequestQueue): + """A short-job-first queue that supports deque operations.""" + + def __init__(self): + deque.__init__(self) + + def add_request(self, request: Request) -> None: + """Add a request to the queue according to SJF policy.""" + self.append(request) + self._sort_requests() + + def pop_request(self) -> Request: + """Pop a request from the queue according to SJF policy.""" + return self.popleft() + + def peek_request(self) -> Request: + """Peek at the next request in the queue without removing it.""" + if not self: + raise IndexError("peek from an empty queue") + self._sort_requests() + return self[0] + + def prepend_request(self, request: Request) -> None: + """Prepend a request to the front of the queue.""" + self.appendleft(request) + + def prepend_requests(self, requests: RequestQueue) -> None: + """Prepend all requests from another queue to the front of this + queue.""" + self.extendleft(reversed(requests)) + + def remove_request(self, request: Request) -> None: + """Remove a specific request from the queue.""" + self.remove(request) + + def remove_requests(self, requests: Iterable[Request]) -> None: + """Remove multiple specific requests from the queue.""" + requests_to_remove = set(requests) + filtered_requests = [ + req for req in self if req not in requests_to_remove + ] + # deque does not support in-place filtering, so we need to clear + # and extend + self.clear() + self.extend(filtered_requests) + + def _sort_requests(self, reverse = False) -> None: + key_func = lambda req: WeightedScoreSorter(request_length=len(req.prompt_token_ids), request_arrival_time=req.arrival_time) + sorted_list = sorted(self, key=key_func, reverse=reverse) + self.clear() + self.extend(sorted_list) + + def __bool__(self) -> bool: + """Check if queue has any requests.""" + return len(self) > 0 + + def __len__(self) -> int: + """Get number of requests in queue.""" + return super().__len__() + + def __iter__(self) -> Iterator[Request]: + """Iterate over the queue according to SJF policy.""" + return super().__iter__() + + def __reversed__(self) -> Iterator[Request]: + """Iterate over the queue in reverse order.""" + return super().__reversed__() + + +class SJFRequestQueueInHeap(RequestQueue): + """ + A SJF queue that supports heap operations. + + Requests with a larger value of weighted score value are processed first. + """ + + def __init__(self) -> None: + self._heap: list[tuple[WeightedScoreSorter, Request]] = [] + + def add_request(self, request: Request) -> None: + """Add a request to the queue according to SJF policy.""" + heapq.heappush(self._heap, + (WeightedScoreSorter(len(request.prompt_token_ids), request.arrival_time), request)) + + def pop_request(self) -> Request: + """Pop a request from the queue according to SJF policy.""" + if not self._heap: + raise IndexError("pop from empty heap") + _, request = heapq.heappop(self._heap) + return request + + def peek_request(self) -> Request: + """Peek at the next request in the queue without removing it.""" + if not self._heap: + raise IndexError("peek from empty heap") + _, request = self._heap[0] + return request + + def prepend_request(self, request: Request) -> None: + """Add a request to the queue according to SJF policy. + + Note: In a SJF queue, there is no concept of prepending to the + front. Requests are ordered by (priority, arrival_time).""" + self.add_request(request) + + def prepend_requests(self, requests: RequestQueue) -> None: + """Add all requests from another queue according to SJF policy. + + Note: In a SJF queue, there is no concept of prepending to the + front. Requests are ordered by weighted score.""" + for request in requests: + self.add_request(request) + + def remove_request(self, request: Request) -> None: + """Remove a specific request from the queue.""" + self._heap = [(ws, r) for ws, r in self._heap if r != request] + heapq.heapify(self._heap) + + def remove_requests(self, requests: Iterable[Request]) -> None: + """Remove multiple specific requests from the queue.""" + requests_to_remove = set(requests) + self._heap = [(ws, r) for ws , r in self._heap + if r not in requests_to_remove] + heapq.heapify(self._heap) + + def __bool__(self) -> bool: + """Check if queue has any requests.""" + return bool(self._heap) + + def __len__(self) -> int: + """Get number of requests in queue.""" + return len(self._heap) + + def __iter__(self) -> Iterator[Request]: + """Iterate over the queue according to SJF policy.""" + heap_copy = self._heap[:] + while heap_copy: + _, _, request = heapq.heappop(heap_copy) + yield request + + def __reversed__(self) -> Iterator[Request]: + """Iterate over the queue in reverse SJF order.""" + return reversed(list(self)) + + def create_request_queue(policy: SchedulingPolicy) -> RequestQueue: """Create request queue based on scheduling policy.""" if policy == SchedulingPolicy.PRIORITY: return PriorityRequestQueue() elif policy == SchedulingPolicy.FCFS: return FCFSRequestQueue() + elif policy == SchedulingPolicy.SJF: + return SJFRequestQueue() else: raise ValueError(f"Unknown scheduling policy: {policy}")