From a8f5aec20ad685851f972847c0567db270d9845f Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Thu, 15 May 2025 02:17:57 -0400 Subject: [PATCH] [V1] Update zmq socket creation in nixl connector (#18148) Signed-off-by: Russell Bryant --- tests/test_utils.py | 7 +++++- .../kv_connector/v1/nixl_connector.py | 24 ++++++++----------- vllm/utils.py | 18 ++++++++++++++ 3 files changed, 34 insertions(+), 15 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index deff33e5c3caf..ea7db0a79c86b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,7 +17,7 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache, MemorySnapshot, PlaceholderModule, StoreBoolean, bind_kv_cache, deprecate_kwargs, get_open_port, - make_zmq_socket, memory_profiling, + make_zmq_path, make_zmq_socket, memory_profiling, merge_async_iterators, sha256, split_zmq_path, supports_kw, swap_dict_values) @@ -714,3 +714,8 @@ def test_make_zmq_socket_ipv6(): # Clean up zsock.close() ctx.term() + + +def test_make_zmq_path(): + assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555" + assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index abd1ea2bea82b..c0c03efcdbf4f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -21,7 +21,7 @@ from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group) from vllm.logger import init_logger -from vllm.utils import round_down +from vllm.utils import make_zmq_path, make_zmq_socket, round_down from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus @@ -379,7 +379,7 @@ class NixlConnectorWorker: # hack to keeps us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank - path = f"tcp://{host}:{port}" + path = make_zmq_path("tcp", host, port) logger.debug("Starting listening on path: %s", path) with zmq_ctx(zmq.ROUTER, path) as sock: ready_event.set() @@ -397,7 +397,7 @@ class NixlConnectorWorker: # NOTE(rob): we need each rank to have a unique port. This is # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. - path = f"tcp://{host}:{port + self.rank}" + path = make_zmq_path("tcp", host, port + self.rank) logger.debug("Querying metadata on path: %s", path) with zmq_ctx(zmq.REQ, path) as sock: # Send query for the request. @@ -741,20 +741,16 @@ class NixlConnectorWorker: def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]: """Context manager for a ZMQ socket""" + if socket_type not in (zmq.ROUTER, zmq.REQ): + raise ValueError(f"Unexpected socket type: {socket_type}") + ctx: Optional[zmq.Context] = None try: ctx = zmq.Context() # type: ignore[attr-defined] - - if socket_type == zmq.ROUTER: - socket = ctx.socket(zmq.ROUTER) - socket.bind(addr) - elif socket_type == zmq.REQ: - socket = ctx.socket(zmq.REQ) - socket.connect(addr) - else: - raise ValueError(f"Unexpected socket type: {socket_type}") - - yield socket + yield make_zmq_socket(ctx=ctx, + path=addr, + socket_type=socket_type, + bind=socket_type == zmq.ROUTER) finally: if ctx is not None: ctx.destroy(linger=0) diff --git a/vllm/utils.py b/vllm/utils.py index 9a7da8067ba4d..edfbb8c9481e1 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -2350,6 +2350,24 @@ def split_zmq_path(path: str) -> Tuple[str, str, str]: return scheme, host, port +def make_zmq_path(scheme: str, host: str, port: Optional[int] = None) -> str: + """Make a ZMQ path from its parts. + + Args: + scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc). + host: The host - can be an IPv4 address, IPv6 address, or hostname. + port: Optional port number, only used for TCP sockets. + + Returns: + A properly formatted ZMQ path string. + """ + if not port: + return f"{scheme}://{host}" + if is_valid_ipv6_address(host): + return f"{scheme}://[{host}]:{port}" + return f"{scheme}://{host}:{port}" + + # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501 def make_zmq_socket( ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]