Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 02:05:00 +08:00)
Add pynccl all-gatherv and reducescatterv (#20154)
Signed-off-by: Trevor Morris <tmorris@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
  parent fc0f41d10a
  commit a8593237c0
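Before the diff itself, here is a single-process reference for what the new variable-size ("v") collectives compute. This is an illustrative sketch written for this page only; all_gatherv_ref and reduce_scatterv_ref are made-up helper names, not part of the vLLM API, and the real implementations below run over NCCL across GPU ranks.

import torch

def all_gatherv_ref(chunks: list[torch.Tensor]) -> torch.Tensor:
    # Every rank contributes a chunk whose length may differ per rank;
    # afterwards every rank holds the concatenation of all chunks.
    return torch.cat(chunks, dim=0)

def reduce_scatterv_ref(inputs: list[torch.Tensor],
                        sizes: list[int]) -> list[torch.Tensor]:
    # Every rank holds a full-length tensor; the element-wise sum is split
    # into per-rank chunks of the given (possibly unequal) sizes.
    total = torch.stack(inputs).sum(dim=0)
    return list(torch.split(total, sizes, dim=0))

sizes = [3, 1]
chunks = [torch.arange(3.0), torch.arange(1.0) + 100]
print(all_gatherv_ref(chunks))            # tensor([  0.,   1.,   2., 100.])
full = [torch.arange(4.0), torch.arange(4.0) + 100]
print(reduce_scatterv_ref(full, sizes))   # [tensor([100., 102., 104.]), tensor([106.])]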
@@ -4,6 +4,7 @@
 import multiprocessing
 import os

+import numpy as np
 import pytest
 import torch
 import torch.distributed
@@ -177,6 +178,38 @@ def test_pynccl_all_gather():
     distributed_run(all_gather_worker_fn, 2)


+@worker_fn_wrapper
+def all_gatherv_worker_fn():
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
+
+    rank = pynccl_comm.rank
+    world_size = pynccl_comm.world_size
+    device = f'cuda:{pynccl_comm.rank}'
+
+    assert world_size <= 8
+    sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
+    num_elems = sizes[rank]
+    tensor = torch.arange(num_elems, dtype=torch.float32,
+                          device=device) + rank * 100
+    result = torch.zeros(sum(sizes), dtype=torch.float32, device=device)
+
+    expected = torch.cat([
+        torch.arange(sizes[r], dtype=torch.float32) + r * 100
+        for r in range(world_size)
+    ]).to(device)
+
+    pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
+    torch.cuda.synchronize()
+    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+def test_pynccl_all_gatherv():
+    distributed_run(all_gatherv_worker_fn, 2)
+
+
 @worker_fn_wrapper
 def reduce_scatter_worker_fn():
     pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
@@ -214,6 +247,43 @@ def test_pynccl_reduce_scatter():
     distributed_run(reduce_scatter_worker_fn, 2)


+@worker_fn_wrapper
+def reduce_scatterv_worker_fn():
+    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
+                                     device=get_world_group().device)
+
+    rank = pynccl_comm.rank
+    world_size = pynccl_comm.world_size
+    device = f'cuda:{pynccl_comm.rank}'
+
+    assert world_size <= 8
+    sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
+    num_elems = sum(sizes)
+    tensor = torch.arange(num_elems, dtype=torch.float32,
+                          device=device) + rank * 100
+    result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)
+
+    # Calculate expected result for this rank's chunk
+    all_tensors = [
+        torch.arange(num_elems, dtype=torch.float32) + r * 100
+        for r in range(world_size)
+    ]
+    sizes_cumsum = np.cumsum(sizes)
+    start = 0 if rank == 0 else sizes_cumsum[rank - 1]
+    end = sizes_cumsum[rank]
+    expected = sum(tensor[start:end] for tensor in all_tensors).to(device)
+
+    pynccl_comm.reduce_scatterv(result, tensor, sizes=sizes)
+    torch.cuda.synchronize()
+    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+def test_pynccl_reduce_scatterv():
+    distributed_run(reduce_scatterv_worker_fn, 2)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 def test_pynccl_with_cudagraph():
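The expected value in reduce_scatterv_worker_fn above is derived on the CPU by slicing the element-wise sum at cumulative-size offsets. A small worked example of that logic, using made-up sizes for clarity (the test itself uses the fixed list of eight sizes shown above):

import numpy as np
import torch

world_size, sizes = 2, [3, 2]          # hypothetical sizes, not the test's
num_elems = sum(sizes)
all_tensors = [torch.arange(num_elems, dtype=torch.float32) + r * 100
               for r in range(world_size)]
sizes_cumsum = np.cumsum(sizes)
for rank in range(world_size):
    start = 0 if rank == 0 else sizes_cumsum[rank - 1]
    end = sizes_cumsum[rank]
    expected = sum(t[start:end] for t in all_tensors)
    print(rank, expected)
# rank 0 owns elements [0, 3): 0+100, 1+101, 2+102 -> [100., 102., 104.]
# rank 1 owns elements [3, 5): 3+103, 4+104        -> [106., 108.]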
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import threading
-from typing import Optional
+from typing import Optional, Union
 from weakref import WeakValueDictionary

 import torch
@@ -138,6 +138,14 @@ class DeviceCommunicatorBase:
                                               input_size[dim + 1:])
         return output_tensor

+    def all_gatherv(
+        self,
+        input_: Union[torch.Tensor, list[torch.Tensor]],
+        dim: int = 0,
+        sizes: Optional[list[int]] = None
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        raise NotImplementedError
+
     def reduce_scatter(self,
                        input_: torch.Tensor,
                        dim: int = -1) -> torch.Tensor:
@@ -172,6 +180,12 @@ class DeviceCommunicatorBase:
         # Reshape before returning
         return output_tensor.movedim(0, dim).contiguous()

+    def reduce_scatterv(self,
+                        input_: torch.Tensor,
+                        dim: int = -1,
+                        sizes: Optional[list[int]] = None) -> torch.Tensor:
+        raise NotImplementedError
+
     def gather(self,
                input_: torch.Tensor,
                dst: int = 0,
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Optional
+from typing import Optional, Union

 import torch
 from torch.distributed import ProcessGroup
@@ -142,6 +142,42 @@ class CudaCommunicator(DeviceCommunicatorBase):
         # Reshape before returning
         return output.movedim(0, dim).contiguous()

+    def reduce_scatterv(self,
+                        input_: torch.Tensor,
+                        dim: int = -1,
+                        sizes: Optional[list[int]] = None):
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+        assert pynccl_comm is not None
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+
+        # Note: This will produce an incorrect answer if we don't make
+        # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
+        input_tensor = input_.movedim(0, dim).contiguous()
+
+        if sizes is not None:
+            assert len(sizes) == world_size
+            assert input_tensor.shape[0] == sum(sizes)
+            chunk_size = sizes[self.rank_in_group]
+        else:
+            assert input_tensor.shape[0] % world_size == 0
+            chunk_size = input_tensor.shape[0] // world_size
+        output_shape = (chunk_size, ) + input_tensor.shape[1:]
+
+        output = torch.empty(output_shape,
+                             dtype=input_tensor.dtype,
+                             device=input_tensor.device)
+
+        if sizes is not None:
+            pynccl_comm.reduce_scatterv(output, input_, sizes=sizes)
+        else:
+            pynccl_comm.reduce_scatter(output, input_)
+
+        # Reshape before returning
+        return output.movedim(0, dim).contiguous()
+
     def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
         """Sends a tensor to the destination rank in a non-blocking way"""
         """NOTE: `dst` is the local rank of the destination rank."""
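The comment above about needing a contiguous input reflects the usual NCCL constraint: the collective reads a flat buffer via data_ptr(), so a strided view produced by movedim() has to be materialized first. A CPU-only illustration of that point (not vLLM code):

import torch

x = torch.arange(12.).reshape(3, 4)
view = x.movedim(0, 1)                     # shape (4, 3), but strides are permuted
print(view.is_contiguous())                # False: memory order no longer matches shape
print(view.contiguous().is_contiguous())   # True after an explicit copy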
@@ -180,6 +216,51 @@ class CudaCommunicator(DeviceCommunicatorBase):
             self.all2all_manager.destroy()
             self.all2all_manager = None

+    def all_gatherv(self,
+                    input_: Union[torch.Tensor, list[torch.Tensor]],
+                    dim: int = 0,
+                    sizes: Optional[list[int]] = None):
+        if dim != 0:
+            raise NotImplementedError("only dim 0 all-gatherv is supported")
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+        assert pynccl_comm is not None and not pynccl_comm.disabled
+
+        # 'sizes' is not needed if all inputs in the same group have the same
+        # shape
+        if sizes is not None and all(s == sizes[0] for s in sizes):
+            sizes = None
+
+        def _all_gather_single(input_: torch.Tensor,
+                               sizes: Optional[list[int]] = None):
+            input_size = input_.size()
+            if sizes is not None:
+                assert len(sizes) == world_size
+                assert input_.shape[dim] == sizes[self.rank_in_group]
+                output_size = (sum(sizes), ) + input_size[1:]
+            else:
+                output_size = (input_size[0] * world_size, ) + input_size[1:]
+            # Allocate output tensor.
+            output_tensor = torch.empty(output_size,
+                                        dtype=input_.dtype,
+                                        device=input_.device)
+            if sizes is not None:
+                pynccl_comm.all_gatherv(output_tensor, input_, sizes=sizes)
+            else:
+                pynccl_comm.all_gather(output_tensor, input_)
+            return output_tensor
+
+        if isinstance(input_, torch.Tensor):
+            return _all_gather_single(input_, sizes)
+
+        output_list = []
+        pynccl_comm.group_start()
+        for inp in input_:
+            output_list.append(_all_gather_single(inp, sizes=sizes))
+        pynccl_comm.group_end()
+
+        return output_list
+
     def dispatch(
         self, hidden_states: torch.Tensor,
         router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
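Two details of CudaCommunicator.all_gatherv above, sketched on the CPU with made-up sizes: when every rank contributes the same number of rows the method drops back to the plain all_gather kernel, and for uneven chunks the output simply concatenates sum(sizes) rows along dim 0.

import torch

sizes = [5, 3, 7, 1]                    # hypothetical per-rank row counts
assert not all(s == sizes[0] for s in sizes)   # equal sizes would use plain all_gather

local = torch.randn(sizes[2], 8)        # what rank 2 would contribute
output_size = (sum(sizes), ) + local.size()[1:]
print(output_size)                      # (16, 8): every rank's rows stacked on dim 0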
@@ -152,6 +152,40 @@ class PyNcclCommunicator:
            ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
            cudaStream_t(stream.cuda_stream))

+    def all_gatherv(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        sizes: list[int],
+        stream=None,
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}")
+        if stream is None:
+            stream = current_stream()
+        assert output_tensor.shape[0] == sum(sizes)
+        split_offset = 0
+        self.nccl.ncclGroupStart()
+        for root, split_size in enumerate(sizes):
+            dst_slice = output_tensor[split_offset:split_offset + split_size]
+            self.nccl.ncclBroadcast(
+                buffer_type(input_tensor.data_ptr()),
+                buffer_type(dst_slice.data_ptr()),
+                dst_slice.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                root,
+                self.comm,
+                cudaStream_t(stream.cuda_stream),
+            )
+            split_offset += split_size
+        self.nccl.ncclGroupEnd()
+
     def reduce_scatter(self,
                        output_tensor: torch.Tensor,
                        input_tensor: torch.Tensor,
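PyNcclCommunicator.all_gatherv above emulates a vector all-gather with one ncclBroadcast per rank inside a single NCCL group: rank `root` broadcasts its chunk into the matching slice of every rank's output buffer. A CPU simulation of that data movement, with plain tensor copies standing in for the broadcasts and made-up sizes:

import torch

sizes = [3, 1, 2]
rank_inputs = [torch.full((n,), float(r)) for r, n in enumerate(sizes)]
output = torch.zeros(sum(sizes))
offset = 0
for root, n in enumerate(sizes):        # stands in for ncclBroadcast per root
    output[offset:offset + n] = rank_inputs[root]
    offset += n
print(output)                           # tensor([0., 0., 0., 1., 2., 2.])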
@@ -174,6 +208,38 @@ class PyNcclCommunicator:
            ncclRedOpTypeEnum.from_torch(op), self.comm,
            cudaStream_t(stream.cuda_stream))

+    def reduce_scatterv(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        sizes: list[int],
+        op: ReduceOp = ReduceOp.SUM,
+        stream=None,
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}")
+        if stream is None:
+            stream = current_stream()
+
+        split_offset = 0
+        self.nccl.ncclGroupStart()
+        for root, split_size in enumerate(sizes):
+            chunk = input_tensor[split_offset:split_offset + split_size, ...]
+            self.nccl.ncclReduce(
+                buffer_type(chunk.data_ptr()),
+                buffer_type(output_tensor.data_ptr()), chunk.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                ncclRedOpTypeEnum.from_torch(op), root, self.comm,
+                cudaStream_t(stream.cuda_stream))
+            split_offset += split_size
+        self.nccl.ncclGroupEnd()
+
     def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
             return
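Likewise, reduce_scatterv above is built from one ncclReduce per chunk: chunk r of the full-length input is summed across ranks and delivered only to rank r. A CPU simulation of the arithmetic, with made-up sizes:

import torch

world_size, sizes = 2, [2, 3]
inputs = [torch.arange(sum(sizes), dtype=torch.float32) + r * 10
          for r in range(world_size)]          # one full-length tensor per rank
outputs = []
offset = 0
for root, n in enumerate(sizes):               # stands in for ncclReduce per root
    outputs.append(sum(t[offset:offset + n] for t in inputs))
    offset += n
print(outputs)   # [tensor([10., 12.]), tensor([14., 16., 18.])]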
@@ -216,3 +282,9 @@
         self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                 ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                 self.comm, cudaStream_t(stream.cuda_stream))
+
+    def group_start(self):
+        self.nccl.ncclGroupStart()
+
+    def group_end(self):
+        self.nccl.ncclGroupEnd()
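The new group_start()/group_end() helpers fence several collectives into one NCCL group so they can be issued together; this is how CudaCommunicator.all_gatherv above handles a list of input tensors. A hedged usage sketch, assuming an already-initialized PyNcclCommunicator; gather_many is a made-up helper name, not vLLM code:

def gather_many(pynccl_comm, outs, ins, sizes):
    # Issue several variable-size all-gathers as one fused NCCL group.
    pynccl_comm.group_start()
    for out, inp in zip(outs, ins):
        pynccl_comm.all_gatherv(out, inp, sizes=sizes)
    pynccl_comm.group_end()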
@@ -154,6 +154,17 @@ class NCCLLibrary:
            ncclRedOp_t, ncclComm_t, cudaStream_t
        ]),

+        # ncclResult_t ncclReduce(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, ncclRedOp_t op, int root,
+        #   ncclComm_t comm, cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function("ncclReduce", ncclResult_t, [
+            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
+            ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t
+        ]),
+
        # ncclResult_t ncclAllGather(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, ncclComm_t comm,
@@ -207,6 +218,10 @@
        # it is better not to call it at all.
        # ncclResult_t ncclCommDestroy(ncclComm_t comm);
        Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
+        # ncclResult_t ncclGroupStart();
+        Function("ncclGroupStart", ncclResult_t, []),
+        # ncclResult_t ncclGroupEnd();
+        Function("ncclGroupEnd", ncclResult_t, []),
    ]

    # class attribute to store the mapping from the path to the library
@@ -300,6 +315,18 @@
                                                      datatype, op, comm,
                                                      stream))

+    def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type,
+                   count: int, datatype: int, op: int, root: int,
+                   comm: ncclComm_t, stream: cudaStream_t) -> None:
+        # `datatype` actually should be `ncclDataType_t`
+        # and `op` should be `ncclRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(self._funcs["ncclReduce"](sendbuff, recvbuff, count,
+                                                  datatype, op, root, comm,
+                                                  stream))
+
     def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
                           count: int, datatype: int, op: int, comm: ncclComm_t,
                           stream: cudaStream_t) -> None:
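For readers unfamiliar with the ctypes plumbing that backs the Function entries and NCCL_CHECK calls above: declaring a restype and argtypes on a loaded shared-library symbol is all that is needed to call it from Python. A minimal, NCCL-free sketch using the C standard library (assumes a platform where find_library("c") resolves, e.g. Linux):

import ctypes
import ctypes.util

libc = ctypes.CDLL(ctypes.util.find_library("c"))
libc.strlen.restype = ctypes.c_size_t      # same idea as Function(name, restype, argtypes)
libc.strlen.argtypes = [ctypes.c_char_p]
print(libc.strlen(b"ncclReduce"))          # 10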
@@ -342,6 +369,12 @@
     def ncclCommDestroy(self, comm: ncclComm_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))

+    def ncclGroupStart(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
+
+    def ncclGroupEnd(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
+

 __all__ = [
     "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
@@ -383,6 +383,12 @@ class GroupCoordinator:
                               dim: int) -> torch.Tensor:
         return self.device_communicator.all_gather(input_, dim)

+    def all_gatherv(self,
+                    input_: Union[torch.Tensor, list[torch.Tensor]],
+                    dim: int = 0,
+                    sizes: Optional[list[int]] = None):
+        return self.device_communicator.all_gatherv(input_, dim, sizes)
+
     def reduce_scatter(self,
                        input_: torch.Tensor,
                        dim: int = -1) -> torch.Tensor:
@@ -401,6 +407,12 @@
         else:
             return self._reduce_scatter_out_place(input_, dim)

+    def reduce_scatterv(self,
+                        input_: torch.Tensor,
+                        dim: int = -1,
+                        sizes: Optional[list[int]] = None) -> torch.Tensor:
+        return self.device_communicator.reduce_scatterv(input_, dim, sizes)
+
     def _reduce_scatter_out_place(self, input_: torch.Tensor,
                                   dim: int) -> torch.Tensor:
         return self.device_communicator.reduce_scatter(input_, dim)
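Finally, a hedged sketch of how the new GroupCoordinator methods might be called from model code once distributed state is initialized. gather_uneven_hidden_states is a made-up helper, and the snippet assumes vLLM is installed and a tensor-parallel group has already been set up (it is not runnable standalone):

import torch

def gather_uneven_hidden_states(hidden: torch.Tensor,
                                sizes: list[int]) -> torch.Tensor:
    # Each rank holds sizes[rank] rows; the result holds sum(sizes) rows.
    from vllm.distributed.parallel_state import get_tp_group
    return get_tp_group().all_gatherv(hidden, dim=0, sizes=sizes)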