From 0098c3fb936ea941b3d806fc2ada13ea4731d9b5 Mon Sep 17 00:00:00 2001 From: Pr0Wh1teGivee Date: Tue, 25 Nov 2025 11:13:52 +0800 Subject: [PATCH 01/20] [Feat][Sched] Add SJF Scheduling Policy Co-authored-by: HiC4Sh1e Co-authored-by: JiahongZhang-Work Signed-off-by: Pr0Wh1teGivee Signed-off-by: weichen --- vllm/config/scheduler.py | 2 +- .../v1/core/sched/policy/normalized_scorer.py | 82 ++++++++++ .../sched/policy/weighted_score_softer.py | 28 ++++ vllm/v1/core/sched/request_queue.py | 149 ++++++++++++++++++ 4 files changed, 260 insertions(+), 1 deletion(-) create mode 100644 vllm/v1/core/sched/policy/normalized_scorer.py create mode 100644 vllm/v1/core/sched/policy/weighted_score_softer.py diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 8abbe8ba0103e..1fe09a6ae2ce3 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: logger = init_logger(__name__) RunnerType = Literal["generate", "pooling", "draft"] -SchedulerPolicy = Literal["fcfs", "priority"] +SchedulerPolicy = Literal["fcfs", "priority", "sjf"] @config diff --git a/vllm/v1/core/sched/policy/normalized_scorer.py b/vllm/v1/core/sched/policy/normalized_scorer.py new file mode 100644 index 0000000000000..7b7e83cbbd708 --- /dev/null +++ b/vllm/v1/core/sched/policy/normalized_scorer.py @@ -0,0 +1,82 @@ +from typing import List + +from vllm.logger import init_logger + +import math + +logger = init_logger(__name__) + +class ScoreDim: + """ + Normalized scoring dimension. + """ + def __init__(self, name: str, median: float, norm_scale=0.0, weight=0.5, reverse=False): + self.name = name + self.median = median + if norm_scale != 0.0: + self.norm_scale = norm_scale + else: + self.norm_scale = 1/median + self.weight = weight + self.reverse = reverse + +class NormalizedScorer: + """ + Normalize unbounded N-dimensional values into a composite score using the Sigmoid function. + """ + + def __init__(self, dim_list: List[ScoreDim]) -> None: + """ + :param dim_list: Scoring dimensions; each dimension must define a median reference point, scaling factor, and weight. + """ + self.dim_list = dim_list + self.dim_count = len(dim_list) + + @staticmethod + def _sigmoid_normalize(value, median, norm_scale): + """Sigmoid function: Maps value to (0, 1).""" + return 1 / (1 + math.exp(-norm_scale * (value - median))) + + @staticmethod + def _inv_sigmoid_normalize(value, median, norm_scale): + """Inverse Sigmoid: Used for dimensions where a larger value yields a lower score.""" + # Equivalent to sigmoid(-x), but more numerically stable. + return 1 / (1 + math.exp(norm_scale * (value - median))) + + def score(self, *dims: float) -> float: + """ + Compute the composite score. + Larger value → higher score → use forward Sigmoid. + Smaller value → higher score → use inverse Sigmoid. + """ + if len(dims) > self.dim_count: + raise ValueError(f"Dim num({len(dims)}) exceeds max num dim({self.dim_count})") + + final_score = 0.0 + for idx, dim_value in enumerate(dims): + dim_info = self.dim_list[idx] + if dim_info.reverse: + score = self._inv_sigmoid_normalize(dim_value, dim_info.median, dim_info.norm_scale) + else: + score = self._sigmoid_normalize(dim_value, dim_info.median, dim_info.norm_scale) + logger.debug(f"{dim_info.name}({dim_info.reverse}) : {score:.10f}") + + # Weighted summation. + final_score += score * dim_info.weight + return max(0.0, min(1.0, final_score)) # Clamp to [0, 1]. + +class TimeAndLengthScorer(NormalizedScorer): + """ + Scorer for time and length dimensions; defaults to forward scoring with equal weights (0.5 each). + """ + def __init__(self, + time_median, length_median, + time_scale=0.0, length_scale=0.0, + time_weight=0.5, length_weight=0.5, + reverse_time=False, reverse_len=False) -> None: + dim_list = [ScoreDim("time", time_median, time_scale, time_weight, reverse_time), + ScoreDim("length", length_median, length_scale, length_weight, reverse_len)] + super().__init__(dim_list) + + def score(self, time: float, length: float) -> float: + return super().score(time, length) diff --git a/vllm/v1/core/sched/policy/weighted_score_softer.py b/vllm/v1/core/sched/policy/weighted_score_softer.py new file mode 100644 index 0000000000000..17be66a9c6754 --- /dev/null +++ b/vllm/v1/core/sched/policy/weighted_score_softer.py @@ -0,0 +1,28 @@ +from functools import total_ordering + +from vllm.v1.core.sched.policy.normalized_scorer import TimeAndLengthScorer +import time + +TimeAndLengthScorer_Instance = None + +if TimeAndLengthScorer_Instance == None: + TimeAndLengthScorer_Instance = TimeAndLengthScorer(time_median=5, time_weight=0.5, length_median=32 * 1024, + length_weight=0.5, reverse_len=True) +@total_ordering +class WeightedScoreSorter: + def __init__(self, request_length: int, request_arrival_time: float, request_slo_requirement: list = None): + self.request_length = request_length + self.request_arrival_time = request_arrival_time + self.request_slo_requirement = request_slo_requirement + self.__update_stats() + + def __lt__(self, other_request_weighted_score: 'WeightedScoreSorter') -> bool: + self.__update_stats() + return self.weighted_score > other_request_weighted_score.weighted_score + + def __eq__(self, other_request_weighted_score: 'WeightedScoreSorter') -> bool: + return self.weighted_score == other_request_weighted_score.weighted_score + + def __update_stats(self): + self.wait_time = time.time() - self.request_arrival_time + self.weighted_score = TimeAndLengthScorer_Instance.score(self.wait_time, self.request_length) diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index a00ca1912b0f3..792b5098578ab 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -8,6 +8,7 @@ from collections.abc import Iterable, Iterator from enum import Enum from vllm.v1.request import Request +from vllm.v1.core.sched.policy.weighted_score_softer import WeightedScoreSorter class SchedulingPolicy(Enum): @@ -15,6 +16,7 @@ class SchedulingPolicy(Enum): FCFS = "fcfs" PRIORITY = "priority" + SJF = "sjf" class RequestQueue(ABC): @@ -207,11 +209,158 @@ class PriorityRequestQueue(RequestQueue): return reversed(list(self)) +class SJFRequestQueue(deque[Request], RequestQueue): + """A short-job-first queue that supports deque operations.""" + + def __init__(self): + deque.__init__(self) + + def add_request(self, request: Request) -> None: + """Add a request to the queue according to SJF policy.""" + self.append(request) + self._sort_requests() + + def pop_request(self) -> Request: + """Pop a request from the queue according to SJF policy.""" + return self.popleft() + + def peek_request(self) -> Request: + """Peek at the next request in the queue without removing it.""" + if not self: + raise IndexError("peek from an empty queue") + self._sort_requests() + return self[0] + + def prepend_request(self, request: Request) -> None: + """Prepend a request to the front of the queue.""" + self.appendleft(request) + + def prepend_requests(self, requests: RequestQueue) -> None: + """Prepend all requests from another queue to the front of this + queue.""" + self.extendleft(reversed(requests)) + + def remove_request(self, request: Request) -> None: + """Remove a specific request from the queue.""" + self.remove(request) + + def remove_requests(self, requests: Iterable[Request]) -> None: + """Remove multiple specific requests from the queue.""" + requests_to_remove = set(requests) + filtered_requests = [ + req for req in self if req not in requests_to_remove + ] + # deque does not support in-place filtering, so we need to clear + # and extend + self.clear() + self.extend(filtered_requests) + + def _sort_requests(self, reverse = False) -> None: + key_func = lambda req: WeightedScoreSorter(request_length=len(req.prompt_token_ids), request_arrival_time=req.arrival_time) + sorted_list = sorted(self, key=key_func, reverse=reverse) + self.clear() + self.extend(sorted_list) + + def __bool__(self) -> bool: + """Check if queue has any requests.""" + return len(self) > 0 + + def __len__(self) -> int: + """Get number of requests in queue.""" + return super().__len__() + + def __iter__(self) -> Iterator[Request]: + """Iterate over the queue according to SJF policy.""" + return super().__iter__() + + def __reversed__(self) -> Iterator[Request]: + """Iterate over the queue in reverse order.""" + return super().__reversed__() + + +class SJFRequestQueueInHeap(RequestQueue): + """ + A SJF queue that supports heap operations. + + Requests with a larger value of weighted score value are processed first. + """ + + def __init__(self) -> None: + self._heap: list[tuple[WeightedScoreSorter, Request]] = [] + + def add_request(self, request: Request) -> None: + """Add a request to the queue according to SJF policy.""" + heapq.heappush(self._heap, + (WeightedScoreSorter(len(request.prompt_token_ids), request.arrival_time), request)) + + def pop_request(self) -> Request: + """Pop a request from the queue according to SJF policy.""" + if not self._heap: + raise IndexError("pop from empty heap") + _, request = heapq.heappop(self._heap) + return request + + def peek_request(self) -> Request: + """Peek at the next request in the queue without removing it.""" + if not self._heap: + raise IndexError("peek from empty heap") + _, request = self._heap[0] + return request + + def prepend_request(self, request: Request) -> None: + """Add a request to the queue according to SJF policy. + + Note: In a SJF queue, there is no concept of prepending to the + front. Requests are ordered by (priority, arrival_time).""" + self.add_request(request) + + def prepend_requests(self, requests: RequestQueue) -> None: + """Add all requests from another queue according to SJF policy. + + Note: In a SJF queue, there is no concept of prepending to the + front. Requests are ordered by weighted score.""" + for request in requests: + self.add_request(request) + + def remove_request(self, request: Request) -> None: + """Remove a specific request from the queue.""" + self._heap = [(ws, r) for ws, r in self._heap if r != request] + heapq.heapify(self._heap) + + def remove_requests(self, requests: Iterable[Request]) -> None: + """Remove multiple specific requests from the queue.""" + requests_to_remove = set(requests) + self._heap = [(ws, r) for ws , r in self._heap + if r not in requests_to_remove] + heapq.heapify(self._heap) + + def __bool__(self) -> bool: + """Check if queue has any requests.""" + return bool(self._heap) + + def __len__(self) -> int: + """Get number of requests in queue.""" + return len(self._heap) + + def __iter__(self) -> Iterator[Request]: + """Iterate over the queue according to SJF policy.""" + heap_copy = self._heap[:] + while heap_copy: + _, _, request = heapq.heappop(heap_copy) + yield request + + def __reversed__(self) -> Iterator[Request]: + """Iterate over the queue in reverse SJF order.""" + return reversed(list(self)) + + def create_request_queue(policy: SchedulingPolicy) -> RequestQueue: """Create request queue based on scheduling policy.""" if policy == SchedulingPolicy.PRIORITY: return PriorityRequestQueue() elif policy == SchedulingPolicy.FCFS: return FCFSRequestQueue() + elif policy == SchedulingPolicy.SJF: + return SJFRequestQueue() else: raise ValueError(f"Unknown scheduling policy: {policy}") From e14d347982a979ee66c389fb63202444ce6c3618 Mon Sep 17 00:00:00 2001 From: Pr0Wh1teGivee Date: Tue, 25 Nov 2025 14:24:25 +0800 Subject: [PATCH 02/20] use heap Signed-off-by: Pr0Wh1teGivee Signed-off-by: weichen --- vllm/v1/core/sched/request_queue.py | 71 +---------------------------- 1 file changed, 1 insertion(+), 70 deletions(-) diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index 792b5098578ab..a6bdab44a26fa 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -209,76 +209,7 @@ class PriorityRequestQueue(RequestQueue): return reversed(list(self)) -class SJFRequestQueue(deque[Request], RequestQueue): - """A short-job-first queue that supports deque operations.""" - - def __init__(self): - deque.__init__(self) - - def add_request(self, request: Request) -> None: - """Add a request to the queue according to SJF policy.""" - self.append(request) - self._sort_requests() - - def pop_request(self) -> Request: - """Pop a request from the queue according to SJF policy.""" - return self.popleft() - - def peek_request(self) -> Request: - """Peek at the next request in the queue without removing it.""" - if not self: - raise IndexError("peek from an empty queue") - self._sort_requests() - return self[0] - - def prepend_request(self, request: Request) -> None: - """Prepend a request to the front of the queue.""" - self.appendleft(request) - - def prepend_requests(self, requests: RequestQueue) -> None: - """Prepend all requests from another queue to the front of this - queue.""" - self.extendleft(reversed(requests)) - - def remove_request(self, request: Request) -> None: - """Remove a specific request from the queue.""" - self.remove(request) - - def remove_requests(self, requests: Iterable[Request]) -> None: - """Remove multiple specific requests from the queue.""" - requests_to_remove = set(requests) - filtered_requests = [ - req for req in self if req not in requests_to_remove - ] - # deque does not support in-place filtering, so we need to clear - # and extend - self.clear() - self.extend(filtered_requests) - - def _sort_requests(self, reverse = False) -> None: - key_func = lambda req: WeightedScoreSorter(request_length=len(req.prompt_token_ids), request_arrival_time=req.arrival_time) - sorted_list = sorted(self, key=key_func, reverse=reverse) - self.clear() - self.extend(sorted_list) - - def __bool__(self) -> bool: - """Check if queue has any requests.""" - return len(self) > 0 - - def __len__(self) -> int: - """Get number of requests in queue.""" - return super().__len__() - - def __iter__(self) -> Iterator[Request]: - """Iterate over the queue according to SJF policy.""" - return super().__iter__() - - def __reversed__(self) -> Iterator[Request]: - """Iterate over the queue in reverse order.""" - return super().__reversed__() - - -class SJFRequestQueueInHeap(RequestQueue): +class SJFRequestQueue(RequestQueue): """ A SJF queue that supports heap operations. From 379eabac7f4e0f2dd328551576e005e4907b8cf7 Mon Sep 17 00:00:00 2001 From: Pr0Wh1teGivee Date: Tue, 2 Dec 2025 16:12:40 +0800 Subject: [PATCH 03/20] linting Signed-off-by: Pr0Wh1teGivee Signed-off-by: weichen --- .../v1/core/sched/policy/normalized_scorer.py | 53 +++++++++++++------ .../sched/policy/weighted_score_softer.py | 30 ++++++++--- vllm/v1/core/sched/request_queue.py | 24 +++++---- 3 files changed, 75 insertions(+), 32 deletions(-) diff --git a/vllm/v1/core/sched/policy/normalized_scorer.py b/vllm/v1/core/sched/policy/normalized_scorer.py index 7b7e83cbbd708..145bd57ba5329 100644 --- a/vllm/v1/core/sched/policy/normalized_scorer.py +++ b/vllm/v1/core/sched/policy/normalized_scorer.py @@ -1,31 +1,36 @@ -from typing import List +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math from vllm.logger import init_logger -import math - logger = init_logger(__name__) + class ScoreDim: """ Normalized scoring dimension. """ - def __init__(self, name: str, median: float, norm_scale=0.0, weight=0.5, reverse=False): + + def __init__( + self, name: str, median: float, norm_scale=0.0, weight=0.5, reverse=False + ): self.name = name self.median = median if norm_scale != 0.0: self.norm_scale = norm_scale else: - self.norm_scale = 1/median + self.norm_scale = 1 / median self.weight = weight self.reverse = reverse + class NormalizedScorer: """ Normalize unbounded N-dimensional values into a composite score using the Sigmoid function. """ - def __init__(self, dim_list: List[ScoreDim]) -> None: + def __init__(self, dim_list: list[ScoreDim]) -> None: """ :param dim_list: Scoring dimensions; each dimension must define a median reference point, scaling factor, and weight. """ @@ -50,32 +55,48 @@ class NormalizedScorer: Smaller value → higher score → use inverse Sigmoid. """ if len(dims) > self.dim_count: - raise ValueError(f"Dim num({len(dims)}) exceeds max num dim({self.dim_count})") + raise ValueError( + f"Dim num({len(dims)}) exceeds max num dim({self.dim_count})" + ) final_score = 0.0 for idx, dim_value in enumerate(dims): dim_info = self.dim_list[idx] if dim_info.reverse: - score = self._inv_sigmoid_normalize(dim_value, dim_info.median, dim_info.norm_scale) + score = self._inv_sigmoid_normalize( + dim_value, dim_info.median, dim_info.norm_scale + ) else: - score = self._sigmoid_normalize(dim_value, dim_info.median, dim_info.norm_scale) + score = self._sigmoid_normalize( + dim_value, dim_info.median, dim_info.norm_scale + ) logger.debug(f"{dim_info.name}({dim_info.reverse}) : {score:.10f}") # Weighted summation. final_score += score * dim_info.weight return max(0.0, min(1.0, final_score)) # Clamp to [0, 1]. + class TimeAndLengthScorer(NormalizedScorer): """ Scorer for time and length dimensions; defaults to forward scoring with equal weights (0.5 each). """ - def __init__(self, - time_median, length_median, - time_scale=0.0, length_scale=0.0, - time_weight=0.5, length_weight=0.5, - reverse_time=False, reverse_len=False) -> None: - dim_list = [ScoreDim("time", time_median, time_scale, time_weight, reverse_time), - ScoreDim("length", length_median, length_scale, length_weight, reverse_len)] + + def __init__( + self, + time_median, + length_median, + time_scale=0.0, + length_scale=0.0, + time_weight=0.5, + length_weight=0.5, + reverse_time=False, + reverse_len=False, + ) -> None: + dim_list = [ + ScoreDim("time", time_median, time_scale, time_weight, reverse_time), + ScoreDim("length", length_median, length_scale, length_weight, reverse_len), + ] super().__init__(dim_list) def score(self, time: float, length: float) -> float: diff --git a/vllm/v1/core/sched/policy/weighted_score_softer.py b/vllm/v1/core/sched/policy/weighted_score_softer.py index 17be66a9c6754..cfc6c7bc55bca 100644 --- a/vllm/v1/core/sched/policy/weighted_score_softer.py +++ b/vllm/v1/core/sched/policy/weighted_score_softer.py @@ -1,28 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time from functools import total_ordering from vllm.v1.core.sched.policy.normalized_scorer import TimeAndLengthScorer -import time TimeAndLengthScorer_Instance = None if TimeAndLengthScorer_Instance == None: - TimeAndLengthScorer_Instance = TimeAndLengthScorer(time_median=5, time_weight=0.5, length_median=32 * 1024, - length_weight=0.5, reverse_len=True) + TimeAndLengthScorer_Instance = TimeAndLengthScorer( + time_median=5, + time_weight=0.5, + length_median=32 * 1024, + length_weight=0.5, + reverse_len=True, + ) + + @total_ordering class WeightedScoreSorter: - def __init__(self, request_length: int, request_arrival_time: float, request_slo_requirement: list = None): + def __init__( + self, + request_length: int, + request_arrival_time: float, + request_slo_requirement: list = None, + ): self.request_length = request_length self.request_arrival_time = request_arrival_time self.request_slo_requirement = request_slo_requirement self.__update_stats() - def __lt__(self, other_request_weighted_score: 'WeightedScoreSorter') -> bool: + def __lt__(self, other_request_weighted_score: "WeightedScoreSorter") -> bool: self.__update_stats() return self.weighted_score > other_request_weighted_score.weighted_score - def __eq__(self, other_request_weighted_score: 'WeightedScoreSorter') -> bool: + def __eq__(self, other_request_weighted_score: "WeightedScoreSorter") -> bool: return self.weighted_score == other_request_weighted_score.weighted_score def __update_stats(self): self.wait_time = time.time() - self.request_arrival_time - self.weighted_score = TimeAndLengthScorer_Instance.score(self.wait_time, self.request_length) + self.weighted_score = TimeAndLengthScorer_Instance.score( + self.wait_time, self.request_length + ) diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index a6bdab44a26fa..9200ea524362a 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -7,8 +7,8 @@ from collections import deque from collections.abc import Iterable, Iterator from enum import Enum -from vllm.v1.request import Request from vllm.v1.core.sched.policy.weighted_score_softer import WeightedScoreSorter +from vllm.v1.request import Request class SchedulingPolicy(Enum): @@ -212,7 +212,7 @@ class PriorityRequestQueue(RequestQueue): class SJFRequestQueue(RequestQueue): """ A SJF queue that supports heap operations. - + Requests with a larger value of weighted score value are processed first. """ @@ -221,8 +221,15 @@ class SJFRequestQueue(RequestQueue): def add_request(self, request: Request) -> None: """Add a request to the queue according to SJF policy.""" - heapq.heappush(self._heap, - (WeightedScoreSorter(len(request.prompt_token_ids), request.arrival_time), request)) + heapq.heappush( + self._heap, + ( + WeightedScoreSorter( + len(request.prompt_token_ids), request.arrival_time + ), + request, + ), + ) def pop_request(self) -> Request: """Pop a request from the queue according to SJF policy.""" @@ -240,14 +247,14 @@ class SJFRequestQueue(RequestQueue): def prepend_request(self, request: Request) -> None: """Add a request to the queue according to SJF policy. - + Note: In a SJF queue, there is no concept of prepending to the front. Requests are ordered by (priority, arrival_time).""" self.add_request(request) def prepend_requests(self, requests: RequestQueue) -> None: """Add all requests from another queue according to SJF policy. - + Note: In a SJF queue, there is no concept of prepending to the front. Requests are ordered by weighted score.""" for request in requests: @@ -261,8 +268,7 @@ class SJFRequestQueue(RequestQueue): def remove_requests(self, requests: Iterable[Request]) -> None: """Remove multiple specific requests from the queue.""" requests_to_remove = set(requests) - self._heap = [(ws, r) for ws , r in self._heap - if r not in requests_to_remove] + self._heap = [(ws, r) for ws, r in self._heap if r not in requests_to_remove] heapq.heapify(self._heap) def __bool__(self) -> bool: @@ -282,7 +288,7 @@ class SJFRequestQueue(RequestQueue): def __reversed__(self) -> Iterator[Request]: """Iterate over the queue in reverse SJF order.""" - return reversed(list(self)) + return reversed(list(self)) def create_request_queue(policy: SchedulingPolicy) -> RequestQueue: From dd0e1224bccbc139bcebd789297426bfbc5cb8ff Mon Sep 17 00:00:00 2001 From: Pr0Wh1teGivee Date: Wed, 3 Dec 2025 15:17:55 +0800 Subject: [PATCH 04/20] linting Signed-off-by: Pr0Wh1teGivee Signed-off-by: weichen --- vllm/v1/core/sched/policy/weighted_score_softer.py | 8 ++++++-- vllm/v1/core/sched/request_queue.py | 3 ++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/v1/core/sched/policy/weighted_score_softer.py b/vllm/v1/core/sched/policy/weighted_score_softer.py index cfc6c7bc55bca..b4cd54e30124e 100644 --- a/vllm/v1/core/sched/policy/weighted_score_softer.py +++ b/vllm/v1/core/sched/policy/weighted_score_softer.py @@ -4,6 +4,7 @@ import time from functools import total_ordering from vllm.v1.core.sched.policy.normalized_scorer import TimeAndLengthScorer +from typing import Optional, List, Any TimeAndLengthScorer_Instance = None @@ -23,7 +24,7 @@ class WeightedScoreSorter: self, request_length: int, request_arrival_time: float, - request_slo_requirement: list = None, + request_slo_requirement: Optional[List[Any]] = None, ): self.request_length = request_length self.request_arrival_time = request_arrival_time @@ -34,11 +35,14 @@ class WeightedScoreSorter: self.__update_stats() return self.weighted_score > other_request_weighted_score.weighted_score - def __eq__(self, other_request_weighted_score: "WeightedScoreSorter") -> bool: + def __eq__(self, other_request_weighted_score: object) -> bool: + if not isinstance(other_request_weighted_score, WeightedScoreSorter): + return NotImplemented return self.weighted_score == other_request_weighted_score.weighted_score def __update_stats(self): self.wait_time = time.time() - self.request_arrival_time + assert TimeAndLengthScorer_Instance is not None self.weighted_score = TimeAndLengthScorer_Instance.score( self.wait_time, self.request_length ) diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index 9200ea524362a..3d28fb3652c54 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -221,6 +221,7 @@ class SJFRequestQueue(RequestQueue): def add_request(self, request: Request) -> None: """Add a request to the queue according to SJF policy.""" + assert request.prompt_token_ids is not None, "prompt_token_ids cannot be None for SJF scheduling." heapq.heappush( self._heap, ( @@ -283,7 +284,7 @@ class SJFRequestQueue(RequestQueue): """Iterate over the queue according to SJF policy.""" heap_copy = self._heap[:] while heap_copy: - _, _, request = heapq.heappop(heap_copy) + _, request = heapq.heappop(heap_copy) yield request def __reversed__(self) -> Iterator[Request]: From db3e0a576ebd8516f079ff83f9524063ed18a56d Mon Sep 17 00:00:00 2001 From: Pr0Wh1teGivee Date: Wed, 3 Dec 2025 16:04:00 +0800 Subject: [PATCH 05/20] linting Signed-off-by: Pr0Wh1teGivee Signed-off-by: weichen --- vllm/v1/core/sched/policy/normalized_scorer.py | 5 +++-- vllm/v1/core/sched/policy/weighted_score_softer.py | 4 ++-- vllm/v1/core/sched/request_queue.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/v1/core/sched/policy/normalized_scorer.py b/vllm/v1/core/sched/policy/normalized_scorer.py index 145bd57ba5329..6929c611d9475 100644 --- a/vllm/v1/core/sched/policy/normalized_scorer.py +++ b/vllm/v1/core/sched/policy/normalized_scorer.py @@ -99,5 +99,6 @@ class TimeAndLengthScorer(NormalizedScorer): ] super().__init__(dim_list) - def score(self, time: float, length: float) -> float: - return super().score(time, length) + def score(self, *dims: float) -> float: + assert len(dims) == 2 + return super().score(*dims) diff --git a/vllm/v1/core/sched/policy/weighted_score_softer.py b/vllm/v1/core/sched/policy/weighted_score_softer.py index b4cd54e30124e..3280a15c7a32b 100644 --- a/vllm/v1/core/sched/policy/weighted_score_softer.py +++ b/vllm/v1/core/sched/policy/weighted_score_softer.py @@ -2,9 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time from functools import total_ordering +from typing import Any from vllm.v1.core.sched.policy.normalized_scorer import TimeAndLengthScorer -from typing import Optional, List, Any TimeAndLengthScorer_Instance = None @@ -24,7 +24,7 @@ class WeightedScoreSorter: self, request_length: int, request_arrival_time: float, - request_slo_requirement: Optional[List[Any]] = None, + request_slo_requirement: list[Any] | None = None, ): self.request_length = request_length self.request_arrival_time = request_arrival_time diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index 3d28fb3652c54..7cea8a881f0a7 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -221,7 +221,7 @@ class SJFRequestQueue(RequestQueue): def add_request(self, request: Request) -> None: """Add a request to the queue according to SJF policy.""" - assert request.prompt_token_ids is not None, "prompt_token_ids cannot be None for SJF scheduling." + assert request.prompt_token_ids is not None heapq.heappush( self._heap, ( From 1e8b313afb2d6e09039f3221bf98785e9c9aa6d9 Mon Sep 17 00:00:00 2001 From: Pr0Wh1teGivee Date: Wed, 3 Dec 2025 16:22:51 +0800 Subject: [PATCH 06/20] linting Signed-off-by: Pr0Wh1teGivee Signed-off-by: weichen --- vllm/v1/core/sched/policy/normalized_scorer.py | 7 ++++++- vllm/v1/core/sched/policy/weighted_score_softer.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/sched/policy/normalized_scorer.py b/vllm/v1/core/sched/policy/normalized_scorer.py index 6929c611d9475..d44738068ee56 100644 --- a/vllm/v1/core/sched/policy/normalized_scorer.py +++ b/vllm/v1/core/sched/policy/normalized_scorer.py @@ -70,7 +70,12 @@ class NormalizedScorer: score = self._sigmoid_normalize( dim_value, dim_info.median, dim_info.norm_scale ) - logger.debug(f"{dim_info.name}({dim_info.reverse}) : {score:.10f}") + logger.debug( + "%s(%s) : %.10f", + dim_info.name, + dim_info.reverse, + score + ) # Weighted summation. final_score += score * dim_info.weight diff --git a/vllm/v1/core/sched/policy/weighted_score_softer.py b/vllm/v1/core/sched/policy/weighted_score_softer.py index 3280a15c7a32b..4828559d27695 100644 --- a/vllm/v1/core/sched/policy/weighted_score_softer.py +++ b/vllm/v1/core/sched/policy/weighted_score_softer.py @@ -8,7 +8,7 @@ from vllm.v1.core.sched.policy.normalized_scorer import TimeAndLengthScorer TimeAndLengthScorer_Instance = None -if TimeAndLengthScorer_Instance == None: +if TimeAndLengthScorer_Instance is None: TimeAndLengthScorer_Instance = TimeAndLengthScorer( time_median=5, time_weight=0.5, From b04f6786593742f15eb4bf869d3c9dd3d922fe89 Mon Sep 17 00:00:00 2001 From: Pr0Wh1teGivee Date: Wed, 3 Dec 2025 17:01:51 +0800 Subject: [PATCH 07/20] linting Signed-off-by: Pr0Wh1teGivee Signed-off-by: weichen --- vllm/v1/core/sched/policy/normalized_scorer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/policy/normalized_scorer.py b/vllm/v1/core/sched/policy/normalized_scorer.py index d44738068ee56..ff124dd5c8929 100644 --- a/vllm/v1/core/sched/policy/normalized_scorer.py +++ b/vllm/v1/core/sched/policy/normalized_scorer.py @@ -32,7 +32,11 @@ class NormalizedScorer: def __init__(self, dim_list: list[ScoreDim]) -> None: """ - :param dim_list: Scoring dimensions; each dimension must define a median reference point, scaling factor, and weight. + Initialize the scorer with a list of scoring dimensions. + + Args: + dim_list: A list of `ScoreDim` objects. Each dimension must define a + median reference point, scaling factor, and weight. """ self.dim_list = dim_list self.dim_count = len(dim_list) From ed2a8082520ada5b575b77158437006e059dc8c3 Mon Sep 17 00:00:00 2001 From: Pr0Wh1teGivee Date: Thu, 4 Dec 2025 10:19:36 +0800 Subject: [PATCH 08/20] Update normalized_scorer.py Signed-off-by: Pr0Wh1teGivee Signed-off-by: weichen --- vllm/v1/core/sched/policy/normalized_scorer.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/v1/core/sched/policy/normalized_scorer.py b/vllm/v1/core/sched/policy/normalized_scorer.py index ff124dd5c8929..9460d7454085c 100644 --- a/vllm/v1/core/sched/policy/normalized_scorer.py +++ b/vllm/v1/core/sched/policy/normalized_scorer.py @@ -27,7 +27,8 @@ class ScoreDim: class NormalizedScorer: """ - Normalize unbounded N-dimensional values into a composite score using the Sigmoid function. + Normalize unbounded N-dimensional values into a composite score using the Sigmoid + function. """ def __init__(self, dim_list: list[ScoreDim]) -> None: @@ -48,7 +49,9 @@ class NormalizedScorer: @staticmethod def _inv_sigmoid_normalize(value, median, norm_scale): - """Inverse Sigmoid: Used for dimensions where a larger value yields a lower score.""" + """Inverse Sigmoid: Used for dimensions where a larger value yields a lower + score. + """ # Equivalent to sigmoid(-x), but more numerically stable. return 1 / (1 + math.exp(norm_scale * (value - median))) @@ -74,12 +77,7 @@ class NormalizedScorer: score = self._sigmoid_normalize( dim_value, dim_info.median, dim_info.norm_scale ) - logger.debug( - "%s(%s) : %.10f", - dim_info.name, - dim_info.reverse, - score - ) + logger.debug("%s(%s) : %.10f", dim_info.name, dim_info.reverse, score) # Weighted summation. final_score += score * dim_info.weight @@ -88,7 +86,8 @@ class NormalizedScorer: class TimeAndLengthScorer(NormalizedScorer): """ - Scorer for time and length dimensions; defaults to forward scoring with equal weights (0.5 each). + Scorer for time and length dimensions; defaults to forward scoring with equal + weights (0.5 each). """ def __init__( From 779769ea97744986f4ef0ca10471d3a6d4261c0e Mon Sep 17 00:00:00 2001 From: Pr0Wh1teGivee Date: Fri, 5 Dec 2025 17:06:27 +0800 Subject: [PATCH 09/20] Create __init__.py Signed-off-by: Pr0Wh1teGivee Signed-off-by: weichen --- vllm/v1/core/sched/policy/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 vllm/v1/core/sched/policy/__init__.py diff --git a/vllm/v1/core/sched/policy/__init__.py b/vllm/v1/core/sched/policy/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d From 64137934664cd02d9c80ce585302c333b84f45cf Mon Sep 17 00:00:00 2001 From: Pr0Wh1teGivee Date: Fri, 5 Dec 2025 17:15:33 +0800 Subject: [PATCH 10/20] Update scheduler.py Signed-off-by: Pr0Wh1teGivee Signed-off-by: weichen --- vllm/config/scheduler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 1fe09a6ae2ce3..23d21c0fd45ca 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -21,7 +21,11 @@ logger = init_logger(__name__) RunnerType = Literal["generate", "pooling", "draft"] SchedulerPolicy = Literal["fcfs", "priority", "sjf"] - +""" SJF Scheduling Policy: +It stands for shortest-job-first — requests are scheduled by total prompt + +output length (shorter first), with aging to prevent starvation. For more +information, please check: https://github.com/vllm-project/vllm/issues/29406 +""" @config @dataclass From ac674f6fc77f98d186c8f72ab13caedc5b21b4f0 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Fri, 5 Dec 2025 11:03:38 +0100 Subject: [PATCH 11/20] Move docstring Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: Pr0Wh1teGivee Signed-off-by: weichen --- vllm/config/scheduler.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 23d21c0fd45ca..bb0ecf38a7468 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -21,11 +21,7 @@ logger = init_logger(__name__) RunnerType = Literal["generate", "pooling", "draft"] SchedulerPolicy = Literal["fcfs", "priority", "sjf"] -""" SJF Scheduling Policy: -It stands for shortest-job-first — requests are scheduled by total prompt + -output length (shorter first), with aging to prevent starvation. For more -information, please check: https://github.com/vllm-project/vllm/issues/29406 -""" + @config @dataclass @@ -109,7 +105,9 @@ class SchedulerConfig: - "fcfs" means first come first served, i.e. requests are handled in order of arrival.\n - "priority" means requests are handled based on given priority (lower - value means earlier handling) and time of arrival deciding any ties).""" + value means earlier handling) and time of arrival deciding any ties).\n + - "sjf" means shortest job first. Requests are scheduled by prompt length + (shortest first), with aging to prevent starvation.""" disable_chunked_mm_input: bool = False """If set to true and chunked prefill is enabled, we do not want to From cc0a8ae572d50a243f55705c821fa45b34cb8b1f Mon Sep 17 00:00:00 2001 From: Pr0Wh1teGivee Date: Sat, 6 Dec 2025 10:16:03 +0800 Subject: [PATCH 12/20] naming Signed-off-by: Pr0Wh1teGivee Signed-off-by: weichen --- .../{weighted_score_softer.py => weighted_score_sorter.py} | 0 vllm/v1/core/sched/request_queue.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename vllm/v1/core/sched/policy/{weighted_score_softer.py => weighted_score_sorter.py} (100%) diff --git a/vllm/v1/core/sched/policy/weighted_score_softer.py b/vllm/v1/core/sched/policy/weighted_score_sorter.py similarity index 100% rename from vllm/v1/core/sched/policy/weighted_score_softer.py rename to vllm/v1/core/sched/policy/weighted_score_sorter.py diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index 7cea8a881f0a7..b0efd043c0668 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -7,7 +7,7 @@ from collections import deque from collections.abc import Iterable, Iterator from enum import Enum -from vllm.v1.core.sched.policy.weighted_score_softer import WeightedScoreSorter +from vllm.v1.core.sched.policy.weighted_score_sorter import WeightedScoreSorter from vllm.v1.request import Request From 4fe722fae5b747af437a889bbeab90f511542414 Mon Sep 17 00:00:00 2001 From: Pr0Wh1teGivee Date: Tue, 9 Dec 2025 15:26:05 +0800 Subject: [PATCH 13/20] abstracting common code to HeapBasedRequestQueue Signed-off-by: Pr0Wh1teGivee Signed-off-by: weichen --- vllm/v1/core/sched/request_queue.py | 181 +++++++++++++--------------- 1 file changed, 82 insertions(+), 99 deletions(-) diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index b0efd043c0668..3c5618a4700fd 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -135,58 +135,72 @@ class FCFSRequestQueue(deque[Request], RequestQueue): return super().__reversed__() -class PriorityRequestQueue(RequestQueue): - """ - A priority queue that supports heap operations. - - Respects the ordering defined in the Request class, where - requests with a smaller value of `priority` are processed first. - If multiple requests have the same priority, the one with the earlier - `arrival_time` is processed first. - """ +class HeapBasedRequestQueue(RequestQueue, ABC): + """Base class for heap-based request queues (priority and SJF).""" def __init__(self) -> None: - self._heap: list[Request] = [] + self._heap: list = [] + + @abstractmethod + def _to_heap_element(self, request: Request) -> object: + """Convert a request to the appropriate heap element.""" + pass + + @abstractmethod + def _from_heap_element(self, heap_element: object) -> Request: + """Extract the request from a heap element.""" + pass def add_request(self, request: Request) -> None: - """Add a request to the queue according to priority policy.""" - heapq.heappush(self._heap, request) + """Add a request to the heap queue.""" + heap_element = self._to_heap_element(request) + heapq.heappush(self._heap, heap_element) def pop_request(self) -> Request: - """Pop a request from the queue according to priority policy.""" + """Pop the highest priority request from the heap.""" if not self._heap: raise IndexError("pop from empty heap") - return heapq.heappop(self._heap) + heap_element = heapq.heappop(self._heap) + return self._from_heap_element(heap_element) def peek_request(self) -> Request: - """Peek at the next request in the queue without removing it.""" + """Peek at the highest priority request without removing it.""" if not self._heap: raise IndexError("peek from empty heap") - return self._heap[0] + return self._from_heap_element(self._heap[0]) def prepend_request(self, request: Request) -> None: - """Add a request to the queue according to priority policy. - - Note: In a priority queue, there is no concept of prepending to the - front. Requests are ordered by (priority, arrival_time).""" + """ + Add request to the queue. In heap-based queues, "prepend" has no + special meaning as elements are ordered by priority/score. This + behaves like add_request. + """ self.add_request(request) def prepend_requests(self, requests: RequestQueue) -> None: - """Add all requests from another queue according to priority policy. - - Note: In a priority queue, there is no concept of prepending to the - front. Requests are ordered by (priority, arrival_time).""" + """ + Add all requests from another queue. In heap-based queues, + "prepend" has no special meaning as elements are ordered by + priority/score. This behaves like adding all requests. + """ for request in requests: self.add_request(request) def remove_request(self, request: Request) -> None: - """Remove a specific request from the queue.""" - self._heap.remove(request) - heapq.heapify(self._heap) + """Remove a specific request from the heap.""" + try: + self._heap.remove(request) + heapq.heapify(self._heap) + except ValueError as err: + raise ValueError( + f"Request {request.request_id} not found in queue" + ) from err def remove_requests(self, requests: Iterable[Request]) -> None: - """Remove multiple specific requests from the queue.""" - requests_to_remove = requests if isinstance(requests, set) else set(requests) + """Remove multiple specific requests from the heap.""" + requests_to_remove = ( + set(requests) if not isinstance(requests, set) else requests + ) self._heap = [r for r in self._heap if r not in requests_to_remove] heapq.heapify(self._heap) @@ -199,98 +213,67 @@ class PriorityRequestQueue(RequestQueue): return len(self._heap) def __iter__(self) -> Iterator[Request]: - """Iterate over the queue according to priority policy.""" - heap_copy = self._heap[:] + """Iterate over requests in priority/score order.""" + heap_copy = self._heap.copy() while heap_copy: - yield heapq.heappop(heap_copy) + heap_element = heapq.heappop(heap_copy) + yield self._from_heap_element(heap_element) def __reversed__(self) -> Iterator[Request]: - """Iterate over the queue in reverse priority order.""" + """Iterate over requests in reverse priority/score order.""" return reversed(list(self)) -class SJFRequestQueue(RequestQueue): +class PriorityRequestQueue(HeapBasedRequestQueue): """ - A SJF queue that supports heap operations. - - Requests with a larger value of weighted score value are processed first. + A priority queue where requests are ordered by (priority, arrival_time). + Lower priority values and earlier arrival times are processed first. """ - def __init__(self) -> None: - self._heap: list[tuple[WeightedScoreSorter, Request]] = [] + def _to_heap_element(self, request: Request) -> Request: + """For priority queue, the heap element is the request itself.""" + return request - def add_request(self, request: Request) -> None: - """Add a request to the queue according to SJF policy.""" + def _from_heap_element(self, heap_element: object) -> Request: + """Extract request from heap element with type checking.""" + assert isinstance(heap_element, Request) + return heap_element + + +class SJFRequestQueue(HeapBasedRequestQueue): + """ + A Shortest Job First (SJF) queue where requests are ordered by weighted score. + Requests with higher weighted scores (shorter jobs) are processed first. + """ + + def _to_heap_element(self, request: Request) -> tuple[WeightedScoreSorter, Request]: + """Convert request to (weighted_score, request) tuple for heap.""" assert request.prompt_token_ids is not None - heapq.heappush( - self._heap, - ( - WeightedScoreSorter( - len(request.prompt_token_ids), request.arrival_time - ), - request, - ), + return ( + WeightedScoreSorter(len(request.prompt_token_ids), request.arrival_time), + request, ) - def pop_request(self) -> Request: - """Pop a request from the queue according to SJF policy.""" - if not self._heap: - raise IndexError("pop from empty heap") - _, request = heapq.heappop(self._heap) + def _from_heap_element(self, heap_element: object) -> Request: + """Extract request from the (score, request) tuple with type checking.""" + assert isinstance(heap_element, tuple) and len(heap_element) == 2 + _, request = heap_element return request - def peek_request(self) -> Request: - """Peek at the next request in the queue without removing it.""" - if not self._heap: - raise IndexError("peek from empty heap") - _, request = self._heap[0] - return request - - def prepend_request(self, request: Request) -> None: - """Add a request to the queue according to SJF policy. - - Note: In a SJF queue, there is no concept of prepending to the - front. Requests are ordered by (priority, arrival_time).""" - self.add_request(request) - - def prepend_requests(self, requests: RequestQueue) -> None: - """Add all requests from another queue according to SJF policy. - - Note: In a SJF queue, there is no concept of prepending to the - front. Requests are ordered by weighted score.""" - for request in requests: - self.add_request(request) - def remove_request(self, request: Request) -> None: - """Remove a specific request from the queue.""" - self._heap = [(ws, r) for ws, r in self._heap if r != request] + """Remove a specific request from the SJF heap.""" + original_length = len(self._heap) + self._heap = [(ws, r) for (ws, r) in self._heap if r != request] + if len(self._heap) == original_length: + raise ValueError(f"Request {request.request_id} not found in SJF queue") heapq.heapify(self._heap) def remove_requests(self, requests: Iterable[Request]) -> None: - """Remove multiple specific requests from the queue.""" + """Remove multiple specific requests from the SJF heap.""" requests_to_remove = set(requests) - self._heap = [(ws, r) for ws, r in self._heap if r not in requests_to_remove] + self._heap = [(ws, r) for (ws, r) in self._heap if r not in requests_to_remove] heapq.heapify(self._heap) - def __bool__(self) -> bool: - """Check if queue has any requests.""" - return bool(self._heap) - - def __len__(self) -> int: - """Get number of requests in queue.""" - return len(self._heap) - - def __iter__(self) -> Iterator[Request]: - """Iterate over the queue according to SJF policy.""" - heap_copy = self._heap[:] - while heap_copy: - _, request = heapq.heappop(heap_copy) - yield request - - def __reversed__(self) -> Iterator[Request]: - """Iterate over the queue in reverse SJF order.""" - return reversed(list(self)) - def create_request_queue(policy: SchedulingPolicy) -> RequestQueue: """Create request queue based on scheduling policy.""" From 601387735c87eb6aaaa8a69296cfb8369b2db739 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 18 Dec 2025 17:20:49 +0100 Subject: [PATCH 14/20] Fix removal from heap Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: weichen --- vllm/v1/core/sched/request_queue.py | 113 ++++++++++------------------ 1 file changed, 40 insertions(+), 73 deletions(-) diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index 3c5618a4700fd..2c0a9d62eaae3 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -6,6 +6,7 @@ from abc import ABC, abstractmethod from collections import deque from collections.abc import Iterable, Iterator from enum import Enum +from typing import Any from vllm.v1.core.sched.policy.weighted_score_sorter import WeightedScoreSorter from vllm.v1.request import Request @@ -135,73 +136,57 @@ class FCFSRequestQueue(deque[Request], RequestQueue): return super().__reversed__() -class HeapBasedRequestQueue(RequestQueue, ABC): - """Base class for heap-based request queues (priority and SJF).""" +class RequestHeap(RequestQueue): + """A queue that supports heap operations.""" def __init__(self) -> None: self._heap: list = [] - @abstractmethod - def _to_heap_element(self, request: Request) -> object: + def _request_to_heap(self, request: Request) -> Any: """Convert a request to the appropriate heap element.""" - pass + raise NotImplementedError - @abstractmethod - def _from_heap_element(self, heap_element: object) -> Request: + def _heap_to_request(self, element: Any) -> Request: """Extract the request from a heap element.""" - pass + raise NotImplementedError def add_request(self, request: Request) -> None: - """Add a request to the heap queue.""" - heap_element = self._to_heap_element(request) - heapq.heappush(self._heap, heap_element) + """Add a request to the queue according to heap priority.""" + heapq.heappush(self._heap, self._request_to_heap(request)) def pop_request(self) -> Request: """Pop the highest priority request from the heap.""" if not self._heap: raise IndexError("pop from empty heap") - heap_element = heapq.heappop(self._heap) - return self._from_heap_element(heap_element) + return self._heap_to_request(heapq.heappop(self._heap)) def peek_request(self) -> Request: - """Peek at the highest priority request without removing it.""" + """Peek at the highest priority request in the heap without removing it.""" if not self._heap: raise IndexError("peek from empty heap") - return self._from_heap_element(self._heap[0]) + return self._heap_to_request(self._heap[0]) def prepend_request(self, request: Request) -> None: - """ - Add request to the queue. In heap-based queues, "prepend" has no - special meaning as elements are ordered by priority/score. This - behaves like add_request. - """ + """Add a request to the heap. In heap-based queues there is no beginning as + elements are ordered by priority/score. This behaves like add_request.""" self.add_request(request) def prepend_requests(self, requests: RequestQueue) -> None: - """ - Add all requests from another queue. In heap-based queues, - "prepend" has no special meaning as elements are ordered by - priority/score. This behaves like adding all requests. - """ + """Add all requests from another queue to the heap. In heap-based queues there + is no beginning as elements are ordered by priority/score. This behaves like + add_request.""" for request in requests: self.add_request(request) def remove_request(self, request: Request) -> None: """Remove a specific request from the heap.""" - try: - self._heap.remove(request) - heapq.heapify(self._heap) - except ValueError as err: - raise ValueError( - f"Request {request.request_id} not found in queue" - ) from err + self._heap.remove(self._request_to_heap(request)) + heapq.heapify(self._heap) def remove_requests(self, requests: Iterable[Request]) -> None: """Remove multiple specific requests from the heap.""" - requests_to_remove = ( - set(requests) if not isinstance(requests, set) else requests - ) - self._heap = [r for r in self._heap if r not in requests_to_remove] + remove = requests if isinstance(requests, set) else set(requests) + self._heap = [h for h in self._heap if self._heap_to_request(h) not in remove] heapq.heapify(self._heap) def __bool__(self) -> bool: @@ -213,40 +198,37 @@ class HeapBasedRequestQueue(RequestQueue, ABC): return len(self._heap) def __iter__(self) -> Iterator[Request]: - """Iterate over requests in priority/score order.""" - heap_copy = self._heap.copy() + """Iterate over the queue to heap order.""" + heap_copy = self._heap[:] while heap_copy: - heap_element = heapq.heappop(heap_copy) - yield self._from_heap_element(heap_element) + yield self._heap_to_request(heapq.heappop(heap_copy)) def __reversed__(self) -> Iterator[Request]: - """Iterate over requests in reverse priority/score order.""" + """Iterate over the queue in reverse heap order.""" return reversed(list(self)) -class PriorityRequestQueue(HeapBasedRequestQueue): - """ - A priority queue where requests are ordered by (priority, arrival_time). - Lower priority values and earlier arrival times are processed first. - """ +class PriorityRequestQueue(RequestHeap): + """A priority queue that supports heap operations. - def _to_heap_element(self, request: Request) -> Request: + Respects the ordering defined in the Request class, where requests with a smaller + value of `priority` are processed first. If multiple requests have the same + priority, the one with the earlier `arrival_time` is processed first.""" + + def _request_to_heap(self, request: Request) -> Request: """For priority queue, the heap element is the request itself.""" return request - def _from_heap_element(self, heap_element: object) -> Request: + def _heap_to_request(self, element: Request) -> Request: """Extract request from heap element with type checking.""" - assert isinstance(heap_element, Request) - return heap_element + return element -class SJFRequestQueue(HeapBasedRequestQueue): - """ - A Shortest Job First (SJF) queue where requests are ordered by weighted score. - Requests with higher weighted scores (shorter jobs) are processed first. - """ +class SJFRequestQueue(RequestHeap): + """A Shortest Job First (SJF) queue where requests are ordered by weighted score. + Requests with higher weighted scores (shorter jobs) are processed first.""" - def _to_heap_element(self, request: Request) -> tuple[WeightedScoreSorter, Request]: + def _request_to_heap(self, request: Request) -> tuple[WeightedScoreSorter, Request]: """Convert request to (weighted_score, request) tuple for heap.""" assert request.prompt_token_ids is not None return ( @@ -254,26 +236,11 @@ class SJFRequestQueue(HeapBasedRequestQueue): request, ) - def _from_heap_element(self, heap_element: object) -> Request: + def _heap_to_request(self, element: tuple[WeightedScoreSorter, Request]) -> Request: """Extract request from the (score, request) tuple with type checking.""" - assert isinstance(heap_element, tuple) and len(heap_element) == 2 - _, request = heap_element + _, request = element return request - def remove_request(self, request: Request) -> None: - """Remove a specific request from the SJF heap.""" - original_length = len(self._heap) - self._heap = [(ws, r) for (ws, r) in self._heap if r != request] - if len(self._heap) == original_length: - raise ValueError(f"Request {request.request_id} not found in SJF queue") - heapq.heapify(self._heap) - - def remove_requests(self, requests: Iterable[Request]) -> None: - """Remove multiple specific requests from the SJF heap.""" - requests_to_remove = set(requests) - self._heap = [(ws, r) for (ws, r) in self._heap if r not in requests_to_remove] - heapq.heapify(self._heap) - def create_request_queue(policy: SchedulingPolicy) -> RequestQueue: """Create request queue based on scheduling policy.""" From 53d57d9dca9d60bb5068652679ff53ea47f2b553 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 18 Dec 2025 17:25:03 +0100 Subject: [PATCH 15/20] Remove tuple stuff Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: weichen --- .../core/sched/policy/weighted_score_sorter.py | 16 ++++++---------- vllm/v1/core/sched/request_queue.py | 17 ++++++----------- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/vllm/v1/core/sched/policy/weighted_score_sorter.py b/vllm/v1/core/sched/policy/weighted_score_sorter.py index 4828559d27695..db72de3b4dbe1 100644 --- a/vllm/v1/core/sched/policy/weighted_score_sorter.py +++ b/vllm/v1/core/sched/policy/weighted_score_sorter.py @@ -2,9 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time from functools import total_ordering -from typing import Any from vllm.v1.core.sched.policy.normalized_scorer import TimeAndLengthScorer +from vllm.v1.request import Request TimeAndLengthScorer_Instance = None @@ -20,15 +20,11 @@ if TimeAndLengthScorer_Instance is None: @total_ordering class WeightedScoreSorter: - def __init__( - self, - request_length: int, - request_arrival_time: float, - request_slo_requirement: list[Any] | None = None, - ): - self.request_length = request_length - self.request_arrival_time = request_arrival_time - self.request_slo_requirement = request_slo_requirement + def __init__(self, request: Request): + self.request = request + assert request.prompt_token_ids is not None + self.request_length = len(request.prompt_token_ids) + self.request_arrival_time = request.arrival_time self.__update_stats() def __lt__(self, other_request_weighted_score: "WeightedScoreSorter") -> bool: diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index 2c0a9d62eaae3..e339b5de6255f 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -228,18 +228,13 @@ class SJFRequestQueue(RequestHeap): """A Shortest Job First (SJF) queue where requests are ordered by weighted score. Requests with higher weighted scores (shorter jobs) are processed first.""" - def _request_to_heap(self, request: Request) -> tuple[WeightedScoreSorter, Request]: - """Convert request to (weighted_score, request) tuple for heap.""" - assert request.prompt_token_ids is not None - return ( - WeightedScoreSorter(len(request.prompt_token_ids), request.arrival_time), - request, - ) + def _request_to_heap(self, request: Request) -> WeightedScoreSorter: + """Convert request to `WeightedScoreSorter` for heap.""" + return WeightedScoreSorter(request) - def _heap_to_request(self, element: tuple[WeightedScoreSorter, Request]) -> Request: - """Extract request from the (score, request) tuple with type checking.""" - _, request = element - return request + def _heap_to_request(self, element: WeightedScoreSorter) -> Request: + """Extract request from the `WeightedScoreSorter`.""" + return element.request def create_request_queue(policy: SchedulingPolicy) -> RequestQueue: From 9e8d9e1231620bb3a4579a2fb9df7865ce41ce54 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 18 Dec 2025 17:34:28 +0100 Subject: [PATCH 16/20] Consolidate SJF code and remove global variable Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: weichen --- .../core/sched/policy/shortest_job_first.py | 139 ++++++++++++++++++ .../sched/policy/weighted_score_sorter.py | 44 ------ vllm/v1/core/sched/request_queue.py | 11 +- 3 files changed, 148 insertions(+), 46 deletions(-) create mode 100644 vllm/v1/core/sched/policy/shortest_job_first.py delete mode 100644 vllm/v1/core/sched/policy/weighted_score_sorter.py diff --git a/vllm/v1/core/sched/policy/shortest_job_first.py b/vllm/v1/core/sched/policy/shortest_job_first.py new file mode 100644 index 0000000000000..4d9ba0de11b65 --- /dev/null +++ b/vllm/v1/core/sched/policy/shortest_job_first.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +import time +from functools import total_ordering + +from vllm.logger import init_logger +from vllm.v1.request import Request + +logger = init_logger(__name__) + + +class ScoreDim: + """ + Normalized scoring dimension. + """ + + def __init__( + self, name: str, median: float, norm_scale=0.0, weight=0.5, reverse=False + ): + self.name = name + self.median = median + if norm_scale != 0.0: + self.norm_scale = norm_scale + else: + self.norm_scale = 1 / median + self.weight = weight + self.reverse = reverse + + +class NormalizedScorer: + """ + Normalize unbounded N-dimensional values into a composite score using the Sigmoid + function. + """ + + def __init__(self, dim_list: list[ScoreDim]) -> None: + """ + Initialize the scorer with a list of scoring dimensions. + + Args: + dim_list: A list of `ScoreDim` objects. Each dimension must define a + median reference point, scaling factor, and weight. + """ + self.dim_list = dim_list + self.dim_count = len(dim_list) + + @staticmethod + def _sigmoid_normalize(value, median, norm_scale): + """Sigmoid function: Maps value to (0, 1).""" + return 1 / (1 + math.exp(-norm_scale * (value - median))) + + @staticmethod + def _inv_sigmoid_normalize(value, median, norm_scale): + """Inverse Sigmoid: Used for dimensions where a larger value yields a lower + score. + """ + # Equivalent to sigmoid(-x), but more numerically stable. + return 1 / (1 + math.exp(norm_scale * (value - median))) + + def score(self, *dims: float) -> float: + """ + Compute the composite score. + Larger value → higher score → use forward Sigmoid. + Smaller value → higher score → use inverse Sigmoid. + """ + if len(dims) > self.dim_count: + raise ValueError( + f"Dim num({len(dims)}) exceeds max num dim({self.dim_count})" + ) + + final_score = 0.0 + for idx, dim_value in enumerate(dims): + dim_info = self.dim_list[idx] + if dim_info.reverse: + score = self._inv_sigmoid_normalize( + dim_value, dim_info.median, dim_info.norm_scale + ) + else: + score = self._sigmoid_normalize( + dim_value, dim_info.median, dim_info.norm_scale + ) + logger.debug("%s(%s) : %.10f", dim_info.name, dim_info.reverse, score) + + # Weighted summation. + final_score += score * dim_info.weight + return max(0.0, min(1.0, final_score)) # Clamp to [0, 1]. + + +class TimeAndLengthScorer(NormalizedScorer): + """ + Scorer for time and length dimensions; defaults to forward scoring with equal + weights (0.5 each). + """ + + def __init__( + self, + time_median=5, + length_median=1024 * 32, + time_scale=0.0, + length_scale=0.0, + time_weight=0.5, + length_weight=0.5, + reverse_time=False, + reverse_len=True, + ) -> None: + dim_list = [ + ScoreDim("time", time_median, time_scale, time_weight, reverse_time), + ScoreDim("length", length_median, length_scale, length_weight, reverse_len), + ] + super().__init__(dim_list) + + def score(self, *dims: float) -> float: + assert len(dims) == 2 + return super().score(*dims) + + +@total_ordering +class WeightedScoreSorter: + def __init__(self, request: Request, scorer: TimeAndLengthScorer): + self.request = request + self.scorer = scorer + assert request.prompt_token_ids is not None + self.request_length = len(request.prompt_token_ids) + self.request_arrival_time = request.arrival_time + self.__update_stats() + + def __lt__(self, other_request_weighted_score: "WeightedScoreSorter") -> bool: + self.__update_stats() + return self.weighted_score > other_request_weighted_score.weighted_score + + def __eq__(self, other_request_weighted_score: object) -> bool: + if not isinstance(other_request_weighted_score, WeightedScoreSorter): + return NotImplemented + return self.weighted_score == other_request_weighted_score.weighted_score + + def __update_stats(self): + self.wait_time = time.time() - self.request_arrival_time + self.weighted_score = self.scorer.score(self.wait_time, self.request_length) diff --git a/vllm/v1/core/sched/policy/weighted_score_sorter.py b/vllm/v1/core/sched/policy/weighted_score_sorter.py deleted file mode 100644 index db72de3b4dbe1..0000000000000 --- a/vllm/v1/core/sched/policy/weighted_score_sorter.py +++ /dev/null @@ -1,44 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time -from functools import total_ordering - -from vllm.v1.core.sched.policy.normalized_scorer import TimeAndLengthScorer -from vllm.v1.request import Request - -TimeAndLengthScorer_Instance = None - -if TimeAndLengthScorer_Instance is None: - TimeAndLengthScorer_Instance = TimeAndLengthScorer( - time_median=5, - time_weight=0.5, - length_median=32 * 1024, - length_weight=0.5, - reverse_len=True, - ) - - -@total_ordering -class WeightedScoreSorter: - def __init__(self, request: Request): - self.request = request - assert request.prompt_token_ids is not None - self.request_length = len(request.prompt_token_ids) - self.request_arrival_time = request.arrival_time - self.__update_stats() - - def __lt__(self, other_request_weighted_score: "WeightedScoreSorter") -> bool: - self.__update_stats() - return self.weighted_score > other_request_weighted_score.weighted_score - - def __eq__(self, other_request_weighted_score: object) -> bool: - if not isinstance(other_request_weighted_score, WeightedScoreSorter): - return NotImplemented - return self.weighted_score == other_request_weighted_score.weighted_score - - def __update_stats(self): - self.wait_time = time.time() - self.request_arrival_time - assert TimeAndLengthScorer_Instance is not None - self.weighted_score = TimeAndLengthScorer_Instance.score( - self.wait_time, self.request_length - ) diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py index e339b5de6255f..a2e2ba6c67d31 100644 --- a/vllm/v1/core/sched/request_queue.py +++ b/vllm/v1/core/sched/request_queue.py @@ -8,7 +8,10 @@ from collections.abc import Iterable, Iterator from enum import Enum from typing import Any -from vllm.v1.core.sched.policy.weighted_score_sorter import WeightedScoreSorter +from vllm.v1.core.sched.policy.shortest_job_first import ( + TimeAndLengthScorer, + WeightedScoreSorter, +) from vllm.v1.request import Request @@ -228,9 +231,13 @@ class SJFRequestQueue(RequestHeap): """A Shortest Job First (SJF) queue where requests are ordered by weighted score. Requests with higher weighted scores (shorter jobs) are processed first.""" + def __init__(self): + super().__init__() + self.scorer = TimeAndLengthScorer() + def _request_to_heap(self, request: Request) -> WeightedScoreSorter: """Convert request to `WeightedScoreSorter` for heap.""" - return WeightedScoreSorter(request) + return WeightedScoreSorter(request, self.scorer) def _heap_to_request(self, element: WeightedScoreSorter) -> Request: """Extract request from the `WeightedScoreSorter`.""" From 58615e5889b91e48752fe5db47fcaccaf513885b Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 18 Dec 2025 17:36:36 +0100 Subject: [PATCH 17/20] docstring Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: weichen --- vllm/config/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index bb0ecf38a7468..a9640943b8c18 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -106,7 +106,7 @@ class SchedulerConfig: of arrival.\n - "priority" means requests are handled based on given priority (lower value means earlier handling) and time of arrival deciding any ties).\n - - "sjf" means shortest job first. Requests are scheduled by prompt length + - "sjf" means shortest job first. Requests are scheduled by prompt length (shortest first), with aging to prevent starvation.""" disable_chunked_mm_input: bool = False From da9d1531121d3debd7ebb745d420c97d1226bbb8 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 18 Dec 2025 17:37:06 +0100 Subject: [PATCH 18/20] Delete now empty file Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: weichen --- .../v1/core/sched/policy/normalized_scorer.py | 112 ------------------ 1 file changed, 112 deletions(-) delete mode 100644 vllm/v1/core/sched/policy/normalized_scorer.py diff --git a/vllm/v1/core/sched/policy/normalized_scorer.py b/vllm/v1/core/sched/policy/normalized_scorer.py deleted file mode 100644 index 9460d7454085c..0000000000000 --- a/vllm/v1/core/sched/policy/normalized_scorer.py +++ /dev/null @@ -1,112 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math - -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -class ScoreDim: - """ - Normalized scoring dimension. - """ - - def __init__( - self, name: str, median: float, norm_scale=0.0, weight=0.5, reverse=False - ): - self.name = name - self.median = median - if norm_scale != 0.0: - self.norm_scale = norm_scale - else: - self.norm_scale = 1 / median - self.weight = weight - self.reverse = reverse - - -class NormalizedScorer: - """ - Normalize unbounded N-dimensional values into a composite score using the Sigmoid - function. - """ - - def __init__(self, dim_list: list[ScoreDim]) -> None: - """ - Initialize the scorer with a list of scoring dimensions. - - Args: - dim_list: A list of `ScoreDim` objects. Each dimension must define a - median reference point, scaling factor, and weight. - """ - self.dim_list = dim_list - self.dim_count = len(dim_list) - - @staticmethod - def _sigmoid_normalize(value, median, norm_scale): - """Sigmoid function: Maps value to (0, 1).""" - return 1 / (1 + math.exp(-norm_scale * (value - median))) - - @staticmethod - def _inv_sigmoid_normalize(value, median, norm_scale): - """Inverse Sigmoid: Used for dimensions where a larger value yields a lower - score. - """ - # Equivalent to sigmoid(-x), but more numerically stable. - return 1 / (1 + math.exp(norm_scale * (value - median))) - - def score(self, *dims: float) -> float: - """ - Compute the composite score. - Larger value → higher score → use forward Sigmoid. - Smaller value → higher score → use inverse Sigmoid. - """ - if len(dims) > self.dim_count: - raise ValueError( - f"Dim num({len(dims)}) exceeds max num dim({self.dim_count})" - ) - - final_score = 0.0 - for idx, dim_value in enumerate(dims): - dim_info = self.dim_list[idx] - if dim_info.reverse: - score = self._inv_sigmoid_normalize( - dim_value, dim_info.median, dim_info.norm_scale - ) - else: - score = self._sigmoid_normalize( - dim_value, dim_info.median, dim_info.norm_scale - ) - logger.debug("%s(%s) : %.10f", dim_info.name, dim_info.reverse, score) - - # Weighted summation. - final_score += score * dim_info.weight - return max(0.0, min(1.0, final_score)) # Clamp to [0, 1]. - - -class TimeAndLengthScorer(NormalizedScorer): - """ - Scorer for time and length dimensions; defaults to forward scoring with equal - weights (0.5 each). - """ - - def __init__( - self, - time_median, - length_median, - time_scale=0.0, - length_scale=0.0, - time_weight=0.5, - length_weight=0.5, - reverse_time=False, - reverse_len=False, - ) -> None: - dim_list = [ - ScoreDim("time", time_median, time_scale, time_weight, reverse_time), - ScoreDim("length", length_median, length_scale, length_weight, reverse_len), - ] - super().__init__(dim_list) - - def score(self, *dims: float) -> float: - assert len(dims) == 2 - return super().score(*dims) From 0000d981d2556a78223666e777b53b68f0cb22d5 Mon Sep 17 00:00:00 2001 From: weichen Date: Wed, 24 Dec 2025 16:26:52 +0800 Subject: [PATCH 19/20] add ut for sjf scheduler policy Signed-off-by: weichen --- tests/v1/core/test_sjf_scheduler.py | 229 ++++++++++++++++++++++++++++ 1 file changed, 229 insertions(+) create mode 100644 tests/v1/core/test_sjf_scheduler.py diff --git a/tests/v1/core/test_sjf_scheduler.py b/tests/v1/core/test_sjf_scheduler.py new file mode 100644 index 0000000000000..cd87e80bbef88 --- /dev/null +++ b/tests/v1/core/test_sjf_scheduler.py @@ -0,0 +1,229 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from unittest.mock import Mock, patch + +import pytest +import torch + +from vllm.config import ( + CacheConfig, + ModelConfig, + SchedulerConfig, + VllmConfig, +) +from vllm.sampling_params import SamplingParams +from vllm.utils.hashing import sha256 +from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, +) +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager + +from .utils import EOS_TOKEN_ID + +pytestmark = pytest.mark.cpu_test + + +def create_scheduler_with_sjf( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 8192, + enable_prefix_caching: bool = False, + long_prefill_token_threshold: int = 0, + num_blocks: int = 10000, + block_size: int = 16, + max_model_len: int | None = None, +) -> Scheduler: + """Create scheduler with SJF policy enabled. + + Args: + model: model under test + max_num_seqs: max sequences to schedule + max_num_batch_tokens: max num tokens to batch + enable_prefix_caching: optionally force APC config + (True/False) or use default + (False) + + Returns: + {class}`Scheduler` instance with SJF scheduling + """ + model_config = ModelConfig( + model=model, + trust_remote_code=True, + dtype="float16", + seed=42, + ) + if max_model_len is None: + max_model_len = max_num_batched_tokens + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_model_len, + long_prefill_token_threshold=long_prefill_token_threshold, + enable_chunked_prefill=True, + is_encoder_decoder=model_config.is_encoder_decoder, + policy="sjf", # Enable SJF scheduling + ) + + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + enable_prefix_caching=enable_prefix_caching, + ) + + vllm_config = VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config + ) + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False) + ) + ], + ) + cache_config.num_gpu_blocks = num_blocks + return Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + block_size=block_size, + ) + + +_none_hash_initialized = False + + +def create_requests_for_sjf( + num_requests: int, + prompt_lengths: list[int], + arrival_times: list[float] | None = None, + max_tokens: int = 16, + stop_token_ids: list[int] | None = None, + prompt_logprobs: int | None = None, + starting_idx: int = 0, + same_prompt: bool = False, + block_size: int = 16, + req_ids: list[str] | None = None, +): + """Create requests with specified prompt lengths and arrival times for SJF testing.""" + assert len(prompt_lengths) == num_requests + if arrival_times is not None: + assert len(arrival_times) == num_requests + else: + arrival_times = [float(i) for i in range(num_requests)] + + global _none_hash_initialized + if not _none_hash_initialized: + init_none_hash(sha256) + _none_hash_initialized = True + + block_hasher = get_request_block_hasher(block_size, sha256) + sampling_params = SamplingParams( + ignore_eos=False, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids, + prompt_logprobs=prompt_logprobs, + ) + requests = [] + + if req_ids: + assert len(req_ids) == num_requests + else: + req_ids = [f"{i + starting_idx}" for i in range(num_requests)] + + for i in range(num_requests): + num_tokens = prompt_lengths[i] + prompt_token_ids = ( + [starting_idx] * num_tokens + if same_prompt + else [i + starting_idx] * num_tokens + ) + request = Request( + request_id=req_ids[i], + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + pooling_params=None, + eos_token_id=EOS_TOKEN_ID, + arrival_time=arrival_times[i], + priority=1, # SJF ignores priority, set to default + block_hasher=block_hasher, + ) + requests.append(request) + return requests + + +def test_sjf_scheduling_basic_ordering(): + """Test that requests are scheduled in SJF order + (shorter job = higher priority).""" + scheduler = create_scheduler_with_sjf() + + # Create requests with different prompt lengths + # Shorter jobs should be scheduled first + prompt_lengths = [100, 50, 75] # Add in non-length order + arrival_times = [0.0, 0.0, 0.0] # All same arrival times + requests = create_requests_for_sjf( + num_requests=3, prompt_lengths=prompt_lengths, arrival_times=arrival_times + ) + + # Add requests in non-length order + for request in requests: + scheduler.add_request(request) + + # Schedule and verify SJF order + output = scheduler.schedule() + + # Should schedule all requests since they fit in budget + assert len(output.scheduled_new_reqs) == 3 + + # Verify they are scheduled in length order (shortest first): + # req_1 (length 50), req_2 (length 75), req_0 (length 100) + scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs] + assert scheduled_req_ids == ["1", "2", "0"] + + +def test_sjf_scheduling_waiting_time_tiebreaker_fixed(): + """Test that waiting time is used as tiebreaker when lengths are equal. + """ + scheduler = create_scheduler_with_sjf() + + # Mock current time, fixed at 10.0 seconds + current_time = 10.0 + time_patch = Mock(return_value=current_time) + + with patch('time.time', time_patch): + # Create 3 requests with same length but different arrival times + prompt_lengths = [64, 64, 64] # All requests have same length + # Arrival times: req1 earliest, req2 second, req0 latest + arrival_times = [3.0, 1.0, 2.0] + + requests = create_requests_for_sjf( + num_requests=3, + prompt_lengths=prompt_lengths, + arrival_times=arrival_times + ) + + # Add requests to scheduler (order of addition doesn't affect final scheduling order) + for request in requests: + scheduler.add_request(request) + + # Execute scheduling + output = scheduler.schedule() + + # Verify all requests are scheduled (resources are sufficient) + assert len(output.scheduled_new_reqs) == 3 + + # Verify scheduling order: longest wait first + # Expected order: req1 (waited 9.0s), req2 (waited 8.0s), req0 (waited 7.0s) + scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs] + assert scheduled_req_ids == ["1", "2", "0"] From 0431508388b8e130a170ef017d760d873a80ee23 Mon Sep 17 00:00:00 2001 From: weichen Date: Wed, 24 Dec 2025 16:30:07 +0800 Subject: [PATCH 20/20] Use request_id as the identifier when removing a request Signed-off-by: weichen --- vllm/v1/core/sched/policy/shortest_job_first.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/policy/shortest_job_first.py b/vllm/v1/core/sched/policy/shortest_job_first.py index 4d9ba0de11b65..409f9e6ec96ca 100644 --- a/vllm/v1/core/sched/policy/shortest_job_first.py +++ b/vllm/v1/core/sched/policy/shortest_job_first.py @@ -132,7 +132,7 @@ class WeightedScoreSorter: def __eq__(self, other_request_weighted_score: object) -> bool: if not isinstance(other_request_weighted_score, WeightedScoreSorter): return NotImplemented - return self.weighted_score == other_request_weighted_score.weighted_score + return self.request.request_id == other_request_weighted_score.request.request_id def __update_stats(self): self.wait_time = time.time() - self.request_arrival_time