diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py
index 534b60312fd1..84441aa7d28c 100644
--- a/tests/v1/engine/test_engine_core.py
+++ b/tests/v1/engine/test_engine_core.py
@@ -66,7 +66,7 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 0
 
-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 1
 
@@ -75,7 +75,7 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 1
 
-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 2
 
@@ -85,12 +85,12 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 2
     assert len(engine_core.scheduler.running) == 2
 
-    _ = engine_core.step()
+    _ = engine_core.step_fn()
    assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 4
 
     # Loop through until they are all done.
-    while (outs := engine_core.step()[0].get(0)) and outs.outputs:
+    while (outs := engine_core.step_fn()[0].get(0)) and outs.outputs:
         pass
 
     assert len(engine_core.scheduler.waiting) == 0
@@ -107,7 +107,7 @@ def test_engine_core():
     assert engine_core.scheduler.has_unfinished_requests()
     assert not engine_core.scheduler.has_finished_requests()
 
-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 1
     assert engine_core.scheduler.has_unfinished_requests()
@@ -119,7 +119,7 @@ def test_engine_core():
     assert not engine_core.scheduler.has_unfinished_requests()
     assert engine_core.scheduler.has_finished_requests()
 
-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert not engine_core.scheduler.has_unfinished_requests()
     assert not engine_core.scheduler.has_finished_requests()
 
@@ -133,7 +133,7 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 2
     assert len(engine_core.scheduler.running) == 0
 
-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 2
 
@@ -141,7 +141,7 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 2
 
-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 3
 
@@ -150,7 +150,7 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 2
 
-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 2
 
@@ -165,12 +165,12 @@ def test_engine_core():
     req0.request_id = req1.request_id = "test"
 
     engine_core.add_request(*engine_core.preprocess_add_request(req0))
-    while (outs := engine_core.step()[0].get(0)) and outs.outputs:
-        pass
+    while engine_core.scheduler.has_requests():
+        engine_core.step_fn()
 
     engine_core.add_request(*engine_core.preprocess_add_request(req1))
-    while (outs := engine_core.step()[0].get(0)) and outs.outputs:
-        pass
+    while engine_core.scheduler.has_requests():
+        engine_core.step_fn()
 
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 0
@@ -208,8 +208,8 @@ def test_engine_core_advanced_sampling():
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 0
     # Loop through until they are all done.
-    while (outs := engine_core.step()[0].get(0)) and outs.outputs:
-        pass
+    while engine_core.scheduler.has_requests():
+        engine_core.step_fn()
 
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 0
@@ -297,6 +297,8 @@ def test_engine_core_concurrent_batches():
         max_num_batched_tokens=10,
         # Reduce startup time.
         enforce_eager=True,
+        # Test concurrent batch behaviour independently of async scheduling.
+        async_scheduling=False,
     )
     vllm_config = engine_args.create_engine_config()
     with set_default_torch_num_threads(1):
diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py
index b837b830e774..47aa343527b3 100644
--- a/vllm/config/scheduler.py
+++ b/vllm/config/scheduler.py
@@ -4,7 +4,7 @@
 import hashlib
 from collections.abc import Callable
 from dataclasses import InitVar
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, cast
 
 from pydantic import Field, field_validator, model_validator
 from pydantic.dataclasses import dataclass
@@ -17,6 +17,10 @@ from vllm.utils import (
     MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
     POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
 )
+from vllm.utils.import_utils import resolve_obj_by_qualname
+
+if TYPE_CHECKING:
+    from vllm.v1.core.sched.interface import SchedulerInterface
 
 logger = init_logger(__name__)
 
@@ -120,7 +124,7 @@ class SchedulerConfig:
 
     # scheduler class or path. "vllm.v1.core.sched.scheduler.Scheduler"
     # (default) or "mod.custom_class".
-    scheduler_cls: str | type[object] = "vllm.v1.core.sched.scheduler.Scheduler"
+    scheduler_cls: str | type[object] = Field(default=None)
     """The scheduler class to use. "vllm.v1.core.sched.scheduler.Scheduler" is
     the default scheduler. Can be a class directly or the path to a class of
     form "mod.custom_class"."""
@@ -132,12 +136,34 @@ class SchedulerConfig:
     """
 
     async_scheduling: bool = False
-    """EXPERIMENTAL: If set to True, perform async scheduling. This may help
-    reduce the CPU overheads, leading to better latency and throughput. However,
-    async scheduling is currently not supported with some features such as
-    structured outputs, speculative decoding, and pipeline parallelism.
+    """If set to True, perform async scheduling. This helps to avoid gaps in
+    GPU utilization, leading to better latency and throughput.
+    Async scheduling is currently not supported with some features such as
+    speculative decoding and pipeline parallelism.
     """
 
+    def get_scheduler_cls(self) -> type["SchedulerInterface"]:
+        if self.scheduler_cls is None:
+            if self.async_scheduling:
+                from vllm.v1.core.sched.async_scheduler import AsyncScheduler
+
+                return AsyncScheduler
+            from vllm.v1.core.sched.scheduler import Scheduler
+
+            return Scheduler
+
+        # This warning can be removed once the Scheduler interface is
+        # finalized and we can maintain support for scheduler classes that
+        # implement it
+        logger.warning_once(
+            "Using custom scheduler class %s. This scheduler interface is "
+            "not public and compatibility may not be maintained.",
+            self.scheduler_cls,
+        )
+        if not isinstance(self.scheduler_cls, str):
+            return cast(type["SchedulerInterface"], self.scheduler_cls)
+        return resolve_obj_by_qualname(self.scheduler_cls)
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -161,6 +187,8 @@
         "max_num_seqs",
         "max_model_len",
         "enable_chunked_prefill",
+        "scheduler_cls",
+        "async_scheduling",
         mode="wrap",
     )
     @classmethod
@@ -242,9 +270,6 @@
                 self.long_prefill_token_threshold,
             )
 
-        if self.async_scheduling:
-            self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler"
-
     @model_validator(mode="after")
     def _verify_args(self) -> Self:
         if (
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index ee91cb0ef5c3..ac4607886305 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -353,6 +353,53 @@ class VllmConfig:
             self.model_config, self.load_config
         )
 
+        executor_backend = self.parallel_config.distributed_executor_backend
+        executor_supports_async_sched = executor_backend in (
+            "mp",
+            "uni",
+            "external_launcher",
+        )
+
+        if self.scheduler_config.async_scheduling:
+            # Async scheduling explicitly enabled, hard fail any incompatibilities.
+            if self.parallel_config.pipeline_parallel_size > 1:
+                raise ValueError(
+                    "Async scheduling is not yet compatible with "
+                    "pipeline_parallel_size > 1."
+                )
+            if self.speculative_config is not None:
+                raise ValueError(
+                    "Async scheduling is not yet compatible with speculative decoding."
+                )
+            if not executor_supports_async_sched:
+                raise ValueError(
+                    "Currently, async scheduling only supports `mp`, `uni`, or "
+                    "`external_launcher` distributed executor backend, but you chose "
+                    f"`{executor_backend}`."
+                )
+        elif self.scheduler_config.async_scheduling is None:
+            # Enable async scheduling unless there is an incompatible option.
+            # NOTE: we won't reach here until async scheduling is enabled by default.
+            if (
+                self.parallel_config.pipeline_parallel_size > 1
+                or self.speculative_config is not None
+            ):
+                logger.warning(
+                    "Async scheduling is not yet supported with speculative decoding "
+                    " or pipeline_parallel_size > 1 and will be disabled."
+                )
+                self.scheduler_config.async_scheduling = False
+            elif not executor_supports_async_sched:
+                logger.warning(
+                    "Async scheduling will be disabled because it is not supported "
+                    "with the `%s` distributed executor backend (only `mp`, `uni`, and "
+                    "`external_launcher` are supported).",
+                    executor_backend,
+                )
+                self.scheduler_config.async_scheduling = False
+            else:
+                self.scheduler_config.async_scheduling = True
+
         from vllm.platforms import current_platform
 
         if (
@@ -467,7 +514,7 @@
             self.speculative_config is not None
             and self.speculative_config.use_eagle()
         ):
-            raise NotImplementedError(
+            raise ValueError(
                 "Fast prefill optimization for KV sharing is not "
                 "compatible with EAGLE as EAGLE requires correct logits "
                 "for all tokens while fast prefill gives incorrect logits "
@@ -491,7 +538,7 @@
             )
         if not getattr(self.model_config.hf_config, "is_causal", True):
             disable_chunked_prefill_reasons.append(
-                "Only models using causal attention supports chunked "
+                "Only models using causal attention support chunked "
                 "prefill and prefix caching; disabling both."
             )
         elif self.model_config.is_encoder_decoder:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index fe48e4293c03..f1a6c0716e4c 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -513,7 +513,7 @@ class EngineArgs:
         ObservabilityConfig.collect_detailed_traces
     )
     scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
-    scheduler_cls: str | type[object] = SchedulerConfig.scheduler_cls
+    scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
 
     pooler_config: PoolerConfig | None = ModelConfig.pooler_config
     override_pooler_config: dict | PoolerConfig | None = (
@@ -552,7 +552,7 @@
     )
     """Custom logitproc types"""
 
-    async_scheduling: bool = SchedulerConfig.async_scheduling
+    async_scheduling: bool | None = SchedulerConfig.async_scheduling
 
     kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
 
@@ -1479,20 +1479,6 @@
             else ParallelConfig.data_parallel_rpc_port
         )
 
-        if self.async_scheduling:
-            if self.pipeline_parallel_size > 1:
-                raise ValueError(
-                    "Async scheduling is not supported with pipeline-parallel-size > 1."
-                )
-
-            # Currently, async scheduling does not support speculative decoding.
-            # TODO(woosuk): Support it.
-            if self.speculative_config is not None:
-                raise ValueError(
-                    "Currently, speculative decoding is not supported with "
-                    "async scheduling."
-                )
-
         # Forward the deprecated CLI args to the EPLB config.
         if self.num_redundant_experts is not None:
             self.eplb_config.num_redundant_experts = self.num_redundant_experts
@@ -1536,16 +1522,6 @@
             _api_process_rank=self._api_process_rank,
         )
 
-        if self.async_scheduling and (
-            parallel_config.distributed_executor_backend
-            not in ("mp", "uni", "external_launcher")
-        ):
-            raise ValueError(
-                "Currently, async scheduling only supports `mp`, `uni` or "
-                "`external_launcher` distributed executor backend, but you choose "
-                f"`{parallel_config.distributed_executor_backend}`."
-            )
-
         speculative_config = self.create_speculative_config(
             target_model_config=model_config,
             target_parallel_config=parallel_config,
diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py
index 291d33c9bf98..88d99d940282 100644
--- a/vllm/v1/core/sched/interface.py
+++ b/vllm/v1/core/sched/interface.py
@@ -4,16 +4,34 @@ from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, Optional
 
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+
 if TYPE_CHECKING:
+    from vllm.config import VllmConfig
     from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
     from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
     from vllm.v1.engine import EngineCoreOutputs
+    from vllm.v1.kv_cache_interface import KVCacheConfig
     from vllm.v1.metrics.stats import SchedulerStats
     from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
     from vllm.v1.request import Request, RequestStatus
+    from vllm.v1.structured_output import StructuredOutputManager
 
 
 class SchedulerInterface(ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        vllm_config: "VllmConfig",
+        kv_cache_config: "KVCacheConfig",
+        structured_output_manager: "StructuredOutputManager",
+        block_size: int,
+        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
+        include_finished_set: bool = False,
+        log_stats: bool = False,
+    ) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def schedule(self) -> "SchedulerOutput":
         """Schedule the requests to process in this scheduling step.
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 78af197821e2..fba018432e0a 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -29,7 +29,6 @@ from vllm.tasks import POOLING_TASKS, SupportedTask
 from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
 from vllm.utils.gc_utils import maybe_attach_gc_debug_callback
 from vllm.utils.hashing import get_hash_fn_by_name
-from vllm.utils.import_utils import resolve_obj_by_qualname
 from vllm.utils.network_utils import make_zmq_socket
 from vllm.utils.system_utils import decorate_logs, set_process_title
 from vllm.v1.core.kv_cache_utils import (
@@ -41,7 +40,6 @@ from vllm.v1.core.kv_cache_utils import (
 )
 from vllm.v1.core.sched.interface import SchedulerInterface
 from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
 from vllm.v1.engine import (
     EngineCoreOutputs,
     EngineCoreRequest,
@@ -117,23 +115,7 @@ class EngineCore:
         self.structured_output_manager = StructuredOutputManager(vllm_config)
 
         # Setup scheduler.
-        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
-            Scheduler = resolve_obj_by_qualname(
-                vllm_config.scheduler_config.scheduler_cls
-            )
-        else:
-            Scheduler = vllm_config.scheduler_config.scheduler_cls
-
-        # This warning can be removed once the V1 Scheduler interface is
-        # finalized and we can maintain support for scheduler classes that
-        # implement it
-        if Scheduler is not V1Scheduler:
-            logger.warning(
-                "Using configured V1 scheduler class %s. "
-                "This scheduler interface is not public and "
-                "compatibility may not be maintained.",
-                vllm_config.scheduler_config.scheduler_cls,
-            )
+        Scheduler = vllm_config.scheduler_config.get_scheduler_cls()
 
         if len(kv_cache_config.kv_cache_groups) == 0:
             # Encoder models without KV cache don't support