[AsyncScheduling] Support ray backend for async scheduling

Signed-off-by: Lehua Ding <lehuading@tencent.com>
This commit is contained in:
Lehua Ding 2025-12-24 16:52:30 +08:00
parent 8f8fda261a
commit e591fc16ca
4 changed files with 62 additions and 25 deletions

View File

@ -69,6 +69,12 @@ def test_without_spec_decoding(
(False, "mp", True, None, True), (False, "mp", True, None, True),
(True, "mp", True, None, True), (True, "mp", True, None, True),
(True, "uni", True, None, True), (True, "uni", True, None, True),
(False, "ray", False, None, False),
(True, "ray", False, None, True),
(False, "ray", True, None, False),
(True, "ray", True, None, False),
(False, "ray", True, None, True),
(True, "ray", True, None, True),
] ]
if current_platform.is_rocm(): if current_platform.is_rocm():
@ -119,6 +125,13 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
(True, "uni", True, spec_config_short, False), (True, "uni", True, spec_config_short, False),
(True, "mp", True, spec_config, True), (True, "mp", True, spec_config, True),
(True, "uni", True, spec_config_short, True), (True, "uni", True, spec_config_short, True),
(False, "ray", False, None, False),
(False, "ray", False, spec_config, False),
(True, "ray", False, spec_config, True),
(False, "ray", True, spec_config, False),
(True, "ray", True, spec_config, False),
(False, "ray", True, spec_config_short, True),
(True, "ray", True, spec_config, True),
] ]
# On ROCm, use TRITON_ATTN + float32 for better numerical consistency # On ROCm, use TRITON_ATTN + float32 for better numerical consistency

View File

@ -532,13 +532,6 @@ class VllmConfig:
self.model_config, self.load_config self.model_config, self.load_config
) )
executor_backend = self.parallel_config.distributed_executor_backend
executor_supports_async_sched = executor_backend in (
"mp",
"uni",
"external_launcher",
)
if self.scheduler_config.async_scheduling: if self.scheduler_config.async_scheduling:
# Async scheduling explicitly enabled, hard fail any incompatibilities. # Async scheduling explicitly enabled, hard fail any incompatibilities.
if self.parallel_config.pipeline_parallel_size > 1: if self.parallel_config.pipeline_parallel_size > 1:
@ -562,12 +555,6 @@ class VllmConfig:
"this situation now. please set " "this situation now. please set "
"disable_padded_drafter_batch=Fasle" "disable_padded_drafter_batch=Fasle"
) )
if not executor_supports_async_sched:
raise ValueError(
"Currently, async scheduling only supports `mp`, `uni`, or "
"`external_launcher` distributed executor backend, but you chose "
f"`{executor_backend}`."
)
elif self.scheduler_config.async_scheduling is None: elif self.scheduler_config.async_scheduling is None:
# Enable async scheduling unless there is an incompatible option. # Enable async scheduling unless there is an incompatible option.
# NOTE: we won't reach here until async scheduling is enabled by default. # NOTE: we won't reach here until async scheduling is enabled by default.
@ -580,14 +567,6 @@ class VllmConfig:
" or pipeline_parallel_size > 1 and will be disabled." " or pipeline_parallel_size > 1 and will be disabled."
) )
self.scheduler_config.async_scheduling = False self.scheduler_config.async_scheduling = False
elif not executor_supports_async_sched:
logger.warning(
"Async scheduling will be disabled because it is not supported "
"with the `%s` distributed executor backend (only `mp`, `uni`, and "
"`external_launcher` are supported).",
executor_backend,
)
self.scheduler_config.async_scheduling = False
else: else:
self.scheduler_config.async_scheduling = True self.scheduler_config.async_scheduling = True

View File

@ -6,6 +6,7 @@ from collections import defaultdict
from collections.abc import Callable from collections.abc import Callable
from concurrent.futures import Future from concurrent.futures import Future
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import cloudpickle import cloudpickle
@ -434,6 +435,11 @@ class RayDistributedExecutor(Executor):
return self._execute_dag(scheduler_output, grammar_output, non_block) return self._execute_dag(scheduler_output, grammar_output, non_block)
@staticmethod
def _get_async_refs(refs, worker, timeout=None):
ray.get(refs, timeout=timeout)
return worker.execute_method.remote("get_execute_model_output")
def _execute_dag( def _execute_dag(
self, self,
scheduler_output: SchedulerOutput, scheduler_output: SchedulerOutput,
@ -446,6 +452,14 @@ class RayDistributedExecutor(Executor):
refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore
if self.scheduler_config.async_scheduling:
assert non_block
refs = [
partial(RayDistributedExecutor._get_async_refs, ref, worker)
for ref, worker in zip(refs, self.workers)
]
if not self.has_connector: if not self.has_connector:
# Get output only from a single worker (output_rank) # Get output only from a single worker (output_rank)
# When PP is not used, we block here until the result is available. # When PP is not used, we block here until the result is available.

View File

@ -3,7 +3,7 @@
import os import os
import time import time
from collections import defaultdict from collections import defaultdict, deque
from concurrent.futures import Future from concurrent.futures import Future
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING, Union
@ -49,6 +49,7 @@ try:
# The flag indicates is set_device is called on # The flag indicates is set_device is called on
# that thread. # that thread.
self.compiled_dag_cuda_device_set = False self.compiled_dag_cuda_device_set = False
self._execute_model_outputs = deque[AsyncModelRunnerOutput]()
def get_node_ip(self) -> str: def get_node_ip(self) -> str:
return get_ip() return get_ip()
@ -87,6 +88,7 @@ try:
) -> Union[ ) -> Union[
"ModelRunnerOutput", "ModelRunnerOutput",
tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"], tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
None,
]: ]:
# This method is used by Ray Compiled Graph to execute the model, # This method is used by Ray Compiled Graph to execute the model,
# and it needs a special logic of self.setup_device_if_necessary() # and it needs a special logic of self.setup_device_if_necessary()
@ -112,11 +114,24 @@ try:
output = scheduler_output, grammar_output, None output = scheduler_output, grammar_output, None
elif output is None: elif output is None:
output = self.worker.model_runner.sample_tokens(grammar_output) output = self.worker.model_runner.sample_tokens(grammar_output)
if self.vllm_config.scheduler_config.async_scheduling:
self._execute_model_outputs.append(output)
return None
return output
def get_execute_model_output(self) -> "ModelRunnerOutput":
assert self.vllm_config.scheduler_config.async_scheduling
assert self._execute_model_outputs, "No execute_model output available"
output = self._execute_model_outputs.popleft()
if isinstance(output, AsyncModelRunnerOutput):
# Ensure outputs crossing Ray compiled DAG are serializable. # Ensure outputs crossing Ray compiled DAG are serializable.
# AsyncModelRunnerOutput holds CUDA events and cannot be # AsyncModelRunnerOutput holds CUDA events and cannot be
# pickled. # pickled.
if isinstance(output, AsyncModelRunnerOutput): output = output.get_output()
output = output.get_output()
return output return output
def override_env_vars(self, vars: dict[str, str]): def override_env_vars(self, vars: dict[str, str]):
@ -146,8 +161,24 @@ class FutureWrapper(Future):
self.ref_or_refs = ref_or_refs self.ref_or_refs = ref_or_refs
self.aggregator = aggregator self.aggregator = aggregator
def is_callable(self, ref_or_refs):
if isinstance(ref_or_refs, list):
return callable(ref_or_refs[0])
else:
return callable(ref_or_refs)
def get_refs(self, timeout=None):
if self.is_callable(self.ref_or_refs):
if isinstance(self.ref_or_refs, list):
refs = [ref(timeout) for ref in self.ref_or_refs]
else:
refs = self.ref_or_refs(timeout)
else:
refs = self.ref_or_refs
return refs
def result(self, timeout=None): def result(self, timeout=None):
outputs = ray.get(self.ref_or_refs, timeout=timeout) outputs = ray.get(self.get_refs(), timeout=timeout)
if self.aggregator is None: if self.aggregator is None:
return outputs return outputs