From 296c6572dd1f76b31b93be19e550790afcfb8843 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Sat, 5 Apr 2025 21:10:57 -0700 Subject: [PATCH] Revert "[V1] DP scale-out (1/N): Use zmq ROUTER/DEALER sockets for input queue (#15906)" This reverts commit 651cf0fec19ed1da44ad266066f86b74e5246006. --- vllm/utils.py | 35 ++++------- vllm/v1/engine/core.py | 28 ++++----- vllm/v1/engine/core_client.py | 108 ++++++++++++---------------------- vllm/v1/utils.py | 11 +++- 4 files changed, 69 insertions(+), 113 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 46f01638d0eb5..5f32f8cb66a5c 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2189,8 +2189,6 @@ def make_zmq_socket( ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined] path: str, socket_type: Any, - bind: Optional[bool] = None, - identity: Optional[bytes] = None, ) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] """Make a ZMQ socket with the proper bind/connect semantics.""" @@ -2209,24 +2207,16 @@ def make_zmq_socket( else: buf_size = -1 # Use system default buffer size - if bind is None: - bind = socket_type != zmq.PUSH - - if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER): - socket.setsockopt(zmq.RCVHWM, 0) - socket.setsockopt(zmq.RCVBUF, buf_size) - - if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER): - socket.setsockopt(zmq.SNDHWM, 0) - socket.setsockopt(zmq.SNDBUF, buf_size) - - if identity is not None: - socket.setsockopt(zmq.IDENTITY, identity) - - if bind: + if socket_type == zmq.constants.PULL: + socket.setsockopt(zmq.constants.RCVHWM, 0) + socket.setsockopt(zmq.constants.RCVBUF, buf_size) socket.bind(path) - else: + elif socket_type == zmq.constants.PUSH: + socket.setsockopt(zmq.constants.SNDHWM, 0) + socket.setsockopt(zmq.constants.SNDBUF, buf_size) socket.connect(path) + else: + raise ValueError(f"Unknown Socket Type: {socket_type}") return socket @@ -2235,19 +2225,14 @@ def make_zmq_socket( def zmq_socket_ctx( path: str, socket_type: Any, - bind: Optional[bool] = None, linger: int = 0, - identity: Optional[bytes] = None, ) -> Iterator[zmq.Socket]: """Context manager for a ZMQ socket""" ctx = zmq.Context() # type: ignore[attr-defined] try: - yield make_zmq_socket(ctx, - path, - socket_type, - bind=bind, - identity=identity) + yield make_zmq_socket(ctx, path, socket_type) + except KeyboardInterrupt: logger.debug("Got Keyboard Interrupt.") diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f58c77e4f1658..39caca0c2a452 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -318,11 +318,6 @@ class EngineCoreProc(EngineCore): ): super().__init__(vllm_config, executor_class, log_stats) - self.step_fn = (self.step if self.batch_queue is None else - self.step_with_batch_queue) - - self.global_unfinished_reqs = False - # Background Threads and Queues for IO. These enable us to # overlap ZMQ socket IO with GPU since they release the GIL, # and to overlap some serialization/deserialization with the @@ -332,16 +327,22 @@ class EngineCoreProc(EngineCore): Any]] = queue.Queue() self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() threading.Thread(target=self.process_input_socket, - args=(input_path, engine_index), + args=(input_path, ), daemon=True).start() threading.Thread(target=self.process_output_socket, args=(output_path, engine_index), daemon=True).start() + self.global_unfinished_reqs = False + + self.step_fn = (self.step if self.batch_queue is None else + self.step_with_batch_queue) + @staticmethod def run_engine_core(*args, dp_rank: int = 0, local_dp_rank: int = 0, + ready_pipe, **kwargs): """Launch EngineCore busy loop in background process.""" @@ -376,6 +377,9 @@ class EngineCoreProc(EngineCore): else: engine_core = EngineCoreProc(*args, **kwargs) + # Send Readiness signal to EngineClient. + ready_pipe.send({"status": "READY"}) + engine_core.run_busy_loop() except SystemExit: @@ -472,22 +476,14 @@ class EngineCoreProc(EngineCore): and not isinstance(v, p.annotation) else v for v, p in zip(args, arg_types)) - def process_input_socket(self, input_path: str, engine_index: int): + def process_input_socket(self, input_path: str): """Input socket IO thread.""" # Msgpack serialization decoding. add_request_decoder = MsgpackDecoder(EngineCoreRequest) generic_decoder = MsgpackDecoder() - identity = engine_index.to_bytes(length=2, byteorder="little") - - with zmq_socket_ctx(input_path, - zmq.DEALER, - identity=identity, - bind=False) as socket: - - # Send ready message to front-end once input socket is connected. - socket.send(b'READY') + with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket: while True: # (RequestType, RequestData) type_frame, data_frame = socket.recv_multipart(copy=False) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index b94b0aa75386a..e948e59b8c425 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -8,7 +8,7 @@ import threading import uuid import weakref from abc import ABC, abstractmethod -from collections.abc import Awaitable +from collections.abc import Awaitable, Sequence from concurrent.futures import Future from dataclasses import dataclass, field from threading import Thread @@ -35,8 +35,6 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]] _R = TypeVar('_R') # Return type for collective_rpc -STARTUP_POLL_PERIOD_MS = 10000 - class EngineCoreClient(ABC): """ @@ -263,13 +261,15 @@ class CoreEngine: vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool, - input_path: str, + ctx: Union[zmq.Context, zmq.asyncio.Context], output_path: str, index: int = 0, local_dp_rank: int = 0, ): - self.index = index - self.identity = index.to_bytes(length=2, byteorder="little") + # Paths and sockets for IPC. + input_path = get_open_zmq_ipc_path() + self.input_socket = make_zmq_socket(ctx, input_path, + zmq.constants.PUSH) try: # Start EngineCore in background process. self.proc_handle = BackgroundProcHandle( @@ -291,9 +291,14 @@ class CoreEngine: # Ensure socket is closed if process fails to start. self.close() + def send_multipart(self, msg_parts: Sequence): + return self.input_socket.send_multipart(msg_parts, copy=False) + def close(self): if proc_handle := getattr(self, "proc_handle", None): proc_handle.shutdown() + if socket := getattr(self, "input_socket", None): + socket.close(linger=0) @dataclass @@ -304,7 +309,6 @@ class BackgroundResources: ctx: Union[zmq.Context] core_engines: list[CoreEngine] = field(default_factory=list) output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None - input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None shutdown_path: Optional[str] = None def __call__(self): @@ -317,8 +321,6 @@ class BackgroundResources: # aren't explicitly closed first. if self.output_socket is not None: self.output_socket.close(linger=0) - if self.input_socket is not None: - self.input_socket.close(linger=0) if self.shutdown_path is not None: # We must ensure that the sync output socket is # closed cleanly in its own thread. @@ -385,51 +387,21 @@ class MPClient(EngineCoreClient): # Paths and sockets for IPC. self.output_path = get_open_zmq_ipc_path() - input_path = get_open_zmq_ipc_path() - self.input_socket = make_zmq_socket(self.ctx, - input_path, - zmq.ROUTER, - bind=True) - self.resources.input_socket = self.input_socket new_core_engine = lambda index, local_dp_rank=None: CoreEngine( - vllm_config, executor_class, log_stats, input_path, self. - output_path, index, local_dp_rank) + vllm_config, executor_class, log_stats, self.ctx, self.output_path, + index, local_dp_rank) # Start engine core process(es). self._init_core_engines(vllm_config, new_core_engine, self.resources.core_engines) # Wait for engine core process(es) to start. - self._wait_for_engine_startup() + for engine in self.resources.core_engines: + engine.proc_handle.wait_for_startup() self.utility_results: dict[int, AnyFuture] = {} - def _wait_for_engine_startup(self): - # Get a sync handle to the socket which can be sync or async. - sync_input_socket = zmq.Socket.shadow(self.input_socket) - - # Wait for engine core process(es) to send ready messages. - identities = set(eng.index for eng in self.resources.core_engines) - while identities: - while not sync_input_socket.poll(timeout=STARTUP_POLL_PERIOD_MS): - logger.info("Waiting for %d core engine proc(s) to start: %s", - len(identities), identities) - eng_id_bytes, msg = sync_input_socket.recv_multipart() - eng_id = int.from_bytes(eng_id_bytes, byteorder="little") - if eng_id not in identities: - raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}") - if msg != b'READY': - raise RuntimeError(f"Engine {eng_id} failed: {msg.decode()}") - logger.info("Core engine process %d ready.", eng_id) - identities.discard(eng_id) - - # Double check that the process are running. - for engine in self.resources.core_engines: - proc = engine.proc_handle.proc - if proc.exitcode is not None: - raise RuntimeError(f"Engine proc {proc.name} not running") - def _init_core_engines( self, vllm_config: VllmConfig, @@ -522,10 +494,9 @@ class SyncMPClient(MPClient): return self.outputs_queue.get() def _send_input(self, request_type: EngineCoreRequestType, request: Any): - # (Identity, RequestType, SerializedRequest) - msg = (self.core_engine.identity, request_type.value, - self.encoder.encode(request)) - self.input_socket.send_multipart(msg, copy=False) + # (RequestType, SerializedRequest) + msg = (request_type.value, self.encoder.encode(request)) + self.core_engine.send_multipart(msg) def call_utility(self, method: str, *args) -> Any: call_id = uuid.uuid1().int >> 64 @@ -654,34 +625,30 @@ class AsyncMPClient(MPClient): assert self.outputs_queue is not None return await self.outputs_queue.get() - def _send_input(self, - request_type: EngineCoreRequestType, - request: Any, - engine: Optional[CoreEngine] = None) -> Awaitable[None]: - if engine is None: - engine = self.core_engine + async def _send_input(self, request_type: EngineCoreRequestType, + request: Any) -> None: + await self.core_engine.send_multipart( + (request_type.value, self.encoder.encode(request))) - message = (request_type.value, self.encoder.encode(request)) - return self._send_input_message(message, engine) - - def _send_input_message(self, message: tuple[bytes, bytes], - engine: CoreEngine) -> Awaitable[None]: - message = (engine.identity, ) + message # type: ignore[assignment] - return self.input_socket.send_multipart(message, copy=False) + self._ensure_output_queue_task() async def call_utility_async(self, method: str, *args) -> Any: return await self._call_utility_async(method, *args, engine=self.core_engine) - async def _call_utility_async(self, method: str, *args, - engine: CoreEngine) -> Any: + async def _call_utility_async( + self, + method: str, + *args, + engine: CoreEngine, + ) -> Any: call_id = uuid.uuid1().int >> 64 future = asyncio.get_running_loop().create_future() self.utility_results[call_id] = future message = (EngineCoreRequestType.UTILITY.value, self.encoder.encode((call_id, method, args))) - await self._send_input_message(message, engine) + await engine.send_multipart(message) self._ensure_output_queue_task() return await future @@ -690,7 +657,6 @@ class AsyncMPClient(MPClient): # tokenized. request.prompt = None await self._send_input(EngineCoreRequestType.ADD, request) - self._ensure_output_queue_task() async def abort_requests_async(self, request_ids: list[str]) -> None: if len(request_ids) > 0: @@ -795,15 +761,15 @@ class DPAsyncMPClient(AsyncMPClient): self.reqs_in_flight[request.request_id] = chosen_engine chosen_engine.num_reqs_in_flight += 1 if self.num_engines_running >= len(self.core_engines): - await self._send_input_message(msg, chosen_engine) + await chosen_engine.send_multipart(msg) else: # Send request to chosen engine and dp start loop # control message to all other engines. self.num_engines_running += len(self.core_engines) await asyncio.gather(*[ - self._send_input_message( - msg if engine is chosen_engine else self.start_dp_msg, - engine) for engine in self.core_engines + engine.send_multipart(msg if engine is + chosen_engine else self.start_dp_msg) + for engine in self.core_engines ]) self._ensure_output_queue_task() @@ -828,7 +794,7 @@ class DPAsyncMPClient(AsyncMPClient): # sure to start the other engines: self.num_engines_running = len(self.core_engines) coros = [ - self._send_input_message(self.start_dp_msg, engine) + engine.send_multipart(self.start_dp_msg) for engine in self.core_engines if not engine.num_reqs_in_flight ] @@ -854,5 +820,5 @@ class DPAsyncMPClient(AsyncMPClient): async def _abort_requests(self, request_ids: list[str], engine: CoreEngine) -> None: - await self._send_input(EngineCoreRequestType.ABORT, request_ids, - engine) + await engine.send_multipart((EngineCoreRequestType.ABORT.value, + self.encoder.encode(request_ids))) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index fed5761b04b6c..f42b3501adb3b 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -105,9 +105,12 @@ class BackgroundProcHandle: process_kwargs: dict[Any, Any], ): context = get_mp_context() + self.reader, writer = context.Pipe(duplex=False) - assert ("input_path" not in process_kwargs + assert ("ready_pipe" not in process_kwargs + and "input_path" not in process_kwargs and "output_path" not in process_kwargs) + process_kwargs["ready_pipe"] = writer process_kwargs["input_path"] = input_path process_kwargs["output_path"] = output_path @@ -119,6 +122,12 @@ class BackgroundProcHandle: input_path, output_path) self.proc.start() + def wait_for_startup(self): + # Wait for startup. + if self.reader.recv()["status"] != "READY": + raise RuntimeError(f"{self.proc.name} initialization failed. " + "See root cause above.") + def shutdown(self): self._finalizer()