From 5a2ba16f5c6cf508d6a8e8a82cdfd4741190166a Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sun, 23 Feb 2025 02:54:29 -0800 Subject: [PATCH] [Core][Distributed] Use IPC (domain socket) ZMQ socket for local comms (#13688) --- .../device_communicators/shm_broadcast.py | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 48ac81ac008b2..12a720d47fbba 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -19,7 +19,8 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore import vllm.envs as envs from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger -from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address +from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path, + is_valid_ipv6_address) VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL @@ -165,12 +166,12 @@ class ShmRingBuffer: @dataclass class Handle: - connect_ip: str local_reader_ranks: List[int] = field(default_factory=list) buffer_handle: Optional[Tuple[int, int, int, str]] = None - local_subscribe_port: Optional[int] = None - remote_subscribe_port: Optional[int] = None + local_subscribe_addr: Optional[str] = None + remote_subscribe_addr: Optional[str] = None + remote_addr_ipv6: bool = False class MessageQueue: @@ -192,9 +193,6 @@ class MessageQueue: n_remote_reader = n_reader - n_local_reader self.n_remote_reader = n_remote_reader - if connect_ip is None: - connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1" - context = Context() if n_local_reader > 0: @@ -212,32 +210,34 @@ class MessageQueue: # message. otherwise, we will only receive the first subscription # see http://api.zeromq.org/3-3:zmq-setsockopt for more details self.local_socket.setsockopt(XPUB_VERBOSE, True) - local_subscribe_port = get_open_port() - socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}" - logger.debug("Binding to %s", socket_addr) - self.local_socket.bind(socket_addr) + local_subscribe_addr = get_open_zmq_ipc_path() + logger.debug("Binding to %s", local_subscribe_addr) + self.local_socket.bind(local_subscribe_addr) self.current_idx = 0 - else: self.buffer = None # type: ignore - local_subscribe_port = None + local_subscribe_addr = None self.local_socket = None self.current_idx = -1 + remote_addr_ipv6 = False if n_remote_reader > 0: # for remote readers, we will: # create a publish-subscribe socket to communicate large data + if not connect_ip: + connect_ip = get_ip() self.remote_socket = context.socket(XPUB) self.remote_socket.setsockopt(XPUB_VERBOSE, True) remote_subscribe_port = get_open_port() if is_valid_ipv6_address(connect_ip): self.remote_socket.setsockopt(IPV6, 1) + remote_addr_ipv6 = True socket_addr = f"tcp://*:{remote_subscribe_port}" self.remote_socket.bind(socket_addr) - + remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}" else: - remote_subscribe_port = None + remote_subscribe_addr = None self.remote_socket = None self._is_writer = True @@ -247,12 +247,12 @@ class MessageQueue: self._is_remote_reader = False self.handle = Handle( - connect_ip=connect_ip, local_reader_ranks=local_reader_ranks, buffer_handle=self.buffer.handle() if self.buffer is not None else None, - local_subscribe_port=local_subscribe_port, - remote_subscribe_port=remote_subscribe_port, + local_subscribe_addr=local_subscribe_addr, + remote_subscribe_addr=remote_subscribe_addr, + remote_addr_ipv6=remote_addr_ipv6, ) logger.info("vLLM message queue communication handle: %s", self.handle) @@ -278,7 +278,7 @@ class MessageQueue: self.local_socket = context.socket(SUB) self.local_socket.setsockopt_string(SUBSCRIBE, "") - socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}" + socket_addr = handle.local_subscribe_addr logger.debug("Connecting to %s", socket_addr) self.local_socket.connect(socket_addr) @@ -294,9 +294,9 @@ class MessageQueue: self.remote_socket = context.socket(SUB) self.remote_socket.setsockopt_string(SUBSCRIBE, "") - if is_valid_ipv6_address(handle.connect_ip): + if handle.remote_addr_ipv6: self.remote_socket.setsockopt(IPV6, 1) - socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}" + socket_addr = handle.remote_subscribe_addr logger.debug("Connecting to %s", socket_addr) self.remote_socket.connect(socket_addr)