[Misc] Add pynccl wrappers for all_gather and reduce_scatter (#9432)

Tyler Michael Smith 2024-11-22 22:14:03 -05:00 committed by GitHub
parent ebda51968b
commit 978b39744b
3 changed files with 155 additions and 0 deletions
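For context on how the new collectives are meant to be driven, here is a minimal usage sketch (not part of the diff). It assumes the same helpers the tests below rely on (distributed_run and worker_fn_wrapper from tests/distributed/test_pynccl.py, get_world_group from vllm.distributed.parallel_state) and at least 2 GPUs; shapes follow the NCCL convention that the all_gather output is world_size times the input, and the reduce_scatter output is the input size divided by world_size.

# Usage sketch only; assumes vLLM's test helpers are importable and 2+ GPUs.
import torch

from tests.distributed.test_pynccl import distributed_run, worker_fn_wrapper
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world_group


@worker_fn_wrapper
def demo_worker():
    comm = PyNcclCommunicator(get_world_group().cpu_group,
                              device=get_world_group().device)
    device = f'cuda:{comm.rank}'
    x = torch.ones(8, dtype=torch.float32, device=device) * comm.rank

    gathered = torch.zeros(8 * comm.world_size, dtype=torch.float32,
                           device=device)
    scattered = torch.zeros(8 // comm.world_size, dtype=torch.float32,
                            device=device)
    with comm.change_state(enable=True):
        comm.all_gather(gathered, x)       # every rank receives all chunks
        comm.reduce_scatter(scattered, x)  # each rank keeps its summed chunk

    # With x = rank * ones, every reduce_scatter element is sum(range(world_size)).
    expected = float(sum(range(comm.world_size)))
    torch.testing.assert_close(scattered, torch.full_like(scattered, expected))


if __name__ == "__main__":
    distributed_run(demo_worker, 2)  # spawn one worker per GPU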


@@ -150,6 +150,75 @@ def worker_fn_with_cudagraph():
        assert a.mean().cpu().item() == pynccl_comm.world_size**1


@worker_fn_wrapper
def all_gather_worker_fn():
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)

    rank = pynccl_comm.rank
    world_size = pynccl_comm.world_size
    device = f'cuda:{pynccl_comm.rank}'

    num_elems = 1000
    tensor = torch.arange(num_elems, dtype=torch.float32,
                          device=device) + rank * num_elems
    result = torch.zeros(num_elems * world_size,
                         dtype=torch.float32,
                         device=device)

    expected = torch.cat([
        torch.arange(num_elems, dtype=torch.float32) + r * num_elems
        for r in range(world_size)
    ]).to(device)

    with pynccl_comm.change_state(enable=True):
        pynccl_comm.all_gather(result, tensor)
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_all_gather():
    distributed_run(all_gather_worker_fn, 2)


@worker_fn_wrapper
def reduce_scatter_worker_fn():
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)

    rank = pynccl_comm.rank
    world_size = pynccl_comm.world_size
    device = f'cuda:{pynccl_comm.rank}'

    num_elems = 1000
    tensor = torch.arange(num_elems, dtype=torch.float32,
                          device=device) + rank * num_elems
    assert (num_elems % world_size == 0)
    result = torch.zeros(num_elems // world_size,
                         dtype=torch.float32,
                         device=device)

    # Calculate expected result for this rank's chunk
    scattered_size = num_elems // world_size
    all_tensors = [
        torch.arange(num_elems, dtype=torch.float32) + r * num_elems
        for r in range(world_size)
    ]
    expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                   for tensor in all_tensors).to(device)

    with pynccl_comm.change_state(enable=True):
        pynccl_comm.reduce_scatter(result, tensor)
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_reduce_scatter():
    distributed_run(reduce_scatter_worker_fn, 2)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
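For intuition, the expected tensor that reduce_scatter_worker_fn computes above can be re-derived on the CPU without any GPUs or NCCL. A standalone check (illustration only, not part of the diff), using the same world_size=2 and num_elems=1000 as the test:

import torch

# Re-derive the reduce_scatter test's expected values on CPU (illustration only).
world_size, num_elems = 2, 1000
chunk = num_elems // world_size
all_tensors = [torch.arange(num_elems, dtype=torch.float32) + r * num_elems
               for r in range(world_size)]
for rank in range(world_size):
    expected = sum(t[rank * chunk:(rank + 1) * chunk] for t in all_tensors)
    # Local element i of rank r's chunk sums to 2*i + num_elems*(r + 1),
    # e.g. rank 0's chunk starts at 1000.0 and rank 1's at 2000.0.
    closed_form = (2 * torch.arange(chunk, dtype=torch.float32)
                   + num_elems * (rank + 1))
    assert torch.equal(expected, closed_form)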


@@ -131,6 +131,48 @@ class PyNcclCommunicator:
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cudaStream_t(stream.cuda_stream))

    def all_gather(self,
                   output_tensor: torch.Tensor,
                   input_tensor: torch.Tensor,
                   stream=None):
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclAllGather(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
            cudaStream_t(stream.cuda_stream))

    def reduce_scatter(self,
                       output_tensor: torch.Tensor,
                       input_tensor: torch.Tensor,
                       op: ReduceOp = ReduceOp.SUM,
                       stream=None):
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclReduceScatter(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op), self.comm,
            cudaStream_t(stream.cuda_stream))

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        if self.disabled:
            return
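A note on the count argument passed above: in both collectives NCCL expects the per-rank element count, which is why all_gather passes input_tensor.numel() while reduce_scatter passes output_tensor.numel(). The implied size relationship, written as hypothetical caller-side checks (a sketch, not part of the diff):

# Hypothetical caller-side shape checks restating the count semantics above.
def check_all_gather_shapes(output_numel: int, input_numel: int,
                            world_size: int) -> None:
    # all_gather: every rank contributes input_numel elements.
    assert output_numel == input_numel * world_size


def check_reduce_scatter_shapes(output_numel: int, input_numel: int,
                                world_size: int) -> None:
    # reduce_scatter: every rank keeps input_numel // world_size elements.
    assert input_numel == output_numel * world_size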


@@ -151,6 +151,28 @@ class NCCLLibrary:
            ncclRedOp_t, ncclComm_t, cudaStream_t
        ]),

        # ncclResult_t ncclAllGather(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, ncclComm_t comm,
        #   cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function("ncclAllGather", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ncclComm_t, cudaStream_t
        ]),

        # ncclResult_t ncclReduceScatter(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
        #   cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function("ncclReduceScatter", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ncclRedOp_t, ncclComm_t, cudaStream_t
        ]),

        # ncclResult_t ncclSend(
        #   const void* sendbuff, size_t count, ncclDataType_t datatype,
        #   int dest, ncclComm_t comm, cudaStream_t stream);
@@ -258,6 +280,28 @@ class NCCLLibrary:
                                                     datatype, op, comm,
                                                     stream))

    def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
                          count: int, datatype: int, op: int, comm: ncclComm_t,
                          stream: cudaStream_t) -> None:
        # `datatype` actually should be `ncclDataType_t`
        # and `op` should be `ncclRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff,
                                                          count, datatype, op,
                                                          comm, stream))

    def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, comm: ncclComm_t,
                      stream: cudaStream_t) -> None:
        # `datatype` actually should be `ncclDataType_t`
        # which is an alias of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count,
                                                     datatype, comm, stream))

    def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
                 dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,