# SPDX-License-Identifier: Apache-2.0
"""Abstract scheduler interface (``vllm/v1/core/sched/interface.py``)."""
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
    # Imported for annotations only; every use below is a quoted forward
    # reference, so none of these modules are loaded at runtime.
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.engine import EngineCoreOutputs
    from vllm.v1.metrics.stats import SchedulerStats
    from vllm.v1.outputs import ModelRunnerOutput
    from vllm.v1.request import Request, RequestStatus


class SchedulerInterface(ABC):
    """Abstract base class listing the methods a scheduler must provide.

    Every method is abstract, so a subclass can only be instantiated once it
    overrides all of them. The concrete ``Scheduler`` class in
    ``vllm/v1/core/sched/scheduler.py`` derives from this interface.
    """

    @abstractmethod
    def schedule(self) -> "SchedulerOutput":
        """Make a scheduling decision and return it.

        Returns:
            A ``SchedulerOutput`` describing the decision. Its contents are
            defined by the implementation and are not specified here.
        """
        raise NotImplementedError

    @abstractmethod
    def update_from_output(
        self,
        scheduler_output: "SchedulerOutput",
        model_runner_output: "ModelRunnerOutput",
    ) -> "EngineCoreOutputs":
        """Update the scheduler from a model runner's output.

        Args:
            scheduler_output: A ``SchedulerOutput``.
                NOTE(review): presumably the value previously returned by
                ``schedule()`` for the same step -- confirm against callers.
            model_runner_output: The ``ModelRunnerOutput`` to incorporate.

        Returns:
            The resulting ``EngineCoreOutputs``.
        """
        raise NotImplementedError

    @abstractmethod
    def add_request(self, request: "Request") -> None:
        """Add ``request`` to the scheduler."""
        raise NotImplementedError

    @abstractmethod
    def finish_requests(
        self,
        request_ids: Union[str, Iterable[str]],
        finished_status: "RequestStatus",
    ) -> None:
        """Finish one or more requests with the given status.

        Args:
            request_ids: A single request ID, or an iterable of request IDs.
            finished_status: The ``RequestStatus`` to apply.
                NOTE(review): presumably must be a "finished" status --
                verify against the implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_unfinished_requests(self) -> int:
        """Return the number of unfinished requests."""
        raise NotImplementedError

    @abstractmethod
    def has_unfinished_requests(self) -> bool:
        """Return True if there are unfinished requests."""
        raise NotImplementedError

    @abstractmethod
    def has_finished_requests(self) -> bool:
        """Return True if there are finished requests.

        NOTE(review): judging by ``has_requests``, this presumably refers to
        finished requests not yet returned in SchedulerOutputs -- confirm.
        """
        raise NotImplementedError

    @abstractmethod
    def has_requests(self) -> bool:
        """Returns True if there are unfinished requests, or finished requests
        not yet returned in SchedulerOutputs."""
        raise NotImplementedError

    @abstractmethod
    def get_num_unscheduled_requests(self) -> int:
        """Number of requests that are not being processed by the executor."""
        raise NotImplementedError

    @abstractmethod
    def reset_prefix_cache(self) -> bool:
        """Reset the prefix cache.

        Returns:
            A bool. NOTE(review): presumably True on success -- confirm
            against the implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def make_stats(self) -> Optional["SchedulerStats"]:
        """Return a ``SchedulerStats`` object, or None.

        NOTE(review): when None is returned is up to the implementation
        (presumably when stats collection is disabled) -- confirm.
        """
        raise NotImplementedError