[Core] Rework handling of async scheduling config (#28250)
Signed-off-by: Nick Hill <nhill@redhat.com>
parent 18903216f5
commit da786e339e
@@ -66,7 +66,7 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 0

-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 1

@@ -75,7 +75,7 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 1

-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 2

@@ -85,12 +85,12 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 2
     assert len(engine_core.scheduler.running) == 2

-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 4

     # Loop through until they are all done.
-    while (outs := engine_core.step()[0].get(0)) and outs.outputs:
+    while (outs := engine_core.step_fn()[0].get(0)) and outs.outputs:
         pass

     assert len(engine_core.scheduler.waiting) == 0
@@ -107,7 +107,7 @@ def test_engine_core():
     assert engine_core.scheduler.has_unfinished_requests()
     assert not engine_core.scheduler.has_finished_requests()

-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 1
     assert engine_core.scheduler.has_unfinished_requests()
@@ -119,7 +119,7 @@ def test_engine_core():
     assert not engine_core.scheduler.has_unfinished_requests()
     assert engine_core.scheduler.has_finished_requests()

-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert not engine_core.scheduler.has_unfinished_requests()
     assert not engine_core.scheduler.has_finished_requests()

@@ -133,7 +133,7 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 2
     assert len(engine_core.scheduler.running) == 0

-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 2

@@ -141,7 +141,7 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 2

-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 3

@@ -150,7 +150,7 @@ def test_engine_core():
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 2

-    _ = engine_core.step()
+    _ = engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 2

@@ -165,12 +165,12 @@ def test_engine_core():
     req0.request_id = req1.request_id = "test"
     engine_core.add_request(*engine_core.preprocess_add_request(req0))

-    while (outs := engine_core.step()[0].get(0)) and outs.outputs:
-        pass
+    while engine_core.scheduler.has_requests():
+        engine_core.step_fn()

     engine_core.add_request(*engine_core.preprocess_add_request(req1))
-    while (outs := engine_core.step()[0].get(0)) and outs.outputs:
-        pass
+    while engine_core.scheduler.has_requests():
+        engine_core.step_fn()

     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 0
@@ -208,8 +208,8 @@ def test_engine_core_advanced_sampling():
     assert len(engine_core.scheduler.waiting) == 1
     assert len(engine_core.scheduler.running) == 0
     # Loop through until they are all done.
-    while (outs := engine_core.step()[0].get(0)) and outs.outputs:
-        pass
+    while engine_core.scheduler.has_requests():
+        engine_core.step_fn()
     assert len(engine_core.scheduler.waiting) == 0
     assert len(engine_core.scheduler.running) == 0

@@ -297,6 +297,8 @@ def test_engine_core_concurrent_batches():
         max_num_batched_tokens=10,
         # Reduce startup time.
         enforce_eager=True,
+        # Test concurrent batch behaviour independently of async scheduling.
+        async_scheduling=False,
     )
     vllm_config = engine_args.create_engine_config()
     with set_default_torch_num_threads(1):
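
The tests above replace engine_core.step() with engine_core.step_fn(). A minimal sketch of the indirection this appears to rely on, assuming EngineCore binds step_fn once at construction time depending on whether a batch queue (used for async scheduling and concurrent batches) exists; the class below is a stand-in, not the real vllm.v1.engine.core.EngineCore:

    # Sketch only, under the assumption stated above.
    class EngineCoreSketch:
        def __init__(self, uses_batch_queue: bool) -> None:
            # With a batch queue, scheduling of batch N+1 overlaps execution
            # of batch N; otherwise each step is fully synchronous.
            self.step_fn = (
                self.step_with_batch_queue if uses_batch_queue else self.step
            )

        def step(self):
            """Schedule, execute the model, and update state in one call."""

        def step_with_batch_queue(self):
            """Pipelined variant; a given call may return no new outputs."""

Calling step_fn() lets the same assertions run under both modes. The drain loops are likewise rewritten to poll scheduler.has_requests() instead of stopping when a step yields no outputs, which stays correct even when outputs lag a step behind scheduling, as can happen with a batch queue.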
@@ -4,7 +4,7 @@
 import hashlib
 from collections.abc import Callable
 from dataclasses import InitVar
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, cast

 from pydantic import Field, field_validator, model_validator
 from pydantic.dataclasses import dataclass
@@ -17,6 +17,10 @@ from vllm.utils import (
     MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
     POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
 )
+from vllm.utils.import_utils import resolve_obj_by_qualname
+
+if TYPE_CHECKING:
+    from vllm.v1.core.sched.interface import SchedulerInterface

 logger = init_logger(__name__)

@@ -120,7 +124,7 @@ class SchedulerConfig:

     # scheduler class or path. "vllm.v1.core.sched.scheduler.Scheduler"
     # (default) or "mod.custom_class".
-    scheduler_cls: str | type[object] = "vllm.v1.core.sched.scheduler.Scheduler"
+    scheduler_cls: str | type[object] = Field(default=None)
     """The scheduler class to use. "vllm.v1.core.sched.scheduler.Scheduler" is
     the default scheduler. Can be a class directly or the path to a class of
     form "mod.custom_class"."""
@@ -132,12 +136,34 @@ class SchedulerConfig:
     """

     async_scheduling: bool = False
-    """EXPERIMENTAL: If set to True, perform async scheduling. This may help
-    reduce the CPU overheads, leading to better latency and throughput. However,
-    async scheduling is currently not supported with some features such as
-    structured outputs, speculative decoding, and pipeline parallelism.
+    """If set to True, perform async scheduling. This helps to avoid gaps in
+    GPU utilization, leading to better latency and throughput.
+    Async scheduling is currently not supported with some features such as
+    speculative decoding and pipeline parallelism.
     """
+
+    def get_scheduler_cls(self) -> type["SchedulerInterface"]:
+        if self.scheduler_cls is None:
+            if self.async_scheduling:
+                from vllm.v1.core.sched.async_scheduler import AsyncScheduler
+
+                return AsyncScheduler
+            from vllm.v1.core.sched.scheduler import Scheduler
+
+            return Scheduler
+
+        # This warning can be removed once the Scheduler interface is
+        # finalized and we can maintain support for scheduler classes that
+        # implement it
+        logger.warning_once(
+            "Using custom scheduler class %s. This scheduler interface is "
+            "not public and compatibility may not be maintained.",
+            self.scheduler_cls,
+        )
+        if not isinstance(self.scheduler_cls, str):
+            return cast(type["SchedulerInterface"], self.scheduler_cls)
+        return resolve_obj_by_qualname(self.scheduler_cls)

     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
@@ -161,6 +187,8 @@ class SchedulerConfig:
         "max_num_seqs",
         "max_model_len",
         "enable_chunked_prefill",
+        "scheduler_cls",
+        "async_scheduling",
         mode="wrap",
     )
     @classmethod
@@ -242,9 +270,6 @@ class SchedulerConfig:
                 self.long_prefill_token_threshold,
             )

-        if self.async_scheduling:
-            self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler"
-
     @model_validator(mode="after")
     def _verify_args(self) -> Self:
         if (
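
With scheduler_cls defaulting to None, choosing the built-in scheduler moves out of __post_init__ (the removed mutation above) into get_scheduler_cls(). A hedged usage sketch, assuming SchedulerConfig is constructible with its defaults:

    from vllm.config import SchedulerConfig

    # No custom class, async scheduling off: built-in Scheduler.
    assert SchedulerConfig().get_scheduler_cls().__name__ == "Scheduler"

    # No custom class, async scheduling on: built-in AsyncScheduler.
    cfg = SchedulerConfig(async_scheduling=True)
    assert cfg.get_scheduler_cls().__name__ == "AsyncScheduler"

    # An explicit dotted path (or a class object) always wins and logs a
    # one-time warning that the scheduler interface is not public.
    cfg = SchedulerConfig(scheduler_cls="vllm.v1.core.sched.scheduler.Scheduler")
    assert cfg.get_scheduler_cls().__name__ == "Scheduler"

This also fixes a quirk of the old flow: a user-supplied scheduler_cls could be silently overwritten by the async_scheduling branch in __post_init__, whereas an explicit class now always takes precedence.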
@@ -353,6 +353,53 @@ class VllmConfig:
                 self.model_config, self.load_config
             )

+        executor_backend = self.parallel_config.distributed_executor_backend
+        executor_supports_async_sched = executor_backend in (
+            "mp",
+            "uni",
+            "external_launcher",
+        )
+
+        if self.scheduler_config.async_scheduling:
+            # Async scheduling explicitly enabled, hard fail any incompatibilities.
+            if self.parallel_config.pipeline_parallel_size > 1:
+                raise ValueError(
+                    "Async scheduling is not yet compatible with "
+                    "pipeline_parallel_size > 1."
+                )
+            if self.speculative_config is not None:
+                raise ValueError(
+                    "Async scheduling is not yet compatible with speculative decoding."
+                )
+            if not executor_supports_async_sched:
+                raise ValueError(
+                    "Currently, async scheduling only supports `mp`, `uni`, or "
+                    "`external_launcher` distributed executor backend, but you chose "
+                    f"`{executor_backend}`."
+                )
+        elif self.scheduler_config.async_scheduling is None:
+            # Enable async scheduling unless there is an incompatible option.
+            # NOTE: we won't reach here until async scheduling is enabled by default.
+            if (
+                self.parallel_config.pipeline_parallel_size > 1
+                or self.speculative_config is not None
+            ):
+                logger.warning(
+                    "Async scheduling is not yet supported with speculative decoding "
+                    "or pipeline_parallel_size > 1 and will be disabled."
+                )
+                self.scheduler_config.async_scheduling = False
+            elif not executor_supports_async_sched:
+                logger.warning(
+                    "Async scheduling will be disabled because it is not supported "
+                    "with the `%s` distributed executor backend (only `mp`, `uni`, and "
+                    "`external_launcher` are supported).",
+                    executor_backend,
+                )
+                self.scheduler_config.async_scheduling = False
+            else:
+                self.scheduler_config.async_scheduling = True
+
         from vllm.platforms import current_platform

         if (
@@ -467,7 +514,7 @@ class VllmConfig:
             self.speculative_config is not None
             and self.speculative_config.use_eagle()
         ):
-            raise NotImplementedError(
+            raise ValueError(
                 "Fast prefill optimization for KV sharing is not "
                 "compatible with EAGLE as EAGLE requires correct logits "
                 "for all tokens while fast prefill gives incorrect logits "
@@ -491,7 +538,7 @@ class VllmConfig:
             )
         if not getattr(self.model_config.hf_config, "is_causal", True):
             disable_chunked_prefill_reasons.append(
-                "Only models using causal attention supports chunked "
+                "Only models using causal attention support chunked "
                 "prefill and prefix caching; disabling both."
             )
         elif self.model_config.is_encoder_decoder:
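
The new block gives async_scheduling three states: True fails fast on any incompatibility, None (the anticipated future default) quietly falls back to False with a warning, and False is left untouched. A self-contained sketch of the same decision table, with plain values standing in for the real config objects:

    # Sketch of the decision table implemented in VllmConfig.__post_init__.
    def resolve_async_scheduling(
        requested: bool | None, pp_size: int, has_spec_decode: bool, backend: str
    ) -> bool:
        backend_ok = backend in ("mp", "uni", "external_launcher")
        compatible = pp_size == 1 and not has_spec_decode and backend_ok
        if requested is True:
            if not compatible:
                raise ValueError("async scheduling incompatible with this setup")
            return True
        if requested is None:  # auto mode: enable unless something conflicts
            return compatible
        return False  # explicitly disabled

    assert resolve_async_scheduling(None, 1, False, "mp") is True
    assert resolve_async_scheduling(None, 2, False, "mp") is False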
@@ -513,7 +513,7 @@ class EngineArgs:
         ObservabilityConfig.collect_detailed_traces
     )
     scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
-    scheduler_cls: str | type[object] = SchedulerConfig.scheduler_cls
+    scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls

     pooler_config: PoolerConfig | None = ModelConfig.pooler_config
     override_pooler_config: dict | PoolerConfig | None = (
@@ -552,7 +552,7 @@ class EngineArgs:
     )
     """Custom logitproc types"""

-    async_scheduling: bool = SchedulerConfig.async_scheduling
+    async_scheduling: bool | None = SchedulerConfig.async_scheduling

     kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill

@@ -1479,20 +1479,6 @@ class EngineArgs:
             else ParallelConfig.data_parallel_rpc_port
         )

-        if self.async_scheduling:
-            if self.pipeline_parallel_size > 1:
-                raise ValueError(
-                    "Async scheduling is not supported with pipeline-parallel-size > 1."
-                )
-
-            # Currently, async scheduling does not support speculative decoding.
-            # TODO(woosuk): Support it.
-            if self.speculative_config is not None:
-                raise ValueError(
-                    "Currently, speculative decoding is not supported with "
-                    "async scheduling."
-                )
-
         # Forward the deprecated CLI args to the EPLB config.
         if self.num_redundant_experts is not None:
             self.eplb_config.num_redundant_experts = self.num_redundant_experts
@@ -1536,16 +1522,6 @@ class EngineArgs:
             _api_process_rank=self._api_process_rank,
         )

-        if self.async_scheduling and (
-            parallel_config.distributed_executor_backend
-            not in ("mp", "uni", "external_launcher")
-        ):
-            raise ValueError(
-                "Currently, async scheduling only supports `mp`, `uni` or "
-                "`external_launcher` distributed executor backend, but you choose "
-                f"`{parallel_config.distributed_executor_backend}`."
-            )
-
         speculative_config = self.create_speculative_config(
             target_model_config=model_config,
             target_parallel_config=parallel_config,
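
Both removed blocks reappear, consolidated, in VllmConfig.__post_init__ shown earlier, so the checks now run wherever a VllmConfig is constructed rather than only on the EngineArgs path. A hedged sketch of the user-visible effect; the model name is an arbitrary small example and loading it requires its config to be available:

    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(
        model="facebook/opt-125m",  # arbitrary example model
        async_scheduling=True,
        pipeline_parallel_size=2,
    )
    try:
        # The ValueError now surfaces from VllmConfig construction inside
        # create_engine_config(), not from EngineArgs itself.
        args.create_engine_config()
    except ValueError as err:
        print(err)  # async scheduling vs. pipeline parallelism conflict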
@@ -4,16 +4,34 @@ from abc import ABC, abstractmethod
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, Optional

+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+
 if TYPE_CHECKING:
+    from vllm.config import VllmConfig
     from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
     from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
     from vllm.v1.engine import EngineCoreOutputs
+    from vllm.v1.kv_cache_interface import KVCacheConfig
     from vllm.v1.metrics.stats import SchedulerStats
     from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
     from vllm.v1.request import Request, RequestStatus
+    from vllm.v1.structured_output import StructuredOutputManager


 class SchedulerInterface(ABC):
+    @abstractmethod
+    def __init__(
+        self,
+        vllm_config: "VllmConfig",
+        kv_cache_config: "KVCacheConfig",
+        structured_output_manager: "StructuredOutputManager",
+        block_size: int,
+        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
+        include_finished_set: bool = False,
+        log_stats: bool = False,
+    ) -> None:
+        raise NotImplementedError
+
     @abstractmethod
     def schedule(self) -> "SchedulerOutput":
         """Schedule the requests to process in this scheduling step.
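
Because __init__ is now part of the abstract interface, a custom scheduler must accept this constructor signature; the simplest route is to subclass the built-in Scheduler and forward everything. A minimal sketch (LoggingScheduler and my_pkg are hypothetical names):

    from vllm.v1.core.sched.scheduler import Scheduler

    class LoggingScheduler(Scheduler):
        """Custom scheduler compatible with the new abstract __init__."""

        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)  # forward the full signature
            print("LoggingScheduler initialized")

    # Selected via SchedulerConfig.scheduler_cls, either as the class object
    # or as a dotted path such as "my_pkg.sched.LoggingScheduler".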
@@ -29,7 +29,6 @@ from vllm.tasks import POOLING_TASKS, SupportedTask
 from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
 from vllm.utils.gc_utils import maybe_attach_gc_debug_callback
 from vllm.utils.hashing import get_hash_fn_by_name
-from vllm.utils.import_utils import resolve_obj_by_qualname
 from vllm.utils.network_utils import make_zmq_socket
 from vllm.utils.system_utils import decorate_logs, set_process_title
 from vllm.v1.core.kv_cache_utils import (
@@ -41,7 +40,6 @@ from vllm.v1.core.kv_cache_utils import (
 )
 from vllm.v1.core.sched.interface import SchedulerInterface
 from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
 from vllm.v1.engine import (
     EngineCoreOutputs,
     EngineCoreRequest,
@@ -117,23 +115,7 @@ class EngineCore:
         self.structured_output_manager = StructuredOutputManager(vllm_config)

         # Setup scheduler.
-        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
-            Scheduler = resolve_obj_by_qualname(
-                vllm_config.scheduler_config.scheduler_cls
-            )
-        else:
-            Scheduler = vllm_config.scheduler_config.scheduler_cls
-
-        # This warning can be removed once the V1 Scheduler interface is
-        # finalized and we can maintain support for scheduler classes that
-        # implement it
-        if Scheduler is not V1Scheduler:
-            logger.warning(
-                "Using configured V1 scheduler class %s. "
-                "This scheduler interface is not public and "
-                "compatibility may not be maintained.",
-                vllm_config.scheduler_config.scheduler_cls,
-            )
+        Scheduler = vllm_config.scheduler_config.get_scheduler_cls()

         if len(kv_cache_config.kv_cache_groups) == 0:
             # Encoder models without KV cache don't support
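
EngineCore no longer resolves scheduler classes or emits the custom-class warning itself; both now live behind get_scheduler_cls(). A hedged sketch of the resulting construction path, with keyword arguments mirroring the abstract __init__ added to SchedulerInterface above (the exact arguments EngineCore passes are not shown in this diff):

    def build_scheduler(
        vllm_config, kv_cache_config, structured_output_manager, block_size
    ):
        # Resolution plus the "not public" warning happen inside the config.
        Scheduler = vllm_config.scheduler_config.get_scheduler_cls()
        return Scheduler(
            vllm_config=vllm_config,
            kv_cache_config=kv_cache_config,
            structured_output_manager=structured_output_manager,
            block_size=block_size,
            include_finished_set=False,
            log_stats=False,
        )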