mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 10:40:44 +08:00
[Core][Distributed] Use IPC (domain socket) ZMQ socket for local comms (#13688)
This commit is contained in:
parent
ba5106e519
commit
5a2ba16f5c
@ -19,7 +19,8 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address
|
||||
from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
|
||||
is_valid_ipv6_address)
|
||||
|
||||
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
||||
|
||||
@ -165,12 +166,12 @@ class ShmRingBuffer:
|
||||
|
||||
@dataclass
|
||||
class Handle:
|
||||
connect_ip: str
|
||||
local_reader_ranks: List[int] = field(default_factory=list)
|
||||
|
||||
buffer_handle: Optional[Tuple[int, int, int, str]] = None
|
||||
local_subscribe_port: Optional[int] = None
|
||||
remote_subscribe_port: Optional[int] = None
|
||||
local_subscribe_addr: Optional[str] = None
|
||||
remote_subscribe_addr: Optional[str] = None
|
||||
remote_addr_ipv6: bool = False
|
||||
|
||||
|
||||
class MessageQueue:
|
||||
@ -192,9 +193,6 @@ class MessageQueue:
|
||||
n_remote_reader = n_reader - n_local_reader
|
||||
self.n_remote_reader = n_remote_reader
|
||||
|
||||
if connect_ip is None:
|
||||
connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"
|
||||
|
||||
context = Context()
|
||||
|
||||
if n_local_reader > 0:
|
||||
@ -212,32 +210,34 @@ class MessageQueue:
|
||||
# message. otherwise, we will only receive the first subscription
|
||||
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
|
||||
self.local_socket.setsockopt(XPUB_VERBOSE, True)
|
||||
local_subscribe_port = get_open_port()
|
||||
socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}"
|
||||
logger.debug("Binding to %s", socket_addr)
|
||||
self.local_socket.bind(socket_addr)
|
||||
local_subscribe_addr = get_open_zmq_ipc_path()
|
||||
logger.debug("Binding to %s", local_subscribe_addr)
|
||||
self.local_socket.bind(local_subscribe_addr)
|
||||
|
||||
self.current_idx = 0
|
||||
|
||||
else:
|
||||
self.buffer = None # type: ignore
|
||||
local_subscribe_port = None
|
||||
local_subscribe_addr = None
|
||||
self.local_socket = None
|
||||
self.current_idx = -1
|
||||
|
||||
remote_addr_ipv6 = False
|
||||
if n_remote_reader > 0:
|
||||
# for remote readers, we will:
|
||||
# create a publish-subscribe socket to communicate large data
|
||||
if not connect_ip:
|
||||
connect_ip = get_ip()
|
||||
self.remote_socket = context.socket(XPUB)
|
||||
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
|
||||
remote_subscribe_port = get_open_port()
|
||||
if is_valid_ipv6_address(connect_ip):
|
||||
self.remote_socket.setsockopt(IPV6, 1)
|
||||
remote_addr_ipv6 = True
|
||||
socket_addr = f"tcp://*:{remote_subscribe_port}"
|
||||
self.remote_socket.bind(socket_addr)
|
||||
|
||||
remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
|
||||
else:
|
||||
remote_subscribe_port = None
|
||||
remote_subscribe_addr = None
|
||||
self.remote_socket = None
|
||||
|
||||
self._is_writer = True
|
||||
@ -247,12 +247,12 @@ class MessageQueue:
|
||||
self._is_remote_reader = False
|
||||
|
||||
self.handle = Handle(
|
||||
connect_ip=connect_ip,
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
buffer_handle=self.buffer.handle()
|
||||
if self.buffer is not None else None,
|
||||
local_subscribe_port=local_subscribe_port,
|
||||
remote_subscribe_port=remote_subscribe_port,
|
||||
local_subscribe_addr=local_subscribe_addr,
|
||||
remote_subscribe_addr=remote_subscribe_addr,
|
||||
remote_addr_ipv6=remote_addr_ipv6,
|
||||
)
|
||||
|
||||
logger.info("vLLM message queue communication handle: %s", self.handle)
|
||||
@ -278,7 +278,7 @@ class MessageQueue:
|
||||
|
||||
self.local_socket = context.socket(SUB)
|
||||
self.local_socket.setsockopt_string(SUBSCRIBE, "")
|
||||
socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}"
|
||||
socket_addr = handle.local_subscribe_addr
|
||||
logger.debug("Connecting to %s", socket_addr)
|
||||
self.local_socket.connect(socket_addr)
|
||||
|
||||
@ -294,9 +294,9 @@ class MessageQueue:
|
||||
|
||||
self.remote_socket = context.socket(SUB)
|
||||
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
|
||||
if is_valid_ipv6_address(handle.connect_ip):
|
||||
if handle.remote_addr_ipv6:
|
||||
self.remote_socket.setsockopt(IPV6, 1)
|
||||
socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
|
||||
socket_addr = handle.remote_subscribe_addr
|
||||
logger.debug("Connecting to %s", socket_addr)
|
||||
self.remote_socket.connect(socket_addr)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user