mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-24 02:29:10 +08:00
Merge e591fc16ca3c3391ede2fba8f0a4c25adabb57b0 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
5d1311404d
@ -69,6 +69,12 @@ def test_without_spec_decoding(
|
||||
(False, "mp", True, None, True),
|
||||
(True, "mp", True, None, True),
|
||||
(True, "uni", True, None, True),
|
||||
(False, "ray", False, None, False),
|
||||
(True, "ray", False, None, True),
|
||||
(False, "ray", True, None, False),
|
||||
(True, "ray", True, None, False),
|
||||
(False, "ray", True, None, True),
|
||||
(True, "ray", True, None, True),
|
||||
]
|
||||
|
||||
if current_platform.is_rocm():
|
||||
@ -119,6 +125,13 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
|
||||
(True, "uni", True, spec_config_short, False),
|
||||
(True, "mp", True, spec_config, True),
|
||||
(True, "uni", True, spec_config_short, True),
|
||||
(False, "ray", False, None, False),
|
||||
(False, "ray", False, spec_config, False),
|
||||
(True, "ray", False, spec_config, True),
|
||||
(False, "ray", True, spec_config, False),
|
||||
(True, "ray", True, spec_config, False),
|
||||
(False, "ray", True, spec_config_short, True),
|
||||
(True, "ray", True, spec_config, True),
|
||||
]
|
||||
|
||||
# On ROCm, use TRITON_ATTN + float32 for better numerical consistency
|
||||
|
||||
@ -532,13 +532,6 @@ class VllmConfig:
|
||||
self.model_config, self.load_config
|
||||
)
|
||||
|
||||
executor_backend = self.parallel_config.distributed_executor_backend
|
||||
executor_supports_async_sched = executor_backend in (
|
||||
"mp",
|
||||
"uni",
|
||||
"external_launcher",
|
||||
)
|
||||
|
||||
if self.scheduler_config.async_scheduling:
|
||||
# Async scheduling explicitly enabled, hard fail any incompatibilities.
|
||||
if self.parallel_config.pipeline_parallel_size > 1:
|
||||
@ -562,12 +555,6 @@ class VllmConfig:
|
||||
"this situation now. please set "
|
||||
"disable_padded_drafter_batch=Fasle"
|
||||
)
|
||||
if not executor_supports_async_sched:
|
||||
raise ValueError(
|
||||
"Currently, async scheduling only supports `mp`, `uni`, or "
|
||||
"`external_launcher` distributed executor backend, but you chose "
|
||||
f"`{executor_backend}`."
|
||||
)
|
||||
elif self.scheduler_config.async_scheduling is None:
|
||||
# Enable async scheduling unless there is an incompatible option.
|
||||
# NOTE: we won't reach here until async scheduling is enabled by default.
|
||||
@ -580,14 +567,6 @@ class VllmConfig:
|
||||
" or pipeline_parallel_size > 1 and will be disabled."
|
||||
)
|
||||
self.scheduler_config.async_scheduling = False
|
||||
elif not executor_supports_async_sched:
|
||||
logger.warning(
|
||||
"Async scheduling will be disabled because it is not supported "
|
||||
"with the `%s` distributed executor backend (only `mp`, `uni`, and "
|
||||
"`external_launcher` are supported).",
|
||||
executor_backend,
|
||||
)
|
||||
self.scheduler_config.async_scheduling = False
|
||||
else:
|
||||
self.scheduler_config.async_scheduling = True
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import cloudpickle
|
||||
@ -434,6 +435,11 @@ class RayDistributedExecutor(Executor):
|
||||
|
||||
return self._execute_dag(scheduler_output, grammar_output, non_block)
|
||||
|
||||
@staticmethod
|
||||
def _get_async_refs(refs, worker, timeout=None):
|
||||
ray.get(refs, timeout=timeout)
|
||||
return worker.execute_method.remote("get_execute_model_output")
|
||||
|
||||
def _execute_dag(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
@ -446,6 +452,14 @@ class RayDistributedExecutor(Executor):
|
||||
|
||||
refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore
|
||||
|
||||
if self.scheduler_config.async_scheduling:
|
||||
assert non_block
|
||||
|
||||
refs = [
|
||||
partial(RayDistributedExecutor._get_async_refs, ref, worker)
|
||||
for ref, worker in zip(refs, self.workers)
|
||||
]
|
||||
|
||||
if not self.has_connector:
|
||||
# Get output only from a single worker (output_rank)
|
||||
# When PP is not used, we block here until the result is available.
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, deque
|
||||
from concurrent.futures import Future
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
@ -49,6 +49,7 @@ try:
|
||||
# The flag indicates is set_device is called on
|
||||
# that thread.
|
||||
self.compiled_dag_cuda_device_set = False
|
||||
self._execute_model_outputs = deque[AsyncModelRunnerOutput]()
|
||||
|
||||
def get_node_ip(self) -> str:
|
||||
return get_ip()
|
||||
@ -87,6 +88,7 @@ try:
|
||||
) -> Union[
|
||||
"ModelRunnerOutput",
|
||||
tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
|
||||
None,
|
||||
]:
|
||||
# This method is used by Ray Compiled Graph to execute the model,
|
||||
# and it needs a special logic of self.setup_device_if_necessary()
|
||||
@ -112,11 +114,24 @@ try:
|
||||
output = scheduler_output, grammar_output, None
|
||||
elif output is None:
|
||||
output = self.worker.model_runner.sample_tokens(grammar_output)
|
||||
|
||||
if self.vllm_config.scheduler_config.async_scheduling:
|
||||
self._execute_model_outputs.append(output)
|
||||
return None
|
||||
|
||||
return output
|
||||
|
||||
def get_execute_model_output(self) -> "ModelRunnerOutput":
|
||||
assert self.vllm_config.scheduler_config.async_scheduling
|
||||
assert self._execute_model_outputs, "No execute_model output available"
|
||||
output = self._execute_model_outputs.popleft()
|
||||
|
||||
if isinstance(output, AsyncModelRunnerOutput):
|
||||
# Ensure outputs crossing Ray compiled DAG are serializable.
|
||||
# AsyncModelRunnerOutput holds CUDA events and cannot be
|
||||
# pickled.
|
||||
if isinstance(output, AsyncModelRunnerOutput):
|
||||
output = output.get_output()
|
||||
output = output.get_output()
|
||||
|
||||
return output
|
||||
|
||||
def override_env_vars(self, vars: dict[str, str]):
|
||||
@ -146,8 +161,24 @@ class FutureWrapper(Future):
|
||||
self.ref_or_refs = ref_or_refs
|
||||
self.aggregator = aggregator
|
||||
|
||||
def is_callable(self, ref_or_refs):
|
||||
if isinstance(ref_or_refs, list):
|
||||
return callable(ref_or_refs[0])
|
||||
else:
|
||||
return callable(ref_or_refs)
|
||||
|
||||
def get_refs(self, timeout=None):
|
||||
if self.is_callable(self.ref_or_refs):
|
||||
if isinstance(self.ref_or_refs, list):
|
||||
refs = [ref(timeout) for ref in self.ref_or_refs]
|
||||
else:
|
||||
refs = self.ref_or_refs(timeout)
|
||||
else:
|
||||
refs = self.ref_or_refs
|
||||
return refs
|
||||
|
||||
def result(self, timeout=None):
|
||||
outputs = ray.get(self.ref_or_refs, timeout=timeout)
|
||||
outputs = ray.get(self.get_refs(), timeout=timeout)
|
||||
if self.aggregator is None:
|
||||
return outputs
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user