[Core] Rework handling of async scheduling config (#28250)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-11-07 12:01:23 -08:00 committed by GitHub
parent 18903216f5
commit da786e339e
GPG Key ID: B5690EEEBB952194
6 changed files with 121 additions and 71 deletions

View File

@@ -66,7 +66,7 @@ def test_engine_core():
assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 0
_ = engine_core.step()
_ = engine_core.step_fn()
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 1
@@ -75,7 +75,7 @@ def test_engine_core():
assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 1
_ = engine_core.step()
_ = engine_core.step_fn()
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 2
@@ -85,12 +85,12 @@ def test_engine_core():
assert len(engine_core.scheduler.waiting) == 2
assert len(engine_core.scheduler.running) == 2
_ = engine_core.step()
_ = engine_core.step_fn()
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 4
# Loop through until they are all done.
while (outs := engine_core.step()[0].get(0)) and outs.outputs:
while (outs := engine_core.step_fn()[0].get(0)) and outs.outputs:
pass
assert len(engine_core.scheduler.waiting) == 0
@@ -107,7 +107,7 @@ def test_engine_core():
assert engine_core.scheduler.has_unfinished_requests()
assert not engine_core.scheduler.has_finished_requests()
_ = engine_core.step()
_ = engine_core.step_fn()
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 1
assert engine_core.scheduler.has_unfinished_requests()
@@ -119,7 +119,7 @@ def test_engine_core():
assert not engine_core.scheduler.has_unfinished_requests()
assert engine_core.scheduler.has_finished_requests()
_ = engine_core.step()
_ = engine_core.step_fn()
assert not engine_core.scheduler.has_unfinished_requests()
assert not engine_core.scheduler.has_finished_requests()
@@ -133,7 +133,7 @@ def test_engine_core():
assert len(engine_core.scheduler.waiting) == 2
assert len(engine_core.scheduler.running) == 0
_ = engine_core.step()
_ = engine_core.step_fn()
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 2
@@ -141,7 +141,7 @@ def test_engine_core():
assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 2
_ = engine_core.step()
_ = engine_core.step_fn()
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 3
@@ -150,7 +150,7 @@ def test_engine_core():
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 2
_ = engine_core.step()
_ = engine_core.step_fn()
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 2
@@ -165,12 +165,12 @@ def test_engine_core():
req0.request_id = req1.request_id = "test"
engine_core.add_request(*engine_core.preprocess_add_request(req0))
while (outs := engine_core.step()[0].get(0)) and outs.outputs:
pass
while engine_core.scheduler.has_requests():
engine_core.step_fn()
engine_core.add_request(*engine_core.preprocess_add_request(req1))
while (outs := engine_core.step()[0].get(0)) and outs.outputs:
pass
while engine_core.scheduler.has_requests():
engine_core.step_fn()
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 0
@@ -208,8 +208,8 @@ def test_engine_core_advanced_sampling():
assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 0
# Loop through until they are all done.
while (outs := engine_core.step()[0].get(0)) and outs.outputs:
pass
while engine_core.scheduler.has_requests():
engine_core.step_fn()
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 0
@@ -297,6 +297,8 @@ def test_engine_core_concurrent_batches():
max_num_batched_tokens=10,
# Reduce startup time.
enforce_eager=True,
# Test concurrent batch behaviour independently of async scheduling.
async_scheduling=False,
)
vllm_config = engine_args.create_engine_config()
with set_default_torch_num_threads(1):
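The test updates in this file replace direct engine_core.step() calls with engine_core.step_fn(). The definition of step_fn is not part of this diff, so the following is only a plausible reading: the engine binds step_fn once at construction to whichever step implementation matches the configured scheduling mode, and the tests now drive that bound entry point instead of hard-coding step(). A minimal, self-contained sketch of the dispatch pattern (hypothetical names, not the real EngineCore):

class TinyEngine:
    def __init__(self, async_scheduling: bool) -> None:
        # Bind the step entry point once; callers never branch on the mode.
        self.step_fn = self._step_overlapped if async_scheduling else self._step_sync

    def _step_sync(self) -> str:
        return "plain step"

    def _step_overlapped(self) -> str:
        return "overlapped step"

engine = TinyEngine(async_scheduling=True)
assert engine.step_fn() == "overlapped step"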

View File

@@ -4,7 +4,7 @@
import hashlib
from collections.abc import Callable
from dataclasses import InitVar
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal, cast
from pydantic import Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
@@ -17,6 +17,10 @@ from vllm.utils import (
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
from vllm.utils.import_utils import resolve_obj_by_qualname
if TYPE_CHECKING:
from vllm.v1.core.sched.interface import SchedulerInterface
logger = init_logger(__name__)
@@ -120,7 +124,7 @@ class SchedulerConfig:
# scheduler class or path. "vllm.v1.core.sched.scheduler.Scheduler"
# (default) or "mod.custom_class".
scheduler_cls: str | type[object] = "vllm.v1.core.sched.scheduler.Scheduler"
scheduler_cls: str | type[object] = Field(default=None)
"""The scheduler class to use. "vllm.v1.core.sched.scheduler.Scheduler" is
the default scheduler. Can be a class directly or the path to a class of
form "mod.custom_class"."""
@@ -132,12 +136,34 @@ class SchedulerConfig:
"""
async_scheduling: bool = False
"""EXPERIMENTAL: If set to True, perform async scheduling. This may help
reduce the CPU overheads, leading to better latency and throughput. However,
async scheduling is currently not supported with some features such as
structured outputs, speculative decoding, and pipeline parallelism.
"""If set to True, perform async scheduling. This helps to avoid gaps in
GPU utilization, leading to better latency and throughput.
Async scheduling is currently not supported with some features such as
speculative decoding and pipeline parallelism.
"""
def get_scheduler_cls(self) -> type["SchedulerInterface"]:
if self.scheduler_cls is None:
if self.async_scheduling:
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
return AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
return Scheduler
# This warning can be removed once the Scheduler interface is
# finalized and we can maintain support for scheduler classes that
# implement it
logger.warning_once(
"Using custom scheduler class %s. This scheduler interface is "
"not public and compatibility may not be maintained.",
self.scheduler_cls,
)
if not isinstance(self.scheduler_cls, str):
return cast(type["SchedulerInterface"], self.scheduler_cls)
return resolve_obj_by_qualname(self.scheduler_cls)
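For custom schedulers, get_scheduler_cls() passes a class object through unchanged and resolves a dotted-path string. A hedged illustration of the same resolution mechanism (the helper and the default path appear in this diff; calling it standalone like this is an assumption):

from vllm.utils.import_utils import resolve_obj_by_qualname

# The same helper get_scheduler_cls() uses when scheduler_cls is a string.
scheduler_cls = resolve_obj_by_qualname("vllm.v1.core.sched.scheduler.Scheduler")
print(scheduler_cls)  # <class 'vllm.v1.core.sched.scheduler.Scheduler'>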
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -161,6 +187,8 @@ class SchedulerConfig:
"max_num_seqs",
"max_model_len",
"enable_chunked_prefill",
"scheduler_cls",
"async_scheduling",
mode="wrap",
)
@classmethod
@@ -242,9 +270,6 @@ class SchedulerConfig:
self.long_prefill_token_threshold,
)
if self.async_scheduling:
self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler"
@model_validator(mode="after")
def _verify_args(self) -> Self:
if (

View File

@@ -353,6 +353,53 @@ class VllmConfig:
self.model_config, self.load_config
)
executor_backend = self.parallel_config.distributed_executor_backend
executor_supports_async_sched = executor_backend in (
"mp",
"uni",
"external_launcher",
)
if self.scheduler_config.async_scheduling:
# Async scheduling explicitly enabled, hard fail any incompatibilities.
if self.parallel_config.pipeline_parallel_size > 1:
raise ValueError(
"Async scheduling is not yet compatible with "
"pipeline_parallel_size > 1."
)
if self.speculative_config is not None:
raise ValueError(
"Async scheduling is not yet compatible with speculative decoding."
)
if not executor_supports_async_sched:
raise ValueError(
"Currently, async scheduling only supports `mp`, `uni`, or "
"`external_launcher` distributed executor backend, but you chose "
f"`{executor_backend}`."
)
elif self.scheduler_config.async_scheduling is None:
# Enable async scheduling unless there is an incompatible option.
# NOTE: we won't reach here until async scheduling is enabled by default.
if (
self.parallel_config.pipeline_parallel_size > 1
or self.speculative_config is not None
):
logger.warning(
"Async scheduling is not yet supported with speculative decoding "
" or pipeline_parallel_size > 1 and will be disabled."
)
self.scheduler_config.async_scheduling = False
elif not executor_supports_async_sched:
logger.warning(
"Async scheduling will be disabled because it is not supported "
"with the `%s` distributed executor backend (only `mp`, `uni`, and "
"`external_launcher` are supported).",
executor_backend,
)
self.scheduler_config.async_scheduling = False
else:
self.scheduler_config.async_scheduling = True
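The block above gives async_scheduling three effective states: True validates hard requirements and raises on conflict, None (the eventual default) enables it opportunistically and falls back with a warning, and False leaves it off. A self-contained sketch of that tri-state resolution (illustration only, not vLLM code):

def resolve_async_scheduling(requested: bool | None, compatible: bool) -> bool:
    if requested:
        # Explicitly enabled: incompatibilities are hard errors.
        if not compatible:
            raise ValueError("async scheduling is not supported with this configuration")
        return True
    if requested is None:
        # Unset: enable when possible, otherwise quietly fall back.
        return compatible
    # Explicitly disabled.
    return False

assert resolve_async_scheduling(True, compatible=True) is True
assert resolve_async_scheduling(None, compatible=False) is False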
from vllm.platforms import current_platform
if (
@@ -467,7 +514,7 @@ class VllmConfig:
self.speculative_config is not None
and self.speculative_config.use_eagle()
):
raise NotImplementedError(
raise ValueError(
"Fast prefill optimization for KV sharing is not "
"compatible with EAGLE as EAGLE requires correct logits "
"for all tokens while fast prefill gives incorrect logits "
@@ -491,7 +538,7 @@ class VllmConfig:
)
if not getattr(self.model_config.hf_config, "is_causal", True):
disable_chunked_prefill_reasons.append(
"Only models using causal attention supports chunked "
"Only models using causal attention support chunked "
"prefill and prefix caching; disabling both."
)
elif self.model_config.is_encoder_decoder:

View File

@@ -513,7 +513,7 @@ class EngineArgs:
ObservabilityConfig.collect_detailed_traces
)
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: str | type[object] = SchedulerConfig.scheduler_cls
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
pooler_config: PoolerConfig | None = ModelConfig.pooler_config
override_pooler_config: dict | PoolerConfig | None = (
@@ -552,7 +552,7 @@ class EngineArgs:
)
"""Custom logitproc types"""
async_scheduling: bool = SchedulerConfig.async_scheduling
async_scheduling: bool | None = SchedulerConfig.async_scheduling
kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
@@ -1479,20 +1479,6 @@ class EngineArgs:
else ParallelConfig.data_parallel_rpc_port
)
if self.async_scheduling:
if self.pipeline_parallel_size > 1:
raise ValueError(
"Async scheduling is not supported with pipeline-parallel-size > 1."
)
# Currently, async scheduling does not support speculative decoding.
# TODO(woosuk): Support it.
if self.speculative_config is not None:
raise ValueError(
"Currently, speculative decoding is not supported with "
"async scheduling."
)
# Forward the deprecated CLI args to the EPLB config.
if self.num_redundant_experts is not None:
self.eplb_config.num_redundant_experts = self.num_redundant_experts
@@ -1536,16 +1522,6 @@ class EngineArgs:
_api_process_rank=self._api_process_rank,
)
if self.async_scheduling and (
parallel_config.distributed_executor_backend
not in ("mp", "uni", "external_launcher")
):
raise ValueError(
"Currently, async scheduling only supports `mp`, `uni` or "
"`external_launcher` distributed executor backend, but you choose "
f"`{parallel_config.distributed_executor_backend}`."
)
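Both removed checks now live in VllmConfig (see the earlier hunk), so they apply to every construction path rather than only the CLI arguments. A hedged sketch of where the error now surfaces; constructing EngineArgs directly and the model name are assumptions for illustration:

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="facebook/opt-125m",
    async_scheduling=True,
    pipeline_parallel_size=2,  # incompatible with async scheduling
)
vllm_config = args.create_engine_config()  # expected to raise ValueError here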
speculative_config = self.create_speculative_config(
target_model_config=model_config,
target_parallel_config=parallel_config,

View File

@@ -4,16 +4,34 @@ from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
class SchedulerInterface(ABC):
@abstractmethod
def __init__(
self,
vllm_config: "VllmConfig",
kv_cache_config: "KVCacheConfig",
structured_output_manager: "StructuredOutputManager",
block_size: int,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
) -> None:
raise NotImplementedError
@abstractmethod
def schedule(self) -> "SchedulerOutput":
"""Schedule the requests to process in this scheduling step.

View File

@@ -29,7 +29,6 @@ from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils.gc_utils import maybe_attach_gc_debug_callback
from vllm.utils.hashing import get_hash_fn_by_name
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.network_utils import make_zmq_socket
from vllm.utils.system_utils import decorate_logs, set_process_title
from vllm.v1.core.kv_cache_utils import (
@@ -41,7 +40,6 @@ from vllm.v1.core.kv_cache_utils import (
)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (
EngineCoreOutputs,
EngineCoreRequest,
@@ -117,23 +115,7 @@ class EngineCore:
self.structured_output_manager = StructuredOutputManager(vllm_config)
# Setup scheduler.
if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
Scheduler = resolve_obj_by_qualname(
vllm_config.scheduler_config.scheduler_cls
)
else:
Scheduler = vllm_config.scheduler_config.scheduler_cls
# This warning can be removed once the V1 Scheduler interface is
# finalized and we can maintain support for scheduler classes that
# implement it
if Scheduler is not V1Scheduler:
logger.warning(
"Using configured V1 scheduler class %s. "
"This scheduler interface is not public and "
"compatibility may not be maintained.",
vllm_config.scheduler_config.scheduler_cls,
)
Scheduler = vllm_config.scheduler_config.get_scheduler_cls()
if len(kv_cache_config.kv_cache_groups) == 0:
# Encoder models without KV cache don't support