From fbefc8a78d22b20eac042c586805c7dcbfc66b1c Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Thu, 1 May 2025 05:38:18 -0400 Subject: [PATCH] [Core] Enable IPv6 with vllm.utils.make_zmq_socket() (#16506) Signed-off-by: Russell Bryant --- tests/test_utils.py | 54 ++++++++++++++++++++++++++++++++++++++++++++- vllm/utils.py | 28 +++++++++++++++++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 580e65f1f833..deff33e5c3ca 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,13 +10,15 @@ from unittest.mock import patch import pytest import torch +import zmq from vllm_test_utils.monitor import monitor from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache, MemorySnapshot, PlaceholderModule, StoreBoolean, bind_kv_cache, deprecate_kwargs, get_open_port, - memory_profiling, merge_async_iterators, sha256, + make_zmq_socket, memory_profiling, + merge_async_iterators, sha256, split_zmq_path, supports_kw, swap_dict_values) from .utils import create_new_process_for_each_test, error_on_warning @@ -662,3 +664,53 @@ def test_sha256(input: tuple, output: int): # hashing different input, returns different value assert hash != sha256(input + (1, )) + + +@pytest.mark.parametrize( + "path,expected", + [ + ("ipc://some_path", ("ipc", "some_path", "")), + ("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")), + ("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address + ("inproc://some_identifier", ("inproc", "some_identifier", "")), + ] +) +def test_split_zmq_path(path, expected): + assert split_zmq_path(path) == expected + + +@pytest.mark.parametrize( + "invalid_path", + [ + "invalid_path", # Missing scheme + "tcp://127.0.0.1", # Missing port + "tcp://[::1]", # Missing port for IPv6 + "tcp://:5555", # Missing host + ] +) +def test_split_zmq_path_invalid(invalid_path): + with pytest.raises(ValueError): + split_zmq_path(invalid_path) + + +def test_make_zmq_socket_ipv6(): + # Check if IPv6 is supported by trying to create an IPv6 socket + try: + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.close() + except socket.error: + pytest.skip("IPv6 is not supported on this system") + + ctx = zmq.Context() + ipv6_path = "tcp://[::]:5555" # IPv6 loopback address + socket_type = zmq.REP # Example socket type + + # Create the socket + zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type) + + # Verify that the IPV6 option is set + assert zsock.getsockopt(zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses" + + # Clean up + zsock.close() + ctx.term() diff --git a/vllm/utils.py b/vllm/utils.py index 73726bb9a346..f85bbe3a5990 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -45,6 +45,7 @@ from types import MappingProxyType from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, Optional, Sequence, Tuple, Type, TypeVar, Union, cast, overload) +from urllib.parse import urlparse from uuid import uuid4 import cachetools @@ -2278,6 +2279,27 @@ def get_exception_traceback(): return err_str +def split_zmq_path(path: str) -> Tuple[str, str, str]: + """Split a zmq path into its parts.""" + parsed = urlparse(path) + if not parsed.scheme: + raise ValueError(f"Invalid zmq path: {path}") + + scheme = parsed.scheme + host = parsed.hostname or "" + port = str(parsed.port or "") + + if scheme == "tcp" and not all((host, port)): + # The host and port fields are required for tcp + raise ValueError(f"Invalid zmq path: {path}") + + if scheme != "tcp" and port: + # port only makes sense with tcp + raise ValueError(f"Invalid zmq path: {path}") + + return scheme, host, port + + # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501 def make_zmq_socket( ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined] @@ -2317,6 +2339,12 @@ def make_zmq_socket( if identity is not None: socket.setsockopt(zmq.IDENTITY, identity) + # Determine if the path is a TCP socket with an IPv6 address. + # Enable IPv6 on the zmq socket if so. + scheme, host, _ = split_zmq_path(path) + if scheme == "tcp" and is_valid_ipv6_address(host): + socket.setsockopt(zmq.IPV6, 1) + if bind: socket.bind(path) else: