[V1] Update zmq socket creation in nixl connector (#18148)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Russell Bryant 2025-05-15 02:17:57 -04:00 committed by GitHub
parent de71fec81b
commit a8f5aec20a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 34 additions and 15 deletions

View File

@ -17,7 +17,7 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, deprecate_kwargs, get_open_port,
make_zmq_socket, memory_profiling,
make_zmq_path, make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_zmq_path,
supports_kw, swap_dict_values)
@ -714,3 +714,8 @@ def test_make_zmq_socket_ipv6():
# Clean up
zsock.close()
ctx.term()
def test_make_zmq_path():
    """Check that make_zmq_path joins its parts and brackets IPv6 hosts."""
    # String ports (original coverage).
    assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555"
    assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555"

    # Integer ports: the annotated type, and what the NIXL connector passes.
    assert make_zmq_path("tcp", "127.0.0.1", 5555) == "tcp://127.0.0.1:5555"
    assert make_zmq_path("tcp", "::1", 5555) == "tcp://[::1]:5555"

    # A plain hostname must not be bracketed.
    assert make_zmq_path("tcp", "localhost", 5555) == "tcp://localhost:5555"

    # Without a port, only scheme and host are joined.
    assert make_zmq_path("ipc", "/tmp/sock") == "ipc:///tmp/sock"

View File

@ -21,7 +21,7 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.logger import init_logger
from vllm.utils import round_down
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
@ -379,7 +379,7 @@ class NixlConnectorWorker:
# hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
path = f"tcp://{host}:{port}"
path = make_zmq_path("tcp", host, port)
logger.debug("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock:
ready_event.set()
@ -397,7 +397,7 @@ class NixlConnectorWorker:
# NOTE(rob): we need each rank to have a unique port. This is
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
path = f"tcp://{host}:{port + self.rank}"
path = make_zmq_path("tcp", host, port + self.rank)
logger.debug("Querying metadata on path: %s", path)
with zmq_ctx(zmq.REQ, path) as sock:
# Send query for the request.
@ -741,20 +741,16 @@ class NixlConnectorWorker:
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
    """Context manager for a ZMQ socket.

    Creates a new ZMQ context, yields exactly one socket built by
    ``make_zmq_socket`` for ``addr``, and destroys the context when the
    caller is done (or if an error occurs).

    NOTE(review): this is a generator; presumably it is wrapped by
    ``contextlib.contextmanager`` just above this definition - confirm.

    Args:
        socket_type: ``zmq.ROUTER`` or ``zmq.REQ``.
        addr: The ZMQ path handed to ``make_zmq_socket``.

    Yields:
        The socket returned by ``make_zmq_socket``; ``bind`` is requested
        only for ``zmq.ROUTER`` sockets.

    Raises:
        ValueError: If ``socket_type`` is neither ``zmq.ROUTER`` nor
            ``zmq.REQ``.
    """
    # Validate up front so that no context is created for a bad socket type.
    if socket_type not in (zmq.ROUTER, zmq.REQ):
        raise ValueError(f"Unexpected socket type: {socket_type}")

    ctx: Optional[zmq.Context] = None
    try:
        ctx = zmq.Context()  # type: ignore[attr-defined]
        # Single yield: socket creation is delegated to make_zmq_socket
        # rather than duplicated here as a hand-written bind/connect branch.
        yield make_zmq_socket(ctx=ctx,
                              path=addr,
                              socket_type=socket_type,
                              bind=socket_type == zmq.ROUTER)
    finally:
        # ctx stays None if zmq.Context() itself raised.
        if ctx is not None:
            ctx.destroy(linger=0)

View File

@ -2350,6 +2350,24 @@ def split_zmq_path(path: str) -> Tuple[str, str, str]:
return scheme, host, port
def make_zmq_path(scheme: str, host: str, port: Optional[int] = None) -> str:
    """Make a ZMQ path from its parts.

    Args:
        scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc).
        host: The host - can be an IPv4 address, IPv6 address, or hostname.
        port: Optional port number, only used for TCP sockets. ``None`` (or
            an empty string) means "no port"; ``0`` is a real value and is
            kept in the result.

    Returns:
        A properly formatted ZMQ path string.
    """
    # Only an absent port drops the ":port" suffix. A plain truthiness test
    # would also discard port 0 and emit an endpoint with no port at all.
    if port is None or port == "":
        return f"{scheme}://{host}"
    if is_valid_ipv6_address(host):
        # Bracket IPv6 literals so that their colons are not mistaken for
        # the host/port separator.
        return f"{scheme}://[{host}]:{port}"
    return f"{scheme}://{host}:{port}"
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
def make_zmq_socket(
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]