diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 602bcebc017dd..ef229299b6848 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -29,6 +29,7 @@ import weakref
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
+from datetime import timedelta
 from multiprocessing import shared_memory
 from typing import Any, Callable, Optional, Union
 from unittest.mock import patch
@@ -978,13 +979,12 @@ def set_custom_all_reduce(enable: bool):
     _ENABLE_CUSTOM_ALL_REDUCE = enable
 
 
-def init_distributed_environment(
-    world_size: int = -1,
-    rank: int = -1,
-    distributed_init_method: str = "env://",
-    local_rank: int = -1,
-    backend: str = "nccl",
-):
+def init_distributed_environment(world_size: int = -1,
+                                 rank: int = -1,
+                                 distributed_init_method: str = "env://",
+                                 local_rank: int = -1,
+                                 backend: str = "nccl",
+                                 timeout: Optional[timedelta] = None):
     logger.debug(
         "world_size=%d rank=%d local_rank=%d "
         "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
@@ -1020,7 +1020,8 @@ def init_distributed_environment(
             backend=backend,
             init_method=distributed_init_method,
             world_size=world_size,
-            rank=rank)
+            rank=rank,
+            timeout=timeout)
     # set the local rank
     # local_rank is not available in torch ProcessGroup,
     # see https://github.com/pytorch/pytorch/issues/122816
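
A minimal usage sketch (not part of the patch) of how a caller could pass the new `timeout`, which the diff forwards to `torch.distributed.init_process_group`; the rank/world-size values, the TCP endpoint, and the `gloo` backend below are illustrative placeholders, not vLLM defaults. When `timeout` is left as `None`, `torch.distributed` falls back to its own default timeout.

```python
from datetime import timedelta

from vllm.distributed.parallel_state import init_distributed_environment

# Hypothetical single-process setup; the init method and backend are
# placeholders chosen for illustration only.
init_distributed_environment(
    world_size=1,
    rank=0,
    distributed_init_method="tcp://127.0.0.1:29500",
    local_rank=0,
    backend="gloo",
    # Raise the process-group init/collective timeout beyond the
    # torch.distributed default.
    timeout=timedelta(minutes=30),
)
```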