diff --git a/tests/v1/core/test_sjf_scheduler.py b/tests/v1/core/test_sjf_scheduler.py
new file mode 100644
index 0000000000000..cd87e80bbef88
--- /dev/null
+++ b/tests/v1/core/test_sjf_scheduler.py
@@ -0,0 +1,229 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from unittest.mock import Mock, patch
+
+import pytest
+import torch
+
+from vllm.config import (
+    CacheConfig,
+    ModelConfig,
+    SchedulerConfig,
+    VllmConfig,
+)
+from vllm.sampling_params import SamplingParams
+from vllm.utils.hashing import sha256
+from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
+from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.kv_cache_interface import (
+    FullAttentionSpec,
+    KVCacheConfig,
+    KVCacheGroupSpec,
+)
+from vllm.v1.request import Request
+from vllm.v1.structured_output import StructuredOutputManager
+
+from .utils import EOS_TOKEN_ID
+
+pytestmark = pytest.mark.cpu_test
+
+
+def create_scheduler_with_sjf(
+    model: str = "facebook/opt-125m",
+    max_num_seqs: int = 16,
+    max_num_batched_tokens: int = 8192,
+    enable_prefix_caching: bool = False,
+    long_prefill_token_threshold: int = 0,
+    num_blocks: int = 10000,
+    block_size: int = 16,
+    max_model_len: int | None = None,
+) -> Scheduler:
+    """Create scheduler with SJF policy enabled.
+
+    Args:
+        model: model under test
+        max_num_seqs: max sequences to schedule
+        max_num_batched_tokens: max num tokens to batch
+        enable_prefix_caching: optionally force APC config
+            (True/False) or use default (False)
+
+    Returns:
+        {class}`Scheduler` instance with SJF scheduling
+    """
+    model_config = ModelConfig(
+        model=model,
+        trust_remote_code=True,
+        dtype="float16",
+        seed=42,
+    )
+    if max_model_len is None:
+        max_model_len = max_num_batched_tokens
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=max_num_seqs,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_model_len=max_model_len,
+        long_prefill_token_threshold=long_prefill_token_threshold,
+        enable_chunked_prefill=True,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+        policy="sjf",  # Enable SJF scheduling
+    )
+
+    cache_config = CacheConfig(
+        block_size=block_size,
+        gpu_memory_utilization=0.9,
+        swap_space=0,
+        cache_dtype="auto",
+        enable_prefix_caching=enable_prefix_caching,
+    )
+
+    vllm_config = VllmConfig(
+        scheduler_config=scheduler_config,
+        model_config=model_config,
+        cache_config=cache_config
+    )
+    kv_cache_config = KVCacheConfig(
+        num_blocks=num_blocks,
+        kv_cache_tensors=[],
+        kv_cache_groups=[
+            KVCacheGroupSpec(
+                ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
+            )
+        ],
+    )
+    cache_config.num_gpu_blocks = num_blocks
+    return Scheduler(
+        vllm_config=vllm_config,
+        kv_cache_config=kv_cache_config,
+        log_stats=True,
+        structured_output_manager=StructuredOutputManager(vllm_config),
+        block_size=block_size,
+    )
+
+
+_none_hash_initialized = False
+
+
+def create_requests_for_sjf(
+    num_requests: int,
+    prompt_lengths: list[int],
+    arrival_times: list[float] | None = None,
+    max_tokens: int = 16,
+    stop_token_ids: list[int] | None = None,
+    prompt_logprobs: int | None = None,
+    starting_idx: int = 0,
+    same_prompt: bool = False,
+    block_size: int = 16,
+    req_ids: list[str] | None = None,
+):
+    """Create requests with specified prompt lengths and arrival
+    times for SJF testing."""
+    assert len(prompt_lengths) == num_requests
+    if arrival_times is not None:
+        assert len(arrival_times) == num_requests
+    else:
+        arrival_times = [float(i) for i in range(num_requests)]
+
+    global _none_hash_initialized
+    if not _none_hash_initialized:
+        init_none_hash(sha256)
+        _none_hash_initialized = True
+
+    block_hasher = get_request_block_hasher(block_size, sha256)
+    sampling_params = SamplingParams(
+        ignore_eos=False,
+        max_tokens=max_tokens,
+        stop_token_ids=stop_token_ids,
+        prompt_logprobs=prompt_logprobs,
+    )
+    requests = []
+
+    if req_ids:
+        assert len(req_ids) == num_requests
+    else:
+        req_ids = [f"{i + starting_idx}" for i in range(num_requests)]
+
+    for i in range(num_requests):
+        num_tokens = prompt_lengths[i]
+        prompt_token_ids = (
+            [starting_idx] * num_tokens
+            if same_prompt
+            else [i + starting_idx] * num_tokens
+        )
+        request = Request(
+            request_id=req_ids[i],
+            prompt_token_ids=prompt_token_ids,
+            sampling_params=sampling_params,
+            pooling_params=None,
+            eos_token_id=EOS_TOKEN_ID,
+            arrival_time=arrival_times[i],
+            priority=1,  # SJF ignores priority, set to default
+            block_hasher=block_hasher,
+        )
+        requests.append(request)
+    return requests
+
+
+def test_sjf_scheduling_basic_ordering():
+    """Test that requests are scheduled in SJF order
+    (shorter job = higher priority)."""
+    scheduler = create_scheduler_with_sjf()
+
+    # Create requests with different prompt lengths
+    # Shorter jobs should be scheduled first
+    prompt_lengths = [100, 50, 75]  # Add in non-length order
+    arrival_times = [0.0, 0.0, 0.0]  # All same arrival times
+    requests = create_requests_for_sjf(
+        num_requests=3, prompt_lengths=prompt_lengths, arrival_times=arrival_times
+    )
+
+    # Add requests in non-length order
+    for request in requests:
+        scheduler.add_request(request)
+
+    # Schedule and verify SJF order
+    output = scheduler.schedule()
+
+    # Should schedule all requests since they fit in budget
+    assert len(output.scheduled_new_reqs) == 3
+
+    # Verify they are scheduled in length order (shortest first):
+    # req_1 (length 50), req_2 (length 75), req_0 (length 100)
+    scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
+    assert scheduled_req_ids == ["1", "2", "0"]
+
+
+def test_sjf_scheduling_waiting_time_tiebreaker_fixed():
+    """Test that waiting time is used as tiebreaker when lengths are equal."""
+    scheduler = create_scheduler_with_sjf()
+
+    # Mock current time, fixed at 10.0 seconds
+    current_time = 10.0
+    time_patch = Mock(return_value=current_time)
+
+    with patch('time.time', time_patch):
+        # Create 3 requests with same length but different arrival times
+        prompt_lengths = [64, 64, 64]  # All requests have same length
+        # Arrival times: req1 earliest, req2 second, req0 latest
+        arrival_times = [3.0, 1.0, 2.0]
+
+        requests = create_requests_for_sjf(
+            num_requests=3,
+            prompt_lengths=prompt_lengths,
+            arrival_times=arrival_times
+        )
+
+        # Add requests to scheduler (order of addition doesn't affect
+        # final scheduling order)
+        for request in requests:
+            scheduler.add_request(request)
+
+        # Execute scheduling
+        output = scheduler.schedule()
+
+        # Verify all requests are scheduled (resources are sufficient)
+        assert len(output.scheduled_new_reqs) == 3
+
+        # Verify scheduling order: longest wait first
+        # Expected order: req1 (waited 9.0s), req2 (waited 8.0s), req0 (waited 7.0s)
+        scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
+        assert scheduled_req_ids == ["1", "2", "0"]
diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
index 8abbe8ba0103e..a9640943b8c18 100644
--- a/vllm/config/scheduler.py
+++ b/vllm/config/scheduler.py
@@ -20,7 +20,7 @@ if TYPE_CHECKING:
 
 logger = init_logger(__name__)
 
 RunnerType = Literal["generate", "pooling", "draft"]
-SchedulerPolicy = Literal["fcfs", "priority"]
+SchedulerPolicy = Literal["fcfs", "priority", "sjf"]
 
 
 @config
@@ -105,7 +105,9 @@ class SchedulerConfig:
     - "fcfs" means first come first served, i.e. requests are handled in
       order of arrival.\n
     - "priority" means requests are handled based on given priority (lower
-    value means earlier handling) and time of arrival deciding any ties)."""
+    value means earlier handling) and time of arrival deciding any ties).\n
+    - "sjf" means shortest job first. Requests are scheduled by prompt length
+    (shortest first), with aging to prevent starvation."""
 
     disable_chunked_mm_input: bool = False
     """If set to true and chunked prefill is enabled, we do not want to
diff --git a/vllm/v1/core/sched/policy/__init__.py b/vllm/v1/core/sched/policy/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/vllm/v1/core/sched/policy/shortest_job_first.py b/vllm/v1/core/sched/policy/shortest_job_first.py
new file mode 100644
index 0000000000000..409f9e6ec96ca
--- /dev/null
+++ b/vllm/v1/core/sched/policy/shortest_job_first.py
@@ -0,0 +1,139 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
+import time
+from functools import total_ordering
+
+from vllm.logger import init_logger
+from vllm.v1.request import Request
+
+logger = init_logger(__name__)
+
+
+class ScoreDim:
+    """
+    Normalized scoring dimension.
+    """
+
+    def __init__(
+        self, name: str, median: float, norm_scale=0.0, weight=0.5, reverse=False
+    ):
+        self.name = name
+        self.median = median
+        if norm_scale != 0.0:
+            self.norm_scale = norm_scale
+        else:
+            # Default scale: one sigmoid "unit" per median. median must be
+            # non-zero when no explicit norm_scale is given.
+            self.norm_scale = 1 / median
+        self.weight = weight
+        self.reverse = reverse
+
+
+class NormalizedScorer:
+    """
+    Normalize unbounded N-dimensional values into a composite score using the Sigmoid
+    function.
+    """
+
+    def __init__(self, dim_list: list[ScoreDim]) -> None:
+        """
+        Initialize the scorer with a list of scoring dimensions.
+
+        Args:
+            dim_list: A list of `ScoreDim` objects. Each dimension must define a
+                median reference point, scaling factor, and weight.
+        """
+        self.dim_list = dim_list
+        self.dim_count = len(dim_list)
+
+    @staticmethod
+    def _sigmoid_normalize(value, median, norm_scale):
+        """Sigmoid function: Maps value to (0, 1)."""
+        return 1 / (1 + math.exp(-norm_scale * (value - median)))
+
+    @staticmethod
+    def _inv_sigmoid_normalize(value, median, norm_scale):
+        """Inverse Sigmoid: Used for dimensions where a larger value yields a lower
+        score.
+        """
+        # Equivalent to sigmoid(-x), but more numerically stable.
+        return 1 / (1 + math.exp(norm_scale * (value - median)))
+
+    def score(self, *dims: float) -> float:
+        """
+        Compute the composite score.
+        Larger value → higher score → use forward Sigmoid.
+        Smaller value → higher score → use inverse Sigmoid.
+        """
+        if len(dims) > self.dim_count:
+            raise ValueError(
+                f"Dim num({len(dims)}) exceeds max num dim({self.dim_count})"
+            )
+
+        final_score = 0.0
+        for idx, dim_value in enumerate(dims):
+            dim_info = self.dim_list[idx]
+            if dim_info.reverse:
+                score = self._inv_sigmoid_normalize(
+                    dim_value, dim_info.median, dim_info.norm_scale
+                )
+            else:
+                score = self._sigmoid_normalize(
+                    dim_value, dim_info.median, dim_info.norm_scale
+                )
+            logger.debug("%s(%s) : %.10f", dim_info.name, dim_info.reverse, score)
+
+            # Weighted summation.
+            final_score += score * dim_info.weight
+        return max(0.0, min(1.0, final_score))  # Clamp to [0, 1].
+
+
+class TimeAndLengthScorer(NormalizedScorer):
+    """
+    Scorer for time and length dimensions; defaults to forward scoring with equal
+    weights (0.5 each).
+    """
+
+    def __init__(
+        self,
+        time_median=5,
+        length_median=1024 * 32,
+        time_scale=0.0,
+        length_scale=0.0,
+        time_weight=0.5,
+        length_weight=0.5,
+        reverse_time=False,
+        reverse_len=True,
+    ) -> None:
+        dim_list = [
+            ScoreDim("time", time_median, time_scale, time_weight, reverse_time),
+            ScoreDim("length", length_median, length_scale, length_weight, reverse_len),
+        ]
+        super().__init__(dim_list)
+
+    def score(self, *dims: float) -> float:
+        assert len(dims) == 2
+        return super().score(*dims)
+
+
+@total_ordering
+class WeightedScoreSorter:
+    def __init__(self, request: Request, scorer: TimeAndLengthScorer):
+        self.request = request
+        self.scorer = scorer
+        assert request.prompt_token_ids is not None
+        self.request_length = len(request.prompt_token_ids)
+        self.request_arrival_time = request.arrival_time
+        self.__update_stats()
+
+    def __lt__(self, other: "WeightedScoreSorter") -> bool:
+        # Refresh BOTH operands before comparing: updating only `self` makes
+        # the comparison asymmetric and defeats aging for the other side.
+        self.__update_stats(); other.__update_stats()
+        return self.weighted_score > other.weighted_score
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, WeightedScoreSorter):
+            return NotImplemented
+        return self.request.request_id == other.request.request_id
+
+    def __update_stats(self):
+        self.wait_time = time.time() - self.request_arrival_time
+        self.weighted_score = self.scorer.score(self.wait_time, self.request_length)
diff --git a/vllm/v1/core/sched/request_queue.py b/vllm/v1/core/sched/request_queue.py
index a00ca1912b0f3..a2e2ba6c67d31 100644
--- a/vllm/v1/core/sched/request_queue.py
+++ b/vllm/v1/core/sched/request_queue.py
@@ -6,7 +6,12 @@ from abc import ABC, abstractmethod
 from collections import deque
 from collections.abc import Iterable, Iterator
 from enum import Enum
+from typing import Any
 
+from vllm.v1.core.sched.policy.shortest_job_first import (
+    TimeAndLengthScorer,
+    WeightedScoreSorter,
+)
 from vllm.v1.request import Request
 
 
@@ -15,6 +20,7 @@
 class SchedulingPolicy(Enum):
     FCFS = "fcfs"
     PRIORITY = "priority"
+    SJF = "sjf"
 
 
 class RequestQueue(ABC):
@@ -133,59 +139,57 @@ class FCFSRequestQueue(deque[Request], RequestQueue):
         return super().__reversed__()
 
 
-class PriorityRequestQueue(RequestQueue):
-    """
-    A priority queue that supports heap operations.
-
-    Respects the ordering defined in the Request class, where
-    requests with a smaller value of `priority` are processed first.
-    If multiple requests have the same priority, the one with the earlier
-    `arrival_time` is processed first.
-    """
+class RequestHeap(RequestQueue):
+    """A queue that supports heap operations."""
 
     def __init__(self) -> None:
-        self._heap: list[Request] = []
+        self._heap: list[Any] = []
+
+    def _request_to_heap(self, request: Request) -> Any:
+        """Convert a request to the appropriate heap element."""
+        raise NotImplementedError
+
+    def _heap_to_request(self, element: Any) -> Request:
+        """Extract the request from a heap element."""
+        raise NotImplementedError
 
     def add_request(self, request: Request) -> None:
-        """Add a request to the queue according to priority policy."""
-        heapq.heappush(self._heap, request)
+        """Add a request to the queue according to heap priority."""
+        heapq.heappush(self._heap, self._request_to_heap(request))
 
     def pop_request(self) -> Request:
-        """Pop a request from the queue according to priority policy."""
+        """Pop the highest priority request from the heap."""
         if not self._heap:
             raise IndexError("pop from empty heap")
-        return heapq.heappop(self._heap)
+        return self._heap_to_request(heapq.heappop(self._heap))
 
     def peek_request(self) -> Request:
-        """Peek at the next request in the queue without removing it."""
+        """Peek at the highest priority request in the heap without removing it."""
         if not self._heap:
             raise IndexError("peek from empty heap")
-        return self._heap[0]
+        return self._heap_to_request(self._heap[0])
 
     def prepend_request(self, request: Request) -> None:
-        """Add a request to the queue according to priority policy.
-
-        Note: In a priority queue, there is no concept of prepending to the
-        front. Requests are ordered by (priority, arrival_time)."""
+        """Add a request to the heap. In heap-based queues there is no beginning as
+        elements are ordered by priority/score. This behaves like add_request."""
         self.add_request(request)
 
     def prepend_requests(self, requests: RequestQueue) -> None:
-        """Add all requests from another queue according to priority policy.
-
-        Note: In a priority queue, there is no concept of prepending to the
-        front. Requests are ordered by (priority, arrival_time)."""
+        """Add all requests from another queue to the heap. In heap-based queues there
+        is no beginning as elements are ordered by priority/score. This behaves like
+        add_request."""
         for request in requests:
             self.add_request(request)
 
     def remove_request(self, request: Request) -> None:
-        """Remove a specific request from the queue."""
-        self._heap.remove(request)
+        """Remove a specific request from the heap.
+
+        Relies on heap-element equality (e.g. WeightedScoreSorter.__eq__
+        matches on request_id) to locate the entry."""
+        self._heap.remove(self._request_to_heap(request))
         heapq.heapify(self._heap)
 
     def remove_requests(self, requests: Iterable[Request]) -> None:
-        """Remove multiple specific requests from the queue."""
-        requests_to_remove = requests if isinstance(requests, set) else set(requests)
-        self._heap = [r for r in self._heap if r not in requests_to_remove]
+        """Remove multiple specific requests from the heap."""
+        remove = requests if isinstance(requests, set) else set(requests)
+        self._heap = [h for h in self._heap if self._heap_to_request(h) not in remove]
         heapq.heapify(self._heap)
 
     def __bool__(self) -> bool:
@@ -197,21 +201,56 @@ class PriorityRequestQueue(RequestQueue):
         return len(self._heap)
 
     def __iter__(self) -> Iterator[Request]:
-        """Iterate over the queue according to priority policy."""
+        """Iterate over the queue in heap order."""
         heap_copy = self._heap[:]
         while heap_copy:
-            yield heapq.heappop(heap_copy)
+            yield self._heap_to_request(heapq.heappop(heap_copy))
 
     def __reversed__(self) -> Iterator[Request]:
-        """Iterate over the queue in reverse priority order."""
+        """Iterate over the queue in reverse heap order."""
         return reversed(list(self))
 
 
+class PriorityRequestQueue(RequestHeap):
+    """A priority queue that supports heap operations.
+
+    Respects the ordering defined in the Request class, where requests with a smaller
+    value of `priority` are processed first. If multiple requests have the same
+    priority, the one with the earlier `arrival_time` is processed first."""
+
+    def _request_to_heap(self, request: Request) -> Request:
+        """For priority queue, the heap element is the request itself."""
+        return request
+
+    def _heap_to_request(self, element: Request) -> Request:
+        """For priority queue, the heap element already is the request."""
+        return element
+
+
+class SJFRequestQueue(RequestHeap):
+    """A Shortest Job First (SJF) queue where requests are ordered by weighted score.
+    Requests with higher weighted scores (shorter jobs) are processed first."""
+
+    def __init__(self):
+        super().__init__()
+        self.scorer = TimeAndLengthScorer()
+
+    def _request_to_heap(self, request: Request) -> WeightedScoreSorter:
+        """Convert request to `WeightedScoreSorter` for heap."""
+        return WeightedScoreSorter(request, self.scorer)
+
+    def _heap_to_request(self, element: WeightedScoreSorter) -> Request:
+        """Extract request from the `WeightedScoreSorter`."""
+        return element.request
+
+
 def create_request_queue(policy: SchedulingPolicy) -> RequestQueue:
     """Create request queue based on scheduling policy."""
     if policy == SchedulingPolicy.PRIORITY:
         return PriorityRequestQueue()
     elif policy == SchedulingPolicy.FCFS:
         return FCFSRequestQueue()
+    elif policy == SchedulingPolicy.SJF:
+        return SJFRequestQueue()
     else:
         raise ValueError(f"Unknown scheduling policy: {policy}")