From c391e4b68e6694986f24ccd620d7bf07c237ab60 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Thu, 4 Apr 2024 16:52:12 -0700
Subject: [PATCH] [Core] improve robustness of pynccl (#3860)

---
 vllm/model_executor/parallel_utils/pynccl.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/parallel_utils/pynccl.py b/vllm/model_executor/parallel_utils/pynccl.py
index f7f83528cd06..0a8bb860efa1 100644
--- a/vllm/model_executor/parallel_utils/pynccl.py
+++ b/vllm/model_executor/parallel_utils/pynccl.py
@@ -236,22 +236,25 @@ class NCCLCommunicator:
         if local_rank == -1:
             local_rank = self.rank
         self.local_rank = local_rank
-        torch.cuda.set_device(local_rank)
-        if rank == 0:
+        # don't use these args, as they can be -1
+        # use `self.rank`, `self.local_rank` and `self.world_size` instead
+        del world_size, rank, local_rank
+        torch.cuda.set_device(self.local_rank)
+        if self.rank == 0:
             self.unique_id = ncclGetUniqueId()
         else:
             self.unique_id = NcclUniqueId()
-        tensor = torch.ByteTensor(list(
-            self.unique_id.internal)).cuda(local_rank)
+        tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
+            self.local_rank)
         dist.broadcast(tensor, src=0)
         byte_list = tensor.cpu().tolist()
         for i, byte in enumerate(byte_list):
             self.unique_id.internal[i] = byte
         self.comm = ctypes.c_void_p()
-        result = _c_ncclCommInitRank(ctypes.byref(self.comm), world_size,
-                                     self.unique_id, rank)
+        result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
+                                     self.unique_id, self.rank)
         assert result == 0
-        self.stream = torch.cuda.Stream(device=f"cuda:{local_rank}")
+        self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
 
     def all_reduce(self,
                    tensor: torch.Tensor,
@@ -271,4 +274,6 @@ class NCCLCommunicator:
         # `dist` module might have been already destroyed
         if hasattr(dist, 'destroy_process_group'):
             dist.destroy_process_group()
-        _c_ncclCommDestroy(self.comm)
+        # function might have been already destroyed
+        if _c_ncclCommDestroy is not None:
+            _c_ncclCommDestroy(self.comm)
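
Note (appended after the patch, not part of it): the constructor change keeps
the standard NCCL bootstrap pattern, where rank 0 generates an opaque unique
id and broadcasts it over an existing torch.distributed group so that every
rank joins the same communicator. Below is a minimal single-process sketch of
that serialize/broadcast/deserialize round trip (not vllm's actual code),
using the gloo backend and a dummy 128-byte buffer in place of
ncclGetUniqueId():

    import torch
    import torch.distributed as dist

    # Single-process gloo group, only so the sketch runs standalone.
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29500",
                            world_size=1, rank=0)

    NCCL_UNIQUE_ID_BYTES = 128  # ncclUniqueId is an opaque 128-byte struct

    # Rank 0 would call ncclGetUniqueId(); other ranks start zeroed.
    unique_id = bytearray(NCCL_UNIQUE_ID_BYTES)
    if dist.get_rank() == 0:
        unique_id[:4] = b"\xde\xad\xbe\xef"  # stand-in for NCCL's output

    # Serialize into a byte tensor, broadcast from rank 0, copy back.
    tensor = torch.ByteTensor(list(unique_id))
    dist.broadcast(tensor, src=0)
    for i, byte in enumerate(tensor.tolist()):
        unique_id[i] = byte

    dist.destroy_process_group()

The patched code does the same, except the byte tensor is first moved to the
GPU at self.local_rank so the broadcast can travel over a CUDA-capable
backend.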
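
The __del__ guard deserves a note as well: during interpreter shutdown,
CPython may have already cleared module-level globals (such as
_c_ncclCommDestroy) to None before destructors run, so an unguarded call can
raise inside __del__. A minimal sketch of that guard pattern, with a
hypothetical _destroy standing in for the ctypes-bound function:

    # Hypothetical stand-in for the ctypes-bound _c_ncclCommDestroy.
    def _destroy(handle):
        print(f"released communicator handle {handle!r}")

    class Comm:
        def __init__(self, handle):
            self.handle = handle

        def __del__(self):
            # At interpreter shutdown this module-level name may already
            # be None; calling None would raise inside __del__, so check
            # before calling, exactly as the patch does.
            if _destroy is not None:
                _destroy(self.handle)

    comm = Comm(handle=42)
    del comm  # __del__ runs safely whether or not shutdown is underway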