[Misc] Make timeout passable in init_distributed_environment (#24522)

Signed-off-by: jberkhahn <jaberkha@us.ibm.com>
This commit is contained in:
Jonathan Berkhahn 2025-09-10 15:41:12 -07:00 committed by GitHub
parent dcb28a332b
commit cc99baf14d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -29,6 +29,7 @@ import weakref
from collections import namedtuple from collections import namedtuple
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta
from multiprocessing import shared_memory from multiprocessing import shared_memory
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
from unittest.mock import patch from unittest.mock import patch
@@ -978,13 +979,12 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable _ENABLE_CUSTOM_ALL_REDUCE = enable
def init_distributed_environment( def init_distributed_environment(world_size: int = -1,
world_size: int = -1, rank: int = -1,
rank: int = -1, distributed_init_method: str = "env://",
distributed_init_method: str = "env://", local_rank: int = -1,
local_rank: int = -1, backend: str = "nccl",
backend: str = "nccl", timeout: Optional[timedelta] = None):
):
logger.debug( logger.debug(
"world_size=%d rank=%d local_rank=%d " "world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s", world_size, rank, local_rank, "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
@@ -1020,7 +1020,8 @@ def init_distributed_environment(
backend=backend, backend=backend,
init_method=distributed_init_method, init_method=distributed_init_method,
world_size=world_size, world_size=world_size,
rank=rank) rank=rank,
timeout=timeout)
# set the local rank # set the local rank
# local_rank is not available in torch ProcessGroup, # local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816 # see https://github.com/pytorch/pytorch/issues/122816