Use custom address for listening socket (#15988)

Signed-off-by: Jens Glaser <glaserj@ornl.gov>
This commit is contained in:
jglaser 2025-04-24 21:57:16 -04:00 committed by GitHub
parent 9420a1fc30
commit 0d6e187e88
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -7,6 +7,7 @@
import dataclasses
import datetime
import pickle
import socket
import time
from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
@ -123,6 +124,10 @@ class StatelessProcessGroup:
rank: int
world_size: int
store: torch._C._distributed_c10d.Store
# stores a reference to the socket so that the file descriptor stays alive
socket: Optional[socket.socket]
data_expiration_seconds: int = 3600 # 1 hour
# dst rank -> counter
@ -234,18 +239,33 @@ class StatelessProcessGroup:
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
C, and D can call `StatelessProcessGroup.create` to form another group.
""" # noqa
launch_server = rank == 0
if launch_server:
# listen on the specified interface (instead of 0.0.0.0)
listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
listen_socket.bind((host, port))
listen_socket.listen()
listen_fd = listen_socket.fileno()
else:
listen_socket = None
listen_fd = None
store = TCPStore(
host_name=host,
port=port,
world_size=world_size,
is_master=(rank == 0),
is_master=launch_server,
timeout=datetime.timedelta(seconds=store_timeout),
use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215
master_listen_fd=listen_fd,
)
return StatelessProcessGroup(
rank=rank,
world_size=world_size,
store=store,
socket=listen_socket,
data_expiration_seconds=data_expiration_seconds)