[Core][Distributed] Use IPC (domain socket) ZMQ socket for local comms (#13688)

This commit is contained in:
Nick Hill 2025-02-23 02:54:29 -08:00 committed by GitHub
parent ba5106e519
commit 5a2ba16f5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -19,7 +19,8 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
is_valid_ipv6_address)
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
@@ -165,12 +166,12 @@ class ShmRingBuffer:
@dataclass @dataclass
class Handle: class Handle:
connect_ip: str
local_reader_ranks: List[int] = field(default_factory=list) local_reader_ranks: List[int] = field(default_factory=list)
buffer_handle: Optional[Tuple[int, int, int, str]] = None buffer_handle: Optional[Tuple[int, int, int, str]] = None
local_subscribe_port: Optional[int] = None local_subscribe_addr: Optional[str] = None
remote_subscribe_port: Optional[int] = None remote_subscribe_addr: Optional[str] = None
remote_addr_ipv6: bool = False
class MessageQueue: class MessageQueue:
@@ -192,9 +193,6 @@ class MessageQueue:
n_remote_reader = n_reader - n_local_reader n_remote_reader = n_reader - n_local_reader
self.n_remote_reader = n_remote_reader self.n_remote_reader = n_remote_reader
if connect_ip is None:
connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"
context = Context() context = Context()
if n_local_reader > 0: if n_local_reader > 0:
@@ -212,32 +210,34 @@ class MessageQueue:
# message. otherwise, we will only receive the first subscription # message. otherwise, we will only receive the first subscription
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
self.local_socket.setsockopt(XPUB_VERBOSE, True) self.local_socket.setsockopt(XPUB_VERBOSE, True)
local_subscribe_port = get_open_port() local_subscribe_addr = get_open_zmq_ipc_path()
socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}" logger.debug("Binding to %s", local_subscribe_addr)
logger.debug("Binding to %s", socket_addr) self.local_socket.bind(local_subscribe_addr)
self.local_socket.bind(socket_addr)
self.current_idx = 0 self.current_idx = 0
else: else:
self.buffer = None # type: ignore self.buffer = None # type: ignore
local_subscribe_port = None local_subscribe_addr = None
self.local_socket = None self.local_socket = None
self.current_idx = -1 self.current_idx = -1
remote_addr_ipv6 = False
if n_remote_reader > 0: if n_remote_reader > 0:
# for remote readers, we will: # for remote readers, we will:
# create a publish-subscribe socket to communicate large data # create a publish-subscribe socket to communicate large data
if not connect_ip:
connect_ip = get_ip()
self.remote_socket = context.socket(XPUB) self.remote_socket = context.socket(XPUB)
self.remote_socket.setsockopt(XPUB_VERBOSE, True) self.remote_socket.setsockopt(XPUB_VERBOSE, True)
remote_subscribe_port = get_open_port() remote_subscribe_port = get_open_port()
if is_valid_ipv6_address(connect_ip): if is_valid_ipv6_address(connect_ip):
self.remote_socket.setsockopt(IPV6, 1) self.remote_socket.setsockopt(IPV6, 1)
remote_addr_ipv6 = True
socket_addr = f"tcp://*:{remote_subscribe_port}" socket_addr = f"tcp://*:{remote_subscribe_port}"
self.remote_socket.bind(socket_addr) self.remote_socket.bind(socket_addr)
remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
else: else:
remote_subscribe_port = None remote_subscribe_addr = None
self.remote_socket = None self.remote_socket = None
self._is_writer = True self._is_writer = True
@@ -247,12 +247,12 @@ class MessageQueue:
self._is_remote_reader = False self._is_remote_reader = False
self.handle = Handle( self.handle = Handle(
connect_ip=connect_ip,
local_reader_ranks=local_reader_ranks, local_reader_ranks=local_reader_ranks,
buffer_handle=self.buffer.handle() buffer_handle=self.buffer.handle()
if self.buffer is not None else None, if self.buffer is not None else None,
local_subscribe_port=local_subscribe_port, local_subscribe_addr=local_subscribe_addr,
remote_subscribe_port=remote_subscribe_port, remote_subscribe_addr=remote_subscribe_addr,
remote_addr_ipv6=remote_addr_ipv6,
) )
logger.info("vLLM message queue communication handle: %s", self.handle) logger.info("vLLM message queue communication handle: %s", self.handle)
@@ -278,7 +278,7 @@ class MessageQueue:
self.local_socket = context.socket(SUB) self.local_socket = context.socket(SUB)
self.local_socket.setsockopt_string(SUBSCRIBE, "") self.local_socket.setsockopt_string(SUBSCRIBE, "")
socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}" socket_addr = handle.local_subscribe_addr
logger.debug("Connecting to %s", socket_addr) logger.debug("Connecting to %s", socket_addr)
self.local_socket.connect(socket_addr) self.local_socket.connect(socket_addr)
@@ -294,9 +294,9 @@ class MessageQueue:
self.remote_socket = context.socket(SUB) self.remote_socket = context.socket(SUB)
self.remote_socket.setsockopt_string(SUBSCRIBE, "") self.remote_socket.setsockopt_string(SUBSCRIBE, "")
if is_valid_ipv6_address(handle.connect_ip): if handle.remote_addr_ipv6:
self.remote_socket.setsockopt(IPV6, 1) self.remote_socket.setsockopt(IPV6, 1)
socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}" socket_addr = handle.remote_subscribe_addr
logger.debug("Connecting to %s", socket_addr) logger.debug("Connecting to %s", socket_addr)
self.remote_socket.connect(socket_addr) self.remote_socket.connect(socket_addr)