mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 00:36:32 +08:00
[Core] Introduce SPMD worker execution using Ray accelerated DAG (#6032)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com> Co-authored-by: Stephanie Wang <swang@cs.berkeley.edu>
This commit is contained in:
parent
d25877dd9b
commit
61e592747c
@ -84,6 +84,8 @@ steps:
|
||||
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
|
||||
@ -108,6 +110,7 @@ steps:
|
||||
# We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here.
|
||||
# See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context.
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Set, Type, TypeVar, Union
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
|
||||
LoRAConfig, ModelConfig, MultiModalConfig,
|
||||
ObservabilityConfig, ParallelConfig,
|
||||
@ -414,6 +415,9 @@ class LLMEngine:
|
||||
elif distributed_executor_backend == "mp":
|
||||
from vllm.executor.multiproc_gpu_executor import (
|
||||
MultiprocessingGPUExecutor)
|
||||
assert not envs.VLLM_USE_RAY_SPMD_WORKER, (
|
||||
"multiprocessing distributed executor backend does not "
|
||||
"support VLLM_USE_RAY_SPMD_WORKER=1")
|
||||
executor_class = MultiprocessingGPUExecutor
|
||||
else:
|
||||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
@ -426,6 +430,7 @@ class LLMEngine:
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
)
|
||||
|
||||
return engine
|
||||
|
||||
def __reduce__(self):
|
||||
|
||||
@ -34,6 +34,7 @@ if TYPE_CHECKING:
|
||||
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
|
||||
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
|
||||
VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
|
||||
VLLM_USE_RAY_SPMD_WORKER: bool = False
|
||||
VLLM_USE_RAY_COMPILED_DAG: bool = False
|
||||
VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
|
||||
VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
|
||||
@ -261,6 +262,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
||||
"VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS":
|
||||
lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)),
|
||||
|
||||
# If the env var is set, then all workers will execute as separate
|
||||
# processes from the engine, and we use the same mechanism to trigger
|
||||
# execution on all workers.
|
||||
# Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it.
|
||||
"VLLM_USE_RAY_SPMD_WORKER":
|
||||
lambda: bool(os.getenv("VLLM_USE_RAY_SPMD_WORKER", 0)),
|
||||
|
||||
# If the env var is set, it uses the Ray's compiled DAG API
|
||||
# which optimizes the control plane overhead.
|
||||
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
|
||||
|
||||
@ -64,8 +64,8 @@ class DistributedGPUExecutor(GPUExecutor):
|
||||
num_cpu_blocks=num_cpu_blocks)
|
||||
|
||||
def execute_model(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||
if self.parallel_worker_tasks is None:
|
||||
self.parallel_worker_tasks = self._run_workers(
|
||||
"start_worker_execution_loop",
|
||||
@ -73,7 +73,9 @@ class DistributedGPUExecutor(GPUExecutor):
|
||||
**self.extra_execute_model_run_workers_kwargs)
|
||||
|
||||
# Only the driver worker returns the sampling results.
|
||||
return self._driver_execute_model(execute_model_req)
|
||||
driver_outputs = self._driver_execute_model(execute_model_req)
|
||||
assert driver_outputs is not None
|
||||
return driver_outputs
|
||||
|
||||
def stop_remote_worker_execution_loop(self) -> None:
|
||||
if self.parallel_worker_tasks is None:
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from itertools import islice, repeat
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
@ -23,12 +22,30 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
|
||||
|
||||
|
||||
class RayGPUExecutor(DistributedGPUExecutor):
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
# If the env var is set, it uses the Ray's compiled DAG API
|
||||
# which optimizes the control plane overhead.
|
||||
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
|
||||
# Currently, this requires USE_RAY_SPMD_WORKER=True.
|
||||
self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
|
||||
# If the env var is set, then we do not distinguish between the
|
||||
# "driver worker" vs other workers. Also, the rank 0 worker will
|
||||
# be executed in a remote Ray worker. Currently this requires
|
||||
# USE_RAY_COMPILED_DAG=True.
|
||||
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
|
||||
if self.use_ray_compiled_dag:
|
||||
assert self.use_ray_spmd_worker, (
|
||||
"VLLM_USE_RAY_COMPILED_DAG=1 requires "
|
||||
"VLLM_USE_RAY_SPMD_WORKER=1")
|
||||
if self.use_ray_spmd_worker:
|
||||
# TODO: Support SPMD worker for non-DAG Ray executor.
|
||||
assert self.use_ray_compiled_dag, (
|
||||
"VLLM_USE_RAY_SPMD_WORKER=1 requires "
|
||||
"VLLM_USE_RAY_COMPILED_DAG=1")
|
||||
|
||||
assert self.parallel_config.distributed_executor_backend == "ray"
|
||||
placement_group = self.parallel_config.placement_group
|
||||
|
||||
@ -40,11 +57,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
# Create the parallel GPU workers.
|
||||
self._init_workers_ray(placement_group)
|
||||
|
||||
self.forward_dag = None
|
||||
if USE_RAY_COMPILED_DAG:
|
||||
self.forward_dag = self._compiled_ray_dag()
|
||||
self.extra_execute_model_run_workers_kwargs[
|
||||
"use_ray_compiled_dag"] = True
|
||||
self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
|
||||
|
||||
def _configure_ray_workers_use_nsight(self,
|
||||
ray_remote_kwargs) -> Dict[str, Any]:
|
||||
@ -110,21 +123,24 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
|
||||
worker_ip = ray.get(worker.get_node_ip.remote())
|
||||
if worker_ip == driver_ip and self.driver_dummy_worker is None:
|
||||
# If the worker is on the same node as the driver, we use it
|
||||
# as the resource holder for the driver process.
|
||||
self.driver_dummy_worker = worker
|
||||
self.driver_worker = RayWorkerWrapper(
|
||||
worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
else:
|
||||
# Else, added to the list of workers.
|
||||
if self.use_ray_spmd_worker:
|
||||
self.workers.append(worker)
|
||||
else:
|
||||
worker_ip = ray.get(worker.get_node_ip.remote())
|
||||
if worker_ip == driver_ip and self.driver_dummy_worker is None:
|
||||
# If the worker is on the same node as the driver, we use it
|
||||
# as the resource holder for the driver process.
|
||||
self.driver_dummy_worker = worker
|
||||
self.driver_worker = RayWorkerWrapper(
|
||||
worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
else:
|
||||
# Else, added to the list of workers.
|
||||
self.workers.append(worker)
|
||||
|
||||
if self.driver_dummy_worker is None:
|
||||
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
|
||||
raise ValueError(
|
||||
"Ray does not allocate any GPUs on the driver node. Consider "
|
||||
"adjusting the Ray placement group or running the driver on a "
|
||||
@ -254,9 +270,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
Passing None will cause the driver to stop the model execution
|
||||
loop running in each of the remote workers.
|
||||
"""
|
||||
assert not self.use_ray_spmd_worker, (
|
||||
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
|
||||
return self.driver_worker.execute_method("execute_model",
|
||||
execute_model_req)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||
if not self.use_ray_spmd_worker:
|
||||
return super().execute_model(execute_model_req)
|
||||
|
||||
if self.forward_dag is None:
|
||||
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
|
||||
|
||||
outputs = ray.get(self.forward_dag.execute(execute_model_req))
|
||||
return outputs[0]
|
||||
|
||||
def _run_workers(
|
||||
self,
|
||||
method: str,
|
||||
@ -266,7 +296,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
all_kwargs: Optional[List[Dict[str, Any]]] = None,
|
||||
use_dummy_driver: bool = False,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
use_ray_compiled_dag: bool = False,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers. Can be used in the following
|
||||
@ -281,6 +310,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
- all_args/all_kwargs: args/kwargs for each worker are specified
|
||||
individually
|
||||
"""
|
||||
if self.use_ray_spmd_worker:
|
||||
assert not async_run_tensor_parallel_workers_only, (
|
||||
"async_run_tensor_parallel_workers_only is not supported for "
|
||||
"spmd mode.")
|
||||
|
||||
if max_concurrent_workers:
|
||||
raise NotImplementedError(
|
||||
@ -289,71 +322,69 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
count = len(self.workers) if not \
|
||||
async_run_tensor_parallel_workers_only \
|
||||
else len(self.non_driver_workers)
|
||||
# If using SPMD worker, all workers are the same, so we should execute
|
||||
# the args on all workers. Otherwise, we skip the first worker's args
|
||||
# because those args will go to the driver worker.
|
||||
first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
|
||||
all_worker_args = repeat(args, count) if all_args is None \
|
||||
else islice(all_args, 1, None)
|
||||
else islice(all_args, first_worker_args_index, None)
|
||||
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
|
||||
else islice(all_kwargs, 1, None)
|
||||
else islice(all_kwargs, first_worker_args_index, None)
|
||||
|
||||
if use_ray_compiled_dag:
|
||||
# Right now, compiled DAG can only accept a single
|
||||
# input. TODO(sang): Fix it.
|
||||
assert self.forward_dag is not None
|
||||
output_channels = self.forward_dag.execute(1)
|
||||
ray_worker_outputs = []
|
||||
else:
|
||||
# Start the ray workers first.
|
||||
ray_workers = self.workers
|
||||
if async_run_tensor_parallel_workers_only:
|
||||
ray_workers = self.non_driver_workers
|
||||
ray_worker_outputs = [
|
||||
worker.execute_method.remote(method, *worker_args,
|
||||
**worker_kwargs)
|
||||
for (worker, worker_args, worker_kwargs
|
||||
) in zip(ray_workers, all_worker_args, all_worker_kwargs)
|
||||
]
|
||||
# Start the ray workers first.
|
||||
ray_workers = self.workers
|
||||
if async_run_tensor_parallel_workers_only:
|
||||
ray_workers = self.non_driver_workers
|
||||
ray_worker_outputs = [
|
||||
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
|
||||
for (worker, worker_args, worker_kwargs
|
||||
) in zip(ray_workers, all_worker_args, all_worker_kwargs)
|
||||
]
|
||||
|
||||
if async_run_tensor_parallel_workers_only:
|
||||
# Just return futures
|
||||
return ray_worker_outputs
|
||||
|
||||
driver_args = args if all_args is None else all_args[0]
|
||||
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
|
||||
driver_worker_output = []
|
||||
# In SPMD mode, the driver worker is the same as any other worker,
|
||||
# so we only explicitly execute on the driver worker if using a
|
||||
# non-SPMD worker class.
|
||||
if not self.use_ray_spmd_worker:
|
||||
driver_args = args if all_args is None else all_args[0]
|
||||
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
|
||||
|
||||
# Start the driver worker after all the ray workers.
|
||||
if not use_dummy_driver:
|
||||
driver_worker_output = [
|
||||
self.driver_worker.execute_method(method, *driver_args,
|
||||
**driver_kwargs)
|
||||
]
|
||||
else:
|
||||
assert self.driver_dummy_worker is not None
|
||||
driver_worker_output = [
|
||||
ray.get(
|
||||
self.driver_dummy_worker.execute_method.remote(
|
||||
method, *driver_args, **driver_kwargs))
|
||||
]
|
||||
|
||||
# Start the driver worker after all the ray workers.
|
||||
if not use_dummy_driver:
|
||||
driver_worker_output = self.driver_worker.execute_method(
|
||||
method, *driver_args, **driver_kwargs)
|
||||
else:
|
||||
assert self.driver_dummy_worker is not None
|
||||
driver_worker_output = ray.get(
|
||||
self.driver_dummy_worker.execute_method.remote(
|
||||
method, *driver_args, **driver_kwargs))
|
||||
# Get the results of the ray workers.
|
||||
if self.workers:
|
||||
if use_ray_compiled_dag:
|
||||
try:
|
||||
ray_worker_outputs = [
|
||||
pickle.loads(chan.begin_read())
|
||||
for chan in output_channels
|
||||
]
|
||||
finally:
|
||||
# Has to call end_read in order to reuse the DAG.
|
||||
for chan in output_channels:
|
||||
chan.end_read()
|
||||
else:
|
||||
ray_worker_outputs = ray.get(ray_worker_outputs)
|
||||
ray_worker_outputs = ray.get(ray_worker_outputs)
|
||||
|
||||
return [driver_worker_output] + ray_worker_outputs
|
||||
return driver_worker_output + ray_worker_outputs
|
||||
|
||||
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
|
||||
"""Wait for futures returned from _run_workers() with
|
||||
async_run_remote_workers_only to complete."""
|
||||
ray.get(parallel_worker_tasks)
|
||||
|
||||
def _compiled_ray_dag(self):
|
||||
def _compiled_ray_dag(self, enable_asyncio: bool):
|
||||
import pkg_resources
|
||||
required_version = "2.9"
|
||||
current_version = pkg_resources.get_distribution("ray").version
|
||||
from packaging import version
|
||||
|
||||
required_version = version.parse("2.32")
|
||||
current_version = version.parse(
|
||||
pkg_resources.get_distribution("ray").version)
|
||||
if current_version < required_version:
|
||||
raise ValueError(f"Ray version {required_version} or greater is "
|
||||
f"required, but found {current_version}")
|
||||
@ -365,23 +396,47 @@ class RayGPUExecutor(DistributedGPUExecutor):
|
||||
# a dummy value for now. It will be fixed soon.
|
||||
with InputNode() as input_data:
|
||||
forward_dag = MultiOutputNode([
|
||||
worker.execute_model_compiled_dag_remote.
|
||||
bind( # type: ignore[attr-defined]
|
||||
worker.execute_model_spmd.bind( # type: ignore[attr-defined]
|
||||
input_data) for worker in self.workers
|
||||
])
|
||||
return forward_dag.experimental_compile()
|
||||
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
|
||||
|
||||
def __del__(self):
|
||||
if self.forward_dag is not None:
|
||||
self.forward_dag.teardown()
|
||||
import ray
|
||||
for worker in self.workers:
|
||||
ray.kill(worker)
|
||||
|
||||
|
||||
class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.driver_exec_method = make_async(self.driver_worker.execute_method)
|
||||
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
|
||||
if not self.use_ray_compiled_dag:
|
||||
self.driver_exec_method = make_async(
|
||||
self.driver_worker.execute_method)
|
||||
|
||||
async def execute_model_async(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||
if not self.use_ray_spmd_worker:
|
||||
return await super().execute_model_async(execute_model_req)
|
||||
|
||||
if self.forward_dag is None:
|
||||
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
|
||||
|
||||
dag_future = await self.forward_dag.execute_async(execute_model_req)
|
||||
outputs = await dag_future
|
||||
return outputs[0]
|
||||
|
||||
async def _driver_execute_model_async(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
assert not self.use_ray_spmd_worker, (
|
||||
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
|
||||
if self.pp_locks is None:
|
||||
# This locks each pipeline parallel stage so multiple virtual
|
||||
# engines can't execute on the same stage at the same time
|
||||
@ -415,8 +470,17 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
|
||||
return results[-1]
|
||||
|
||||
async def _start_worker_execution_loop(self):
|
||||
assert not self.use_ray_spmd_worker, (
|
||||
"worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
|
||||
coros = [
|
||||
worker.execute_method.remote("start_worker_execution_loop")
|
||||
for worker in self.non_driver_workers
|
||||
]
|
||||
return await asyncio.gather(*coros)
|
||||
|
||||
def __del__(self):
|
||||
if self.forward_dag is not None:
|
||||
self.forward_dag.teardown()
|
||||
import ray
|
||||
for worker in self.workers:
|
||||
ray.kill(worker)
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import pickle
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import get_ip, is_hip, is_xpu
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
@ -31,16 +31,18 @@ try:
|
||||
gpu_ids = ray.get_gpu_ids()
|
||||
return node_id, gpu_ids
|
||||
|
||||
def execute_model_compiled_dag_remote(self, ignored):
|
||||
"""Used only when compiled DAG is enabled."""
|
||||
def execute_model_spmd(self, execute_model_req: ExecuteModelRequest):
|
||||
"""Used only when SPMD worker and compiled DAG are both
|
||||
enabled."""
|
||||
# TODO(swang): This is needed right now because Ray aDAG executes
|
||||
# on a background thread, so we need to reset torch's current
|
||||
# device.
|
||||
import torch
|
||||
if not self.compiled_dag_cuda_device_set:
|
||||
torch.cuda.set_device(self.worker.device)
|
||||
self.compiled_dag_cuda_device_set = True
|
||||
|
||||
output = self.worker.execute_model()
|
||||
output = pickle.dumps(output)
|
||||
return output
|
||||
return self.worker._execute_model_spmd(execute_model_req)
|
||||
|
||||
ray_import_err = None
|
||||
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
import asyncio
|
||||
import os
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from itertools import islice, repeat
|
||||
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
|
||||
Tuple, Union)
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, MultiModalConfig, ParallelConfig,
|
||||
PromptAdapterConfig, SchedulerConfig,
|
||||
@ -30,7 +30,7 @@ logger = init_logger(__name__)
|
||||
# If the env var is set, it uses the Ray's compiled DAG API
|
||||
# which optimizes the control plane overhead.
|
||||
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
|
||||
USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
|
||||
USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
|
||||
|
||||
|
||||
class RayXPUExecutor(DistributedGPUExecutor):
|
||||
@ -72,10 +72,9 @@ class RayXPUExecutor(DistributedGPUExecutor):
|
||||
# Create the parallel GPU workers.
|
||||
self._init_workers_ray(placement_group)
|
||||
|
||||
# Profile the memory usage and initialize the cache.
|
||||
self.forward_dag = None
|
||||
if USE_RAY_COMPILED_DAG:
|
||||
self.forward_dag = self._compiled_ray_dag()
|
||||
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
|
||||
|
||||
# This is non-None when the execute model loop is running
|
||||
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
|
||||
@ -270,7 +269,6 @@ class RayXPUExecutor(DistributedGPUExecutor):
|
||||
all_kwargs: Optional[List[Dict[str, Any]]] = None,
|
||||
use_dummy_driver: bool = False,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
use_ray_compiled_dag: bool = False,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers. Can be used in the following
|
||||
@ -293,26 +291,20 @@ class RayXPUExecutor(DistributedGPUExecutor):
|
||||
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
|
||||
else islice(all_kwargs, 1, None)
|
||||
|
||||
if use_ray_compiled_dag:
|
||||
# Right now, compiled DAG can only accept a single
|
||||
# input. TODO(sang): Fix it.
|
||||
assert self.forward_dag is not None
|
||||
output_channels = self.forward_dag.execute(1)
|
||||
else:
|
||||
# Start the ray workers first.
|
||||
ray_worker_outputs = [
|
||||
worker.execute_method.remote(method, *worker_args,
|
||||
**worker_kwargs)
|
||||
for (worker, worker_args, worker_kwargs
|
||||
) in zip(self.workers, all_worker_args, all_worker_kwargs)
|
||||
]
|
||||
# Start the ray workers first.
|
||||
ray_worker_outputs = [
|
||||
worker.execute_method.remote(method, *worker_args, **worker_kwargs)
|
||||
for (worker, worker_args, worker_kwargs
|
||||
) in zip(self.workers, all_worker_args, all_worker_kwargs)
|
||||
]
|
||||
|
||||
if async_run_remote_workers_only:
|
||||
# Just return futures
|
||||
return ray_worker_outputs
|
||||
|
||||
driver_worker_output = []
|
||||
driver_args = args if all_args is None else all_args[0]
|
||||
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
|
||||
|
||||
# Start the driver worker after all the ray workers.
|
||||
if not use_dummy_driver:
|
||||
driver_worker_output = self.driver_worker.execute_method(
|
||||
@ -324,36 +316,28 @@ class RayXPUExecutor(DistributedGPUExecutor):
|
||||
method, *driver_args, **driver_kwargs))
|
||||
# Get the results of the ray workers.
|
||||
if self.workers:
|
||||
if use_ray_compiled_dag:
|
||||
try:
|
||||
ray_worker_outputs = [
|
||||
pickle.loads(chan.begin_read())
|
||||
for chan in output_channels
|
||||
]
|
||||
finally:
|
||||
# Has to call end_read in order to reuse the DAG.
|
||||
for chan in output_channels:
|
||||
chan.end_read()
|
||||
else:
|
||||
ray_worker_outputs = ray.get(ray_worker_outputs)
|
||||
ray_worker_outputs = ray.get(ray_worker_outputs)
|
||||
|
||||
return [driver_worker_output] + ray_worker_outputs
|
||||
return driver_worker_output + ray_worker_outputs
|
||||
|
||||
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
|
||||
"""Wait for futures returned from _run_workers() with
|
||||
async_run_remote_workers_only to complete."""
|
||||
ray.get(parallel_worker_tasks)
|
||||
|
||||
def _compiled_ray_dag(self):
|
||||
def _compiled_ray_dag(self, enable_asyncio: bool):
|
||||
import pkg_resources
|
||||
required_version = "2.9"
|
||||
current_version = pkg_resources.get_distribution("ray").version
|
||||
from packaging import version
|
||||
|
||||
required_version = version.parse("2.32")
|
||||
current_version = version.parse(
|
||||
pkg_resources.get_distribution("ray").version)
|
||||
if current_version < required_version:
|
||||
raise ValueError(f"Ray version {required_version} or greater is "
|
||||
f"required, but found {current_version}")
|
||||
|
||||
from ray.dag import InputNode, MultiOutputNode
|
||||
assert self.parallel_config.worker_use_ray
|
||||
assert self.parallel_config.distributed_executor_backend == "ray"
|
||||
|
||||
# Right now, compiled DAG requires at least 1 arg. We send
|
||||
# a dummy value for now. It will be fixed soon.
|
||||
@ -363,7 +347,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
|
||||
bind( # type: ignore[attr-defined]
|
||||
input_data) for worker in self.workers
|
||||
])
|
||||
return forward_dag.experimental_compile()
|
||||
return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
|
||||
|
||||
def check_health(self) -> None:
|
||||
"""Raises an error if engine is unhealthy."""
|
||||
|
||||
@ -281,6 +281,33 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
||||
# list to conform to interface.
|
||||
return output
|
||||
|
||||
def _execute_model_spmd(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
"""
|
||||
Execute model in Single Program Multiple Data (SPMD) fashion.
|
||||
All workers take the same request, prepare the input and
|
||||
execute the model.
|
||||
"""
|
||||
assert execute_model_req is not None, (
|
||||
"_execute_model_spmd() requires each worker to take in an "
|
||||
"ExecuteModelRequest")
|
||||
worker_input: WorkerInput = self.prepare_worker_input(
|
||||
execute_model_req=execute_model_req)
|
||||
model_input: ModelRunnerInputBase = (
|
||||
self.model_runner.prepare_model_input(
|
||||
execute_model_req.seq_group_metadata_list))
|
||||
|
||||
self.execute_worker(worker_input)
|
||||
|
||||
# If there is no input, we don't need to execute the model.
|
||||
if worker_input.num_seq_groups == 0:
|
||||
return []
|
||||
|
||||
return self.model_runner.execute_model(
|
||||
model_input, self.kv_cache[worker_input.virtual_engine]
|
||||
if self.kv_cache is not None else None)
|
||||
|
||||
|
||||
class WorkerWrapperBase:
|
||||
"""
|
||||
@ -296,7 +323,7 @@ class WorkerWrapperBase:
|
||||
trust_remote_code: bool = False) -> None:
|
||||
self.worker_module_name = worker_module_name
|
||||
self.worker_class_name = worker_class_name
|
||||
self.worker = None
|
||||
self.worker: Optional[WorkerBase] = None
|
||||
if trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
@ -323,7 +350,9 @@ class WorkerWrapperBase:
|
||||
|
||||
mod = importlib.import_module(self.worker_module_name)
|
||||
worker_class = getattr(mod, self.worker_class_name)
|
||||
|
||||
self.worker = worker_class(*args, **kwargs)
|
||||
assert self.worker is not None
|
||||
|
||||
def execute_method(self, method, *args, **kwargs):
|
||||
try:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user