mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Misc] Add pynccl wrappers for all_gather and reduce_scatter (#9432)
commit 978b39744b
parent ebda51968b
@@ -150,6 +150,75 @@ def worker_fn_with_cudagraph():
        assert a.mean().cpu().item() == pynccl_comm.world_size**1


@worker_fn_wrapper
def all_gather_worker_fn():
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)

    rank = pynccl_comm.rank
    world_size = pynccl_comm.world_size
    device = f'cuda:{pynccl_comm.rank}'

    num_elems = 1000
    tensor = torch.arange(num_elems, dtype=torch.float32,
                          device=device) + rank * num_elems
    result = torch.zeros(num_elems * world_size,
                         dtype=torch.float32,
                         device=device)

    expected = torch.cat([
        torch.arange(num_elems, dtype=torch.float32) + r * num_elems
        for r in range(world_size)
    ]).to(device)

    with pynccl_comm.change_state(enable=True):
        pynccl_comm.all_gather(result, tensor)
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_all_gather():
    distributed_run(all_gather_worker_fn, 2)


@worker_fn_wrapper
def reduce_scatter_worker_fn():
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)

    rank = pynccl_comm.rank
    world_size = pynccl_comm.world_size
    device = f'cuda:{pynccl_comm.rank}'

    num_elems = 1000
    tensor = torch.arange(num_elems, dtype=torch.float32,
                          device=device) + rank * num_elems
    assert (num_elems % world_size == 0)
    result = torch.zeros(num_elems // world_size,
                         dtype=torch.float32,
                         device=device)

    # Calculate expected result for this rank's chunk
    scattered_size = num_elems // world_size
    all_tensors = [
        torch.arange(num_elems, dtype=torch.float32) + r * num_elems
        for r in range(world_size)
    ]
    expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                   for tensor in all_tensors).to(device)

    with pynccl_comm.change_state(enable=True):
        pynccl_comm.reduce_scatter(result, tensor)
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_reduce_scatter():
    distributed_run(reduce_scatter_worker_fn, 2)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
@@ -131,6 +131,48 @@ class PyNcclCommunicator:
                                ncclRedOpTypeEnum.from_torch(op), self.comm,
                                cudaStream_t(stream.cuda_stream))

    def all_gather(self,
                   output_tensor: torch.Tensor,
                   input_tensor: torch.Tensor,
                   stream=None):
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclAllGather(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
            cudaStream_t(stream.cuda_stream))

    def reduce_scatter(self,
                       output_tensor: torch.Tensor,
                       input_tensor: torch.Tensor,
                       op: ReduceOp = ReduceOp.SUM,
                       stream=None):
        if self.disabled:
            return
        # nccl communicator created on a specific device
        # will only work on tensors on the same device
        # otherwise it will cause "illegal memory access"
        assert input_tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {input_tensor.device}")
        if stream is None:
            stream = self.stream
        self.nccl.ncclReduceScatter(
            buffer_type(input_tensor.data_ptr()),
            buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
            ncclDataTypeEnum.from_torch(input_tensor.dtype),
            ncclRedOpTypeEnum.from_torch(op), self.comm,
            cudaStream_t(stream.cuda_stream))

    def send(self, tensor: torch.Tensor, dst: int, stream=None):
        if self.disabled:
            return
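For orientation, here is a minimal usage sketch of the two new communicator wrappers. It is not part of the diff; it assumes an already-initialized PyNcclCommunicator named comm and a per-rank element count n, as in the tests above. The caller is responsible for sizing the output buffer: all_gather needs world_size times the input length, while reduce_scatter needs the input length divided by world_size.

import torch

# Hypothetical usage sketch: `comm` is an initialized PyNcclCommunicator.
n = 1024
inp = torch.ones(n, dtype=torch.float32, device=comm.device)

# all_gather: output holds the world_size per-rank inputs, concatenated.
gathered = torch.empty(n * comm.world_size,
                       dtype=torch.float32,
                       device=comm.device)

# reduce_scatter: each rank receives its 1/world_size slice of the
# element-wise reduction, so n must be divisible by world_size.
scattered = torch.empty(n // comm.world_size,
                        dtype=torch.float32,
                        device=comm.device)

with comm.change_state(enable=True):
    comm.all_gather(gathered, inp)
    comm.reduce_scatter(scattered, inp)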
@@ -151,6 +151,28 @@ class NCCLLibrary:
            ncclRedOp_t, ncclComm_t, cudaStream_t
        ]),

        # ncclResult_t ncclAllGather(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, ncclComm_t comm,
        #   cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function("ncclAllGather", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ncclComm_t, cudaStream_t
        ]),

        # ncclResult_t ncclReduceScatter(
        #   const void* sendbuff, void* recvbuff, size_t count,
        #   ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
        #   cudaStream_t stream);
        # note that cudaStream_t is a pointer type, so the last argument
        # is a pointer
        Function("ncclReduceScatter", ncclResult_t, [
            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
            ncclRedOp_t, ncclComm_t, cudaStream_t
        ]),

        # ncclResult_t ncclSend(
        #   const void* sendbuff, size_t count, ncclDataType_t datatype,
        #   int dest, ncclComm_t comm, cudaStream_t stream);
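Each Function entry above records a C prototype (return type plus argument types) for one NCCL symbol, which the library presumably uses when binding the symbol from the NCCL shared library. As a rough, hypothetical ctypes-only equivalent of the ncclAllGather entry (purely illustrative, not the actual loader code in this class):

import ctypes

# ctypes aliases for the NCCL C types (enums and opaque pointers), matching
# what the comments in this file describe; the module defines similar aliases.
ncclResult_t = ctypes.c_int
ncclDataType_t = ctypes.c_int
ncclComm_t = ctypes.c_void_p
cudaStream_t = ctypes.c_void_p
buffer_type = ctypes.c_void_p

# Assumed soname for NCCL 2.x; purely illustrative.
lib = ctypes.CDLL("libnccl.so.2")

# Rough equivalent of the Function("ncclAllGather", ncclResult_t, [...]) entry:
# declare return and argument types so ctypes marshals the call correctly.
lib.ncclAllGather.restype = ncclResult_t
lib.ncclAllGather.argtypes = [
    buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
    ncclComm_t, cudaStream_t
]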
@@ -258,6 +280,28 @@ class NCCLLibrary:
                                                     datatype, op, comm,
                                                     stream))

    def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
                          count: int, datatype: int, op: int, comm: ncclComm_t,
                          stream: cudaStream_t) -> None:
        # `datatype` actually should be `ncclDataType_t`
        # and `op` should be `ncclRedOp_t`
        # both are aliases of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff,
                                                         count, datatype, op,
                                                         comm, stream))

    def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
                      count: int, datatype: int, comm: ncclComm_t,
                      stream: cudaStream_t) -> None:
        # `datatype` actually should be `ncclDataType_t`
        # which is an alias of `ctypes.c_int`
        # when we pass int to a function, it will be converted to `ctypes.c_int`
        # by ctypes automatically
        self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count,
                                                     datatype, comm, stream))

    def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
                 dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
        self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
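One detail that is easy to miss when reading the two wrappers: the count passed to NCCL has different semantics for the two collectives. ncclAllGather takes the per-rank send count (hence the communicator wrapper passes input_tensor.numel()), while ncclReduceScatter takes the per-rank receive count (hence output_tensor.numel()). A small worked size check using the values from the tests above:

# Worked size check mirroring the tests above (world_size=2, num_elems=1000).
world_size = 2
num_elems = 1000

# all_gather: count = per-rank input size; the output buffer must hold
# world_size * count elements.
all_gather_count = num_elems                       # 1000 elements sent per rank
all_gather_output_elems = num_elems * world_size   # 2000 elements received

# reduce_scatter: count = per-rank output size; each rank keeps a
# num_elems // world_size slice of the element-wise reduction.
reduce_scatter_count = num_elems // world_size     # 500 elements received per rank

assert all_gather_output_elems == 2000
assert reduce_scatter_count == 500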