mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 18:25:45 +08:00
[Core][Distributed] Use IPC (domain socket) ZMQ socket for local comms (#13688)
This commit is contained in:
parent
ba5106e519
commit
5a2ba16f5c
@ -19,7 +19,8 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.distributed.utils import StatelessProcessGroup
|
from vllm.distributed.utils import StatelessProcessGroup
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address
|
from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
|
||||||
|
is_valid_ipv6_address)
|
||||||
|
|
||||||
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
||||||
|
|
||||||
@ -165,12 +166,12 @@ class ShmRingBuffer:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Handle:
|
class Handle:
|
||||||
connect_ip: str
|
|
||||||
local_reader_ranks: List[int] = field(default_factory=list)
|
local_reader_ranks: List[int] = field(default_factory=list)
|
||||||
|
|
||||||
buffer_handle: Optional[Tuple[int, int, int, str]] = None
|
buffer_handle: Optional[Tuple[int, int, int, str]] = None
|
||||||
local_subscribe_port: Optional[int] = None
|
local_subscribe_addr: Optional[str] = None
|
||||||
remote_subscribe_port: Optional[int] = None
|
remote_subscribe_addr: Optional[str] = None
|
||||||
|
remote_addr_ipv6: bool = False
|
||||||
|
|
||||||
|
|
||||||
class MessageQueue:
|
class MessageQueue:
|
||||||
@ -192,9 +193,6 @@ class MessageQueue:
|
|||||||
n_remote_reader = n_reader - n_local_reader
|
n_remote_reader = n_reader - n_local_reader
|
||||||
self.n_remote_reader = n_remote_reader
|
self.n_remote_reader = n_remote_reader
|
||||||
|
|
||||||
if connect_ip is None:
|
|
||||||
connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"
|
|
||||||
|
|
||||||
context = Context()
|
context = Context()
|
||||||
|
|
||||||
if n_local_reader > 0:
|
if n_local_reader > 0:
|
||||||
@ -212,32 +210,34 @@ class MessageQueue:
|
|||||||
# message. otherwise, we will only receive the first subscription
|
# message. otherwise, we will only receive the first subscription
|
||||||
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
|
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details
|
||||||
self.local_socket.setsockopt(XPUB_VERBOSE, True)
|
self.local_socket.setsockopt(XPUB_VERBOSE, True)
|
||||||
local_subscribe_port = get_open_port()
|
local_subscribe_addr = get_open_zmq_ipc_path()
|
||||||
socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}"
|
logger.debug("Binding to %s", local_subscribe_addr)
|
||||||
logger.debug("Binding to %s", socket_addr)
|
self.local_socket.bind(local_subscribe_addr)
|
||||||
self.local_socket.bind(socket_addr)
|
|
||||||
|
|
||||||
self.current_idx = 0
|
self.current_idx = 0
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.buffer = None # type: ignore
|
self.buffer = None # type: ignore
|
||||||
local_subscribe_port = None
|
local_subscribe_addr = None
|
||||||
self.local_socket = None
|
self.local_socket = None
|
||||||
self.current_idx = -1
|
self.current_idx = -1
|
||||||
|
|
||||||
|
remote_addr_ipv6 = False
|
||||||
if n_remote_reader > 0:
|
if n_remote_reader > 0:
|
||||||
# for remote readers, we will:
|
# for remote readers, we will:
|
||||||
# create a publish-subscribe socket to communicate large data
|
# create a publish-subscribe socket to communicate large data
|
||||||
|
if not connect_ip:
|
||||||
|
connect_ip = get_ip()
|
||||||
self.remote_socket = context.socket(XPUB)
|
self.remote_socket = context.socket(XPUB)
|
||||||
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
|
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
|
||||||
remote_subscribe_port = get_open_port()
|
remote_subscribe_port = get_open_port()
|
||||||
if is_valid_ipv6_address(connect_ip):
|
if is_valid_ipv6_address(connect_ip):
|
||||||
self.remote_socket.setsockopt(IPV6, 1)
|
self.remote_socket.setsockopt(IPV6, 1)
|
||||||
|
remote_addr_ipv6 = True
|
||||||
socket_addr = f"tcp://*:{remote_subscribe_port}"
|
socket_addr = f"tcp://*:{remote_subscribe_port}"
|
||||||
self.remote_socket.bind(socket_addr)
|
self.remote_socket.bind(socket_addr)
|
||||||
|
remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
|
||||||
else:
|
else:
|
||||||
remote_subscribe_port = None
|
remote_subscribe_addr = None
|
||||||
self.remote_socket = None
|
self.remote_socket = None
|
||||||
|
|
||||||
self._is_writer = True
|
self._is_writer = True
|
||||||
@ -247,12 +247,12 @@ class MessageQueue:
|
|||||||
self._is_remote_reader = False
|
self._is_remote_reader = False
|
||||||
|
|
||||||
self.handle = Handle(
|
self.handle = Handle(
|
||||||
connect_ip=connect_ip,
|
|
||||||
local_reader_ranks=local_reader_ranks,
|
local_reader_ranks=local_reader_ranks,
|
||||||
buffer_handle=self.buffer.handle()
|
buffer_handle=self.buffer.handle()
|
||||||
if self.buffer is not None else None,
|
if self.buffer is not None else None,
|
||||||
local_subscribe_port=local_subscribe_port,
|
local_subscribe_addr=local_subscribe_addr,
|
||||||
remote_subscribe_port=remote_subscribe_port,
|
remote_subscribe_addr=remote_subscribe_addr,
|
||||||
|
remote_addr_ipv6=remote_addr_ipv6,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("vLLM message queue communication handle: %s", self.handle)
|
logger.info("vLLM message queue communication handle: %s", self.handle)
|
||||||
@ -278,7 +278,7 @@ class MessageQueue:
|
|||||||
|
|
||||||
self.local_socket = context.socket(SUB)
|
self.local_socket = context.socket(SUB)
|
||||||
self.local_socket.setsockopt_string(SUBSCRIBE, "")
|
self.local_socket.setsockopt_string(SUBSCRIBE, "")
|
||||||
socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}"
|
socket_addr = handle.local_subscribe_addr
|
||||||
logger.debug("Connecting to %s", socket_addr)
|
logger.debug("Connecting to %s", socket_addr)
|
||||||
self.local_socket.connect(socket_addr)
|
self.local_socket.connect(socket_addr)
|
||||||
|
|
||||||
@ -294,9 +294,9 @@ class MessageQueue:
|
|||||||
|
|
||||||
self.remote_socket = context.socket(SUB)
|
self.remote_socket = context.socket(SUB)
|
||||||
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
|
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
|
||||||
if is_valid_ipv6_address(handle.connect_ip):
|
if handle.remote_addr_ipv6:
|
||||||
self.remote_socket.setsockopt(IPV6, 1)
|
self.remote_socket.setsockopt(IPV6, 1)
|
||||||
socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
|
socket_addr = handle.remote_subscribe_addr
|
||||||
logger.debug("Connecting to %s", socket_addr)
|
logger.debug("Connecting to %s", socket_addr)
|
||||||
self.remote_socket.connect(socket_addr)
|
self.remote_socket.connect(socket_addr)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user