Merge 0431508388b8e130a170ef017d760d873a80ee23 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
weichen 2025-12-25 00:06:54 +00:00 committed by GitHub
commit 831501ec55
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 443 additions and 34 deletions

View File

@ -0,0 +1,229 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import Mock, patch
import pytest
import torch
from vllm.config import (
CacheConfig,
ModelConfig,
SchedulerConfig,
VllmConfig,
)
from vllm.sampling_params import SamplingParams
from vllm.utils.hashing import sha256
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
)
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
from .utils import EOS_TOKEN_ID
pytestmark = pytest.mark.cpu_test
def create_scheduler_with_sjf(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: bool = False,
    long_prefill_token_threshold: int = 0,
    num_blocks: int = 10000,
    block_size: int = 16,
    max_model_len: int | None = None,
) -> Scheduler:
    """Create a scheduler with the SJF (shortest-job-first) policy enabled.

    Args:
        model: model under test.
        max_num_seqs: max sequences to schedule.
        max_num_batched_tokens: max num tokens to batch.
        enable_prefix_caching: optionally force APC config
            (True/False) or use default (False).
        long_prefill_token_threshold: forwarded to `SchedulerConfig`.
        num_blocks: number of KV-cache blocks to configure.
        block_size: KV-cache block size in tokens.
        max_model_len: maximum model length; defaults to
            `max_num_batched_tokens` when not given.

    Returns:
        {class}`Scheduler` instance with SJF scheduling
    """
    model_config = ModelConfig(
        model=model,
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    if max_model_len is None:
        max_model_len = max_num_batched_tokens
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        long_prefill_token_threshold=long_prefill_token_threshold,
        enable_chunked_prefill=True,
        is_encoder_decoder=model_config.is_encoder_decoder,
        policy="sjf",  # Enable SJF scheduling
    )
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        enable_prefix_caching=enable_prefix_caching,
    )
    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config
    )
    # Single full-attention KV-cache group; no real tensors are needed for
    # scheduler-only tests.
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
            )
        ],
    )
    cache_config.num_gpu_blocks = num_blocks
    return Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
        block_size=block_size,
    )
_none_hash_initialized = False


def create_requests_for_sjf(
    num_requests: int,
    prompt_lengths: list[int],
    arrival_times: list[float] | None = None,
    max_tokens: int = 16,
    stop_token_ids: list[int] | None = None,
    prompt_logprobs: int | None = None,
    starting_idx: int = 0,
    same_prompt: bool = False,
    block_size: int = 16,
    req_ids: list[str] | None = None,
):
    """Build `Request` objects with given prompt lengths and arrival times
    for SJF testing.
    """
    assert len(prompt_lengths) == num_requests
    if arrival_times is None:
        arrival_times = [float(i) for i in range(num_requests)]
    else:
        assert len(arrival_times) == num_requests

    # Initialize the global none-hash exactly once per process.
    global _none_hash_initialized
    if not _none_hash_initialized:
        init_none_hash(sha256)
        _none_hash_initialized = True
    block_hasher = get_request_block_hasher(block_size, sha256)

    sampling_params = SamplingParams(
        ignore_eos=False,
        max_tokens=max_tokens,
        stop_token_ids=stop_token_ids,
        prompt_logprobs=prompt_logprobs,
    )

    if req_ids:
        assert len(req_ids) == num_requests
    else:
        req_ids = [f"{idx + starting_idx}" for idx in range(num_requests)]

    requests = []
    for idx, length in enumerate(prompt_lengths):
        # All tokens in a prompt are identical; `same_prompt` makes every
        # request share the same token value.
        token = starting_idx if same_prompt else idx + starting_idx
        requests.append(
            Request(
                request_id=req_ids[idx],
                prompt_token_ids=[token] * length,
                sampling_params=sampling_params,
                pooling_params=None,
                eos_token_id=EOS_TOKEN_ID,
                arrival_time=arrival_times[idx],
                priority=1,  # SJF ignores priority, set to default
                block_hasher=block_hasher,
            )
        )
    return requests
def test_sjf_scheduling_basic_ordering():
    """Verify SJF schedules shorter prompts ahead of longer ones."""
    scheduler = create_scheduler_with_sjf()

    # Enqueue in deliberately non-sorted length order; arrival times all tie.
    lengths = [100, 50, 75]
    arrivals = [0.0, 0.0, 0.0]
    for request in create_requests_for_sjf(
        num_requests=3, prompt_lengths=lengths, arrival_times=arrivals
    ):
        scheduler.add_request(request)

    output = scheduler.schedule()

    # Everything fits in the token budget, so all three are scheduled.
    assert len(output.scheduled_new_reqs) == 3
    # Shortest first: req "1" (50 tokens), then "2" (75), then "0" (100).
    assert [req.req_id for req in output.scheduled_new_reqs] == ["1", "2", "0"]
def test_sjf_scheduling_waiting_time_tiebreaker_fixed():
    """Equal-length requests are ordered by waiting time (longest wait first)."""
    scheduler = create_scheduler_with_sjf()

    # Freeze the clock at t=10.0s so computed wait times are deterministic.
    frozen_now = Mock(return_value=10.0)
    with patch('time.time', frozen_now):
        # Same prompt length everywhere; only arrival times differ.
        # req "1" arrived earliest, then "2", then "0".
        requests = create_requests_for_sjf(
            num_requests=3,
            prompt_lengths=[64, 64, 64],
            arrival_times=[3.0, 1.0, 2.0],
        )
        # Insertion order does not influence the final scheduling order.
        for request in requests:
            scheduler.add_request(request)

        output = scheduler.schedule()

        # Capacity is ample, so every request gets scheduled.
        assert len(output.scheduled_new_reqs) == 3
        # Longest wait first: "1" (9.0s), "2" (8.0s), "0" (7.0s).
        assert [req.req_id for req in output.scheduled_new_reqs] == ["1", "2", "0"]

View File

@ -20,7 +20,7 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
RunnerType = Literal["generate", "pooling", "draft"]
SchedulerPolicy = Literal["fcfs", "priority"]
SchedulerPolicy = Literal["fcfs", "priority", "sjf"]
@config
@ -105,7 +105,9 @@ class SchedulerConfig:
- "fcfs" means first come first served, i.e. requests are handled in order
of arrival.\n
- "priority" means requests are handled based on given priority (lower
value means earlier handling) and time of arrival deciding any ties)."""
value means earlier handling) and time of arrival deciding any ties.\n
- "sjf" means shortest job first. Requests are scheduled by prompt length
(shortest first), with aging to prevent starvation."""
disable_chunked_mm_input: bool = False
"""If set to true and chunked prefill is enabled, we do not want to

View File

View File

@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import time
from functools import total_ordering
from vllm.logger import init_logger
from vllm.v1.request import Request
logger = init_logger(__name__)
class ScoreDim:
    """A single normalized scoring dimension.

    Stores the reference median, the sigmoid scaling factor, the weight used
    in the composite score, and whether larger values should score lower.
    """

    def __init__(
        self, name: str, median: float, norm_scale=0.0, weight=0.5, reverse=False
    ):
        self.name = name
        self.median = median
        # A zero norm_scale means "derive the scale from the median".
        self.norm_scale = norm_scale if norm_scale != 0.0 else 1 / median
        self.weight = weight
        self.reverse = reverse
class NormalizedScorer:
    """Normalize unbounded N-dimensional values into a composite score using the
    Sigmoid function.

    Each dimension is squashed into (0, 1) with a (possibly inverted) sigmoid
    and the results are combined as a weighted sum, clamped to [0, 1].
    """

    # math.exp overflows (raises OverflowError) for arguments above ~709.78.
    # Clamp the sigmoid exponent so extreme inputs saturate to 0.0/1.0 instead
    # of crashing the scorer.
    _MAX_EXP = 709.0

    def __init__(self, dim_list: "list[ScoreDim]") -> None:
        """
        Initialize the scorer with a list of scoring dimensions.

        Args:
            dim_list: A list of `ScoreDim` objects. Each dimension must define a
                median reference point, scaling factor, and weight.
        """
        self.dim_list = dim_list
        self.dim_count = len(dim_list)

    @staticmethod
    def _sigmoid_normalize(value, median, norm_scale):
        """Sigmoid function: maps value to (0, 1); larger value -> higher score."""
        x = -norm_scale * (value - median)
        if x >= NormalizedScorer._MAX_EXP:
            return 0.0  # exp(x) would overflow; sigmoid saturates at 0.
        if x <= -NormalizedScorer._MAX_EXP:
            return 1.0  # exp(x) underflows; sigmoid saturates at 1.
        return 1 / (1 + math.exp(x))

    @staticmethod
    def _inv_sigmoid_normalize(value, median, norm_scale):
        """Inverse Sigmoid: used for dimensions where a larger value yields a
        lower score. Equivalent to sigmoid(-x), but more numerically stable.
        """
        x = norm_scale * (value - median)
        if x >= NormalizedScorer._MAX_EXP:
            return 0.0  # exp(x) would overflow; inverse sigmoid saturates at 0.
        if x <= -NormalizedScorer._MAX_EXP:
            return 1.0  # exp(x) underflows; inverse sigmoid saturates at 1.
        return 1 / (1 + math.exp(x))

    def score(self, *dims: float) -> float:
        """
        Compute the composite score.

        Forward-sigmoid dimensions score higher for larger values; reversed
        dimensions score higher for smaller values.

        Raises:
            ValueError: If more values are passed than configured dimensions.
        """
        if len(dims) > self.dim_count:
            raise ValueError(
                f"Dim num({len(dims)}) exceeds max num dim({self.dim_count})"
            )
        final_score = 0.0
        for idx, dim_value in enumerate(dims):
            dim_info = self.dim_list[idx]
            normalize = (
                self._inv_sigmoid_normalize
                if dim_info.reverse
                else self._sigmoid_normalize
            )
            score = normalize(dim_value, dim_info.median, dim_info.norm_scale)
            logger.debug("%s(%s) : %.10f", dim_info.name, dim_info.reverse, score)
            # Weighted summation.
            final_score += score * dim_info.weight
        return max(0.0, min(1.0, final_score))  # Clamp to [0, 1].
class TimeAndLengthScorer(NormalizedScorer):
    """Two-dimensional scorer over (wait time, prompt length).

    By default the time dimension scores forward (longer wait -> higher score)
    and the length dimension is reversed (shorter prompt -> higher score),
    each with weight 0.5.
    """

    def __init__(
        self,
        time_median=5,
        length_median=1024 * 32,
        time_scale=0.0,
        length_scale=0.0,
        time_weight=0.5,
        length_weight=0.5,
        reverse_time=False,
        reverse_len=True,
    ) -> None:
        super().__init__(
            [
                ScoreDim("time", time_median, time_scale, time_weight, reverse_time),
                ScoreDim(
                    "length", length_median, length_scale, length_weight, reverse_len
                ),
            ]
        )

    def score(self, *dims: float) -> float:
        # Exactly (wait_time, length) is expected here.
        assert len(dims) == 2
        return super().score(*dims)
@total_ordering
class WeightedScoreSorter:
    """Heap adapter ordering requests by their weighted (time, length) score.

    `__lt__` is inverted so the request with the HIGHER weighted score sorts
    first in a min-heap. Equality is identity by request id; `total_ordering`
    derives the remaining comparisons from `__lt__` and `__eq__`.
    """

    def __init__(self, request: "Request", scorer: "TimeAndLengthScorer"):
        self.request = request
        self.scorer = scorer
        assert request.prompt_token_ids is not None
        self.request_length = len(request.prompt_token_ids)
        self.request_arrival_time = request.arrival_time
        self.__update_stats()

    def __lt__(self, other_request_weighted_score: "WeightedScoreSorter") -> bool:
        # Refresh BOTH sides before comparing: scores age with wall-clock
        # time, and comparing a fresh score against a stale one could
        # violate the heap ordering invariant.
        self.__update_stats()
        other_request_weighted_score.__update_stats()
        return self.weighted_score > other_request_weighted_score.weighted_score

    def __eq__(self, other_request_weighted_score: object) -> bool:
        if not isinstance(other_request_weighted_score, WeightedScoreSorter):
            return NotImplemented
        return self.request.request_id == other_request_weighted_score.request.request_id

    def __update_stats(self):
        # Recompute the wait time and composite score as of "now".
        self.wait_time = time.time() - self.request_arrival_time
        self.weighted_score = self.scorer.score(self.wait_time, self.request_length)

View File

@ -6,7 +6,12 @@ from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterable, Iterator
from enum import Enum
from typing import Any
from vllm.v1.core.sched.policy.shortest_job_first import (
TimeAndLengthScorer,
WeightedScoreSorter,
)
from vllm.v1.request import Request
@ -15,6 +20,7 @@ class SchedulingPolicy(Enum):
FCFS = "fcfs"
PRIORITY = "priority"
SJF = "sjf"
class RequestQueue(ABC):
@ -133,59 +139,57 @@ class FCFSRequestQueue(deque[Request], RequestQueue):
return super().__reversed__()
class PriorityRequestQueue(RequestQueue):
"""
A priority queue that supports heap operations.
Respects the ordering defined in the Request class, where
requests with a smaller value of `priority` are processed first.
If multiple requests have the same priority, the one with the earlier
`arrival_time` is processed first.
"""
class RequestHeap(RequestQueue):
"""A queue that supports heap operations."""
def __init__(self) -> None:
self._heap: list[Request] = []
self._heap: list = []
    def _request_to_heap(self, request: Request) -> Any:
        """Convert a request to the appropriate heap element.

        Abstract hook: subclasses decide what object actually lives in the
        heap (the request itself, or a comparable wrapper around it).
        """
        raise NotImplementedError

    def _heap_to_request(self, element: Any) -> Request:
        """Extract the request from a heap element.

        Abstract hook: inverse of `_request_to_heap`.
        NOTE(review): these could arguably be `@abstractmethod` — confirm
        whether the base class is ever instantiated directly.
        """
        raise NotImplementedError
def add_request(self, request: Request) -> None:
"""Add a request to the queue according to priority policy."""
heapq.heappush(self._heap, request)
"""Add a request to the queue according to heap priority."""
heapq.heappush(self._heap, self._request_to_heap(request))
def pop_request(self) -> Request:
"""Pop a request from the queue according to priority policy."""
"""Pop the highest priority request from the heap."""
if not self._heap:
raise IndexError("pop from empty heap")
return heapq.heappop(self._heap)
return self._heap_to_request(heapq.heappop(self._heap))
def peek_request(self) -> Request:
"""Peek at the next request in the queue without removing it."""
"""Peek at the highest priority request in the heap without removing it."""
if not self._heap:
raise IndexError("peek from empty heap")
return self._heap[0]
return self._heap_to_request(self._heap[0])
def prepend_request(self, request: Request) -> None:
"""Add a request to the queue according to priority policy.
Note: In a priority queue, there is no concept of prepending to the
front. Requests are ordered by (priority, arrival_time)."""
"""Add a request to the heap. In heap-based queues there is no beginning as
elements are ordered by priority/score. This behaves like add_request."""
self.add_request(request)
def prepend_requests(self, requests: RequestQueue) -> None:
"""Add all requests from another queue according to priority policy.
Note: In a priority queue, there is no concept of prepending to the
front. Requests are ordered by (priority, arrival_time)."""
"""Add all requests from another queue to the heap. In heap-based queues there
is no beginning as elements are ordered by priority/score. This behaves like
add_request."""
for request in requests:
self.add_request(request)
def remove_request(self, request: Request) -> None:
"""Remove a specific request from the queue."""
self._heap.remove(request)
"""Remove a specific request from the heap."""
self._heap.remove(self._request_to_heap(request))
heapq.heapify(self._heap)
def remove_requests(self, requests: Iterable[Request]) -> None:
"""Remove multiple specific requests from the queue."""
requests_to_remove = requests if isinstance(requests, set) else set(requests)
self._heap = [r for r in self._heap if r not in requests_to_remove]
"""Remove multiple specific requests from the heap."""
remove = requests if isinstance(requests, set) else set(requests)
self._heap = [h for h in self._heap if self._heap_to_request(h) not in remove]
heapq.heapify(self._heap)
def __bool__(self) -> bool:
@ -197,21 +201,56 @@ class PriorityRequestQueue(RequestQueue):
return len(self._heap)
def __iter__(self) -> Iterator[Request]:
"""Iterate over the queue according to priority policy."""
"""Iterate over the queue to heap order."""
heap_copy = self._heap[:]
while heap_copy:
yield heapq.heappop(heap_copy)
yield self._heap_to_request(heapq.heappop(heap_copy))
def __reversed__(self) -> Iterator[Request]:
"""Iterate over the queue in reverse priority order."""
"""Iterate over the queue in reverse heap order."""
return reversed(list(self))
class PriorityRequestQueue(RequestHeap):
    """Heap-backed queue honoring `Request`'s native ordering.

    Requests with a smaller `priority` value are processed first; ties fall
    back to the earlier `arrival_time`, as defined by the comparison
    operators on the `Request` class itself.
    """

    def _request_to_heap(self, request: Request) -> Request:
        """The request itself is the heap element; it compares natively."""
        return request

    def _heap_to_request(self, element: Request) -> Request:
        """Identity mapping: heap elements are the requests themselves."""
        return element
class SJFRequestQueue(RequestHeap):
    """Shortest-job-first heap: requests with higher weighted scores pop first.

    Each request is wrapped in a `WeightedScoreSorter`, which combines wait
    time and prompt length into a composite score.
    """

    def __init__(self):
        super().__init__()
        # One shared scorer for every wrapped request.
        self.scorer = TimeAndLengthScorer()

    def _request_to_heap(self, request: Request) -> WeightedScoreSorter:
        """Wrap the request so the heap orders by weighted score."""
        return WeightedScoreSorter(request, self.scorer)

    def _heap_to_request(self, element: WeightedScoreSorter) -> Request:
        """Unwrap the sorter back to its underlying request."""
        return element.request
def create_request_queue(policy: SchedulingPolicy) -> RequestQueue:
"""Create request queue based on scheduling policy."""
if policy == SchedulingPolicy.PRIORITY:
return PriorityRequestQueue()
elif policy == SchedulingPolicy.FCFS:
return FCFSRequestQueue()
elif policy == SchedulingPolicy.SJF:
return SJFRequestQueue()
else:
raise ValueError(f"Unknown scheduling policy: {policy}")