Add pynccl all-gatherv and reducescatterv (#20154)

Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Authored by Trevor Morris on 2025-07-11 18:59:23 -07:00, committed by GitHub
parent fc0f41d10a
commit a8593237c0
6 changed files with 284 additions and 2 deletions

View File

@@ -4,6 +4,7 @@
 import multiprocessing
 import os
 
+import numpy as np
 import pytest
 import torch
 import torch.distributed
@@ -177,6 +178,38 @@ def test_pynccl_all_gather():
     distributed_run(all_gather_worker_fn, 2)
 
 
+@worker_fn_wrapper
+def all_gatherv_worker_fn():
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
+
+    rank = pynccl_comm.rank
+    world_size = pynccl_comm.world_size
+    device = f'cuda:{pynccl_comm.rank}'
+
+    assert world_size <= 8
+    sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
+    num_elems = sizes[rank]
+    tensor = torch.arange(num_elems, dtype=torch.float32,
+                          device=device) + rank * 100
+    result = torch.zeros(sum(sizes), dtype=torch.float32, device=device)
+
+    expected = torch.cat([
+        torch.arange(sizes[r], dtype=torch.float32) + r * 100
+        for r in range(world_size)
+    ]).to(device)
+
+    pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
+    torch.cuda.synchronize()
+    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+def test_pynccl_all_gatherv():
+    distributed_run(all_gatherv_worker_fn, 2)
+
+
 @worker_fn_wrapper
 def reduce_scatter_worker_fn():
     pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
@@ -214,6 +247,43 @@ def test_pynccl_reduce_scatter():
     distributed_run(reduce_scatter_worker_fn, 2)
 
 
+@worker_fn_wrapper
+def reduce_scatterv_worker_fn():
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
+
+    rank = pynccl_comm.rank
+    world_size = pynccl_comm.world_size
+    device = f'cuda:{pynccl_comm.rank}'
+
+    assert world_size <= 8
+    sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
+    num_elems = sum(sizes)
+    tensor = torch.arange(num_elems, dtype=torch.float32,
+                          device=device) + rank * 100
+    result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)
+
+    # Calculate expected result for this rank's chunk
+    all_tensors = [
+        torch.arange(num_elems, dtype=torch.float32) + r * 100
+        for r in range(world_size)
+    ]
+    sizes_cumsum = np.cumsum(sizes)
+    start = 0 if rank == 0 else sizes_cumsum[rank - 1]
+    end = sizes_cumsum[rank]
+    expected = sum(tensor[start:end] for tensor in all_tensors).to(device)
+
+    pynccl_comm.reduce_scatterv(result, tensor, sizes=sizes)
+    torch.cuda.synchronize()
+    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+def test_pynccl_reduce_scatterv():
+    distributed_run(reduce_scatterv_worker_fn, 2)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 def test_pynccl_with_cudagraph():
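For orientation, the single-process sketch below is not part of the commit; it only restates, with the same sizes list used by the tests above, the semantics those tests assert: all_gatherv concatenates variable-length chunks from every rank, while reduce_scatterv (with the default ReduceOp.SUM) leaves rank r holding the element-wise sum of every rank's r-th chunk. Variable names here are ad hoc.

# Reference semantics only -- assumes world_size == 2 and the tests' sizes.
import numpy as np

world_size = 2
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
offsets = np.concatenate(([0], np.cumsum(sizes)))

# Per-rank inputs, constructed the same way as in the tests.
gatherv_inputs = [
    np.arange(sizes[r], dtype=np.float32) + r * 100 for r in range(world_size)
]
scatterv_inputs = [
    np.arange(sum(sizes), dtype=np.float32) + r * 100
    for r in range(world_size)
]

# all_gatherv: every rank receives the concatenation of all ranks' chunks.
all_gatherv_result = np.concatenate(gatherv_inputs)

# reduce_scatterv (SUM): rank r keeps the element-wise sum of the r-th
# variable-size slice of every rank's input.
reduce_scatterv_results = [
    sum(inp[offsets[r]:offsets[r + 1]] for inp in scatterv_inputs)
    for r in range(world_size)
]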

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import threading
-from typing import Optional
+from typing import Optional, Union
 from weakref import WeakValueDictionary
 
 import torch
@@ -138,6 +138,14 @@ class DeviceCommunicatorBase:
                                               input_size[dim + 1:])
         return output_tensor
 
+    def all_gatherv(
+            self,
+            input_: Union[torch.Tensor, list[torch.Tensor]],
+            dim: int = 0,
+            sizes: Optional[list[int]] = None
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        raise NotImplementedError
+
     def reduce_scatter(self,
                        input_: torch.Tensor,
                        dim: int = -1) -> torch.Tensor:
@@ -172,6 +180,12 @@ class DeviceCommunicatorBase:
         # Reshape before returning
         return output_tensor.movedim(0, dim).contiguous()
 
+    def reduce_scatterv(self,
+                        input_: torch.Tensor,
+                        dim: int = -1,
+                        sizes: Optional[list[int]] = None) -> torch.Tensor:
+        raise NotImplementedError
+
     def gather(self,
                input_: torch.Tensor,
                dst: int = 0,

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from torch.distributed import ProcessGroup
@@ -142,6 +142,42 @@ class CudaCommunicator(DeviceCommunicatorBase):
         # Reshape before returning
         return output.movedim(0, dim).contiguous()
 
+    def reduce_scatterv(self,
+                        input_: torch.Tensor,
+                        dim: int = -1,
+                        sizes: Optional[list[int]] = None):
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+        assert pynccl_comm is not None
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+
+        # Note: This will produce an incorrect answer if we don't make
+        # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
+        input_tensor = input_.movedim(0, dim).contiguous()
+
+        if sizes is not None:
+            assert len(sizes) == world_size
+            assert input_tensor.shape[0] == sum(sizes)
+            chunk_size = sizes[self.rank_in_group]
+        else:
+            assert input_tensor.shape[0] % world_size == 0
+            chunk_size = input_tensor.shape[0] // world_size
+        output_shape = (chunk_size, ) + input_tensor.shape[1:]
+
+        output = torch.empty(output_shape,
+                             dtype=input_tensor.dtype,
+                             device=input_tensor.device)
+
+        if sizes is not None:
+            pynccl_comm.reduce_scatterv(output, input_, sizes=sizes)
+        else:
+            pynccl_comm.reduce_scatter(output, input_)
+
+        # Reshape before returning
+        return output.movedim(0, dim).contiguous()
+
     def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
         """Sends a tensor to the destination rank in a non-blocking way"""
         """NOTE: `dst` is the local rank of the destination rank."""
@@ -180,6 +216,51 @@ class CudaCommunicator(DeviceCommunicatorBase):
             self.all2all_manager.destroy()
             self.all2all_manager = None
 
+    def all_gatherv(self,
+                    input_: Union[torch.Tensor, list[torch.Tensor]],
+                    dim: int = 0,
+                    sizes: Optional[list[int]] = None):
+        if dim != 0:
+            raise NotImplementedError("only dim 0 all-gatherv is supported")
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+        assert pynccl_comm is not None and not pynccl_comm.disabled
+
+        # 'sizes' is not needed if all inputs in the same group have the same
+        # shape
+        if sizes is not None and all(s == sizes[0] for s in sizes):
+            sizes = None
+
+        def _all_gather_single(input_: torch.Tensor,
+                               sizes: Optional[list[int]] = None):
+            input_size = input_.size()
+            if sizes is not None:
+                assert len(sizes) == world_size
+                assert input_.shape[dim] == sizes[self.rank_in_group]
+                output_size = (sum(sizes), ) + input_size[1:]
+            else:
+                output_size = (input_size[0] * world_size, ) + input_size[1:]
+            # Allocate output tensor.
+            output_tensor = torch.empty(output_size,
+                                        dtype=input_.dtype,
+                                        device=input_.device)
+            if sizes is not None:
+                pynccl_comm.all_gatherv(output_tensor, input_, sizes=sizes)
+            else:
+                pynccl_comm.all_gather(output_tensor, input_)
+            return output_tensor
+
+        if isinstance(input_, torch.Tensor):
+            return _all_gather_single(input_, sizes)
+
+        output_list = []
+        pynccl_comm.group_start()
+        for inp in input_:
+            output_list.append(_all_gather_single(inp, sizes=sizes))
+        pynccl_comm.group_end()
+        return output_list
+
     def dispatch(
         self, hidden_states: torch.Tensor,
         router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:

View File

@@ -152,6 +152,40 @@ class PyNcclCommunicator:
             ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
             cudaStream_t(stream.cuda_stream))
 
+    def all_gatherv(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        sizes: list[int],
+        stream=None,
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}")
+        if stream is None:
+            stream = current_stream()
+
+        assert output_tensor.shape[0] == sum(sizes)
+        split_offset = 0
+        self.nccl.ncclGroupStart()
+        for root, split_size in enumerate(sizes):
+            dst_slice = output_tensor[split_offset:split_offset + split_size]
+            self.nccl.ncclBroadcast(
+                buffer_type(input_tensor.data_ptr()),
+                buffer_type(dst_slice.data_ptr()),
+                dst_slice.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                root,
+                self.comm,
+                cudaStream_t(stream.cuda_stream),
+            )
+            split_offset += split_size
+        self.nccl.ncclGroupEnd()
+
     def reduce_scatter(self,
                        output_tensor: torch.Tensor,
                        input_tensor: torch.Tensor,
@@ -174,6 +208,38 @@ class PyNcclCommunicator:
             ncclRedOpTypeEnum.from_torch(op), self.comm,
             cudaStream_t(stream.cuda_stream))
 
+    def reduce_scatterv(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        sizes: list[int],
+        op: ReduceOp = ReduceOp.SUM,
+        stream=None,
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}")
+        if stream is None:
+            stream = current_stream()
+
+        split_offset = 0
+        self.nccl.ncclGroupStart()
+        for root, split_size in enumerate(sizes):
+            chunk = input_tensor[split_offset:split_offset + split_size, ...]
+            self.nccl.ncclReduce(
+                buffer_type(chunk.data_ptr()),
+                buffer_type(output_tensor.data_ptr()), chunk.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                ncclRedOpTypeEnum.from_torch(op), root, self.comm,
+                cudaStream_t(stream.cuda_stream))
+            split_offset += split_size
+        self.nccl.ncclGroupEnd()
+
     def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
             return
@@ -216,3 +282,9 @@ class PyNcclCommunicator:
         self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                 ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                 self.comm, cudaStream_t(stream.cuda_stream))
+
+    def group_start(self):
+        self.nccl.ncclGroupStart()
+
+    def group_end(self):
+        self.nccl.ncclGroupEnd()

View File

@@ -154,6 +154,17 @@ class NCCLLibrary:
             ncclRedOp_t, ncclComm_t, cudaStream_t
         ]),
 
+        # ncclResult_t ncclReduce(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, ncclRedOp_t op, int root,
+        #   ncclComm_t comm, cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function("ncclReduce", ncclResult_t, [
+            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
+            ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t
+        ]),
+
         # ncclResult_t ncclAllGather(
         #   const void* sendbuff, void* recvbuff, size_t count,
         #   ncclDataType_t datatype, ncclComm_t comm,
@@ -207,6 +218,10 @@ class NCCLLibrary:
         # it is better not to call it at all.
         # ncclResult_t ncclCommDestroy(ncclComm_t comm);
        Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
+        # ncclResult_t ncclGroupStart();
+        Function("ncclGroupStart", ncclResult_t, []),
+        # ncclResult_t ncclGroupEnd();
+        Function("ncclGroupEnd", ncclResult_t, []),
     ]
 
     # class attribute to store the mapping from the path to the library
@@ -300,6 +315,18 @@ class NCCLLibrary:
                                                      datatype, op, comm,
                                                      stream))
 
+    def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
+                   count: int, datatype: int, op: int, root: int,
+                   comm: ncclComm_t, stream: cudaStream_t) -> None:
+        # `datatype` actually should be `ncclDataType_t`
+        # and `op` should be `ncclRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(self._funcs["ncclReduce"](sendbuff, recvbuff, count,
+                                                  datatype, op, root, comm,
+                                                  stream))
+
     def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
                           count: int, datatype: int, op: int, comm: ncclComm_t,
                           stream: cudaStream_t) -> None:
@@ -342,6 +369,12 @@ class NCCLLibrary:
     def ncclCommDestroy(self, comm: ncclComm_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
 
+    def ncclGroupStart(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
+
+    def ncclGroupEnd(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
+
 
 __all__ = [
     "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",

View File

@@ -383,6 +383,12 @@ class GroupCoordinator:
                               dim: int) -> torch.Tensor:
         return self.device_communicator.all_gather(input_, dim)
 
+    def all_gatherv(self,
+                    input_: Union[torch.Tensor, list[torch.Tensor]],
+                    dim: int = 0,
+                    sizes: Optional[list[int]] = None):
+        return self.device_communicator.all_gatherv(input_, dim, sizes)
+
     def reduce_scatter(self,
                        input_: torch.Tensor,
                        dim: int = -1) -> torch.Tensor:
@@ -401,6 +407,12 @@ class GroupCoordinator:
         else:
             return self._reduce_scatter_out_place(input_, dim)
 
+    def reduce_scatterv(self,
+                        input_: torch.Tensor,
+                        dim: int = -1,
+                        sizes: Optional[list[int]] = None) -> torch.Tensor:
+        return self.device_communicator.reduce_scatterv(input_, dim, sizes)
+
     def _reduce_scatter_out_place(self, input_: torch.Tensor,
                                   dim: int) -> torch.Tensor:
         return self.device_communicator.reduce_scatter(input_, dim)
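To close, a minimal usage sketch of the new group-level entry points defined above. It is illustrative only: it assumes a two-GPU run, an already-initialized vLLM distributed environment with pynccl enabled for the group, and reaches the GroupCoordinator through get_world_group() the same way the tests do; tensor shapes, contents, and the size split are made up.

# Hypothetical caller of the new GroupCoordinator methods (not in this commit).
import torch

from vllm.distributed.parallel_state import get_world_group

group = get_world_group()
assert group.world_size == 2, "this sketch assumes a two-rank group"
rank = group.rank_in_group
sizes = [81, 20]  # one entry per rank: that rank's number of rows

# all_gatherv: each rank contributes sizes[rank] rows; every rank receives the
# concatenation of all ranks' rows (only dim=0 is supported).
shard = torch.randn(sizes[rank], 16, device=group.device)
gathered = group.all_gatherv(shard, dim=0, sizes=sizes)  # (sum(sizes), 16)

# reduce_scatterv: each rank holds a full (sum(sizes), 16) tensor; afterwards
# rank r keeps only the summed chunk of sizes[rank] rows that maps to it.
full = torch.randn(sum(sizes), 16, device=group.device)
chunk = group.reduce_scatterv(full, dim=0, sizes=sizes)  # (sizes[rank], 16)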