[distributed] remove pynccl's redundant change_state (#11749)

commit 9e764e7b10
parent 33fc1e2e86
Author: cennn
Date: 2025-01-06 09:05:48 +08:00 (committed by GitHub)
3 changed files with 28 additions and 62 deletions
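
Every call site changes the same way: the `change_state(enable=True)` wrapper is dropped and the collective is invoked on the communicator directly. A minimal sketch of the new call path, modeled on the first test hunk below (illustrative only; it assumes `torch.distributed` and vLLM's distributed environment are already initialized on each rank, with one GPU per rank, as the test harness does):

import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world_group


def allreduce_worker():
    # assumes init_distributed_environment() has already run on this rank
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    tensor = torch.ones(16, 1024, 1024,
                        dtype=torch.float32).cuda(pynccl_comm.rank)
    # no change_state() context any more -- the communicator is usable as-is
    tensor = pynccl_comm.all_reduce(tensor)
    torch.cuda.synchronize()
    assert torch.all(tensor == pynccl_comm.world_size).cpu().item()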

--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py

@@ -59,8 +59,7 @@ def worker_fn():
                                      device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
-    with pynccl_comm.change_state(enable=True):
-        tensor = pynccl_comm.all_reduce(tensor)
+    tensor = pynccl_comm.all_reduce(tensor)
     torch.cuda.synchronize()
     assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
@@ -81,17 +80,16 @@ def multiple_allreduce_worker_fn():
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     pynccl_comm = PyNcclCommunicator(group=group, device=device)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with pynccl_comm.change_state(enable=True):
-        # two groups can communicate independently
-        if torch.distributed.get_rank() in [0, 1]:
-            tensor = pynccl_comm.all_reduce(tensor)
-            tensor = pynccl_comm.all_reduce(tensor)
-            torch.cuda.synchronize()
-            assert torch.all(tensor == 4).cpu().item()
-        else:
-            tensor = pynccl_comm.all_reduce(tensor)
-            torch.cuda.synchronize()
-            assert torch.all(tensor == 2).cpu().item()
+    # two groups can communicate independently
+    if torch.distributed.get_rank() in [0, 1]:
+        tensor = pynccl_comm.all_reduce(tensor)
+        tensor = pynccl_comm.all_reduce(tensor)
+        torch.cuda.synchronize()
+        assert torch.all(tensor == 4).cpu().item()
+    else:
+        tensor = pynccl_comm.all_reduce(tensor)
+        torch.cuda.synchronize()
+        assert torch.all(tensor == 2).cpu().item()
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -137,8 +135,7 @@ def worker_fn_with_cudagraph():
         # run something in the default stream to initialize torch engine
         a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
-        with torch.cuda.graph(graph), \
-                pynccl_comm.change_state(enable=True):
+        with torch.cuda.graph(graph):
             a_out = pynccl_comm.all_reduce(a)
         torch.cuda.synchronize()
         graph.replay()
@@ -167,8 +164,7 @@ def all_gather_worker_fn():
         for r in range(world_size)
     ]).to(device)
 
-    with pynccl_comm.change_state(enable=True):
-        pynccl_comm.all_gather(result, tensor)
+    pynccl_comm.all_gather(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
@@ -205,8 +201,7 @@ def reduce_scatter_worker_fn():
     expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                    for tensor in all_tensors).to(device)
 
-    with pynccl_comm.change_state(enable=True):
-        pynccl_comm.reduce_scatter(result, tensor)
+    pynccl_comm.reduce_scatter(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
@@ -233,15 +228,13 @@ def send_recv_worker_fn():
     else:
         tensor = torch.empty(16, 1024, 1024,
                              dtype=torch.float32).cuda(pynccl_comm.rank)
-    with pynccl_comm.change_state(enable=True):
-        if pynccl_comm.rank == 0:
-            pynccl_comm.send(tensor,
-                             dst=(pynccl_comm.rank + 1) %
-                             pynccl_comm.world_size)
-        else:
-            pynccl_comm.recv(tensor,
-                             src=(pynccl_comm.rank - 1) %
-                             pynccl_comm.world_size)
+    if pynccl_comm.rank == 0:
+        pynccl_comm.send(tensor,
+                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
+    else:
+        pynccl_comm.recv(tensor,
+                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
     torch.cuda.synchronize()
     assert torch.all(tensor == 1).cpu().item()
@@ -272,15 +265,12 @@ def multiple_send_recv_worker_fn():
                          1024,
                          1024,
                          dtype=torch.float32,
                          device=device)
-    with pynccl_comm.change_state(enable=True):
-        if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.send(tensor,
-                             dst=(pynccl_comm.rank + 1) %
-                             pynccl_comm.world_size)
-        else:
-            pynccl_comm.recv(tensor,
-                             src=(pynccl_comm.rank - 1) %
-                             pynccl_comm.world_size)
+    if torch.distributed.get_rank() in [0, 1]:
+        pynccl_comm.send(tensor,
+                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
+    else:
+        pynccl_comm.recv(tensor,
+                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
     torch.cuda.synchronize()
     if torch.distributed.get_rank() in [0, 2]:
         assert torch.all(tensor == 1).cpu().item()

--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py

@@ -1,4 +1,3 @@
-from contextlib import contextmanager
 from typing import Optional, Union
 
 # ===================== import region =====================
@@ -213,19 +212,3 @@ class PyNcclCommunicator:
         self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                 ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                 self.comm, cudaStream_t(stream.cuda_stream))
-
-    @contextmanager
-    def change_state(self, enable: Optional[bool] = None):
-        """
-        A context manager to change the state of the communicator.
-        """
-        if enable is None:
-            # guess a default value when not specified
-            enable = self.available
-
-        old_disable = self.disabled
-
-        self.disabled = not enable
-        yield
-
-        self.disabled = old_disable
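
The removed helper only saved the communicator's `disabled` flag, overwrote it from `enable` (defaulting to `self.available`), and restored it on exit, which is why the commit title calls it redundant for the call sites above. A self-contained toy rendering of that pattern, to make the behavior concrete (illustrative only; `ToyComm` stands in for the communicator's flag bookkeeping and is not a vLLM class):

from contextlib import contextmanager


class ToyComm:
    """Stand-in for the enabled/disabled bookkeeping on the communicator."""

    def __init__(self, available: bool = True):
        self.available = available
        self.disabled = not available

    @contextmanager
    def change_state(self, enable=None):
        # mirrors the removed helper: save the flag, flip it, restore on exit
        if enable is None:
            enable = self.available
        old_disable = self.disabled
        self.disabled = not enable
        yield
        self.disabled = old_disable


comm = ToyComm(available=True)
with comm.change_state(enable=True):
    assert comm.disabled is False  # flag forced on inside the context
assert comm.disabled is False      # ...and restored afterwards: a no-op here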

--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py

@@ -305,14 +305,7 @@ class GroupCoordinator:
             stream.wait_stream(curr_stream)
 
         with torch.cuda.stream(stream), maybe_ca_context:
-            pynccl_comm = self.pynccl_comm
-            maybe_pynccl_context: Any
-            if not pynccl_comm:
-                maybe_pynccl_context = nullcontext()
-            else:
-                maybe_pynccl_context = pynccl_comm.change_state()
-            with maybe_pynccl_context:
-                yield graph_capture_context
+            yield graph_capture_context
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         """