[Core][Distributed] Use IPC (domain socket) ZMQ socket for local comms (#13688)

This commit is contained in:
Nick Hill 2025-02-23 02:54:29 -08:00 committed by GitHub
parent ba5106e519
commit 5a2ba16f5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -19,7 +19,8 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
import vllm.envs as envs
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address
from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
is_valid_ipv6_address)
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
@ -165,12 +166,12 @@ class ShmRingBuffer:
@dataclass
class Handle:
connect_ip: str
local_reader_ranks: List[int] = field(default_factory=list)
buffer_handle: Optional[Tuple[int, int, int, str]] = None
local_subscribe_port: Optional[int] = None
remote_subscribe_port: Optional[int] = None
local_subscribe_addr: Optional[str] = None
remote_subscribe_addr: Optional[str] = None
remote_addr_ipv6: bool = False
class MessageQueue:
@ -192,9 +193,6 @@ class MessageQueue:
n_remote_reader = n_reader - n_local_reader
self.n_remote_reader = n_remote_reader
if connect_ip is None:
connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"
context = Context()
if n_local_reader > 0:
@ -212,32 +210,34 @@ class MessageQueue:
# message. otherwise, we will only receive the first subscription
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
self.local_socket.setsockopt(XPUB_VERBOSE, True)
local_subscribe_port = get_open_port()
socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}"
logger.debug("Binding to %s", socket_addr)
self.local_socket.bind(socket_addr)
local_subscribe_addr = get_open_zmq_ipc_path()
logger.debug("Binding to %s", local_subscribe_addr)
self.local_socket.bind(local_subscribe_addr)
self.current_idx = 0
else:
self.buffer = None # type: ignore
local_subscribe_port = None
local_subscribe_addr = None
self.local_socket = None
self.current_idx = -1
remote_addr_ipv6 = False
if n_remote_reader > 0:
# for remote readers, we will:
# create a publish-subscribe socket to communicate large data
if not connect_ip:
connect_ip = get_ip()
self.remote_socket = context.socket(XPUB)
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
remote_subscribe_port = get_open_port()
if is_valid_ipv6_address(connect_ip):
self.remote_socket.setsockopt(IPV6, 1)
remote_addr_ipv6 = True
socket_addr = f"tcp://*:{remote_subscribe_port}"
self.remote_socket.bind(socket_addr)
remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
else:
remote_subscribe_port = None
remote_subscribe_addr = None
self.remote_socket = None
self._is_writer = True
@ -247,12 +247,12 @@ class MessageQueue:
self._is_remote_reader = False
self.handle = Handle(
connect_ip=connect_ip,
local_reader_ranks=local_reader_ranks,
buffer_handle=self.buffer.handle()
if self.buffer is not None else None,
local_subscribe_port=local_subscribe_port,
remote_subscribe_port=remote_subscribe_port,
local_subscribe_addr=local_subscribe_addr,
remote_subscribe_addr=remote_subscribe_addr,
remote_addr_ipv6=remote_addr_ipv6,
)
logger.info("vLLM message queue communication handle: %s", self.handle)
@ -278,7 +278,7 @@ class MessageQueue:
self.local_socket = context.socket(SUB)
self.local_socket.setsockopt_string(SUBSCRIBE, "")
socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}"
socket_addr = handle.local_subscribe_addr
logger.debug("Connecting to %s", socket_addr)
self.local_socket.connect(socket_addr)
@ -294,9 +294,9 @@ class MessageQueue:
self.remote_socket = context.socket(SUB)
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
if is_valid_ipv6_address(handle.connect_ip):
if handle.remote_addr_ipv6:
self.remote_socket.setsockopt(IPV6, 1)
socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
socket_addr = handle.remote_subscribe_addr
logger.debug("Connecting to %s", socket_addr)
self.remote_socket.connect(socket_addr)