[AsyncScheduling] Support ray backend for async scheduling

Signed-off-by: Lehua Ding <lehuading@tencent.com>
This commit is contained in:
Lehua Ding 2025-12-24 16:52:30 +08:00
parent 8f8fda261a
commit e591fc16ca
4 changed files with 62 additions and 25 deletions

View File

@ -69,6 +69,12 @@ def test_without_spec_decoding(
(False, "mp", True, None, True),
(True, "mp", True, None, True),
(True, "uni", True, None, True),
(False, "ray", False, None, False),
(True, "ray", False, None, True),
(False, "ray", True, None, False),
(True, "ray", True, None, False),
(False, "ray", True, None, True),
(True, "ray", True, None, True),
]
if current_platform.is_rocm():
@ -119,6 +125,13 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
(True, "uni", True, spec_config_short, False),
(True, "mp", True, spec_config, True),
(True, "uni", True, spec_config_short, True),
(False, "ray", False, None, False),
(False, "ray", False, spec_config, False),
(True, "ray", False, spec_config, True),
(False, "ray", True, spec_config, False),
(True, "ray", True, spec_config, False),
(False, "ray", True, spec_config_short, True),
(True, "ray", True, spec_config, True),
]
# On ROCm, use TRITON_ATTN + float32 for better numerical consistency

View File

@ -532,13 +532,6 @@ class VllmConfig:
self.model_config, self.load_config
)
executor_backend = self.parallel_config.distributed_executor_backend
executor_supports_async_sched = executor_backend in (
"mp",
"uni",
"external_launcher",
)
if self.scheduler_config.async_scheduling:
# Async scheduling explicitly enabled, hard fail any incompatibilities.
if self.parallel_config.pipeline_parallel_size > 1:
@ -562,12 +555,6 @@ class VllmConfig:
"this situation now. please set "
"disable_padded_drafter_batch=Fasle"
)
if not executor_supports_async_sched:
raise ValueError(
"Currently, async scheduling only supports `mp`, `uni`, or "
"`external_launcher` distributed executor backend, but you chose "
f"`{executor_backend}`."
)
elif self.scheduler_config.async_scheduling is None:
# Enable async scheduling unless there is an incompatible option.
# NOTE: we won't reach here until async scheduling is enabled by default.
@ -580,14 +567,6 @@ class VllmConfig:
" or pipeline_parallel_size > 1 and will be disabled."
)
self.scheduler_config.async_scheduling = False
elif not executor_supports_async_sched:
logger.warning(
"Async scheduling will be disabled because it is not supported "
"with the `%s` distributed executor backend (only `mp`, `uni`, and "
"`external_launcher` are supported).",
executor_backend,
)
self.scheduler_config.async_scheduling = False
else:
self.scheduler_config.async_scheduling = True

View File

@ -6,6 +6,7 @@ from collections import defaultdict
from collections.abc import Callable
from concurrent.futures import Future
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any
import cloudpickle
@ -434,6 +435,11 @@ class RayDistributedExecutor(Executor):
return self._execute_dag(scheduler_output, grammar_output, non_block)
@staticmethod
def _get_async_refs(refs, worker, timeout=None):
    # Wait (up to `timeout`) for the compiled-DAG step referenced by
    # `refs` to finish. The value returned by ray.get() is deliberately
    # discarded: with async scheduling the worker buffers its real
    # output and we fetch it with a follow-up actor call.
    ray.get(refs, timeout=timeout)
    # Returns a new ObjectRef for the buffered model output of the step
    # that just completed on this worker.
    return worker.execute_method.remote("get_execute_model_output")
def _execute_dag(
self,
scheduler_output: SchedulerOutput,
@ -446,6 +452,14 @@ class RayDistributedExecutor(Executor):
refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore
if self.scheduler_config.async_scheduling:
assert non_block
refs = [
partial(RayDistributedExecutor._get_async_refs, ref, worker)
for ref, worker in zip(refs, self.workers)
]
if not self.has_connector:
# Get output only from a single worker (output_rank)
# When PP is not used, we block here until the result is available.

View File

@ -3,7 +3,7 @@
import os
import time
from collections import defaultdict
from collections import defaultdict, deque
from concurrent.futures import Future
from typing import TYPE_CHECKING, Union
@ -49,6 +49,7 @@ try:
# The flag indicates is set_device is called on
# that thread.
self.compiled_dag_cuda_device_set = False
self._execute_model_outputs = deque[AsyncModelRunnerOutput]()
def get_node_ip(self) -> str:
    # Report the IP address of the node this Ray worker runs on,
    # as resolved by the get_ip() helper.
    return get_ip()
@ -87,6 +88,7 @@ try:
) -> Union[
"ModelRunnerOutput",
tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
None,
]:
# This method is used by Ray Compiled Graph to execute the model,
# and it needs a special logic of self.setup_device_if_necessary()
@ -112,11 +114,24 @@ try:
output = scheduler_output, grammar_output, None
elif output is None:
output = self.worker.model_runner.sample_tokens(grammar_output)
if self.vllm_config.scheduler_config.async_scheduling:
self._execute_model_outputs.append(output)
return None
return output
def get_execute_model_output(self) -> "ModelRunnerOutput":
    """Return the oldest buffered ``execute_model`` output.

    With async scheduling enabled, ``execute_model`` appends each step's
    output to ``self._execute_model_outputs`` instead of returning it;
    the executor retrieves those outputs later through this method via a
    separate actor call.

    Returns:
        The next buffered output in FIFO (step) order.

    Raises:
        AssertionError: if async scheduling is disabled or no output
            has been buffered yet.
    """
    assert self.vllm_config.scheduler_config.async_scheduling
    assert self._execute_model_outputs, "No execute_model output available"
    # FIFO: outputs were appended by execute_model in step order.
    output = self._execute_model_outputs.popleft()
    # BUGFIX: the guard and get_output() call must each appear exactly
    # once; a duplicated guard would invoke get_output() on an already
    # materialized output.
    if isinstance(output, AsyncModelRunnerOutput):
        # Ensure outputs crossing the Ray compiled DAG are serializable:
        # AsyncModelRunnerOutput holds CUDA events and cannot be
        # pickled, so materialize it before returning.
        output = output.get_output()
    return output
def override_env_vars(self, vars: dict[str, str]):
@ -146,8 +161,24 @@ class FutureWrapper(Future):
self.ref_or_refs = ref_or_refs
self.aggregator = aggregator
def is_callable(self, ref_or_refs):
    """Tell whether ``ref_or_refs`` is deferred work rather than data.

    A single callable, or a list whose first element is callable, is
    treated as callable (lists are assumed homogeneous).
    """
    probe = ref_or_refs[0] if isinstance(ref_or_refs, list) else ref_or_refs
    return callable(probe)
def get_refs(self, timeout=None):
    """Resolve ``self.ref_or_refs`` into concrete ref(s).

    Plain refs are returned as-is. If they are callables (deferred
    fetches), each is invoked with ``timeout`` to produce the ref(s).
    """
    source = self.ref_or_refs
    if not self.is_callable(source):
        return source
    if isinstance(source, list):
        return [fetch(timeout) for fetch in source]
    return source(timeout)
def result(self, timeout=None):
outputs = ray.get(self.ref_or_refs, timeout=timeout)
outputs = ray.get(self.get_refs(), timeout=timeout)
if self.aggregator is None:
return outputs