Add pynccl all-gatherv and reducescatterv (#20154)

Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Trevor Morris 2025-07-11 18:59:23 -07:00 committed by GitHub
parent fc0f41d10a
commit a8593237c0
6 changed files with 284 additions and 2 deletions
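
For context, the two new collectives generalize all-gather and reduce-scatter to per-rank chunk sizes: all-gatherv concatenates chunks of possibly different lengths from every rank, and reduce-scatterv sums full-length inputs and hands each rank only its own uneven slice. A minimal single-process sketch of the intended semantics (plain PyTorch, no NCCL; the helper names here are illustrative, not part of the commit):

    import torch

    def all_gatherv_reference(chunks: list) -> torch.Tensor:
        # Every rank ends up with the concatenation of all ranks' chunks,
        # which may have different lengths (the `sizes` argument).
        return torch.cat(chunks)

    def reduce_scatterv_reference(inputs: list, sizes: list) -> list:
        # Every rank contributes a full-length tensor; the element-wise sum
        # is split into uneven chunks and rank r keeps chunk r.
        total = torch.stack(inputs).sum(dim=0)
        return list(torch.split(total, sizes))

    sizes = [3, 1]                                  # per-rank chunk sizes
    chunks = [torch.arange(3.), torch.arange(1.) + 100]
    assert all_gatherv_reference(chunks).tolist() == [0., 1., 2., 100.]

    inputs = [torch.arange(4.), torch.arange(4.) + 100]
    out = reduce_scatterv_reference(inputs, sizes)
    assert out[0].tolist() == [100., 102., 104.] and out[1].tolist() == [106.]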

tests/distributed/test_pynccl.py  View File

@@ -4,6 +4,7 @@
import multiprocessing
import os
import numpy as np
import pytest
import torch
import torch.distributed
@@ -177,6 +178,38 @@ def test_pynccl_all_gather():
distributed_run(all_gather_worker_fn, 2)
@worker_fn_wrapper
def all_gatherv_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
device=get_world_group().device)
rank = pynccl_comm.rank
world_size = pynccl_comm.world_size
device = f'cuda:{pynccl_comm.rank}'
assert world_size <= 8
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
num_elems = sizes[rank]
tensor = torch.arange(num_elems, dtype=torch.float32,
device=device) + rank * 100
result = torch.zeros(sum(sizes), dtype=torch.float32, device=device)
expected = torch.cat([
torch.arange(sizes[r], dtype=torch.float32) + r * 100
for r in range(world_size)
]).to(device)
pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
torch.cuda.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_all_gatherv():
distributed_run(all_gatherv_worker_fn, 2)
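
With two ranks the test above uses sizes [81, 20]: rank 0 contributes torch.arange(81) and rank 1 contributes torch.arange(20) + 100, so the gathered result holds 101 elements on both ranks. A quick stand-alone check of that expectation:

    import torch

    sizes = [81, 20]
    expected = torch.cat([
        torch.arange(81, dtype=torch.float32),
        torch.arange(20, dtype=torch.float32) + 100,
    ])
    assert expected.shape == (sum(sizes), )
    assert expected[80] == 80 and expected[81] == 100 and expected[100] == 119
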
@worker_fn_wrapper
def reduce_scatter_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
@@ -214,6 +247,43 @@ def test_pynccl_reduce_scatter():
distributed_run(reduce_scatter_worker_fn, 2)
@worker_fn_wrapper
def reduce_scatterv_worker_fn():
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
device=get_world_group().device)
rank = pynccl_comm.rank
world_size = pynccl_comm.world_size
device = f'cuda:{pynccl_comm.rank}'
assert world_size <= 8
sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
num_elems = sum(sizes)
tensor = torch.arange(num_elems, dtype=torch.float32,
device=device) + rank * 100
result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)
# Calculate expected result for this rank's chunk
all_tensors = [
torch.arange(num_elems, dtype=torch.float32) + r * 100
for r in range(world_size)
]
sizes_cumsum = np.cumsum(sizes)
start = 0 if rank == 0 else sizes_cumsum[rank - 1]
end = sizes_cumsum[rank]
expected = sum(tensor[start:end] for tensor in all_tensors).to(device)
pynccl_comm.reduce_scatterv(result, tensor, sizes=sizes)
torch.cuda.synchronize()
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_reduce_scatterv():
distributed_run(reduce_scatterv_worker_fn, 2)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():

vllm/distributed/device_communicators/base_device_communicator.py  View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Optional
from typing import Optional, Union
from weakref import WeakValueDictionary
import torch
@@ -138,6 +138,14 @@ class DeviceCommunicatorBase:
input_size[dim + 1:])
return output_tensor
def all_gatherv(
self,
input_: Union[torch.Tensor, list[torch.Tensor]],
dim: int = 0,
sizes: Optional[list[int]] = None
) -> Union[torch.Tensor, list[torch.Tensor]]:
raise NotImplementedError
def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
@@ -172,6 +180,12 @@ class DeviceCommunicatorBase:
# Reshape before returning
return output_tensor.movedim(0, dim).contiguous()
def reduce_scatterv(self,
input_: torch.Tensor,
dim: int = -1,
sizes: Optional[list[int]] = None) -> torch.Tensor:
raise NotImplementedError
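
These base-class hooks only declare the interface and raise NotImplementedError; the CUDA communicator below supplies the NCCL-backed implementation. A backend without variable-size support could emulate the same semantics with standard collectives, for example (a rough sketch, not part of this commit; the helper names, the `group` argument and the pad-then-trim strategy are assumptions):

    import torch
    import torch.distributed as dist

    def all_gatherv_fallback(input_: torch.Tensor, sizes: list,
                             group=None) -> torch.Tensor:
        # Pad this rank's chunk to the largest size, run a regular
        # all_gather, then trim each gathered chunk back to its true size.
        max_size = max(sizes)
        padded = torch.zeros((max_size, ) + input_.shape[1:],
                             dtype=input_.dtype, device=input_.device)
        padded[:input_.shape[0]] = input_
        gathered = [torch.empty_like(padded) for _ in sizes]
        dist.all_gather(gathered, padded, group=group)
        return torch.cat([g[:s] for g, s in zip(gathered, sizes)], dim=0)

    def reduce_scatterv_fallback(input_: torch.Tensor, sizes: list,
                                 group=None) -> torch.Tensor:
        # All-reduce the full tensor, then keep only this rank's uneven slice.
        reduced = input_.clone()
        dist.all_reduce(reduced, group=group)
        rank = dist.get_rank(group)
        start = sum(sizes[:rank])
        return reduced[start:start + sizes[rank]]
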
def gather(self,
input_: torch.Tensor,
dst: int = 0,

vllm/distributed/device_communicators/cuda_communicator.py  View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Optional, Union
import torch
from torch.distributed import ProcessGroup
@@ -142,6 +142,42 @@ class CudaCommunicator(DeviceCommunicatorBase):
# Reshape before returning
return output.movedim(0, dim).contiguous()
def reduce_scatterv(self,
input_: torch.Tensor,
dim: int = -1,
sizes: Optional[list[int]] = None):
world_size = self.world_size
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor = input_.movedim(0, dim).contiguous()
if sizes is not None:
assert len(sizes) == world_size
assert input_tensor.shape[0] == sum(sizes)
chunk_size = sizes[self.rank_in_group]
else:
assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size, ) + input_tensor.shape[1:]
output = torch.empty(output_shape,
dtype=input_tensor.dtype,
device=input_tensor.device)
if sizes is not None:
pynccl_comm.reduce_scatterv(output, input_, sizes=sizes)
else:
pynccl_comm.reduce_scatter(output, input_)
# Reshape before returning
return output.movedim(0, dim).contiguous()
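
The output shape of reduce_scatterv follows directly from the chunk-size branch above: this rank's entry in sizes when it is given, otherwise an equal split of the leading dimension. A small stand-alone check of that logic (the helper name is mine, purely illustrative):

    from typing import Optional

    def reduce_scatterv_output_shape(input_shape: tuple,
                                     sizes: Optional[list],
                                     rank: int,
                                     world_size: int) -> tuple:
        # Mirrors the branch above: an uneven chunk for this rank when
        # `sizes` is given, otherwise an equal split across the group.
        if sizes is not None:
            assert input_shape[0] == sum(sizes)
            chunk_size = sizes[rank]
        else:
            assert input_shape[0] % world_size == 0
            chunk_size = input_shape[0] // world_size
        return (chunk_size, ) + input_shape[1:]

    assert reduce_scatterv_output_shape((101, 4096), [81, 20], 0, 2) == (81, 4096)
    assert reduce_scatterv_output_shape((128, 4096), None, 1, 2) == (64, 4096)
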
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
@@ -180,6 +216,51 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.all2all_manager.destroy()
self.all2all_manager = None
def all_gatherv(self,
input_: Union[torch.Tensor, list[torch.Tensor]],
dim: int = 0,
sizes: Optional[list[int]] = None):
if dim != 0:
raise NotImplementedError("only dim 0 all-gatherv is supported")
world_size = self.world_size
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None and not pynccl_comm.disabled
# 'sizes' is not needed if all inputs in the same group have the same
# shape
if sizes is not None and all(s == sizes[0] for s in sizes):
sizes = None
def _all_gather_single(input_: torch.Tensor,
sizes: Optional[list[int]] = None):
input_size = input_.size()
if sizes is not None:
assert len(sizes) == world_size
assert input_.shape[dim] == sizes[self.rank_in_group]
output_size = (sum(sizes), ) + input_size[1:]
else:
output_size = (input_size[0] * world_size, ) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(output_size,
dtype=input_.dtype,
device=input_.device)
if sizes is not None:
pynccl_comm.all_gatherv(output_tensor, input_, sizes=sizes)
else:
pynccl_comm.all_gather(output_tensor, input_)
return output_tensor
if isinstance(input_, torch.Tensor):
return _all_gather_single(input_, sizes)
output_list = []
pynccl_comm.group_start()
for inp in input_:
output_list.append(_all_gather_single(inp, sizes=sizes))
pynccl_comm.group_end()
return output_list
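
When a list of tensors is passed, the per-tensor all-gathers are issued inside a single group_start()/group_end() pair so NCCL can batch the kernels, but the result for each tensor is the same as gathering it on its own. A single-process reference for the list overload's semantics (names illustrative):

    import torch

    def all_gatherv_list_reference(per_rank_lists: list, sizes: list) -> list:
        # per_rank_lists[r][i] is rank r's i-th tensor, whose first dimension
        # equals sizes[r]; result i concatenates the i-th tensors over ranks,
        # matching what the list overload returns on every rank.
        num_tensors = len(per_rank_lists[0])
        return [
            torch.cat([per_rank_lists[r][i] for r in range(len(sizes))])
            for i in range(num_tensors)
        ]

    sizes = [2, 3]
    rank0 = [torch.zeros(2, 4), torch.zeros(2, 8)]
    rank1 = [torch.ones(3, 4), torch.ones(3, 8)]
    gathered = all_gatherv_list_reference([rank0, rank1], sizes)
    assert gathered[0].shape == (5, 4) and gathered[1].shape == (5, 8)
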
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:

vllm/distributed/device_communicators/pynccl.py  View File

@@ -152,6 +152,40 @@ class PyNcclCommunicator:
ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
cudaStream_t(stream.cuda_stream))
def all_gatherv(
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
sizes: list[int],
stream=None,
):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = current_stream()
assert output_tensor.shape[0] == sum(sizes)
split_offset = 0
self.nccl.ncclGroupStart()
for root, split_size in enumerate(sizes):
dst_slice = output_tensor[split_offset:split_offset + split_size]
self.nccl.ncclBroadcast(
buffer_type(input_tensor.data_ptr()),
buffer_type(dst_slice.data_ptr()),
dst_slice.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
root,
self.comm,
cudaStream_t(stream.cuda_stream),
)
split_offset += split_size
self.nccl.ncclGroupEnd()
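
The variable-size all-gather is expressed as one ncclBroadcast per rank inside an ncclGroupStart/ncclGroupEnd pair: each rank acts as the root for its own chunk, and every rank receives that chunk into the matching slice of its output buffer. A CPU-only, single-process model of that dataflow (the helper name is mine, purely illustrative):

    import torch

    def all_gatherv_via_broadcasts(inputs: list, sizes: list) -> list:
        # For each root r, "broadcast" rank r's chunk into the
        # [offset, offset + sizes[r]) slice of every rank's output.
        world_size = len(sizes)
        outputs = [torch.zeros(sum(sizes)) for _ in range(world_size)]
        offset = 0
        for root, size in enumerate(sizes):
            for rank in range(world_size):
                outputs[rank][offset:offset + size] = inputs[root]
            offset += size
        return outputs

    sizes = [3, 1]
    inputs = [torch.arange(3.), torch.tensor([100.])]
    outs = all_gatherv_via_broadcasts(inputs, sizes)
    assert all(o.tolist() == [0., 1., 2., 100.] for o in outs)
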
def reduce_scatter(self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
@@ -174,6 +208,38 @@ class PyNcclCommunicator:
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))
def reduce_scatterv(
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
sizes: list[int],
op: ReduceOp = ReduceOp.SUM,
stream=None,
):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = current_stream()
split_offset = 0
self.nccl.ncclGroupStart()
for root, split_size in enumerate(sizes):
chunk = input_tensor[split_offset:split_offset + split_size, ...]
self.nccl.ncclReduce(
buffer_type(chunk.data_ptr()),
buffer_type(output_tensor.data_ptr()), chunk.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), root, self.comm,
cudaStream_t(stream.cuda_stream))
split_offset += split_size
self.nccl.ncclGroupEnd()
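
Likewise, the variable-size reduce-scatter is one ncclReduce per chunk: for chunk r, every rank contributes its slice of the input and rank r, as the root, receives the reduced result. A CPU-only model of that dataflow (helper name illustrative):

    import torch

    def reduce_scatterv_via_reduces(inputs: list, sizes: list) -> list:
        # For each root r, "reduce" (sum) the [offset, offset + sizes[r])
        # slice of every rank's input; only rank `root` keeps outputs[root].
        outputs = []
        offset = 0
        for root, size in enumerate(sizes):
            chunk_sum = sum(inp[offset:offset + size] for inp in inputs)
            outputs.insert(root, chunk_sum)  # rank `root` receives this chunk
            offset += size
        return outputs

    sizes = [3, 1]
    inputs = [torch.arange(4.), torch.arange(4.) + 100]
    outs = reduce_scatterv_via_reduces(inputs, sizes)
    assert outs[0].tolist() == [100., 102., 104.] and outs[1].tolist() == [106.]
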
def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
return
@@ -216,3 +282,9 @@
self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))
def group_start(self):
self.nccl.ncclGroupStart()
def group_end(self):
self.nccl.ncclGroupEnd()

vllm/distributed/device_communicators/pynccl_wrapper.py  View File

@@ -154,6 +154,17 @@ class NCCLLibrary:
ncclRedOp_t, ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, int root,
# ncclComm_t comm, cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclReduce", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t
]),
# ncclResult_t ncclAllGather(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclComm_t comm,
@@ -207,6 +218,10 @@ class NCCLLibrary:
# it is better not to call it at all.
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
# ncclResult_t ncclGroupStart();
Function("ncclGroupStart", ncclResult_t, []),
# ncclResult_t ncclGroupEnd();
Function("ncclGroupEnd", ncclResult_t, []),
]
# class attribute to store the mapping from the path to the library
@@ -300,6 +315,18 @@ class NCCLLibrary:
datatype, op, comm,
stream))
def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, root: int,
comm: ncclComm_t, stream: cudaStream_t) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(self._funcs["ncclReduce"](sendbuff, recvbuff, count,
datatype, op, root, comm,
stream))
def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, op: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
@@ -342,6 +369,12 @@
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
def ncclGroupStart(self) -> None:
self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
def ncclGroupEnd(self) -> None:
self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
__all__ = [
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",

vllm/distributed/parallel_state.py  View File

@@ -383,6 +383,12 @@ class GroupCoordinator:
dim: int) -> torch.Tensor:
return self.device_communicator.all_gather(input_, dim)
def all_gatherv(self,
input_: Union[torch.Tensor, list[torch.Tensor]],
dim: int = 0,
sizes: Optional[list[int]] = None):
return self.device_communicator.all_gatherv(input_, dim, sizes)
def reduce_scatter(self,
input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
@@ -401,6 +407,12 @@ class GroupCoordinator:
else:
return self._reduce_scatter_out_place(input_, dim)
def reduce_scatterv(self,
input_: torch.Tensor,
dim: int = -1,
sizes: Optional[list[int]] = None) -> torch.Tensor:
return self.device_communicator.reduce_scatterv(input_, dim, sizes)
def _reduce_scatter_out_place(self, input_: torch.Tensor,
dim: int) -> torch.Tensor:
return self.device_communicator.reduce_scatter(input_, dim)
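
At call sites these collectives are reached through the group coordinator, typically the tensor-parallel group. A hedged usage sketch, assuming distributed state is already initialized inside a vLLM worker and using the existing get_tp_group accessor from vllm.distributed.parallel_state:

    import torch
    from vllm.distributed.parallel_state import get_tp_group

    def gather_uneven_activations(x: torch.Tensor, sizes: list) -> torch.Tensor:
        # x holds this rank's sizes[rank] rows; the result holds sum(sizes)
        # rows on every rank in the tensor-parallel group.
        return get_tp_group().all_gatherv(x, dim=0, sizes=sizes)

    def scatter_reduced_activations(x: torch.Tensor, sizes: list) -> torch.Tensor:
        # x holds all sum(sizes) rows on every rank; each rank gets back the
        # element-wise sum of its own uneven slice.
        return get_tp_group().reduce_scatterv(x, dim=0, sizes=sizes)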