diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 7405f3986df8..0d3fa6b059be 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -2794,6 +2794,9 @@ def make_zmq_socket( if linger is not None: socket.setsockopt(zmq.LINGER, linger) + if socket_type == zmq.XPUB: + socket.setsockopt(zmq.XPUB_VERBOSE, True) + # Determine if the path is a TCP socket with an IPv6 address. # Enable IPv6 on the zmq socket if so. scheme, host, _ = split_zmq_path(path) diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 440628576bcb..8d8d1689e61e 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -172,6 +172,18 @@ class DPCoordinatorProc: bind=True, ) as publish_back: + # Wait until all engines subscribe. + for _ in self.engines: + if publish_back.recv() != b'\x01': + logger.error( + "DP Coordinator received unexpected message while " + "waiting for engines to subscribe") + return + # Send ready message to engines. + publish_back.send(b"READY") + + logger.info("All engine subscriptions received by DP coordinator") + poller = zmq.Poller() poller.register(publish_front, zmq.POLLIN) poller.register(output_back, zmq.POLLIN) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 6ae5736df98b..0a889b2a0a18 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -461,8 +461,11 @@ class EngineCoreProc(EngineCore): self.has_coordinator = addresses.coordinator_output is not None self.frontend_stats_publish_address = ( addresses.frontend_stats_publish_address) + logger.debug("Has DP Coordinator: %s, stats publish address: %s", + self.has_coordinator, + self.frontend_stats_publish_address) # Only publish request queue stats to coordinator for "internal" - # LB mode. + # and "hybrid" LB modes.
self.publish_dp_lb_stats = ( self.has_coordinator and not vllm_config.parallel_config.data_parallel_external_lb) @@ -472,25 +475,38 @@ class EngineCoreProc(EngineCore): super().__init__(vllm_config, executor_class, log_stats, executor_fail_callback) + # Background Threads and Queues for IO. These enable us to + # overlap ZMQ socket IO with GPU since they release the GIL, + # and to overlap some serialization/deserialization with the + # model forward pass. + # Threads handle Socket <-> Queues and core_busy_loop uses Queue. + ready_event = threading.Event() + input_thread = threading.Thread(target=self.process_input_sockets, + args=(addresses.inputs, + addresses.coordinator_input, + identity, ready_event), + daemon=True) + input_thread.start() + + self.output_thread = threading.Thread( + target=self.process_output_sockets, + args=(addresses.outputs, addresses.coordinator_output, + self.engine_index), + daemon=True) + self.output_thread.start() + + # Don't complete handshake until DP coordinator ready message is + # received. + while not ready_event.wait(timeout=10): + if not input_thread.is_alive(): + raise RuntimeError( + "Input socket thread died during startup") + assert addresses.coordinator_input is not None + logger.info("Waiting for READY message from DP Coordinator...") + self.step_fn = (self.step if self.batch_queue is None else self.step_with_batch_queue) - # Background Threads and Queues for IO. These enable us to - # overlap ZMQ socket IO with GPU since they release the GIL, - # and to overlap some serialization/deserialization with the - # model forward pass. - # Threads handle Socket <-> Queues and core_busy_loop uses Queue. 
- threading.Thread(target=self.process_input_sockets, - args=(addresses.inputs, addresses.coordinator_input, - identity), - daemon=True).start() - self.output_thread = threading.Thread( - target=self.process_output_sockets, - args=(addresses.outputs, addresses.coordinator_output, - self.engine_index), - daemon=True) - self.output_thread.start() - @contextmanager def _perform_handshakes( self, @@ -505,10 +521,10 @@ class EngineCoreProc(EngineCore): For DP=1 or offline mode, this is with the colocated front-end process. - For DP>1 with internal loadbalancing this is with the shared front-end + For DP>1 with internal load-balancing this is with the shared front-end process which may reside on a different node. - For DP>1 with external or hybrid loadbalancing, two handshakes are + For DP>1 with external or hybrid load-balancing, two handshakes are performed: - With the rank 0 front-end process which retrieves the DP Coordinator ZMQ addresses and DP process group address. @@ -772,7 +788,7 @@ class EngineCoreProc(EngineCore): def process_input_sockets(self, input_addresses: list[str], coord_input_address: Optional[str], - identity: bytes): + identity: bytes, ready_event: threading.Event): """Input socket IO thread.""" # Msgpack serialization decoding. @@ -809,9 +825,18 @@ class EngineCoreProc(EngineCore): # back to us. input_socket.send(b'') poller.register(input_socket, zmq.POLLIN) + if coord_socket is not None: + # Wait for ready message from coordinator. The recv() is done + # outside of an assert so it is not stripped under "python -O". + ready_msg = coord_socket.recv() + if ready_msg != b"READY": + raise RuntimeError( + f"Unexpected DP Coordinator message: {ready_msg!r}") poller.register(coord_socket, zmq.POLLIN) + ready_event.set() + del ready_event while True: for input_socket, _ in poller.poll(): # (RequestType, RequestData)