[Core] Enable IPv6 with vllm.utils.make_zmq_socket() (#16506)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Russell Bryant 2025-05-01 05:38:18 -04:00 committed by GitHub
parent 26bc4bbcd8
commit fbefc8a78d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 81 additions and 1 deletions

View File

@ -10,13 +10,15 @@ from unittest.mock import patch
import pytest
import torch
import zmq
from vllm_test_utils.monitor import monitor
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, deprecate_kwargs, get_open_port,
memory_profiling, merge_async_iterators, sha256,
make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_zmq_path,
supports_kw, swap_dict_values)
from .utils import create_new_process_for_each_test, error_on_warning
@ -662,3 +664,53 @@ def test_sha256(input: tuple, output: int):
# hashing different input, returns different value
assert hash != sha256(input + (1, ))
@pytest.mark.parametrize(
"path,expected",
[
("ipc://some_path", ("ipc", "some_path", "")),
("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")),
("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address
("inproc://some_identifier", ("inproc", "some_identifier", "")),
]
)
def test_split_zmq_path(path, expected):
assert split_zmq_path(path) == expected
@pytest.mark.parametrize(
"invalid_path",
[
"invalid_path", # Missing scheme
"tcp://127.0.0.1", # Missing port
"tcp://[::1]", # Missing port for IPv6
"tcp://:5555", # Missing host
]
)
def test_split_zmq_path_invalid(invalid_path):
with pytest.raises(ValueError):
split_zmq_path(invalid_path)
def test_make_zmq_socket_ipv6():
# Check if IPv6 is supported by trying to create an IPv6 socket
try:
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
sock.close()
except socket.error:
pytest.skip("IPv6 is not supported on this system")
ctx = zmq.Context()
ipv6_path = "tcp://[::]:5555" # IPv6 loopback address
socket_type = zmq.REP # Example socket type
# Create the socket
zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type)
# Verify that the IPV6 option is set
assert zsock.getsockopt(zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses"
# Clean up
zsock.close()
ctx.term()

View File

@ -45,6 +45,7 @@ from types import MappingProxyType
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, Sequence, Tuple, Type, TypeVar, Union, cast,
overload)
from urllib.parse import urlparse
from uuid import uuid4
import cachetools
@ -2278,6 +2279,27 @@ def get_exception_traceback():
return err_str
def split_zmq_path(path: str) -> Tuple[str, str, str]:
"""Split a zmq path into its parts."""
parsed = urlparse(path)
if not parsed.scheme:
raise ValueError(f"Invalid zmq path: {path}")
scheme = parsed.scheme
host = parsed.hostname or ""
port = str(parsed.port or "")
if scheme == "tcp" and not all((host, port)):
# The host and port fields are required for tcp
raise ValueError(f"Invalid zmq path: {path}")
if scheme != "tcp" and port:
# port only makes sense with tcp
raise ValueError(f"Invalid zmq path: {path}")
return scheme, host, port
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
def make_zmq_socket(
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
@ -2317,6 +2339,12 @@ def make_zmq_socket(
if identity is not None:
socket.setsockopt(zmq.IDENTITY, identity)
# Determine if the path is a TCP socket with an IPv6 address.
# Enable IPv6 on the zmq socket if so.
scheme, host, _ = split_zmq_path(path)
if scheme == "tcp" and is_valid_ipv6_address(host):
socket.setsockopt(zmq.IPV6, 1)
if bind:
socket.bind(path)
else: