mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 18:35:56 +08:00
[Misc] Add pynccl wrappers for all_gather and reduce_scatter (#9432)
This commit is contained in:
parent
ebda51968b
commit
978b39744b
@ -150,6 +150,75 @@ def worker_fn_with_cudagraph():
|
|||||||
assert a.mean().cpu().item() == pynccl_comm.world_size**1
|
assert a.mean().cpu().item() == pynccl_comm.world_size**1
|
||||||
|
|
||||||
|
|
||||||
|
@worker_fn_wrapper
|
||||||
|
def all_gather_worker_fn():
|
||||||
|
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||||
|
device=get_world_group().device)
|
||||||
|
|
||||||
|
rank = pynccl_comm.rank
|
||||||
|
world_size = pynccl_comm.world_size
|
||||||
|
device = f'cuda:{pynccl_comm.rank}'
|
||||||
|
|
||||||
|
num_elems = 1000
|
||||||
|
tensor = torch.arange(num_elems, dtype=torch.float32,
|
||||||
|
device=device) + rank * num_elems
|
||||||
|
result = torch.zeros(num_elems * world_size,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
expected = torch.cat([
|
||||||
|
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
|
||||||
|
for r in range(world_size)
|
||||||
|
]).to(device)
|
||||||
|
|
||||||
|
with pynccl_comm.change_state(enable=True):
|
||||||
|
pynccl_comm.all_gather(result, tensor)
|
||||||
|
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||||
|
reason="Need at least 2 GPUs to run the test.")
|
||||||
|
def test_pynccl_all_gather():
|
||||||
|
distributed_run(all_gather_worker_fn, 2)
|
||||||
|
|
||||||
|
|
||||||
|
@worker_fn_wrapper
|
||||||
|
def reduce_scatter_worker_fn():
|
||||||
|
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||||
|
device=get_world_group().device)
|
||||||
|
|
||||||
|
rank = pynccl_comm.rank
|
||||||
|
world_size = pynccl_comm.world_size
|
||||||
|
device = f'cuda:{pynccl_comm.rank}'
|
||||||
|
|
||||||
|
num_elems = 1000
|
||||||
|
tensor = torch.arange(num_elems, dtype=torch.float32,
|
||||||
|
device=device) + rank * num_elems
|
||||||
|
assert (num_elems % world_size == 0)
|
||||||
|
result = torch.zeros(num_elems // world_size,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
# Calculate expected result for this rank's chunk
|
||||||
|
scattered_size = num_elems // world_size
|
||||||
|
all_tensors = [
|
||||||
|
torch.arange(num_elems, dtype=torch.float32) + r * num_elems
|
||||||
|
for r in range(world_size)
|
||||||
|
]
|
||||||
|
expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
|
||||||
|
for tensor in all_tensors).to(device)
|
||||||
|
|
||||||
|
with pynccl_comm.change_state(enable=True):
|
||||||
|
pynccl_comm.reduce_scatter(result, tensor)
|
||||||
|
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||||
|
reason="Need at least 2 GPUs to run the test.")
|
||||||
|
def test_pynccl_reduce_scatter():
|
||||||
|
distributed_run(reduce_scatter_worker_fn, 2)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
reason="Need at least 2 GPUs to run the test.")
|
||||||
def test_pynccl_with_cudagraph():
|
def test_pynccl_with_cudagraph():
|
||||||
|
|||||||
@ -131,6 +131,48 @@ class PyNcclCommunicator:
|
|||||||
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
||||||
cudaStream_t(stream.cuda_stream))
|
cudaStream_t(stream.cuda_stream))
|
||||||
|
|
||||||
|
def all_gather(self,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
stream=None):
|
||||||
|
if self.disabled:
|
||||||
|
return
|
||||||
|
# nccl communicator created on a specific device
|
||||||
|
# will only work on tensors on the same device
|
||||||
|
# otherwise it will cause "illegal memory access"
|
||||||
|
assert input_tensor.device == self.device, (
|
||||||
|
f"this nccl communicator is created to work on {self.device}, "
|
||||||
|
f"but the input tensor is on {input_tensor.device}")
|
||||||
|
if stream is None:
|
||||||
|
stream = self.stream
|
||||||
|
self.nccl.ncclAllGather(
|
||||||
|
buffer_type(input_tensor.data_ptr()),
|
||||||
|
buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
|
||||||
|
ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
|
||||||
|
cudaStream_t(stream.cuda_stream))
|
||||||
|
|
||||||
|
def reduce_scatter(self,
|
||||||
|
output_tensor: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
op: ReduceOp = ReduceOp.SUM,
|
||||||
|
stream=None):
|
||||||
|
if self.disabled:
|
||||||
|
return
|
||||||
|
# nccl communicator created on a specific device
|
||||||
|
# will only work on tensors on the same device
|
||||||
|
# otherwise it will cause "illegal memory access"
|
||||||
|
assert input_tensor.device == self.device, (
|
||||||
|
f"this nccl communicator is created to work on {self.device}, "
|
||||||
|
f"but the input tensor is on {input_tensor.device}")
|
||||||
|
if stream is None:
|
||||||
|
stream = self.stream
|
||||||
|
self.nccl.ncclReduceScatter(
|
||||||
|
buffer_type(input_tensor.data_ptr()),
|
||||||
|
buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
|
||||||
|
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||||
|
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
||||||
|
cudaStream_t(stream.cuda_stream))
|
||||||
|
|
||||||
def send(self, tensor: torch.Tensor, dst: int, stream=None):
|
def send(self, tensor: torch.Tensor, dst: int, stream=None):
|
||||||
if self.disabled:
|
if self.disabled:
|
||||||
return
|
return
|
||||||
|
|||||||
@ -151,6 +151,28 @@ class NCCLLibrary:
|
|||||||
ncclRedOp_t, ncclComm_t, cudaStream_t
|
ncclRedOp_t, ncclComm_t, cudaStream_t
|
||||||
]),
|
]),
|
||||||
|
|
||||||
|
# ncclResult_t ncclAllGather(
|
||||||
|
# const void* sendbuff, void* recvbuff, size_t count,
|
||||||
|
# ncclDataType_t datatype, ncclComm_t comm,
|
||||||
|
# cudaStream_t stream);
|
||||||
|
# note that cudaStream_t is a pointer type, so the last argument
|
||||||
|
# is a pointer
|
||||||
|
Function("ncclAllGather", ncclResult_t, [
|
||||||
|
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||||
|
ncclComm_t, cudaStream_t
|
||||||
|
]),
|
||||||
|
|
||||||
|
# ncclResult_t ncclReduceScatter(
|
||||||
|
# const void* sendbuff, void* recvbuff, size_t count,
|
||||||
|
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||||
|
# cudaStream_t stream);
|
||||||
|
# note that cudaStream_t is a pointer type, so the last argument
|
||||||
|
# is a pointer
|
||||||
|
Function("ncclReduceScatter", ncclResult_t, [
|
||||||
|
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
|
||||||
|
ncclRedOp_t, ncclComm_t, cudaStream_t
|
||||||
|
]),
|
||||||
|
|
||||||
# ncclResult_t ncclSend(
|
# ncclResult_t ncclSend(
|
||||||
# const void* sendbuff, size_t count, ncclDataType_t datatype,
|
# const void* sendbuff, size_t count, ncclDataType_t datatype,
|
||||||
# int dest, ncclComm_t comm, cudaStream_t stream);
|
# int dest, ncclComm_t comm, cudaStream_t stream);
|
||||||
@ -258,6 +280,28 @@ class NCCLLibrary:
|
|||||||
datatype, op, comm,
|
datatype, op, comm,
|
||||||
stream))
|
stream))
|
||||||
|
|
||||||
|
def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||||
|
count: int, datatype: int, op: int, comm: ncclComm_t,
|
||||||
|
stream: cudaStream_t) -> None:
|
||||||
|
# `datatype` actually should be `ncclDataType_t`
|
||||||
|
# and `op` should be `ncclRedOp_t`
|
||||||
|
# both are aliases of `ctypes.c_int`
|
||||||
|
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||||
|
# by ctypes automatically
|
||||||
|
self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff,
|
||||||
|
count, datatype, op,
|
||||||
|
comm, stream))
|
||||||
|
|
||||||
|
def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
|
||||||
|
count: int, datatype: int, comm: ncclComm_t,
|
||||||
|
stream: cudaStream_t) -> None:
|
||||||
|
# `datatype` actually should be `ncclDataType_t`
|
||||||
|
# which is an aliases of `ctypes.c_int`
|
||||||
|
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||||
|
# by ctypes automatically
|
||||||
|
self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count,
|
||||||
|
datatype, comm, stream))
|
||||||
|
|
||||||
def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
|
def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int,
|
||||||
dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
|
dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None:
|
||||||
self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
|
self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user