diff --git a/requirements/test.in b/requirements/test.in
index e906752ff875b..9b574a09fcce5 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -17,7 +17,7 @@ vector_quantize_pytorch # required for minicpmo_26 test
 vocos # required for minicpmo_26 test
 peft
 pqdm
-ray[cgraph]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests
+ray[cgraph,default]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests
 sentence-transformers # required for embedding tests
 soundfile # required for audio tests
 jiwer # required for audio tests
diff --git a/requirements/test.txt b/requirements/test.txt
index 60dcaca816a2b..03aec80ac1283 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -10,9 +10,13 @@ aiohappyeyeballs==2.4.3
     # via aiohttp
 aiohttp==3.10.11
     # via
+    #   aiohttp-cors
     #   datasets
     #   fsspec
     #   lm-eval
+    #   ray
+aiohttp-cors==0.8.1
+    # via ray
 aiosignal==1.3.1
     # via
     #   aiohttp
@@ -57,6 +61,8 @@ bounded-pool-executor==0.0.3
     # via pqdm
 buildkite-test-collector==0.1.9
     # via -r requirements/test.in
+cachetools==5.5.2
+    # via google-auth
 certifi==2024.8.30
     # via
     #   httpcore
@@ -81,6 +87,8 @@ colorama==0.4.6
     #   sacrebleu
     #   schemathesis
     #   tqdm-multiprocess
+colorful==0.5.6
+    # via ray
 contourpy==1.3.0
     # via matplotlib
 cramjam==2.9.0
@@ -108,6 +116,8 @@ dill==0.3.8
     #   evaluate
     #   lm-eval
     #   multiprocess
+distlib==0.3.9
+    # via virtualenv
 dnspython==2.7.0
     # via email-validator
 docopt==0.6.2
@@ -143,6 +153,7 @@ filelock==3.16.1
     #   ray
     #   torch
     #   transformers
+    #   virtualenv
 fonttools==4.54.1
     # via matplotlib
 fqdn==1.5.1
@@ -165,8 +176,16 @@ genai-perf==0.0.8
     # via -r requirements/test.in
 genson==1.3.0
     # via datamodel-code-generator
+google-api-core==2.24.2
+    # via opencensus
+google-auth==2.40.2
+    # via google-api-core
+googleapis-common-protos==1.70.0
+    # via google-api-core
 graphql-core==3.2.6
     # via hypothesis-graphql
+grpcio==1.71.0
+    # via ray
 h11==0.14.0
     # via httpcore
 harfile==0.3.0
@@ -392,6 +411,10 @@ nvidia-nvjitlink-cu12==12.8.61
     #   torch
 nvidia-nvtx-cu12==12.8.55
     # via torch
+opencensus==0.11.4
+    # via ray
+opencensus-context==0.1.3
+    # via opencensus
 opencv-python-headless==4.11.0.86
     # via
     #   -r requirements/test.in
@@ -445,6 +468,7 @@ platformdirs==4.3.6
     # via
     #   black
     #   pooch
+    #   virtualenv
 plotly==5.24.1
     # via genai-perf
 pluggy==1.5.0
@@ -457,10 +481,17 @@ portalocker==2.10.1
     # via sacrebleu
 pqdm==0.2.0
     # via -r requirements/test.in
+prometheus-client==0.22.0
+    # via ray
 propcache==0.2.0
     # via yarl
+proto-plus==1.26.1
+    # via google-api-core
 protobuf==5.28.3
     # via
+    #   google-api-core
+    #   googleapis-common-protos
+    #   proto-plus
     #   ray
     #   tensorizer
 psutil==6.1.0
@@ -470,10 +501,18 @@ psutil==6.1.0
     #   tensorizer
 py==1.11.0
     # via pytest-forked
+py-spy==0.4.0
+    # via ray
 pyarrow==18.0.0
     # via
     #   datasets
     #   genai-perf
+pyasn1==0.6.1
+    # via
+    #   pyasn1-modules
+    #   rsa
+pyasn1-modules==0.4.2
+    # via google-auth
 pybind11==2.13.6
     # via lm-eval
 pycparser==2.22
@@ -486,6 +525,7 @@ pydantic==2.11.5
     #   datamodel-code-generator
     #   mistral-common
     #   mteb
+    #   ray
 pydantic-core==2.33.2
     # via pydantic
 pygments==2.18.0
@@ -573,6 +613,7 @@ requests==2.32.3
     #   buildkite-test-collector
     #   datasets
     #   evaluate
+    #   google-api-core
     #   huggingface-hub
     #   lm-eval
     #   mistral-common
@@ -601,6 +642,8 @@ rpds-py==0.20.1
     # via
     #   jsonschema
     #   referencing
+rsa==4.9.1
+    # via google-auth
 runai-model-streamer==0.11.0
     # via -r requirements/test.in
 runai-model-streamer-s3==0.11.0
@@ -648,9 +691,12 @@ shellingham==1.5.4
 six==1.16.0
     # via
     #   junit-xml
+    #   opencensus
     #   python-dateutil
     #   rfc3339-validator
     #   rouge-score
+smart-open==7.1.0
+    # via ray
 sniffio==1.3.1
     # via
     #   anyio
@@ -801,6 +847,8 @@ urllib3==2.2.3
     #   tritonclient
 vector-quantize-pytorch==1.21.2
     # via -r requirements/test.in
+virtualenv==20.31.2
+    # via ray
 vocos==0.1.0
     # via -r requirements/test.in
 webcolors==24.11.1
@@ -809,6 +857,8 @@ werkzeug==3.1.3
     # via schemathesis
 word2number==1.1
     # via lm-eval
+wrapt==1.17.2
+    # via smart-open
 xxhash==3.5.0
     # via
     #   datasets
diff --git a/tests/v1/test_async_llm_dp.py b/tests/v1/test_async_llm_dp.py
index ce4c4d198db58..366fa3b2561fd 100644
--- a/tests/v1/test_async_llm_dp.py
+++ b/tests/v1/test_async_llm_dp.py
@@ -59,14 +59,22 @@ async def generate(engine: AsyncLLM,
 
 
 @pytest.mark.parametrize(
-    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
+    "output_kind",
+    [
+        RequestOutputKind.DELTA,
+        RequestOutputKind.FINAL_ONLY,
+    ],
+)
+@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
 @pytest.mark.asyncio
-async def test_load(output_kind: RequestOutputKind):
+async def test_load(output_kind: RequestOutputKind,
+                    data_parallel_backend: str):
 
     with ExitStack() as after:
 
         prompt = "This is a test of data parallel"
 
+        engine_args.data_parallel_backend = data_parallel_backend
         engine = AsyncLLM.from_engine_args(engine_args)
         after.callback(engine.shutdown)
 
@@ -82,7 +90,6 @@ async def test_load(output_kind: RequestOutputKind):
                 asyncio.create_task(
                     generate(engine, request_id, prompt, output_kind,
                              NUM_EXPECTED_TOKENS)))
-
         # Confirm that we got all the EXPECTED tokens from the requests.
         done, pending = await asyncio.wait(tasks,
                                            return_when=asyncio.FIRST_EXCEPTION)
diff --git a/vllm/config.py b/vllm/config.py
index 1bd53e35b0532..8aa1b56103004 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1742,6 +1742,8 @@ class ParallelConfig:
     """Port for data parallel messaging."""
     data_parallel_master_port: int = 29500
     """Port of the data parallel master."""
+    data_parallel_backend: str = "mp"
+    """Backend to use for data parallel, either "mp" or "ray"."""
     enable_expert_parallel: bool = False
     """Use expert parallelism instead of tensor parallelism for MoE layers."""
     max_parallel_loading_workers: Optional[int] = None
@@ -1911,6 +1913,10 @@ class ParallelConfig:
                     "please install Ray with `pip install "
                     "ray`.") from ray_utils.ray_import_err
                 backend = "ray"
+            elif self.data_parallel_backend == "ray":
+                logger.info("Using ray distributed inference because "
+                            "data_parallel_backend is ray")
+                backend = "ray"
             elif ray_found:
                 if self.placement_group:
                     backend = "ray"
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 299c8347f458a..a5b155024b73a 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -39,7 +39,7 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
-                        GiB_bytes, is_in_ray_actor)
+                        GiB_bytes, get_ip, is_in_ray_actor)
 
 # yapf: enable
@@ -292,6 +292,7 @@ class EngineArgs:
     data_parallel_size_local: Optional[int] = None
     data_parallel_address: Optional[str] = None
     data_parallel_rpc_port: Optional[int] = None
+    data_parallel_backend: str = ParallelConfig.data_parallel_backend
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     max_parallel_loading_workers: Optional[
         int] = ParallelConfig.max_parallel_loading_workers
@@ -624,6 +625,12 @@ class EngineArgs:
                                     type=int,
                                     help='Port for data parallel RPC '
                                     'communication.')
+        parallel_group.add_argument('--data-parallel-backend',
+                                    '-dpb',
+                                    type=str,
+                                    default='mp',
+                                    help='Backend for data parallel, either '
+                                    '"mp" or "ray".')
         parallel_group.add_argument(
             "--enable-expert-parallel",
             **parallel_kwargs["enable_expert_parallel"])
@@ -1059,9 +1066,20 @@ class EngineArgs:
 
         # DP address, used in multi-node case for torch distributed group
         # and ZMQ sockets.
-        data_parallel_address = self.data_parallel_address if (
-            self.data_parallel_address
-            is not None) else ParallelConfig.data_parallel_master_ip
+        if self.data_parallel_address is None:
+            if self.data_parallel_backend == "ray":
+                host_ip = get_ip()
+                logger.info(
+                    "Using host IP %s as ray-based data parallel address",
+                    host_ip)
+                data_parallel_address = host_ip
+            else:
+                assert self.data_parallel_backend == "mp", (
+                    "data_parallel_backend can only be ray or mp, got "
+                    f"{self.data_parallel_backend}")
+                data_parallel_address = ParallelConfig.data_parallel_master_ip
+        else:
+            data_parallel_address = self.data_parallel_address
 
         # This port is only used when there are remote data parallel engines,
         # otherwise the local IPC transport is used.
@@ -1069,6 +1087,8 @@ class EngineArgs:
             self.data_parallel_rpc_port
             is not None) else ParallelConfig.data_parallel_rpc_port
 
+        data_parallel_backend = self.data_parallel_backend
+
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
@@ -1076,6 +1096,7 @@ class EngineArgs:
             data_parallel_size_local=data_parallel_size_local,
             data_parallel_master_ip=data_parallel_address,
             data_parallel_rpc_port=data_parallel_rpc_port,
+            data_parallel_backend=data_parallel_backend,
             enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py
index e65c97073218b..040ae166a2d5f 100644
--- a/vllm/entrypoints/cli/serve.py
+++ b/vllm/entrypoints/cli/serve.py
@@ -27,7 +27,8 @@ from vllm.v1.engine.core_client import CoreEngineProcManager
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
 from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
-                           EngineZmqAddresses, get_engine_client_zmq_addr,
+                           CoreEngineActorManager, EngineZmqAddresses,
+                           get_engine_client_zmq_addr,
                            wait_for_completion_or_failure,
                            wait_for_engine_startup)
 
@@ -229,6 +230,31 @@ def run_multi_api_server(args: argparse.Namespace):
         logger.info("Started DP Coordinator process (PID: %d)",
                     coordinator.proc.pid)
 
+    if parallel_config.data_parallel_backend == "ray":
+        logger.info("Starting ray-based data parallel backend")
+
+        engine_actor_manager = CoreEngineActorManager(
+            vllm_config=vllm_config,
+            addresses=addresses,
+            executor_class=Executor.get_class(vllm_config),
+            log_stats=not engine_args.disable_log_stats,
+        )
+        # Start API servers using the manager
+        api_server_manager = APIServerProcessManager(
+            target_server_fn=run_api_server_worker_proc,
+            listen_address=listen_address,
+            sock=sock,
+            args=args,
+            num_servers=num_api_servers,
+            input_addresses=input_addresses,
+            output_addresses=output_addresses,
+            stats_update_address=stats_update_address)
+
+        wait_for_completion_or_failure(api_server_manager=api_server_manager,
+                                       engine_manager=engine_actor_manager,
+                                       coordinator=coordinator)
+        return
+
     handshake_address = get_engine_client_zmq_addr(
         local_only, host, parallel_config.data_parallel_rpc_port)
@@ -277,10 +303,9 @@ def run_multi_api_server(args: argparse.Namespace):
     )
 
     # Wait for API servers
-    wait_for_completion_or_failure(
-        api_server_manager=api_server_manager,
-        local_engine_manager=local_engine_manager,
-        coordinator=coordinator)
+    wait_for_completion_or_failure(api_server_manager=api_server_manager,
+                                   engine_manager=local_engine_manager,
+                                   coordinator=coordinator)
 
 
 def run_api_server_worker_proc(listen_address,
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 86781e7528fa3..4b235c596ed6d 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -27,7 +27,8 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device, cdiv
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
+from vllm.v1.engine.core_client import (AsyncMPClient, DPAsyncMPClient,
+                                        RayDPClient)
 from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
 from vllm.v1.engine.output_processor import (OutputProcessor,
                                              RequestOutputCollector)
@@ -119,9 +120,13 @@ class AsyncLLM(EngineClient):
             log_stats=self.log_stats)
 
         # EngineCore (starts the engine in background process).
-        core_client_class = AsyncMPClient if (
-            vllm_config.parallel_config.data_parallel_size
-            == 1) else DPAsyncMPClient
+        core_client_class: type[AsyncMPClient]
+        if vllm_config.parallel_config.data_parallel_size == 1:
+            core_client_class = AsyncMPClient
+        elif vllm_config.parallel_config.data_parallel_backend == "ray":
+            core_client_class = RayDPClient
+        else:
+            core_client_class = DPAsyncMPClient
 
         self.engine_core = core_client_class(
             vllm_config=vllm_config,
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index a02abb62b1f36..7253d1dc66d1f 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -6,8 +6,9 @@ import sys
 import threading
 import time
 from collections import deque
+from collections.abc import Generator
 from concurrent.futures import Future
-from contextlib import ExitStack
+from contextlib import ExitStack, contextmanager
 from inspect import isclass, signature
 from logging import DEBUG
 from typing import Any, Callable, Optional, TypeVar, Union
@@ -367,42 +368,66 @@ class EngineCoreProc(EngineCore):
         log_stats: bool,
         engine_index: int = 0,
     ):
-        input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
-
-        executor_fail_callback = lambda: input_queue.put_nowait(
+        self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
+        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
+                                              bytes]]()
+        executor_fail_callback = lambda: self.input_queue.put_nowait(
             (EngineCoreRequestType.EXECUTOR_FAILED, b''))
 
-        # Create input socket.
+        self.engine_index = engine_index
+        identity = self.engine_index.to_bytes(length=2, byteorder="little")
+        self.engines_running = False
+
+        with self._perform_handshake(handshake_address, identity,
+                                     on_head_node, vllm_config) as addresses:
+            self.client_count = len(addresses.outputs)
+
+            # Set up data parallel environment.
+            self.has_coordinator = addresses.coordinator_output is not None
+            self._init_data_parallel(vllm_config)
+
+            super().__init__(vllm_config, executor_class, log_stats,
+                             executor_fail_callback)
+
+            self.step_fn = (self.step if self.batch_queue is None else
+                            self.step_with_batch_queue)
+
+        # Background Threads and Queues for IO. These enable us to
+        # overlap ZMQ socket IO with GPU since they release the GIL,
+        # and to overlap some serialization/deserialization with the
+        # model forward pass.
+        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
+        threading.Thread(target=self.process_input_sockets,
+                         args=(addresses.inputs, addresses.coordinator_input,
+                               identity),
+                         daemon=True).start()
+        self.output_thread = threading.Thread(
+            target=self.process_output_sockets,
+            args=(addresses.outputs, addresses.coordinator_output,
+                  self.engine_index),
+            daemon=True)
+        self.output_thread.start()
+
+    @contextmanager
+    def _perform_handshake(
+        self, handshake_address: str, identity: bytes, on_head_node: bool,
+        vllm_config: VllmConfig
+    ) -> Generator[EngineZmqAddresses, None, None]:
         input_ctx = zmq.Context()
-        identity = engine_index.to_bytes(length=2, byteorder="little")
         with make_zmq_socket(input_ctx,
                              handshake_address,
                              zmq.DEALER,
                              identity=identity,
                              linger=5000,
                              bind=False) as handshake_socket:
-            # Register engine with front-end.
             addresses = self.startup_handshake(handshake_socket, on_head_node,
                                                vllm_config.parallel_config)
-            self.client_count = len(addresses.outputs)
 
-            # Update config which may have changed from the handshake.
+            # Update config which may have changed from the handshake
             vllm_config.__post_init__()
 
-            # Set up data parallel environment.
-            self.has_coordinator = addresses.coordinator_output is not None
-            self._init_data_parallel(vllm_config)
-
-            # Initialize engine core and model.
-            super().__init__(vllm_config, executor_class, log_stats,
-                             executor_fail_callback)
-
-            self.engine_index = engine_index
-            self.step_fn = (self.step if self.batch_queue is None else
-                            self.step_with_batch_queue)
-            self.engines_running = False
-            self.last_counts = (0, 0)
+            yield addresses
 
             # Send ready message.
             num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
@@ -413,25 +438,6 @@ class EngineCoreProc(EngineCore):
                     "num_gpu_blocks": num_gpu_blocks,
                 }))
 
-        # Background Threads and Queues for IO. These enable us to
-        # overlap ZMQ socket IO with GPU since they release the GIL,
-        # and to overlap some serialization/deserialization with the
-        # model forward pass.
-        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
-        self.input_queue = input_queue
-        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
-                                              bytes]]()
-        threading.Thread(target=self.process_input_sockets,
-                         args=(addresses.inputs, addresses.coordinator_input,
-                               identity),
-                         daemon=True).start()
-        self.output_thread = threading.Thread(
-            target=self.process_output_sockets,
-            args=(addresses.outputs, addresses.coordinator_output,
-                  engine_index),
-            daemon=True)
-        self.output_thread.start()
-
     @staticmethod
     def startup_handshake(
         handshake_socket: zmq.Socket, on_head_node: bool,
@@ -743,6 +749,21 @@ class DPEngineCoreProc(EngineCoreProc):
         executor_class: type[Executor],
         log_stats: bool,
     ):
+
+        self._decorate_logs()
+
+        # Counts forward-passes of the model so that we can synchronize
+        # finished with DP peers every N steps.
+        self.counter = 0
+        self.current_wave = 0
+        self.last_counts = (0, 0)
+
+        # Initialize the engine.
+        dp_rank = vllm_config.parallel_config.data_parallel_rank
+        super().__init__(vllm_config, on_head_node, handshake_address,
+                         executor_class, log_stats, dp_rank)
+
+    def _decorate_logs(self):
         # Add process-specific prefix to stdout and stderr before
         # we initialize the engine.
         from multiprocessing import current_process
@@ -751,16 +772,6 @@ class DPEngineCoreProc(EngineCoreProc):
         _add_prefix(sys.stdout, process_name, pid)
         _add_prefix(sys.stderr, process_name, pid)
 
-        # Counts forward-passes of the model so that we can synchronize
-        # finished with DP peers every N steps.
-        self.counter = 0
-        self.current_wave = 0
-
-        # Initialize the engine.
-        dp_rank = vllm_config.parallel_config.data_parallel_rank
-        super().__init__(vllm_config, on_head_node, handshake_address,
-                         executor_class, log_stats, dp_rank)
-
     def _init_data_parallel(self, vllm_config: VllmConfig):
 
         # Configure GPUs and stateless process group for data parallel.
@@ -880,3 +891,70 @@ class DPEngineCoreProc(EngineCoreProc):
 
         return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                 local_unfinished)
+
+
+class DPEngineCoreActor(DPEngineCoreProc):
+    """
+    Ray actor for running EngineCore in a data parallel context.
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        on_head_node: bool,
+        addresses: EngineZmqAddresses,
+        executor_class: type[Executor],
+        log_stats: bool,
+        dp_rank: int = 0,
+        local_dp_rank: int = 0,
+    ):
+        self.addresses = addresses
+        vllm_config.parallel_config.data_parallel_rank = dp_rank
+        vllm_config.parallel_config.data_parallel_rank_local = \
+            local_dp_rank
+
+        # Ray sets CUDA_VISIBLE_DEVICES to an empty string for the actor;
+        # clean this up so that the data parallel groups can be
+        # initialized properly.
+        del os.environ['CUDA_VISIBLE_DEVICES']
+
+        super().__init__(vllm_config, on_head_node, "", executor_class,
+                         log_stats)
+
+    def _decorate_logs(self):
+        pass
+
+    @contextmanager
+    def _perform_handshake(self, handshake_address: str, identity: bytes,
+                           on_head_node: bool, vllm_config: VllmConfig):
+        """
+        For Ray, we don't need to actually perform a handshake.
+        All address information is known before actor creation,
+        so we simply yield those addresses.
+        """
+        yield self.addresses
+
+    def wait_for_init(self):
+        """
+        Wait until the engine core is initialized.
+
+        This is an empty method. When ray.get() on this method
+        (or any other method of the actor) returns, it is guaranteed
+        that actor creation (i.e., __init__) is complete.
+        """
+        pass
+
+    def run(self):
+        """
+        Run the engine core busy loop.
+        """
+        try:
+            self.run_busy_loop()
+        except SystemExit:
+            logger.debug("EngineCore exiting.")
+            raise
+        except Exception:
+            logger.exception("EngineCore encountered a fatal error.")
+            raise
+        finally:
+            self.shutdown()
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 232d6742b7718..fa01998aa9fe2 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -29,9 +29,9 @@ from vllm.v1.engine.core import EngineCore, EngineCoreProc
 from vllm.v1.engine.exceptions import EngineDeadError
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
-from vllm.v1.utils import (CoreEngine, CoreEngineProcManager,
-                           EngineZmqAddresses, get_engine_client_zmq_addr,
-                           wait_for_engine_startup)
+from vllm.v1.utils import (CoreEngine, CoreEngineActorManager,
+                           CoreEngineProcManager, EngineZmqAddresses,
+                           get_engine_client_zmq_addr, wait_for_engine_startup)
 
 logger = init_logger(__name__)
 
@@ -68,6 +68,8 @@ class EngineCoreClient(ABC):
 
         if multiprocess_mode and asyncio_mode:
             if vllm_config.parallel_config.data_parallel_size > 1:
+                if vllm_config.parallel_config.data_parallel_backend == "ray":
+                    return RayDPClient(vllm_config, executor_class, log_stats)
                 return DPAsyncMPClient(vllm_config, executor_class, log_stats)
 
             return AsyncMPClient(vllm_config, executor_class, log_stats)
@@ -273,7 +275,10 @@ class BackgroundResources:
     circular reference back to the client object."""
 
     ctx: Union[zmq.Context]
-    local_engine_manager: Optional[CoreEngineProcManager] = None
+    # If CoreEngineProcManager, it manages local engines;
+    # if CoreEngineActorManager, it manages all engines.
+    engine_manager: Optional[Union[CoreEngineProcManager,
+                                   CoreEngineActorManager]] = None
     coordinator: Optional[DPCoordinator] = None
     output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
    input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
@@ -290,8 +295,8 @@ class BackgroundResources:
         """Clean up background resources."""
         self.engine_dead = True
 
-        if self.local_engine_manager is not None:
-            self.local_engine_manager.close()
+        if self.engine_manager is not None:
+            self.engine_manager.close()
         if self.coordinator is not None:
             self.coordinator.close()
 
@@ -457,7 +462,7 @@ class MPClient(EngineCoreClient):
             if local_engine_count:
                 # In server mode, start_index and local_start_index will
                 # both be 0.
-                self.resources.local_engine_manager = CoreEngineProcManager(
+                self.resources.engine_manager = CoreEngineProcManager(
                     EngineCoreProc.run_engine_core,
                     vllm_config=vllm_config,
                     executor_class=executor_class,
@@ -484,13 +489,18 @@ class MPClient(EngineCoreClient):
             addresses.coordinator_input, addresses.coordinator_output = (
                 coordinator.get_engine_socket_addresses())
 
+        proc_manager = self.resources.engine_manager
+        assert isinstance(proc_manager, (type(None), CoreEngineProcManager)), (
+            "_wait_for_engine_startup should only be "
+            "called with CoreEngineProcManager")
+
         wait_for_engine_startup(
             handshake_socket,
             addresses,
             self.core_engines,
             self.vllm_config.parallel_config,
             self.vllm_config.cache_config,
-            self.resources.local_engine_manager,
+            proc_manager,
             coordinator.proc if coordinator else None,
         )
 
@@ -887,7 +897,6 @@ class DPAsyncMPClient(AsyncMPClient):
                  log_stats: bool,
                  client_addresses: Optional[dict[str, str]] = None,
                  client_index: int = 0):
-        self.current_wave = 0
         self.engines_running = False
 
         # To route aborts to the correct engine.
@@ -1050,3 +1059,50 @@ class DPAsyncMPClient(AsyncMPClient):
         if not self.resources.engine_dead:
             await self._send_input(EngineCoreRequestType.ABORT, request_ids,
                                    engine)
+
+
+class RayDPClient(DPAsyncMPClient):
+    """
+    Ray-based client for multi-proc, multi-engine (data parallel)
+    EngineCore.
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        executor_class: type[Executor],
+        log_stats: bool,
+        client_addresses: Optional[dict[str, str]] = None,
+        client_index: int = 0,
+    ):
+        super().__init__(vllm_config, executor_class, log_stats,
+                         client_addresses, client_index)
+
+    def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
+                             local_start_index: int, input_address: str,
+                             output_address: str,
+                             executor_class: type[Executor], log_stats: bool):
+        """Self-contained client mode, launch engine and coordinator process
+        as needed."""
+
+        parallel_config = vllm_config.parallel_config
+        assert parallel_config.data_parallel_rank == 0
+        assert local_start_index == 0
+
+        addresses = EngineZmqAddresses(
+            inputs=[input_address],
+            outputs=[output_address],
+        )
+
+        if len(self.core_engines) > 1:
+            coordinator = DPCoordinator(parallel_config)
+            self.resources.coordinator = coordinator
+            addresses.coordinator_input, addresses.coordinator_output = (
+                coordinator.get_engine_socket_addresses())
+
+        # Start all engines.
+        self.resources.engine_manager = CoreEngineActorManager(
+            vllm_config=vllm_config,
+            addresses=addresses,
+            executor_class=executor_class,
+            log_stats=log_stats)
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index a26794561a526..d347efc425ef4 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -27,6 +27,8 @@ from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
 from vllm.v1.executor.abstract import Executor
 
 if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
+
     from vllm.attention.layer import Attention
     from vllm.v1.engine.coordinator import DPCoordinator
 
@@ -112,6 +114,45 @@ def get_engine_client_zmq_addr(local_only: bool, host,
                                  port or get_open_port()))
 
 
+class CoreEngineState(Enum):
+    NEW = auto()
+    CONNECTED = auto()
+    READY = auto()
+
+
+class CoreEngine:
+    """One per data parallel rank."""
+
+    def __init__(self, index: int = 0, local: bool = True):
+        self.local = local
+        self.index = index
+        self.identity = index.to_bytes(2, "little")
+
+        self.state = CoreEngineState.NEW
+
+
+@dataclass
+class EngineZmqAddresses:
+    # ZMQ input socket addresses for each front-end client (requests)
+    inputs: list[str]
+    # ZMQ output socket addresses for each front-end client (responses)
+    outputs: list[str]
+    # ZMQ input socket address of DP coordinator if applicable
+    coordinator_input: Optional[str] = None
+    # ZMQ output socket address of DP coordinator if applicable
+    coordinator_output: Optional[str] = None
+
+
+@dataclass
+class EngineHandshakeMetadata:
+    """Metadata sent to each engine process during startup handshake,
+    including addresses of the front-end ZMQ queues that they should
+    connect to.
+    """
+    addresses: EngineZmqAddresses
+    parallel_config: dict[str, Union[int, str]]
+
+
 class APIServerProcessManager:
     """Manages a group of API server processes.
@@ -245,43 +286,168 @@ class CoreEngineProcManager:
         }
 
 
-class CoreEngineState(Enum):
-    NEW = auto()
-    CONNECTED = auto()
-    READY = auto()
-
-
-class CoreEngine:
-    """One per data parallel rank."""
-
-    def __init__(self, index: int = 0, local: bool = True):
-        self.local = local
-        self.index = index
-        self.identity = index.to_bytes(2, "little")
-
-        self.state = CoreEngineState.NEW
-
-
-@dataclass
-class EngineZmqAddresses:
-    # ZMQ input socket addresses for each front-end client (requests)
-    inputs: list[str]
-    # ZMQ output socket addresses for each front-end client (responses)
-    outputs: list[str]
-    # ZMQ input socket address of DP coordinator if applicable
-    coordinator_input: Optional[str] = None
-    # ZMQ output socket address of DP coordinator if applicable
-    coordinator_output: Optional[str] = None
-
-
-@dataclass
-class EngineHandshakeMetadata:
-    """Metadata sent to each engine process during startup handshake,
-    including addresses of the front-end ZMQ queues that they should
-    connect to.
+class CoreEngineActorManager:
     """
-    addresses: EngineZmqAddresses
-    parallel_config: dict[str, Union[int, str]]
+    Utility class to handle creation, readiness, and shutdown
+    of core engine Ray actors used by the AsyncLLM and LLMEngine.
+
+    Different from CoreEngineProcManager, this class manages
+    core engines for both local and remote nodes.
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        addresses: EngineZmqAddresses,
+        executor_class: type[Executor],
+        log_stats: bool,
+        placement_groups: Optional[list["PlacementGroup"]] = None,
+        local_dp_ranks: Optional[list[int]] = None,
+    ):
+        import copy
+
+        import ray
+        from ray.util.scheduling_strategies import (
+            PlacementGroupSchedulingStrategy)
+
+        from vllm.v1.engine.core import DPEngineCoreActor
+
+        self.local_engine_actors: list[ray.ActorHandle] = []
+        self.remote_engine_actors: list[ray.ActorHandle] = []
+        dp_size = vllm_config.parallel_config.data_parallel_size
+        local_engine_count = \
+            vllm_config.parallel_config.data_parallel_size_local
+        world_size = vllm_config.parallel_config.world_size
+
+        if ray.is_initialized():
+            logger.info(
+                "Ray is already initialized. Skipping Ray initialization.")
+        else:
+            ray.init()
+
+        if placement_groups is not None:
+            assert local_dp_ranks is not None, (
+                "local_dp_ranks must be provided if "
+                "placement_groups is provided")
+            assert len(placement_groups) == len(local_dp_ranks), (
+                "placement_groups and local_dp_ranks must "
+                "have the same length")
+            logger.info("Using provided placement groups")
+            # TODO(rui): validate passed-in placement groups
+            self.created_placement_groups = []
+        else:
+            placement_groups, local_dp_ranks = \
+                CoreEngineActorManager.create_dp_placement_groups(vllm_config)
+            self.created_placement_groups = placement_groups
+        assert len(placement_groups) == dp_size, (
+            "Number of placement groups must match data parallel size")
+
+        refs = []
+        for index in range(dp_size):
+            local_index = local_dp_ranks[index]
+            dp_vllm_config = copy.deepcopy(vllm_config)
+            pg = placement_groups[index]
+            dp_vllm_config.parallel_config.placement_group = pg
+            on_head_node = index < local_engine_count
+            actor = ray.remote(DPEngineCoreActor).options(
+                scheduling_strategy=PlacementGroupSchedulingStrategy(
+                    placement_group=pg,
+                    placement_group_bundle_index=world_size,
+                )).remote(vllm_config=dp_vllm_config,
+                          executor_class=executor_class,
+                          log_stats=log_stats,
+                          on_head_node=on_head_node,
+                          addresses=addresses,
+                          dp_rank=index,
+                          local_dp_rank=local_index)
+            if on_head_node:
+                self.local_engine_actors.append(actor)
+            else:
+                self.remote_engine_actors.append(actor)
+            refs.append(actor.wait_for_init.remote())
+
+        ray.get(refs)
+        self.run_refs = []
+        for actor in self.local_engine_actors + self.remote_engine_actors:
+            self.run_refs.append(actor.run.remote())
+
+    @staticmethod
+    def create_dp_placement_groups(
+        vllm_config: VllmConfig
+    ) -> tuple[list["PlacementGroup"], list[int]]:
+
+        import ray
+        from ray._private.state import available_resources_per_node
+        from ray.util.state import list_nodes
+
+        logger.info("Creating placement groups for data parallel")
+        dp_master_ip = \
+            vllm_config.parallel_config.data_parallel_master_ip
+        dp_size = vllm_config.parallel_config.data_parallel_size
+        local_engine_count = \
+            vllm_config.parallel_config.data_parallel_size_local
+
+        nodes = list_nodes()
+        nodes = sorted(list_nodes(),
+                       key=lambda node: node.node_ip != dp_master_ip)
+        assert nodes[0].node_ip == dp_master_ip, (
+            "The first node must be the head node")
+        assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
+            "There can only be one head node")
+
+        available_resources = available_resources_per_node()
+        world_size = vllm_config.parallel_config.world_size
+        placement_groups: list[PlacementGroup] = []
+        local_dp_ranks: list[int] = []
+
+        for node in nodes:
+            node_ip = node.node_ip
+            node_resources = available_resources[node.node_id]
+            # For now, each DP rank can only be assigned to one node
+            # TODO(rui): support allocating a single DP rank
+            # to multiple nodes
+            available_engine_count = node_resources["GPU"] // world_size
+            if node_ip == dp_master_ip:
+                assert available_engine_count >= local_engine_count, (
+                    "Not enough resources to allocate DP ranks "
+                    f"on DP master node {node_ip}")
+                for i in range(local_engine_count):
+                    bundles = [{
+                        "GPU": 1.0,
+                        "node:" + dp_master_ip: 0.001
+                    }] * world_size + [{
+                        "CPU": 1.0
+                    }]
+                    pg = ray.util.placement_group(
+                        name=f"dp_rank_{len(placement_groups)}",
+                        strategy="STRICT_PACK",
+                        bundles=bundles,
+                    )
+                    placement_groups.append(pg)
+                    local_dp_ranks.append(i)
+            else:
+                for i in range(available_engine_count):
+                    if len(placement_groups) == dp_size:
+                        break
+                    bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
+                    pg = ray.util.placement_group(
+                        name=f"dp_rank_{len(placement_groups)}",
+                        strategy="STRICT_PACK",
+                        bundles=bundles,
+                    )
+                    placement_groups.append(pg)
+                    local_dp_ranks.append(i)
+        return placement_groups, local_dp_ranks
+
+    def get_run_refs(self):
+        return self.run_refs
+
+    def close(self):
+        import ray
+        for actor in self.local_engine_actors + self.remote_engine_actors:
+            ray.kill(actor)
+        for pg in self.created_placement_groups:
+            ray.util.remove_placement_group(pg)
 
 
 def wait_for_engine_startup(
@@ -383,11 +549,19 @@ def wait_for_engine_startup(
 
 def wait_for_completion_or_failure(
         api_server_manager: APIServerProcessManager,
-        local_engine_manager: Optional[CoreEngineProcManager] = None,
+        engine_manager: Optional[Union[CoreEngineProcManager,
+                                       CoreEngineActorManager]] = None,
        coordinator: Optional["DPCoordinator"] = None) -> None:
     """Wait for all processes to complete or detect if any fail.
 
     Raises an exception if any process exits with a non-zero status.
+
+    Args:
+        api_server_manager: The manager for API servers.
+        engine_manager: The manager for engine processes.
+            If CoreEngineProcManager, it manages local engines;
+            if CoreEngineActorManager, it manages all engines.
+        coordinator: The coordinator for data parallel.
     """
 
     try:
@@ -402,14 +576,18 @@ def wait_for_completion_or_failure(
         if coordinator:
             sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc
 
-        if local_engine_manager:
-            for proc in local_engine_manager.processes:
+        actor_run_refs = []
+        if isinstance(engine_manager, CoreEngineProcManager):
+            for proc in engine_manager.processes:
                 sentinel_to_proc[proc.sentinel] = proc
+        elif isinstance(engine_manager, CoreEngineActorManager):
+            actor_run_refs = engine_manager.get_run_refs()
 
         # Check if any process terminates
-        while sentinel_to_proc:
+        while sentinel_to_proc or actor_run_refs:
             # Wait for any process to terminate
-            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc)
+            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc,
+                                                         timeout=5)
 
             # Process any terminated processes
             for sentinel in ready_sentinels:
@@ -420,6 +598,11 @@ def wait_for_completion_or_failure(
                     raise RuntimeError(
                         f"Process {proc.name} (PID: {proc.pid}) "
                         f"died with exit code {proc.exitcode}")
+
+            if actor_run_refs:
+                import ray
+                _, actor_run_refs = ray.wait(actor_run_refs, timeout=5)
+
     except KeyboardInterrupt:
         logger.info("Received KeyboardInterrupt, shutting down API servers...")
     except Exception as e:
@@ -431,8 +614,8 @@ def wait_for_completion_or_failure(
         api_server_manager.close()
         if coordinator:
             coordinator.close()
-        if local_engine_manager:
-            local_engine_manager.close()
+        if engine_manager:
+            engine_manager.close()
 
         # Note(rob): shutdown function cannot be a bound method,
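
Usage sketch (illustrative only, not part of the patch): the data_parallel_backend option added above in vllm/config.py and vllm/engine/arg_utils.py can be driven from Python the same way the updated tests/v1/test_async_llm_dp.py exercises it. The model name below is a placeholder, and a Ray cluster is assumed to be reachable (otherwise CoreEngineActorManager calls ray.init() itself).

    # Minimal sketch, assuming enough GPUs for data_parallel_size=2 and a
    # placeholder model name; mirrors the flow used by test_async_llm_dp.py.
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.v1.engine.async_llm import AsyncLLM

    engine_args = AsyncEngineArgs(
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        data_parallel_size=2,
        data_parallel_backend="ray",  # "mp" (default) or "ray"
    )
    # AsyncLLM picks RayDPClient when data_parallel_size > 1 and the
    # backend is "ray"; the client then starts DPEngineCoreActor actors
    # via CoreEngineActorManager.
    engine = AsyncLLM.from_engine_args(engine_args)

The equivalent server invocation would be: vllm serve <model> --data-parallel-size 2 --data-parallel-backend ray (or -dpb ray), which routes run_multi_api_server() through CoreEngineActorManager instead of the local CoreEngineProcManager.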