From 651cf0fec19ed1da44ad266066f86b74e5246006 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 4 Apr 2025 12:56:43 -0700 Subject: [PATCH] [V1] DP scale-out (1/N): Use zmq ROUTER/DEALER sockets for input queue (#15906) Signed-off-by: Nick Hill --- vllm/utils.py | 35 +++++++---- vllm/v1/engine/core.py | 28 +++++---- vllm/v1/engine/core_client.py | 108 ++++++++++++++++++++++------------ vllm/v1/utils.py | 11 +--- 4 files changed, 113 insertions(+), 69 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 5f32f8cb66a5c..46f01638d0eb5 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2189,6 +2189,8 @@ def make_zmq_socket( ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined] path: str, socket_type: Any, + bind: Optional[bool] = None, + identity: Optional[bytes] = None, ) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] """Make a ZMQ socket with the proper bind/connect semantics.""" @@ -2207,16 +2209,24 @@ def make_zmq_socket( else: buf_size = -1 # Use system default buffer size - if socket_type == zmq.constants.PULL: - socket.setsockopt(zmq.constants.RCVHWM, 0) - socket.setsockopt(zmq.constants.RCVBUF, buf_size) + if bind is None: + bind = socket_type != zmq.PUSH + + if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER): + socket.setsockopt(zmq.RCVHWM, 0) + socket.setsockopt(zmq.RCVBUF, buf_size) + + if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER): + socket.setsockopt(zmq.SNDHWM, 0) + socket.setsockopt(zmq.SNDBUF, buf_size) + + if identity is not None: + socket.setsockopt(zmq.IDENTITY, identity) + + if bind: socket.bind(path) - elif socket_type == zmq.constants.PUSH: - socket.setsockopt(zmq.constants.SNDHWM, 0) - socket.setsockopt(zmq.constants.SNDBUF, buf_size) - socket.connect(path) else: - raise ValueError(f"Unknown Socket Type: {socket_type}") + socket.connect(path) return socket @@ -2225,14 +2235,19 @@ def make_zmq_socket( def zmq_socket_ctx( path: str, socket_type: Any, + bind: Optional[bool] = None, linger: int = 0, + identity: Optional[bytes] = None, ) -> Iterator[zmq.Socket]: """Context manager for a ZMQ socket""" ctx = zmq.Context() # type: ignore[attr-defined] try: - yield make_zmq_socket(ctx, path, socket_type) - + yield make_zmq_socket(ctx, + path, + socket_type, + bind=bind, + identity=identity) except KeyboardInterrupt: logger.debug("Got Keyboard Interrupt.") diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 39caca0c2a452..f58c77e4f1658 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -318,6 +318,11 @@ class EngineCoreProc(EngineCore): ): super().__init__(vllm_config, executor_class, log_stats) + self.step_fn = (self.step if self.batch_queue is None else + self.step_with_batch_queue) + + self.global_unfinished_reqs = False + # Background Threads and Queues for IO. These enable us to # overlap ZMQ socket IO with GPU since they release the GIL, # and to overlap some serialization/deserialization with the @@ -327,22 +332,16 @@ class EngineCoreProc(EngineCore): Any]] = queue.Queue() self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() threading.Thread(target=self.process_input_socket, - args=(input_path, ), + args=(input_path, engine_index), daemon=True).start() threading.Thread(target=self.process_output_socket, args=(output_path, engine_index), daemon=True).start() - self.global_unfinished_reqs = False - - self.step_fn = (self.step if self.batch_queue is None else - self.step_with_batch_queue) - @staticmethod def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, - ready_pipe, **kwargs): """Launch EngineCore busy loop in background process.""" @@ -377,9 +376,6 @@ class EngineCoreProc(EngineCore): else: engine_core = EngineCoreProc(*args, **kwargs) - # Send Readiness signal to EngineClient. - ready_pipe.send({"status": "READY"}) - engine_core.run_busy_loop() except SystemExit: @@ -476,14 +472,22 @@ class EngineCoreProc(EngineCore): and not isinstance(v, p.annotation) else v for v, p in zip(args, arg_types)) - def process_input_socket(self, input_path: str): + def process_input_socket(self, input_path: str, engine_index: int): """Input socket IO thread.""" # Msgpack serialization decoding. add_request_decoder = MsgpackDecoder(EngineCoreRequest) generic_decoder = MsgpackDecoder() + identity = engine_index.to_bytes(length=2, byteorder="little") + + with zmq_socket_ctx(input_path, + zmq.DEALER, + identity=identity, + bind=False) as socket: + + # Send ready message to front-end once input socket is connected. + socket.send(b'READY') - with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket: while True: # (RequestType, RequestData) type_frame, data_frame = socket.recv_multipart(copy=False) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index e948e59b8c425..b94b0aa75386a 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -8,7 +8,7 @@ import threading import uuid import weakref from abc import ABC, abstractmethod -from collections.abc import Awaitable, Sequence +from collections.abc import Awaitable from concurrent.futures import Future from dataclasses import dataclass, field from threading import Thread @@ -35,6 +35,8 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]] _R = TypeVar('_R') # Return type for collective_rpc +STARTUP_POLL_PERIOD_MS = 10000 + class EngineCoreClient(ABC): """ @@ -261,15 +263,13 @@ class CoreEngine: vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, - ctx: Union[zmq.Context, zmq.asyncio.Context], + input_path: str, output_path: str, index: int = 0, local_dp_rank: int = 0, ): - # Paths and sockets for IPC. - input_path = get_open_zmq_ipc_path() - self.input_socket = make_zmq_socket(ctx, input_path, - zmq.constants.PUSH) + self.index = index + self.identity = index.to_bytes(length=2, byteorder="little") try: # Start EngineCore in background process. self.proc_handle = BackgroundProcHandle( @@ -291,14 +291,9 @@ class CoreEngine: # Ensure socket is closed if process fails to start. self.close() - def send_multipart(self, msg_parts: Sequence): - return self.input_socket.send_multipart(msg_parts, copy=False) - def close(self): if proc_handle := getattr(self, "proc_handle", None): proc_handle.shutdown() - if socket := getattr(self, "input_socket", None): - socket.close(linger=0) @dataclass @@ -309,6 +304,7 @@ class BackgroundResources: ctx: Union[zmq.Context] core_engines: list[CoreEngine] = field(default_factory=list) output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None + input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None shutdown_path: Optional[str] = None def __call__(self): @@ -321,6 +317,8 @@ class BackgroundResources: # aren't explicitly closed first. if self.output_socket is not None: self.output_socket.close(linger=0) + if self.input_socket is not None: + self.input_socket.close(linger=0) if self.shutdown_path is not None: # We must ensure that the sync output socket is # closed cleanly in its own thread. @@ -387,21 +385,51 @@ class MPClient(EngineCoreClient): # Paths and sockets for IPC. self.output_path = get_open_zmq_ipc_path() + input_path = get_open_zmq_ipc_path() + self.input_socket = make_zmq_socket(self.ctx, + input_path, + zmq.ROUTER, + bind=True) + self.resources.input_socket = self.input_socket new_core_engine = lambda index, local_dp_rank=None: CoreEngine( - vllm_config, executor_class, log_stats, self.ctx, self.output_path, - index, local_dp_rank) + vllm_config, executor_class, log_stats, input_path, self. + output_path, index, local_dp_rank) # Start engine core process(es). self._init_core_engines(vllm_config, new_core_engine, self.resources.core_engines) # Wait for engine core process(es) to start. - for engine in self.resources.core_engines: - engine.proc_handle.wait_for_startup() + self._wait_for_engine_startup() self.utility_results: dict[int, AnyFuture] = {} + def _wait_for_engine_startup(self): + # Get a sync handle to the socket which can be sync or async. + sync_input_socket = zmq.Socket.shadow(self.input_socket) + + # Wait for engine core process(es) to send ready messages. + identities = set(eng.index for eng in self.resources.core_engines) + while identities: + while not sync_input_socket.poll(timeout=STARTUP_POLL_PERIOD_MS): + logger.info("Waiting for %d core engine proc(s) to start: %s", + len(identities), identities) + eng_id_bytes, msg = sync_input_socket.recv_multipart() + eng_id = int.from_bytes(eng_id_bytes, byteorder="little") + if eng_id not in identities: + raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}") + if msg != b'READY': + raise RuntimeError(f"Engine {eng_id} failed: {msg.decode()}") + logger.info("Core engine process %d ready.", eng_id) + identities.discard(eng_id) + + # Double check that the process are running. + for engine in self.resources.core_engines: + proc = engine.proc_handle.proc + if proc.exitcode is not None: + raise RuntimeError(f"Engine proc {proc.name} not running") + def _init_core_engines( self, vllm_config: VllmConfig, @@ -494,9 +522,10 @@ class SyncMPClient(MPClient): return self.outputs_queue.get() def _send_input(self, request_type: EngineCoreRequestType, request: Any): - # (RequestType, SerializedRequest) - msg = (request_type.value, self.encoder.encode(request)) - self.core_engine.send_multipart(msg) + # (Identity, RequestType, SerializedRequest) + msg = (self.core_engine.identity, request_type.value, + self.encoder.encode(request)) + self.input_socket.send_multipart(msg, copy=False) def call_utility(self, method: str, *args) -> Any: call_id = uuid.uuid1().int >> 64 @@ -625,30 +654,34 @@ class AsyncMPClient(MPClient): assert self.outputs_queue is not None return await self.outputs_queue.get() - async def _send_input(self, request_type: EngineCoreRequestType, - request: Any) -> None: - await self.core_engine.send_multipart( - (request_type.value, self.encoder.encode(request))) + def _send_input(self, + request_type: EngineCoreRequestType, + request: Any, + engine: Optional[CoreEngine] = None) -> Awaitable[None]: + if engine is None: + engine = self.core_engine - self._ensure_output_queue_task() + message = (request_type.value, self.encoder.encode(request)) + return self._send_input_message(message, engine) + + def _send_input_message(self, message: tuple[bytes, bytes], + engine: CoreEngine) -> Awaitable[None]: + message = (engine.identity, ) + message # type: ignore[assignment] + return self.input_socket.send_multipart(message, copy=False) async def call_utility_async(self, method: str, *args) -> Any: return await self._call_utility_async(method, *args, engine=self.core_engine) - async def _call_utility_async( - self, - method: str, - *args, - engine: CoreEngine, - ) -> Any: + async def _call_utility_async(self, method: str, *args, + engine: CoreEngine) -> Any: call_id = uuid.uuid1().int >> 64 future = asyncio.get_running_loop().create_future() self.utility_results[call_id] = future message = (EngineCoreRequestType.UTILITY.value, self.encoder.encode((call_id, method, args))) - await engine.send_multipart(message) + await self._send_input_message(message, engine) self._ensure_output_queue_task() return await future @@ -657,6 +690,7 @@ class AsyncMPClient(MPClient): # tokenized. request.prompt = None await self._send_input(EngineCoreRequestType.ADD, request) + self._ensure_output_queue_task() async def abort_requests_async(self, request_ids: list[str]) -> None: if len(request_ids) > 0: @@ -761,15 +795,15 @@ class DPAsyncMPClient(AsyncMPClient): self.reqs_in_flight[request.request_id] = chosen_engine chosen_engine.num_reqs_in_flight += 1 if self.num_engines_running >= len(self.core_engines): - await chosen_engine.send_multipart(msg) + await self._send_input_message(msg, chosen_engine) else: # Send request to chosen engine and dp start loop # control message to all other engines. self.num_engines_running += len(self.core_engines) await asyncio.gather(*[ - engine.send_multipart(msg if engine is - chosen_engine else self.start_dp_msg) - for engine in self.core_engines + self._send_input_message( + msg if engine is chosen_engine else self.start_dp_msg, + engine) for engine in self.core_engines ]) self._ensure_output_queue_task() @@ -794,7 +828,7 @@ class DPAsyncMPClient(AsyncMPClient): # sure to start the other engines: self.num_engines_running = len(self.core_engines) coros = [ - engine.send_multipart(self.start_dp_msg) + self._send_input_message(self.start_dp_msg, engine) for engine in self.core_engines if not engine.num_reqs_in_flight ] @@ -820,5 +854,5 @@ class DPAsyncMPClient(AsyncMPClient): async def _abort_requests(self, request_ids: list[str], engine: CoreEngine) -> None: - await engine.send_multipart((EngineCoreRequestType.ABORT.value, - self.encoder.encode(request_ids))) + await self._send_input(EngineCoreRequestType.ABORT, request_ids, + engine) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index f42b3501adb3b..fed5761b04b6c 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -105,12 +105,9 @@ class BackgroundProcHandle: process_kwargs: dict[Any, Any], ): context = get_mp_context() - self.reader, writer = context.Pipe(duplex=False) - assert ("ready_pipe" not in process_kwargs - and "input_path" not in process_kwargs + assert ("input_path" not in process_kwargs and "output_path" not in process_kwargs) - process_kwargs["ready_pipe"] = writer process_kwargs["input_path"] = input_path process_kwargs["output_path"] = output_path @@ -122,12 +119,6 @@ class BackgroundProcHandle: input_path, output_path) self.proc.start() - def wait_for_startup(self): - # Wait for startup. - if self.reader.recv()["status"] != "READY": - raise RuntimeError(f"{self.proc.name} initialization failed. " - "See root cause above.") - def shutdown(self): self._finalizer()