mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-18 17:46:58 +08:00
[AsyncScheduling] Support ray backend for async scheduling
Signed-off-by: Lehua Ding <lehuading@tencent.com>
This commit is contained in:
parent
8f8fda261a
commit
e591fc16ca
@ -69,6 +69,12 @@ def test_without_spec_decoding(
|
|||||||
(False, "mp", True, None, True),
|
(False, "mp", True, None, True),
|
||||||
(True, "mp", True, None, True),
|
(True, "mp", True, None, True),
|
||||||
(True, "uni", True, None, True),
|
(True, "uni", True, None, True),
|
||||||
|
(False, "ray", False, None, False),
|
||||||
|
(True, "ray", False, None, True),
|
||||||
|
(False, "ray", True, None, False),
|
||||||
|
(True, "ray", True, None, False),
|
||||||
|
(False, "ray", True, None, True),
|
||||||
|
(True, "ray", True, None, True),
|
||||||
]
|
]
|
||||||
|
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
@ -119,6 +125,13 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
|
|||||||
(True, "uni", True, spec_config_short, False),
|
(True, "uni", True, spec_config_short, False),
|
||||||
(True, "mp", True, spec_config, True),
|
(True, "mp", True, spec_config, True),
|
||||||
(True, "uni", True, spec_config_short, True),
|
(True, "uni", True, spec_config_short, True),
|
||||||
|
(False, "ray", False, None, False),
|
||||||
|
(False, "ray", False, spec_config, False),
|
||||||
|
(True, "ray", False, spec_config, True),
|
||||||
|
(False, "ray", True, spec_config, False),
|
||||||
|
(True, "ray", True, spec_config, False),
|
||||||
|
(False, "ray", True, spec_config_short, True),
|
||||||
|
(True, "ray", True, spec_config, True),
|
||||||
]
|
]
|
||||||
|
|
||||||
# On ROCm, use TRITON_ATTN + float32 for better numerical consistency
|
# On ROCm, use TRITON_ATTN + float32 for better numerical consistency
|
||||||
|
|||||||
@ -532,13 +532,6 @@ class VllmConfig:
|
|||||||
self.model_config, self.load_config
|
self.model_config, self.load_config
|
||||||
)
|
)
|
||||||
|
|
||||||
executor_backend = self.parallel_config.distributed_executor_backend
|
|
||||||
executor_supports_async_sched = executor_backend in (
|
|
||||||
"mp",
|
|
||||||
"uni",
|
|
||||||
"external_launcher",
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.scheduler_config.async_scheduling:
|
if self.scheduler_config.async_scheduling:
|
||||||
# Async scheduling explicitly enabled, hard fail any incompatibilities.
|
# Async scheduling explicitly enabled, hard fail any incompatibilities.
|
||||||
if self.parallel_config.pipeline_parallel_size > 1:
|
if self.parallel_config.pipeline_parallel_size > 1:
|
||||||
@ -562,12 +555,6 @@ class VllmConfig:
|
|||||||
"this situation now. please set "
|
"this situation now. please set "
|
||||||
"disable_padded_drafter_batch=Fasle"
|
"disable_padded_drafter_batch=Fasle"
|
||||||
)
|
)
|
||||||
if not executor_supports_async_sched:
|
|
||||||
raise ValueError(
|
|
||||||
"Currently, async scheduling only supports `mp`, `uni`, or "
|
|
||||||
"`external_launcher` distributed executor backend, but you chose "
|
|
||||||
f"`{executor_backend}`."
|
|
||||||
)
|
|
||||||
elif self.scheduler_config.async_scheduling is None:
|
elif self.scheduler_config.async_scheduling is None:
|
||||||
# Enable async scheduling unless there is an incompatible option.
|
# Enable async scheduling unless there is an incompatible option.
|
||||||
# NOTE: we won't reach here until async scheduling is enabled by default.
|
# NOTE: we won't reach here until async scheduling is enabled by default.
|
||||||
@ -580,14 +567,6 @@ class VllmConfig:
|
|||||||
" or pipeline_parallel_size > 1 and will be disabled."
|
" or pipeline_parallel_size > 1 and will be disabled."
|
||||||
)
|
)
|
||||||
self.scheduler_config.async_scheduling = False
|
self.scheduler_config.async_scheduling = False
|
||||||
elif not executor_supports_async_sched:
|
|
||||||
logger.warning(
|
|
||||||
"Async scheduling will be disabled because it is not supported "
|
|
||||||
"with the `%s` distributed executor backend (only `mp`, `uni`, and "
|
|
||||||
"`external_launcher` are supported).",
|
|
||||||
executor_backend,
|
|
||||||
)
|
|
||||||
self.scheduler_config.async_scheduling = False
|
|
||||||
else:
|
else:
|
||||||
self.scheduler_config.async_scheduling = True
|
self.scheduler_config.async_scheduling = True
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from collections import defaultdict
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
@ -434,6 +435,11 @@ class RayDistributedExecutor(Executor):
|
|||||||
|
|
||||||
return self._execute_dag(scheduler_output, grammar_output, non_block)
|
return self._execute_dag(scheduler_output, grammar_output, non_block)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_async_refs(refs, worker, timeout=None):
|
||||||
|
ray.get(refs, timeout=timeout)
|
||||||
|
return worker.execute_method.remote("get_execute_model_output")
|
||||||
|
|
||||||
def _execute_dag(
|
def _execute_dag(
|
||||||
self,
|
self,
|
||||||
scheduler_output: SchedulerOutput,
|
scheduler_output: SchedulerOutput,
|
||||||
@ -446,6 +452,14 @@ class RayDistributedExecutor(Executor):
|
|||||||
|
|
||||||
refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore
|
refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore
|
||||||
|
|
||||||
|
if self.scheduler_config.async_scheduling:
|
||||||
|
assert non_block
|
||||||
|
|
||||||
|
refs = [
|
||||||
|
partial(RayDistributedExecutor._get_async_refs, ref, worker)
|
||||||
|
for ref, worker in zip(refs, self.workers)
|
||||||
|
]
|
||||||
|
|
||||||
if not self.has_connector:
|
if not self.has_connector:
|
||||||
# Get output only from a single worker (output_rank)
|
# Get output only from a single worker (output_rank)
|
||||||
# When PP is not used, we block here until the result is available.
|
# When PP is not used, we block here until the result is available.
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict, deque
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
@ -49,6 +49,7 @@ try:
|
|||||||
# The flag indicates if set_device is called on
|
# The flag indicates if set_device is called on
|
||||||
# that thread.
|
# that thread.
|
||||||
self.compiled_dag_cuda_device_set = False
|
self.compiled_dag_cuda_device_set = False
|
||||||
|
self._execute_model_outputs = deque[AsyncModelRunnerOutput]()
|
||||||
|
|
||||||
def get_node_ip(self) -> str:
|
def get_node_ip(self) -> str:
|
||||||
return get_ip()
|
return get_ip()
|
||||||
@ -87,6 +88,7 @@ try:
|
|||||||
) -> Union[
|
) -> Union[
|
||||||
"ModelRunnerOutput",
|
"ModelRunnerOutput",
|
||||||
tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
|
tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
|
||||||
|
None,
|
||||||
]:
|
]:
|
||||||
# This method is used by Ray Compiled Graph to execute the model,
|
# This method is used by Ray Compiled Graph to execute the model,
|
||||||
# and it needs a special logic of self.setup_device_if_necessary()
|
# and it needs a special logic of self.setup_device_if_necessary()
|
||||||
@ -112,11 +114,24 @@ try:
|
|||||||
output = scheduler_output, grammar_output, None
|
output = scheduler_output, grammar_output, None
|
||||||
elif output is None:
|
elif output is None:
|
||||||
output = self.worker.model_runner.sample_tokens(grammar_output)
|
output = self.worker.model_runner.sample_tokens(grammar_output)
|
||||||
|
|
||||||
|
if self.vllm_config.scheduler_config.async_scheduling:
|
||||||
|
self._execute_model_outputs.append(output)
|
||||||
|
return None
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def get_execute_model_output(self) -> "ModelRunnerOutput":
|
||||||
|
assert self.vllm_config.scheduler_config.async_scheduling
|
||||||
|
assert self._execute_model_outputs, "No execute_model output available"
|
||||||
|
output = self._execute_model_outputs.popleft()
|
||||||
|
|
||||||
|
if isinstance(output, AsyncModelRunnerOutput):
|
||||||
# Ensure outputs crossing Ray compiled DAG are serializable.
|
# Ensure outputs crossing Ray compiled DAG are serializable.
|
||||||
# AsyncModelRunnerOutput holds CUDA events and cannot be
|
# AsyncModelRunnerOutput holds CUDA events and cannot be
|
||||||
# pickled.
|
# pickled.
|
||||||
if isinstance(output, AsyncModelRunnerOutput):
|
output = output.get_output()
|
||||||
output = output.get_output()
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def override_env_vars(self, vars: dict[str, str]):
|
def override_env_vars(self, vars: dict[str, str]):
|
||||||
@ -146,8 +161,24 @@ class FutureWrapper(Future):
|
|||||||
self.ref_or_refs = ref_or_refs
|
self.ref_or_refs = ref_or_refs
|
||||||
self.aggregator = aggregator
|
self.aggregator = aggregator
|
||||||
|
|
||||||
|
def is_callable(self, ref_or_refs):
|
||||||
|
if isinstance(ref_or_refs, list):
|
||||||
|
return callable(ref_or_refs[0])
|
||||||
|
else:
|
||||||
|
return callable(ref_or_refs)
|
||||||
|
|
||||||
|
def get_refs(self, timeout=None):
|
||||||
|
if self.is_callable(self.ref_or_refs):
|
||||||
|
if isinstance(self.ref_or_refs, list):
|
||||||
|
refs = [ref(timeout) for ref in self.ref_or_refs]
|
||||||
|
else:
|
||||||
|
refs = self.ref_or_refs(timeout)
|
||||||
|
else:
|
||||||
|
refs = self.ref_or_refs
|
||||||
|
return refs
|
||||||
|
|
||||||
def result(self, timeout=None):
|
def result(self, timeout=None):
|
||||||
outputs = ray.get(self.ref_or_refs, timeout=timeout)
|
outputs = ray.get(self.get_refs(), timeout=timeout)
|
||||||
if self.aggregator is None:
|
if self.aggregator is None:
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user