mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-30 04:57:18 +08:00
[V1] Update zmq socket creation in nixl connector (#18148)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
parent
de71fec81b
commit
a8f5aec20a
@ -17,7 +17,7 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
|
||||
MemorySnapshot, PlaceholderModule, StoreBoolean,
|
||||
bind_kv_cache, deprecate_kwargs, get_open_port,
|
||||
make_zmq_socket, memory_profiling,
|
||||
make_zmq_path, make_zmq_socket, memory_profiling,
|
||||
merge_async_iterators, sha256, split_zmq_path,
|
||||
supports_kw, swap_dict_values)
|
||||
|
||||
@ -714,3 +714,8 @@ def test_make_zmq_socket_ipv6():
|
||||
# Clean up
|
||||
zsock.close()
|
||||
ctx.term()
|
||||
|
||||
|
||||
def test_make_zmq_path():
|
||||
assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555"
|
||||
assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555"
|
||||
|
||||
@ -21,7 +21,7 @@ from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
|
||||
get_tp_group)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import round_down
|
||||
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
@ -379,7 +379,7 @@ class NixlConnectorWorker:
|
||||
# hack to keeps us moving. We will switch when moving to etcd
|
||||
# or where we have a single ZMQ socket in the scheduler.
|
||||
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
|
||||
path = f"tcp://{host}:{port}"
|
||||
path = make_zmq_path("tcp", host, port)
|
||||
logger.debug("Starting listening on path: %s", path)
|
||||
with zmq_ctx(zmq.ROUTER, path) as sock:
|
||||
ready_event.set()
|
||||
@ -397,7 +397,7 @@ class NixlConnectorWorker:
|
||||
# NOTE(rob): we need each rank to have a unique port. This is
|
||||
# a hack to keep us moving. We will switch when moving to etcd
|
||||
# or where we have a single ZMQ socket in the scheduler.
|
||||
path = f"tcp://{host}:{port + self.rank}"
|
||||
path = make_zmq_path("tcp", host, port + self.rank)
|
||||
logger.debug("Querying metadata on path: %s", path)
|
||||
with zmq_ctx(zmq.REQ, path) as sock:
|
||||
# Send query for the request.
|
||||
@ -741,20 +741,16 @@ class NixlConnectorWorker:
|
||||
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
|
||||
"""Context manager for a ZMQ socket"""
|
||||
|
||||
if socket_type not in (zmq.ROUTER, zmq.REQ):
|
||||
raise ValueError(f"Unexpected socket type: {socket_type}")
|
||||
|
||||
ctx: Optional[zmq.Context] = None
|
||||
try:
|
||||
ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
|
||||
if socket_type == zmq.ROUTER:
|
||||
socket = ctx.socket(zmq.ROUTER)
|
||||
socket.bind(addr)
|
||||
elif socket_type == zmq.REQ:
|
||||
socket = ctx.socket(zmq.REQ)
|
||||
socket.connect(addr)
|
||||
else:
|
||||
raise ValueError(f"Unexpected socket type: {socket_type}")
|
||||
|
||||
yield socket
|
||||
yield make_zmq_socket(ctx=ctx,
|
||||
path=addr,
|
||||
socket_type=socket_type,
|
||||
bind=socket_type == zmq.ROUTER)
|
||||
finally:
|
||||
if ctx is not None:
|
||||
ctx.destroy(linger=0)
|
||||
|
||||
@ -2350,6 +2350,24 @@ def split_zmq_path(path: str) -> Tuple[str, str, str]:
|
||||
return scheme, host, port
|
||||
|
||||
|
||||
def make_zmq_path(scheme: str, host: str, port: Optional[int] = None) -> str:
|
||||
"""Make a ZMQ path from its parts.
|
||||
|
||||
Args:
|
||||
scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc).
|
||||
host: The host - can be an IPv4 address, IPv6 address, or hostname.
|
||||
port: Optional port number, only used for TCP sockets.
|
||||
|
||||
Returns:
|
||||
A properly formatted ZMQ path string.
|
||||
"""
|
||||
if not port:
|
||||
return f"{scheme}://{host}"
|
||||
if is_valid_ipv6_address(host):
|
||||
return f"{scheme}://[{host}]:{port}"
|
||||
return f"{scheme}://{host}:{port}"
|
||||
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
|
||||
def make_zmq_socket(
|
||||
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user