From e591fc16ca3c3391ede2fba8f0a4c25adabb57b0 Mon Sep 17 00:00:00 2001 From: Lehua Ding Date: Wed, 24 Dec 2025 16:52:30 +0800 Subject: [PATCH] [AsyncScheduling] Support ray backend for async scheduling Signed-off-by: Lehua Ding --- tests/v1/e2e/test_async_scheduling.py | 13 +++++++++ vllm/config/vllm.py | 21 --------------- vllm/v1/executor/ray_executor.py | 14 ++++++++++ vllm/v1/executor/ray_utils.py | 39 ++++++++++++++++++++++++--- 4 files changed, 62 insertions(+), 25 deletions(-) diff --git a/tests/v1/e2e/test_async_scheduling.py b/tests/v1/e2e/test_async_scheduling.py index 5cef9b33c9984..f70d6d7adde41 100644 --- a/tests/v1/e2e/test_async_scheduling.py +++ b/tests/v1/e2e/test_async_scheduling.py @@ -69,6 +69,12 @@ def test_without_spec_decoding( (False, "mp", True, None, True), (True, "mp", True, None, True), (True, "uni", True, None, True), + (False, "ray", False, None, False), + (True, "ray", False, None, True), + (False, "ray", True, None, False), + (True, "ray", True, None, False), + (False, "ray", True, None, True), + (True, "ray", True, None, True), ] if current_platform.is_rocm(): @@ -119,6 +125,13 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch): (True, "uni", True, spec_config_short, False), (True, "mp", True, spec_config, True), (True, "uni", True, spec_config_short, True), + (False, "ray", False, None, False), + (False, "ray", False, spec_config, False), + (True, "ray", False, spec_config, True), + (False, "ray", True, spec_config, False), + (True, "ray", True, spec_config, False), + (False, "ray", True, spec_config_short, True), + (True, "ray", True, spec_config, True), ] # On ROCm, use TRITON_ATTN + float32 for better numerical consistency diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index b5f8f916de438..18a62a6abe453 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -532,13 +532,6 @@ class VllmConfig: self.model_config, self.load_config ) - executor_backend = self.parallel_config.distributed_executor_backend - 
executor_supports_async_sched = executor_backend in ( - "mp", - "uni", - "external_launcher", - ) - if self.scheduler_config.async_scheduling: # Async scheduling explicitly enabled, hard fail any incompatibilities. if self.parallel_config.pipeline_parallel_size > 1: @@ -562,12 +555,6 @@ class VllmConfig: "this situation now. please set " "disable_padded_drafter_batch=Fasle" ) - if not executor_supports_async_sched: - raise ValueError( - "Currently, async scheduling only supports `mp`, `uni`, or " - "`external_launcher` distributed executor backend, but you chose " - f"`{executor_backend}`." - ) elif self.scheduler_config.async_scheduling is None: # Enable async scheduling unless there is an incompatible option. # NOTE: we won't reach here until async scheduling is enabled by default. @@ -580,14 +567,6 @@ class VllmConfig: " or pipeline_parallel_size > 1 and will be disabled." ) self.scheduler_config.async_scheduling = False - elif not executor_supports_async_sched: - logger.warning( - "Async scheduling will be disabled because it is not supported " - "with the `%s` distributed executor backend (only `mp`, `uni`, and " - "`external_launcher` are supported).", - executor_backend, - ) - self.scheduler_config.async_scheduling = False else: self.scheduler_config.async_scheduling = True diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py index 406eafcd339b0..47637be3b6520 100644 --- a/vllm/v1/executor/ray_executor.py +++ b/vllm/v1/executor/ray_executor.py @@ -6,6 +6,7 @@ from collections import defaultdict from collections.abc import Callable from concurrent.futures import Future from dataclasses import dataclass +from functools import partial from typing import TYPE_CHECKING, Any import cloudpickle @@ -434,6 +435,11 @@ class RayDistributedExecutor(Executor): return self._execute_dag(scheduler_output, grammar_output, non_block) + @staticmethod + def _get_async_refs(refs, worker, timeout=None): + ray.get(refs, timeout=timeout) + return 
worker.execute_method.remote("get_execute_model_output") + def _execute_dag( self, scheduler_output: SchedulerOutput, @@ -446,6 +452,14 @@ class RayDistributedExecutor(Executor): refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore + if self.scheduler_config.async_scheduling: + assert non_block + + refs = [ + partial(RayDistributedExecutor._get_async_refs, ref, worker) + for ref, worker in zip(refs, self.workers) + ] + if not self.has_connector: # Get output only from a single worker (output_rank) # When PP is not used, we block here until the result is available. diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py index 21910d1160bd4..e160c3991a0f6 100644 --- a/vllm/v1/executor/ray_utils.py +++ b/vllm/v1/executor/ray_utils.py @@ -3,7 +3,7 @@ import os import time -from collections import defaultdict +from collections import defaultdict, deque from concurrent.futures import Future from typing import TYPE_CHECKING, Union @@ -49,6 +49,7 @@ try: # The flag indicates is set_device is called on # that thread. 
self.compiled_dag_cuda_device_set = False + self._execute_model_outputs = deque[AsyncModelRunnerOutput]() def get_node_ip(self) -> str: return get_ip() @@ -87,6 +88,7 @@ try: ) -> Union[ "ModelRunnerOutput", tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"], + None, ]: # This method is used by Ray Compiled Graph to execute the model, # and it needs a special logic of self.setup_device_if_necessary() @@ -112,11 +114,24 @@ try: output = scheduler_output, grammar_output, None elif output is None: output = self.worker.model_runner.sample_tokens(grammar_output) + + if self.vllm_config.scheduler_config.async_scheduling: + self._execute_model_outputs.append(output) + return None + + return output + + def get_execute_model_output(self) -> "ModelRunnerOutput": + assert self.vllm_config.scheduler_config.async_scheduling + assert self._execute_model_outputs, "No execute_model output available" + output = self._execute_model_outputs.popleft() + + if isinstance(output, AsyncModelRunnerOutput): # Ensure outputs crossing Ray compiled DAG are serializable. # AsyncModelRunnerOutput holds CUDA events and cannot be # pickled. 
- if isinstance(output, AsyncModelRunnerOutput): - output = output.get_output() + output = output.get_output() + return output def override_env_vars(self, vars: dict[str, str]): @@ -146,8 +161,24 @@ class FutureWrapper(Future): self.ref_or_refs = ref_or_refs self.aggregator = aggregator + def is_callable(self, ref_or_refs): + if isinstance(ref_or_refs, list): + return callable(ref_or_refs[0]) + else: + return callable(ref_or_refs) + + def get_refs(self, timeout=None): + if self.is_callable(self.ref_or_refs): + if isinstance(self.ref_or_refs, list): + refs = [ref(timeout) for ref in self.ref_or_refs] + else: + refs = self.ref_or_refs(timeout) + else: + refs = self.ref_or_refs + return refs + def result(self, timeout=None): - outputs = ray.get(self.ref_or_refs, timeout=timeout) + outputs = ray.get(self.get_refs(timeout), timeout=timeout) if self.aggregator is None: return outputs