diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index 30d1bf10138bb..9c2bf51a813e2 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -147,6 +147,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
             assert out is not None
             return out
         pynccl_comm = self.pynccl_comm
+        if pynccl_comm is None or pynccl_comm.disabled:
+            out = input_.clone()
+            torch.distributed.all_reduce(out, group=self.device_group)
+            return out
         assert pynccl_comm is not None
         out = pynccl_comm.all_reduce(input_)
         if out is None:
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index 76fe9a93259fa..81c02d1899e5a 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -8,6 +8,7 @@
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup, ReduceOp
+import vllm.envs as envs
 from vllm.distributed.device_communicators.pynccl_wrapper import (
     NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
     ncclRedOpTypeEnum, ncclUniqueId)
@@ -83,7 +84,7 @@ class PyNcclCommunicator:
         self.group = group
 
         # if world_size == 1, no need to create communicator
-        if self.world_size == 1:
+        if self.world_size == 1 or envs.VLLM_DISABLE_PYNCCL:
             self.available = False
             self.disabled = True
             return
diff --git a/vllm/envs.py b/vllm/envs.py
index ffa7ed5c3aa5a..03a22e4b2c7e3 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -98,6 +98,7 @@ if TYPE_CHECKING:
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_DISABLED_KERNELS: list[str] = []
     VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: bool = False
+    VLLM_DISABLE_PYNCCL: bool = False
     VLLM_USE_V1: bool = True
     VLLM_ROCM_USE_AITER: bool = False
     VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
@@ -897,6 +898,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     (os.getenv("VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION", "False").lower() in
      ("true", "1")),
 
+    # Disable pynccl (using torch.distributed instead)
+    "VLLM_DISABLE_PYNCCL":
+    lambda:
+    (os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1")),
+
     # If set, use the V1 code path.
     "VLLM_USE_V1":
     lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))),
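A minimal usage sketch of the new flag, for illustration only: the model name and `tensor_parallel_size` below are not part of this patch and assume a working multi-GPU vLLM install. Setting `VLLM_DISABLE_PYNCCL` before vLLM builds its communicators makes `PyNcclCommunicator` mark itself as disabled, so `CudaCommunicator.all_reduce` takes the new `torch.distributed.all_reduce` fallback on the device group.

```python
import os

# The flag is read when the communicators are created, so set it before
# importing or constructing anything from vllm.
os.environ["VLLM_DISABLE_PYNCCL"] = "1"

from vllm import LLM

# Illustrative model/config: with pynccl disabled, tensor-parallel
# all-reduces go through torch.distributed.all_reduce instead of the
# pynccl NCCL wrapper.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
outputs = llm.generate("Hello, my name is")
print(outputs[0].outputs[0].text)
```

The same effect should be reachable from the shell, e.g. `VLLM_DISABLE_PYNCCL=1` prefixed to the usual launch command.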