mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 15:55:40 +08:00
[misc] Rename Ray ADAG to Compiled Graph (#13928)
This commit is contained in:
parent
ca377cf1b9
commit
c9944acbf9
@ -117,7 +117,7 @@ def test_models_distributed(
|
|||||||
pytest.skip(f"Skip test for {test_suite}")
|
pytest.skip(f"Skip test for {test_suite}")
|
||||||
|
|
||||||
if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa
|
if model == "meta-llama/Llama-3.2-1B-Instruct" and distributed_executor_backend == "ray" and attention_backend == "" and test_suite == "L4": # noqa
|
||||||
# test ray adag
|
# test Ray Compiled Graph
|
||||||
os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1"
|
os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1"
|
||||||
os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1"
|
os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1"
|
||||||
|
|
||||||
|
|||||||
@ -93,7 +93,7 @@ def test_models_distributed(
|
|||||||
|
|
||||||
if (model == "meta-llama/Llama-3.2-1B-Instruct"
|
if (model == "meta-llama/Llama-3.2-1B-Instruct"
|
||||||
and distributed_executor_backend == "ray"):
|
and distributed_executor_backend == "ray"):
|
||||||
# test ray adag
|
# test Ray Compiled Graph
|
||||||
os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1"
|
os.environ['VLLM_USE_RAY_SPMD_WORKER'] = "1"
|
||||||
os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1"
|
os.environ['VLLM_USE_RAY_COMPILED_DAG'] = "1"
|
||||||
|
|
||||||
|
|||||||
@ -324,8 +324,8 @@ def _compare_tp(
|
|||||||
specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
|
specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
|
||||||
if distributed_backend == "ray" and (vllm_major_version == "1"
|
if distributed_backend == "ray" and (vllm_major_version == "1"
|
||||||
or specific_case):
|
or specific_case):
|
||||||
# For V1, test Ray ADAG for all the tests
|
# For V1, test Ray Compiled Graph for all the tests
|
||||||
# For V0, test Ray ADAG for a subset of the tests
|
# For V0, test Ray Compiled Graph for a subset of the tests
|
||||||
pp_env = {
|
pp_env = {
|
||||||
"VLLM_USE_V1": vllm_major_version,
|
"VLLM_USE_V1": vllm_major_version,
|
||||||
"VLLM_USE_RAY_COMPILED_DAG": "1",
|
"VLLM_USE_RAY_COMPILED_DAG": "1",
|
||||||
@ -333,7 +333,7 @@ def _compare_tp(
|
|||||||
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
|
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
|
||||||
}
|
}
|
||||||
# Temporary. Currently when zeromq + SPMD is used, it does not properly
|
# Temporary. Currently when zeromq + SPMD is used, it does not properly
|
||||||
# terminate because of aDAG issue.
|
# terminate because of a Ray Compiled Graph issue.
|
||||||
common_args.append("--disable-frontend-multiprocessing")
|
common_args.append("--disable-frontend-multiprocessing")
|
||||||
else:
|
else:
|
||||||
pp_env = None
|
pp_env = None
|
||||||
@ -367,8 +367,9 @@ def _compare_tp(
|
|||||||
if pp_env is None:
|
if pp_env is None:
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
# Ray ADAG tests are flaky, so we don't want to fail the test
|
# Ray Compiled Graph tests are flaky,
|
||||||
logger.exception("Ray ADAG tests failed")
|
# so we don't want to fail the test
|
||||||
|
logger.exception("Ray Compiled Graph tests failed")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|||||||
@ -371,21 +371,22 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_USE_RAY_SPMD_WORKER":
|
"VLLM_USE_RAY_SPMD_WORKER":
|
||||||
lambda: bool(int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))),
|
lambda: bool(int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))),
|
||||||
|
|
||||||
# If the env var is set, it uses the Ray's compiled DAG API
|
# If the env var is set, it uses the Ray's Compiled Graph
|
||||||
# which optimizes the control plane overhead.
|
# (previously known as ADAG) API which optimizes the
|
||||||
|
# control plane overhead.
|
||||||
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
|
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
|
||||||
"VLLM_USE_RAY_COMPILED_DAG":
|
"VLLM_USE_RAY_COMPILED_DAG":
|
||||||
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))),
|
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))),
|
||||||
|
|
||||||
# If the env var is set, it uses NCCL for communication in
|
# If the env var is set, it uses NCCL for communication in
|
||||||
# Ray's compiled DAG. This flag is ignored if
|
# Ray's Compiled Graph. This flag is ignored if
|
||||||
# VLLM_USE_RAY_COMPILED_DAG is not set.
|
# VLLM_USE_RAY_COMPILED_DAG is not set.
|
||||||
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL":
|
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL":
|
||||||
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", "1"))
|
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", "1"))
|
||||||
),
|
),
|
||||||
|
|
||||||
# If the env var is set, it enables GPU communication overlap
|
# If the env var is set, it enables GPU communication overlap
|
||||||
# (experimental feature) in Ray's compiled DAG. This flag is ignored if
|
# (experimental feature) in Ray's Compiled Graph. This flag is ignored if
|
||||||
# VLLM_USE_RAY_COMPILED_DAG is not set.
|
# VLLM_USE_RAY_COMPILED_DAG is not set.
|
||||||
"VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM":
|
"VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM":
|
||||||
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0"))
|
lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "0"))
|
||||||
|
|||||||
@ -491,7 +491,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
|
|||||||
async_run_remote_workers_only to complete."""
|
async_run_remote_workers_only to complete."""
|
||||||
ray.get(parallel_worker_tasks)
|
ray.get(parallel_worker_tasks)
|
||||||
|
|
||||||
def _check_ray_adag_installation(self):
|
def _check_ray_cgraph_installation(self):
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
@ -503,10 +503,10 @@ class RayDistributedExecutor(DistributedExecutorBase):
|
|||||||
f"required, but found {current_version}")
|
f"required, but found {current_version}")
|
||||||
|
|
||||||
import importlib.util
|
import importlib.util
|
||||||
adag_spec = importlib.util.find_spec(
|
cgraph_spec = importlib.util.find_spec(
|
||||||
"ray.experimental.compiled_dag_ref")
|
"ray.experimental.compiled_dag_ref")
|
||||||
if adag_spec is None:
|
if cgraph_spec is None:
|
||||||
raise ValueError("Ray accelerated DAG is not installed. "
|
raise ValueError("Ray Compiled Graph is not installed. "
|
||||||
"Run `pip install ray[adag]` to install it.")
|
"Run `pip install ray[adag]` to install it.")
|
||||||
|
|
||||||
cupy_spec = importlib.util.find_spec("cupy")
|
cupy_spec = importlib.util.find_spec("cupy")
|
||||||
@ -518,7 +518,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
|
|||||||
|
|
||||||
def _compiled_ray_dag(self, enable_asyncio: bool):
|
def _compiled_ray_dag(self, enable_asyncio: bool):
|
||||||
assert self.parallel_config.use_ray
|
assert self.parallel_config.use_ray
|
||||||
self._check_ray_adag_installation()
|
self._check_ray_cgraph_installation()
|
||||||
from ray.dag import InputNode, MultiOutputNode
|
from ray.dag import InputNode, MultiOutputNode
|
||||||
from ray.experimental.channel.torch_tensor_type import TorchTensorType
|
from ray.experimental.channel.torch_tensor_type import TorchTensorType
|
||||||
|
|
||||||
|
|||||||
@ -83,9 +83,9 @@ try:
|
|||||||
|
|
||||||
execute_model_req = self.input_decoder.decode(serialized_req)
|
execute_model_req = self.input_decoder.decode(serialized_req)
|
||||||
|
|
||||||
# TODO(swang): This is needed right now because Ray aDAG executes
|
# TODO(swang): This is needed right now because Ray Compiled Graph
|
||||||
# on a background thread, so we need to reset torch's current
|
# executes on a background thread, so we need to reset torch's
|
||||||
# device.
|
# current device.
|
||||||
import torch
|
import torch
|
||||||
if not self.compiled_dag_cuda_device_set:
|
if not self.compiled_dag_cuda_device_set:
|
||||||
torch.cuda.set_device(self.worker.device)
|
torch.cuda.set_device(self.worker.device)
|
||||||
@ -119,7 +119,7 @@ try:
|
|||||||
"IntermediateTensors"]],
|
"IntermediateTensors"]],
|
||||||
) -> Union["ModelRunnerOutput", Tuple["SchedulerOutput",
|
) -> Union["ModelRunnerOutput", Tuple["SchedulerOutput",
|
||||||
"IntermediateTensors"]]:
|
"IntermediateTensors"]]:
|
||||||
# this method is used to compile ray CG,
|
# This method is used by Ray Compiled Graph to execute the model,
|
||||||
# and it needs a special logic of self.setup_device_if_necessary()
|
# and it needs a special logic of self.setup_device_if_necessary()
|
||||||
self.setup_device_if_necessary()
|
self.setup_device_if_necessary()
|
||||||
assert self.worker is not None, "Worker is not initialized"
|
assert self.worker is not None, "Worker is not initialized"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user