mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-03 02:44:30 +08:00
Merge 0431508388b8e130a170ef017d760d873a80ee23 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
831501ec55
229
tests/v1/core/test_sjf_scheduler.py
Normal file
229
tests/v1/core/test_sjf_scheduler.py
Normal file
@ -0,0 +1,229 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.config import (
|
||||||
|
CacheConfig,
|
||||||
|
ModelConfig,
|
||||||
|
SchedulerConfig,
|
||||||
|
VllmConfig,
|
||||||
|
)
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.utils.hashing import sha256
|
||||||
|
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
|
||||||
|
from vllm.v1.core.sched.scheduler import Scheduler
|
||||||
|
from vllm.v1.kv_cache_interface import (
|
||||||
|
FullAttentionSpec,
|
||||||
|
KVCacheConfig,
|
||||||
|
KVCacheGroupSpec,
|
||||||
|
)
|
||||||
|
from vllm.v1.request import Request
|
||||||
|
from vllm.v1.structured_output import StructuredOutputManager
|
||||||
|
|
||||||
|
from .utils import EOS_TOKEN_ID
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.cpu_test
|
||||||
|
|
||||||
|
|
||||||
|
def create_scheduler_with_sjf(
|
||||||
|
model: str = "facebook/opt-125m",
|
||||||
|
max_num_seqs: int = 16,
|
||||||
|
max_num_batched_tokens: int = 8192,
|
||||||
|
enable_prefix_caching: bool = False,
|
||||||
|
long_prefill_token_threshold: int = 0,
|
||||||
|
num_blocks: int = 10000,
|
||||||
|
block_size: int = 16,
|
||||||
|
max_model_len: int | None = None,
|
||||||
|
) -> Scheduler:
|
||||||
|
"""Create scheduler with SJF policy enabled.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: model under test
|
||||||
|
max_num_seqs: max sequences to schedule
|
||||||
|
max_num_batch_tokens: max num tokens to batch
|
||||||
|
enable_prefix_caching: optionally force APC config
|
||||||
|
(True/False) or use default
|
||||||
|
(False)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{class}`Scheduler` instance with SJF scheduling
|
||||||
|
"""
|
||||||
|
model_config = ModelConfig(
|
||||||
|
model=model,
|
||||||
|
trust_remote_code=True,
|
||||||
|
dtype="float16",
|
||||||
|
seed=42,
|
||||||
|
)
|
||||||
|
if max_model_len is None:
|
||||||
|
max_model_len = max_num_batched_tokens
|
||||||
|
scheduler_config = SchedulerConfig(
|
||||||
|
max_num_seqs=max_num_seqs,
|
||||||
|
max_num_batched_tokens=max_num_batched_tokens,
|
||||||
|
max_model_len=max_model_len,
|
||||||
|
long_prefill_token_threshold=long_prefill_token_threshold,
|
||||||
|
enable_chunked_prefill=True,
|
||||||
|
is_encoder_decoder=model_config.is_encoder_decoder,
|
||||||
|
policy="sjf", # Enable SJF scheduling
|
||||||
|
)
|
||||||
|
|
||||||
|
cache_config = CacheConfig(
|
||||||
|
block_size=block_size,
|
||||||
|
gpu_memory_utilization=0.9,
|
||||||
|
swap_space=0,
|
||||||
|
cache_dtype="auto",
|
||||||
|
enable_prefix_caching=enable_prefix_caching,
|
||||||
|
)
|
||||||
|
|
||||||
|
vllm_config = VllmConfig(
|
||||||
|
scheduler_config=scheduler_config,
|
||||||
|
model_config=model_config,
|
||||||
|
cache_config=cache_config
|
||||||
|
)
|
||||||
|
kv_cache_config = KVCacheConfig(
|
||||||
|
num_blocks=num_blocks,
|
||||||
|
kv_cache_tensors=[],
|
||||||
|
kv_cache_groups=[
|
||||||
|
KVCacheGroupSpec(
|
||||||
|
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
cache_config.num_gpu_blocks = num_blocks
|
||||||
|
return Scheduler(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
|
log_stats=True,
|
||||||
|
structured_output_manager=StructuredOutputManager(vllm_config),
|
||||||
|
block_size=block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_none_hash_initialized = False
|
||||||
|
|
||||||
|
|
||||||
|
def create_requests_for_sjf(
|
||||||
|
num_requests: int,
|
||||||
|
prompt_lengths: list[int],
|
||||||
|
arrival_times: list[float] | None = None,
|
||||||
|
max_tokens: int = 16,
|
||||||
|
stop_token_ids: list[int] | None = None,
|
||||||
|
prompt_logprobs: int | None = None,
|
||||||
|
starting_idx: int = 0,
|
||||||
|
same_prompt: bool = False,
|
||||||
|
block_size: int = 16,
|
||||||
|
req_ids: list[str] | None = None,
|
||||||
|
):
|
||||||
|
"""Create requests with specified prompt lengths and arrival times for SJF testing."""
|
||||||
|
assert len(prompt_lengths) == num_requests
|
||||||
|
if arrival_times is not None:
|
||||||
|
assert len(arrival_times) == num_requests
|
||||||
|
else:
|
||||||
|
arrival_times = [float(i) for i in range(num_requests)]
|
||||||
|
|
||||||
|
global _none_hash_initialized
|
||||||
|
if not _none_hash_initialized:
|
||||||
|
init_none_hash(sha256)
|
||||||
|
_none_hash_initialized = True
|
||||||
|
|
||||||
|
block_hasher = get_request_block_hasher(block_size, sha256)
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
ignore_eos=False,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
stop_token_ids=stop_token_ids,
|
||||||
|
prompt_logprobs=prompt_logprobs,
|
||||||
|
)
|
||||||
|
requests = []
|
||||||
|
|
||||||
|
if req_ids:
|
||||||
|
assert len(req_ids) == num_requests
|
||||||
|
else:
|
||||||
|
req_ids = [f"{i + starting_idx}" for i in range(num_requests)]
|
||||||
|
|
||||||
|
for i in range(num_requests):
|
||||||
|
num_tokens = prompt_lengths[i]
|
||||||
|
prompt_token_ids = (
|
||||||
|
[starting_idx] * num_tokens
|
||||||
|
if same_prompt
|
||||||
|
else [i + starting_idx] * num_tokens
|
||||||
|
)
|
||||||
|
request = Request(
|
||||||
|
request_id=req_ids[i],
|
||||||
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
pooling_params=None,
|
||||||
|
eos_token_id=EOS_TOKEN_ID,
|
||||||
|
arrival_time=arrival_times[i],
|
||||||
|
priority=1, # SJF ignores priority, set to default
|
||||||
|
block_hasher=block_hasher,
|
||||||
|
)
|
||||||
|
requests.append(request)
|
||||||
|
return requests
|
||||||
|
|
||||||
|
|
||||||
|
def test_sjf_scheduling_basic_ordering():
|
||||||
|
"""Test that requests are scheduled in SJF order
|
||||||
|
(shorter job = higher priority)."""
|
||||||
|
scheduler = create_scheduler_with_sjf()
|
||||||
|
|
||||||
|
# Create requests with different prompt lengths
|
||||||
|
# Shorter jobs should be scheduled first
|
||||||
|
prompt_lengths = [100, 50, 75] # Add in non-length order
|
||||||
|
arrival_times = [0.0, 0.0, 0.0] # All same arrival times
|
||||||
|
requests = create_requests_for_sjf(
|
||||||
|
num_requests=3, prompt_lengths=prompt_lengths, arrival_times=arrival_times
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add requests in non-length order
|
||||||
|
for request in requests:
|
||||||
|
scheduler.add_request(request)
|
||||||
|
|
||||||
|
# Schedule and verify SJF order
|
||||||
|
output = scheduler.schedule()
|
||||||
|
|
||||||
|
# Should schedule all requests since they fit in budget
|
||||||
|
assert len(output.scheduled_new_reqs) == 3
|
||||||
|
|
||||||
|
# Verify they are scheduled in length order (shortest first):
|
||||||
|
# req_1 (length 50), req_2 (length 75), req_0 (length 100)
|
||||||
|
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
|
||||||
|
assert scheduled_req_ids == ["1", "2", "0"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sjf_scheduling_waiting_time_tiebreaker_fixed():
|
||||||
|
"""Test that waiting time is used as tiebreaker when lengths are equal.
|
||||||
|
"""
|
||||||
|
scheduler = create_scheduler_with_sjf()
|
||||||
|
|
||||||
|
# Mock current time, fixed at 10.0 seconds
|
||||||
|
current_time = 10.0
|
||||||
|
time_patch = Mock(return_value=current_time)
|
||||||
|
|
||||||
|
with patch('time.time', time_patch):
|
||||||
|
# Create 3 requests with same length but different arrival times
|
||||||
|
prompt_lengths = [64, 64, 64] # All requests have same length
|
||||||
|
# Arrival times: req1 earliest, req2 second, req0 latest
|
||||||
|
arrival_times = [3.0, 1.0, 2.0]
|
||||||
|
|
||||||
|
requests = create_requests_for_sjf(
|
||||||
|
num_requests=3,
|
||||||
|
prompt_lengths=prompt_lengths,
|
||||||
|
arrival_times=arrival_times
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add requests to scheduler (order of addition doesn't affect final scheduling order)
|
||||||
|
for request in requests:
|
||||||
|
scheduler.add_request(request)
|
||||||
|
|
||||||
|
# Execute scheduling
|
||||||
|
output = scheduler.schedule()
|
||||||
|
|
||||||
|
# Verify all requests are scheduled (resources are sufficient)
|
||||||
|
assert len(output.scheduled_new_reqs) == 3
|
||||||
|
|
||||||
|
# Verify scheduling order: longest wait first
|
||||||
|
# Expected order: req1 (waited 9.0s), req2 (waited 8.0s), req0 (waited 7.0s)
|
||||||
|
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
|
||||||
|
assert scheduled_req_ids == ["1", "2", "0"]
|
||||||
@ -20,7 +20,7 @@ if TYPE_CHECKING:
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
RunnerType = Literal["generate", "pooling", "draft"]
|
RunnerType = Literal["generate", "pooling", "draft"]
|
||||||
SchedulerPolicy = Literal["fcfs", "priority"]
|
SchedulerPolicy = Literal["fcfs", "priority", "sjf"]
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
@ -105,7 +105,9 @@ class SchedulerConfig:
|
|||||||
- "fcfs" means first come first served, i.e. requests are handled in order
|
- "fcfs" means first come first served, i.e. requests are handled in order
|
||||||
of arrival.\n
|
of arrival.\n
|
||||||
- "priority" means requests are handled based on given priority (lower
|
- "priority" means requests are handled based on given priority (lower
|
||||||
value means earlier handling) and time of arrival deciding any ties)."""
|
value means earlier handling) and time of arrival deciding any ties).\n
|
||||||
|
- "sjf" means shortest job first. Requests are scheduled by prompt length
|
||||||
|
(shortest first), with aging to prevent starvation."""
|
||||||
|
|
||||||
disable_chunked_mm_input: bool = False
|
disable_chunked_mm_input: bool = False
|
||||||
"""If set to true and chunked prefill is enabled, we do not want to
|
"""If set to true and chunked prefill is enabled, we do not want to
|
||||||
|
|||||||
0
vllm/v1/core/sched/policy/__init__.py
Normal file
0
vllm/v1/core/sched/policy/__init__.py
Normal file
139
vllm/v1/core/sched/policy/shortest_job_first.py
Normal file
139
vllm/v1/core/sched/policy/shortest_job_first.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
from functools import total_ordering
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ScoreDim:
|
||||||
|
"""
|
||||||
|
Normalized scoring dimension.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, name: str, median: float, norm_scale=0.0, weight=0.5, reverse=False
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.median = median
|
||||||
|
if norm_scale != 0.0:
|
||||||
|
self.norm_scale = norm_scale
|
||||||
|
else:
|
||||||
|
self.norm_scale = 1 / median
|
||||||
|
self.weight = weight
|
||||||
|
self.reverse = reverse
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizedScorer:
|
||||||
|
"""
|
||||||
|
Normalize unbounded N-dimensional values into a composite score using the Sigmoid
|
||||||
|
function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim_list: list[ScoreDim]) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the scorer with a list of scoring dimensions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dim_list: A list of `ScoreDim` objects. Each dimension must define a
|
||||||
|
median reference point, scaling factor, and weight.
|
||||||
|
"""
|
||||||
|
self.dim_list = dim_list
|
||||||
|
self.dim_count = len(dim_list)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sigmoid_normalize(value, median, norm_scale):
|
||||||
|
"""Sigmoid function: Maps value to (0, 1)."""
|
||||||
|
return 1 / (1 + math.exp(-norm_scale * (value - median)))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _inv_sigmoid_normalize(value, median, norm_scale):
|
||||||
|
"""Inverse Sigmoid: Used for dimensions where a larger value yields a lower
|
||||||
|
score.
|
||||||
|
"""
|
||||||
|
# Equivalent to sigmoid(-x), but more numerically stable.
|
||||||
|
return 1 / (1 + math.exp(norm_scale * (value - median)))
|
||||||
|
|
||||||
|
def score(self, *dims: float) -> float:
|
||||||
|
"""
|
||||||
|
Compute the composite score.
|
||||||
|
Larger value → higher score → use forward Sigmoid.
|
||||||
|
Smaller value → higher score → use inverse Sigmoid.
|
||||||
|
"""
|
||||||
|
if len(dims) > self.dim_count:
|
||||||
|
raise ValueError(
|
||||||
|
f"Dim num({len(dims)}) exceeds max num dim({self.dim_count})"
|
||||||
|
)
|
||||||
|
|
||||||
|
final_score = 0.0
|
||||||
|
for idx, dim_value in enumerate(dims):
|
||||||
|
dim_info = self.dim_list[idx]
|
||||||
|
if dim_info.reverse:
|
||||||
|
score = self._inv_sigmoid_normalize(
|
||||||
|
dim_value, dim_info.median, dim_info.norm_scale
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
score = self._sigmoid_normalize(
|
||||||
|
dim_value, dim_info.median, dim_info.norm_scale
|
||||||
|
)
|
||||||
|
logger.debug("%s(%s) : %.10f", dim_info.name, dim_info.reverse, score)
|
||||||
|
|
||||||
|
# Weighted summation.
|
||||||
|
final_score += score * dim_info.weight
|
||||||
|
return max(0.0, min(1.0, final_score)) # Clamp to [0, 1].
|
||||||
|
|
||||||
|
|
||||||
|
class TimeAndLengthScorer(NormalizedScorer):
|
||||||
|
"""
|
||||||
|
Scorer for time and length dimensions; defaults to forward scoring with equal
|
||||||
|
weights (0.5 each).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
time_median=5,
|
||||||
|
length_median=1024 * 32,
|
||||||
|
time_scale=0.0,
|
||||||
|
length_scale=0.0,
|
||||||
|
time_weight=0.5,
|
||||||
|
length_weight=0.5,
|
||||||
|
reverse_time=False,
|
||||||
|
reverse_len=True,
|
||||||
|
) -> None:
|
||||||
|
dim_list = [
|
||||||
|
ScoreDim("time", time_median, time_scale, time_weight, reverse_time),
|
||||||
|
ScoreDim("length", length_median, length_scale, length_weight, reverse_len),
|
||||||
|
]
|
||||||
|
super().__init__(dim_list)
|
||||||
|
|
||||||
|
def score(self, *dims: float) -> float:
|
||||||
|
assert len(dims) == 2
|
||||||
|
return super().score(*dims)
|
||||||
|
|
||||||
|
|
||||||
|
@total_ordering
|
||||||
|
class WeightedScoreSorter:
|
||||||
|
def __init__(self, request: Request, scorer: TimeAndLengthScorer):
|
||||||
|
self.request = request
|
||||||
|
self.scorer = scorer
|
||||||
|
assert request.prompt_token_ids is not None
|
||||||
|
self.request_length = len(request.prompt_token_ids)
|
||||||
|
self.request_arrival_time = request.arrival_time
|
||||||
|
self.__update_stats()
|
||||||
|
|
||||||
|
def __lt__(self, other_request_weighted_score: "WeightedScoreSorter") -> bool:
|
||||||
|
self.__update_stats()
|
||||||
|
return self.weighted_score > other_request_weighted_score.weighted_score
|
||||||
|
|
||||||
|
def __eq__(self, other_request_weighted_score: object) -> bool:
|
||||||
|
if not isinstance(other_request_weighted_score, WeightedScoreSorter):
|
||||||
|
return NotImplemented
|
||||||
|
return self.request.request_id == other_request_weighted_score.request.request_id
|
||||||
|
|
||||||
|
def __update_stats(self):
|
||||||
|
self.wait_time = time.time() - self.request_arrival_time
|
||||||
|
self.weighted_score = self.scorer.score(self.wait_time, self.request_length)
|
||||||
@ -6,7 +6,12 @@ from abc import ABC, abstractmethod
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from collections.abc import Iterable, Iterator
|
from collections.abc import Iterable, Iterator
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from vllm.v1.core.sched.policy.shortest_job_first import (
|
||||||
|
TimeAndLengthScorer,
|
||||||
|
WeightedScoreSorter,
|
||||||
|
)
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
|
|
||||||
@ -15,6 +20,7 @@ class SchedulingPolicy(Enum):
|
|||||||
|
|
||||||
FCFS = "fcfs"
|
FCFS = "fcfs"
|
||||||
PRIORITY = "priority"
|
PRIORITY = "priority"
|
||||||
|
SJF = "sjf"
|
||||||
|
|
||||||
|
|
||||||
class RequestQueue(ABC):
|
class RequestQueue(ABC):
|
||||||
@ -133,59 +139,57 @@ class FCFSRequestQueue(deque[Request], RequestQueue):
|
|||||||
return super().__reversed__()
|
return super().__reversed__()
|
||||||
|
|
||||||
|
|
||||||
class PriorityRequestQueue(RequestQueue):
|
class RequestHeap(RequestQueue):
|
||||||
"""
|
"""A queue that supports heap operations."""
|
||||||
A priority queue that supports heap operations.
|
|
||||||
|
|
||||||
Respects the ordering defined in the Request class, where
|
|
||||||
requests with a smaller value of `priority` are processed first.
|
|
||||||
If multiple requests have the same priority, the one with the earlier
|
|
||||||
`arrival_time` is processed first.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._heap: list[Request] = []
|
self._heap: list = []
|
||||||
|
|
||||||
|
def _request_to_heap(self, request: Request) -> Any:
|
||||||
|
"""Convert a request to the appropriate heap element."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _heap_to_request(self, element: Any) -> Request:
|
||||||
|
"""Extract the request from a heap element."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def add_request(self, request: Request) -> None:
|
def add_request(self, request: Request) -> None:
|
||||||
"""Add a request to the queue according to priority policy."""
|
"""Add a request to the queue according to heap priority."""
|
||||||
heapq.heappush(self._heap, request)
|
heapq.heappush(self._heap, self._request_to_heap(request))
|
||||||
|
|
||||||
def pop_request(self) -> Request:
|
def pop_request(self) -> Request:
|
||||||
"""Pop a request from the queue according to priority policy."""
|
"""Pop the highest priority request from the heap."""
|
||||||
if not self._heap:
|
if not self._heap:
|
||||||
raise IndexError("pop from empty heap")
|
raise IndexError("pop from empty heap")
|
||||||
return heapq.heappop(self._heap)
|
return self._heap_to_request(heapq.heappop(self._heap))
|
||||||
|
|
||||||
def peek_request(self) -> Request:
|
def peek_request(self) -> Request:
|
||||||
"""Peek at the next request in the queue without removing it."""
|
"""Peek at the highest priority request in the heap without removing it."""
|
||||||
if not self._heap:
|
if not self._heap:
|
||||||
raise IndexError("peek from empty heap")
|
raise IndexError("peek from empty heap")
|
||||||
return self._heap[0]
|
return self._heap_to_request(self._heap[0])
|
||||||
|
|
||||||
def prepend_request(self, request: Request) -> None:
|
def prepend_request(self, request: Request) -> None:
|
||||||
"""Add a request to the queue according to priority policy.
|
"""Add a request to the heap. In heap-based queues there is no beginning as
|
||||||
|
elements are ordered by priority/score. This behaves like add_request."""
|
||||||
Note: In a priority queue, there is no concept of prepending to the
|
|
||||||
front. Requests are ordered by (priority, arrival_time)."""
|
|
||||||
self.add_request(request)
|
self.add_request(request)
|
||||||
|
|
||||||
def prepend_requests(self, requests: RequestQueue) -> None:
|
def prepend_requests(self, requests: RequestQueue) -> None:
|
||||||
"""Add all requests from another queue according to priority policy.
|
"""Add all requests from another queue to the heap. In heap-based queues there
|
||||||
|
is no beginning as elements are ordered by priority/score. This behaves like
|
||||||
Note: In a priority queue, there is no concept of prepending to the
|
add_request."""
|
||||||
front. Requests are ordered by (priority, arrival_time)."""
|
|
||||||
for request in requests:
|
for request in requests:
|
||||||
self.add_request(request)
|
self.add_request(request)
|
||||||
|
|
||||||
def remove_request(self, request: Request) -> None:
|
def remove_request(self, request: Request) -> None:
|
||||||
"""Remove a specific request from the queue."""
|
"""Remove a specific request from the heap."""
|
||||||
self._heap.remove(request)
|
self._heap.remove(self._request_to_heap(request))
|
||||||
heapq.heapify(self._heap)
|
heapq.heapify(self._heap)
|
||||||
|
|
||||||
def remove_requests(self, requests: Iterable[Request]) -> None:
|
def remove_requests(self, requests: Iterable[Request]) -> None:
|
||||||
"""Remove multiple specific requests from the queue."""
|
"""Remove multiple specific requests from the heap."""
|
||||||
requests_to_remove = requests if isinstance(requests, set) else set(requests)
|
remove = requests if isinstance(requests, set) else set(requests)
|
||||||
self._heap = [r for r in self._heap if r not in requests_to_remove]
|
self._heap = [h for h in self._heap if self._heap_to_request(h) not in remove]
|
||||||
heapq.heapify(self._heap)
|
heapq.heapify(self._heap)
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
def __bool__(self) -> bool:
|
||||||
@ -197,21 +201,56 @@ class PriorityRequestQueue(RequestQueue):
|
|||||||
return len(self._heap)
|
return len(self._heap)
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[Request]:
|
def __iter__(self) -> Iterator[Request]:
|
||||||
"""Iterate over the queue according to priority policy."""
|
"""Iterate over the queue to heap order."""
|
||||||
heap_copy = self._heap[:]
|
heap_copy = self._heap[:]
|
||||||
while heap_copy:
|
while heap_copy:
|
||||||
yield heapq.heappop(heap_copy)
|
yield self._heap_to_request(heapq.heappop(heap_copy))
|
||||||
|
|
||||||
def __reversed__(self) -> Iterator[Request]:
|
def __reversed__(self) -> Iterator[Request]:
|
||||||
"""Iterate over the queue in reverse priority order."""
|
"""Iterate over the queue in reverse heap order."""
|
||||||
return reversed(list(self))
|
return reversed(list(self))
|
||||||
|
|
||||||
|
|
||||||
|
class PriorityRequestQueue(RequestHeap):
|
||||||
|
"""A priority queue that supports heap operations.
|
||||||
|
|
||||||
|
Respects the ordering defined in the Request class, where requests with a smaller
|
||||||
|
value of `priority` are processed first. If multiple requests have the same
|
||||||
|
priority, the one with the earlier `arrival_time` is processed first."""
|
||||||
|
|
||||||
|
def _request_to_heap(self, request: Request) -> Request:
|
||||||
|
"""For priority queue, the heap element is the request itself."""
|
||||||
|
return request
|
||||||
|
|
||||||
|
def _heap_to_request(self, element: Request) -> Request:
|
||||||
|
"""Extract request from heap element with type checking."""
|
||||||
|
return element
|
||||||
|
|
||||||
|
|
||||||
|
class SJFRequestQueue(RequestHeap):
|
||||||
|
"""A Shortest Job First (SJF) queue where requests are ordered by weighted score.
|
||||||
|
Requests with higher weighted scores (shorter jobs) are processed first."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.scorer = TimeAndLengthScorer()
|
||||||
|
|
||||||
|
def _request_to_heap(self, request: Request) -> WeightedScoreSorter:
|
||||||
|
"""Convert request to `WeightedScoreSorter` for heap."""
|
||||||
|
return WeightedScoreSorter(request, self.scorer)
|
||||||
|
|
||||||
|
def _heap_to_request(self, element: WeightedScoreSorter) -> Request:
|
||||||
|
"""Extract request from the `WeightedScoreSorter`."""
|
||||||
|
return element.request
|
||||||
|
|
||||||
|
|
||||||
def create_request_queue(policy: SchedulingPolicy) -> RequestQueue:
|
def create_request_queue(policy: SchedulingPolicy) -> RequestQueue:
|
||||||
"""Create request queue based on scheduling policy."""
|
"""Create request queue based on scheduling policy."""
|
||||||
if policy == SchedulingPolicy.PRIORITY:
|
if policy == SchedulingPolicy.PRIORITY:
|
||||||
return PriorityRequestQueue()
|
return PriorityRequestQueue()
|
||||||
elif policy == SchedulingPolicy.FCFS:
|
elif policy == SchedulingPolicy.FCFS:
|
||||||
return FCFSRequestQueue()
|
return FCFSRequestQueue()
|
||||||
|
elif policy == SchedulingPolicy.SJF:
|
||||||
|
return SJFRequestQueue()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown scheduling policy: {policy}")
|
raise ValueError(f"Unknown scheduling policy: {policy}")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user