[Core] improve robustness of pynccl (#3860)
commit c391e4b68e
parent 9117f892f0
@@ -236,22 +236,25 @@ class NCCLCommunicator:
         if local_rank == -1:
             local_rank = self.rank
         self.local_rank = local_rank
-        torch.cuda.set_device(local_rank)
-        if rank == 0:
+        # don't use these args, as they can be -1
+        # use `self.rank`, `self.local_rank` and `self.world_size` instead
+        del world_size, rank, local_rank
+        torch.cuda.set_device(self.local_rank)
+        if self.rank == 0:
             self.unique_id = ncclGetUniqueId()
         else:
             self.unique_id = NcclUniqueId()
-        tensor = torch.ByteTensor(list(
-            self.unique_id.internal)).cuda(local_rank)
+        tensor = torch.ByteTensor(list(self.unique_id.internal)).cuda(
+            self.local_rank)
         dist.broadcast(tensor, src=0)
         byte_list = tensor.cpu().tolist()
         for i, byte in enumerate(byte_list):
             self.unique_id.internal[i] = byte
         self.comm = ctypes.c_void_p()
-        result = _c_ncclCommInitRank(ctypes.byref(self.comm), world_size,
-                                     self.unique_id, rank)
+        result = _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
+                                     self.unique_id, self.rank)
         assert result == 0
-        self.stream = torch.cuda.Stream(device=f"cuda:{local_rank}")
+        self.stream = torch.cuda.Stream(device=f"cuda:{self.local_rank}")
 
     def all_reduce(self,
                    tensor: torch.Tensor,
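
This hunk stops trusting the raw `world_size`, `rank`, and `local_rank` arguments (which can be -1), deletes them, and consistently reads the `self.*` attributes instead, including in the broadcast of the NCCL unique id. That broadcast packs the id's raw bytes into a tensor, ships it from rank 0, and copies the bytes back on every rank. A minimal sketch of the same pattern, assuming an already-initialized `torch.distributed` process group; `broadcast_bytes` is a hypothetical helper, and a CPU-capable backend such as Gloo is assumed here, whereas the real code moves the tensor to `.cuda(self.local_rank)` because it broadcasts over NCCL:

import torch
import torch.distributed as dist

def broadcast_bytes(buf: bytearray, src: int = 0) -> bytearray:
    # Pack the raw bytes into a tensor so dist.broadcast can ship them.
    tensor = torch.ByteTensor(list(buf))
    # Rank `src` sends its buffer; every other rank's tensor is overwritten.
    dist.broadcast(tensor, src=src)
    # Copy the received bytes back into the caller's buffer, mirroring how
    # the diff writes the tensor contents into `self.unique_id.internal`.
    for i, byte in enumerate(tensor.tolist()):
        buf[i] = byte
    return buf
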
@@ -271,4 +274,6 @@ class NCCLCommunicator:
         # `dist` module might have been already destroyed
         if hasattr(dist, 'destroy_process_group'):
             dist.destroy_process_group()
-        _c_ncclCommDestroy(self.comm)
+        # function might have been already destroyed
+        if _c_ncclCommDestroy is not None:
+            _c_ncclCommDestroy(self.comm)
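
This hunk hardens the destructor against interpreter shutdown: by the time `__del__` runs at exit, module-level names such as `_c_ncclCommDestroy` may already have been torn down (CPython can clear or rebind module globals to `None` during finalization), so an unguarded call can raise. A minimal sketch of the guard pattern, with illustrative names (`_c_destroy` and `Handle` are not from the source):

import ctypes

# Stands in for a ctypes function pointer loaded from a shared library;
# during interpreter shutdown this module global may already be None.
_c_destroy = None

class Handle:
    def __init__(self):
        self.comm = ctypes.c_void_p()

    def __del__(self):
        # Guard the call exactly as the diff does: skip cleanup rather
        # than raise if the module-level name has been destroyed.
        if _c_destroy is not None:
            _c_destroy(self.comm)
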