[Feature] Pluggable platform-specific scheduler (#13161)

Signed-off-by: Yannick Schnider <yannick.schnider1@ibm.com>
Signed-off-by: Yannick Schnider <Yannick.Schnider1@ibm.com>
Committed by GitHub on behalf of Yannick Schnider, 2025-02-19 10:16:38 +01:00
Parent: caf7ff4456
Commit: 423330263b
5 changed files with 56 additions and 3 deletions
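In short: SchedulerConfig and EngineArgs gain a scheduler_cls option that accepts either a scheduler class or a dotted import path, and LLMEngine instantiates whatever it resolves to instead of always using vllm.core.scheduler.Scheduler. A minimal usage sketch (the LoggingScheduler class below is illustrative and not part of this commit):

# Illustrative sketch, not part of this commit: plug a custom scheduler class
# into the engine via the new scheduler_cls option.
from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine

class LoggingScheduler(Scheduler):
    """Hypothetical subclass that logs and then delegates to the default scheduler."""

    def schedule(self):
        print("custom scheduler invoked")
        return super().schedule()

engine_args = EngineArgs(
    model="facebook/opt-125m",
    enforce_eager=True,
    scheduler_cls=LoggingScheduler,  # pass the class object directly
)
engine = LLMEngine.from_engine_args(engine_args=engine_args)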


@@ -531,6 +531,7 @@ steps:
 - pip uninstall vllm_add_dummy_platform -y
 # end platform plugin tests
 # other tests continue here:
+- pytest -v -s plugins_tests/test_scheduler_plugins.py
 - pip install -e ./plugins/vllm_add_dummy_model
 - pytest -v -s distributed/test_distributed_oot.py
 - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process


@@ -0,0 +1,33 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from vllm.core.scheduler import Scheduler
+
+
+class DummyScheduler(Scheduler):
+
+    def schedule(self):
+        raise Exception("Exception raised by DummyScheduler")
+
+
+def test_scheduler_plugins():
+    import pytest
+
+    from vllm.engine.arg_utils import EngineArgs
+    from vllm.engine.llm_engine import LLMEngine
+    from vllm.sampling_params import SamplingParams
+
+    with pytest.raises(Exception) as exception_info:
+
+        engine_args = EngineArgs(
+            model="facebook/opt-125m",
+            enforce_eager=True,  # reduce test time
+            scheduler_cls=DummyScheduler,
+        )
+
+        engine = LLMEngine.from_engine_args(engine_args=engine_args)
+
+        sampling_params = SamplingParams(max_tokens=1)
+        engine.add_request("0", "foo", sampling_params)
+        engine.step()
+
+    assert str(exception_info.value) == "Exception raised by DummyScheduler"


@@ -1495,6 +1495,10 @@ class SchedulerConfig:
     chunked_prefill_enabled: bool = field(init=False)
 
+    # scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
+    # or "mod.custom_class".
+    scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,


@@ -192,6 +192,7 @@ class EngineArgs:
     collect_detailed_traces: Optional[str] = None
     disable_async_output_proc: bool = False
     scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
+    scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
 
     override_neuron_config: Optional[Dict[str, Any]] = None
     override_pooler_config: Optional[PoolerConfig] = None
@@ -938,6 +939,13 @@ class EngineArgs:
             'priority (lower value means earlier handling) and time of '
             'arrival deciding any ties).')
+        parser.add_argument(
+            '--scheduler-cls',
+            default=EngineArgs.scheduler_cls,
+            help='The scheduler class to use. "vllm.core.scheduler.Scheduler" '
+            'is the default scheduler. Can be a class directly or the path to '
+            'a class of form "mod.custom_class".')
         parser.add_argument(
             '--override-neuron-config',
             type=json.loads,
@@ -1273,10 +1281,12 @@ class EngineArgs:
             send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                              and parallel_config.use_ray),
             policy=self.scheduling_policy,
+            scheduler_cls=self.scheduler_cls,
             max_num_partial_prefills=self.max_num_partial_prefills,
             max_long_partial_prefills=self.max_long_partial_prefills,
             long_prefill_token_threshold=self.long_prefill_token_threshold,
         )
 
         lora_config = LoRAConfig(
             bias_enabled=self.enable_lora_bias,
             max_lora_rank=self.max_lora_rank,
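The new --scheduler-cls flag mirrors the EngineArgs.scheduler_cls field, so the dotted-path form described in the help text can also be set programmatically. A small sketch, assuming a hypothetical plugin module my_plugin.sched:

# Illustrative only: the string (dotted path) form of scheduler_cls.
# "my_plugin.sched.PriorityScheduler" is a hypothetical placeholder path.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",
    scheduler_cls="my_plugin.sched.PriorityScheduler",  # resolved by name at engine start
)
# Command-line equivalent of the same setting:
#   --scheduler-cls my_plugin.sched.PriorityScheduler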


@@ -19,8 +19,7 @@ import vllm.envs as envs
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                          ObservabilityConfig, ParallelConfig, SchedulerConfig,
                          VllmConfig)
-from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler,
-                                 SchedulerOutputs)
+from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics_types import StatLoggerBase, Stats
 from vllm.engine.output_processor.interfaces import (
@@ -58,7 +57,8 @@ from vllm.transformers_utils.tokenizer_group import (
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
-from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
+from vllm.utils import (Counter, Device, deprecate_kwargs,
+                        resolve_obj_by_qualname, weak_bind)
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -346,6 +346,11 @@ class LLMEngine:
         # Create the scheduler.
         # NOTE: the cache_config here have been updated with the numbers of
         # GPU and CPU blocks, which are profiled in the distributed executor.
+        if isinstance(self.vllm_config.scheduler_config.scheduler_cls, str):
+            Scheduler = resolve_obj_by_qualname(
+                self.vllm_config.scheduler_config.scheduler_cls)
+        else:
+            Scheduler = self.vllm_config.scheduler_config.scheduler_cls
         self.scheduler = [
             Scheduler(
                 self.scheduler_config, self.cache_config, self.lora_config,
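The string form works because LLMEngine resolves it with resolve_obj_by_qualname from vllm.utils before instantiating the scheduler. A simplified sketch of what such a resolver does (an approximation for illustration, not the exact vLLM implementation):

import importlib

def resolve_by_qualname_sketch(qualname: str):
    # Split "pkg.module.ClassName" into a module path and an attribute name,
    # import the module, and return the attribute (here: the scheduler class).
    module_name, obj_name = qualname.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)

# e.g. resolve_by_qualname_sketch("vllm.core.scheduler.Scheduler")
# returns the default Scheduler class.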