[distributed] remove pynccl's redundant change_state (#11749)
commit 9e764e7b10
parent 33fc1e2e86
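This change removes the `change_state(enable=True)` wrappers from the pynccl tests, deletes the `change_state` context manager from `PyNcclCommunicator`, and drops its use in `GroupCoordinator`'s graph-capture path. Judging by the functions touched, the hunks below come from `tests/distributed/test_pynccl.py`, `vllm/distributed/device_communicators/pynccl.py`, and `vllm/distributed/parallel_state.py` (the page did not preserve the per-file headers). For orientation, here is a minimal sketch of the calling pattern after the change, distilled from the first test hunk; it assumes it runs inside an already-initialized distributed worker (as the tests' worker wrapper provides) and uses the usual vLLM import paths, nothing introduced by this commit.

import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world_group


def allreduce_worker() -> None:
    # Communicator over the default world group, as the tests below build it.
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)
    tensor = torch.ones(16, 1024, 1024,
                        dtype=torch.float32).cuda(pynccl_comm.rank)
    # Previously this call sat inside `with pynccl_comm.change_state(enable=True):`;
    # after this commit the collective is invoked directly.
    tensor = pynccl_comm.all_reduce(tensor)
    torch.cuda.synchronize()
    assert torch.all(tensor == pynccl_comm.world_size).cpu().item()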
@@ -59,8 +59,7 @@ def worker_fn():
                                      device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
-    with pynccl_comm.change_state(enable=True):
-        tensor = pynccl_comm.all_reduce(tensor)
+    tensor = pynccl_comm.all_reduce(tensor)
     torch.cuda.synchronize()
     assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
 
@@ -81,17 +80,16 @@ def multiple_allreduce_worker_fn():
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     pynccl_comm = PyNcclCommunicator(group=group, device=device)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with pynccl_comm.change_state(enable=True):
-        # two groups can communicate independently
-        if torch.distributed.get_rank() in [0, 1]:
-            tensor = pynccl_comm.all_reduce(tensor)
-            tensor = pynccl_comm.all_reduce(tensor)
-            torch.cuda.synchronize()
-            assert torch.all(tensor == 4).cpu().item()
-        else:
-            tensor = pynccl_comm.all_reduce(tensor)
-            torch.cuda.synchronize()
-            assert torch.all(tensor == 2).cpu().item()
+    # two groups can communicate independently
+    if torch.distributed.get_rank() in [0, 1]:
+        tensor = pynccl_comm.all_reduce(tensor)
+        tensor = pynccl_comm.all_reduce(tensor)
+        torch.cuda.synchronize()
+        assert torch.all(tensor == 4).cpu().item()
+    else:
+        tensor = pynccl_comm.all_reduce(tensor)
+        torch.cuda.synchronize()
+        assert torch.all(tensor == 2).cpu().item()
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -137,8 +135,7 @@ def worker_fn_with_cudagraph():
         # run something in the default stream to initialize torch engine
         a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
-        with torch.cuda.graph(graph), \
-                pynccl_comm.change_state(enable=True):
+        with torch.cuda.graph(graph):
             a_out = pynccl_comm.all_reduce(a)
         torch.cuda.synchronize()
         graph.replay()
@@ -167,8 +164,7 @@ def all_gather_worker_fn():
         for r in range(world_size)
     ]).to(device)
 
-    with pynccl_comm.change_state(enable=True):
-        pynccl_comm.all_gather(result, tensor)
+    pynccl_comm.all_gather(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
 
@@ -205,8 +201,7 @@ def reduce_scatter_worker_fn():
     expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                    for tensor in all_tensors).to(device)
 
-    with pynccl_comm.change_state(enable=True):
-        pynccl_comm.reduce_scatter(result, tensor)
+    pynccl_comm.reduce_scatter(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
 
@@ -233,15 +228,13 @@ def send_recv_worker_fn():
     else:
         tensor = torch.empty(16, 1024, 1024,
                              dtype=torch.float32).cuda(pynccl_comm.rank)
-    with pynccl_comm.change_state(enable=True):
-        if pynccl_comm.rank == 0:
-            pynccl_comm.send(tensor,
-                             dst=(pynccl_comm.rank + 1) %
-                             pynccl_comm.world_size)
-        else:
-            pynccl_comm.recv(tensor,
-                             src=(pynccl_comm.rank - 1) %
-                             pynccl_comm.world_size)
+
+    if pynccl_comm.rank == 0:
+        pynccl_comm.send(tensor,
+                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
+    else:
+        pynccl_comm.recv(tensor,
+                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
     torch.cuda.synchronize()
     assert torch.all(tensor == 1).cpu().item()
 
@@ -272,15 +265,12 @@ def multiple_send_recv_worker_fn():
                         1024,
                         dtype=torch.float32,
                         device=device)
-    with pynccl_comm.change_state(enable=True):
-        if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.send(tensor,
-                             dst=(pynccl_comm.rank + 1) %
-                             pynccl_comm.world_size)
-        else:
-            pynccl_comm.recv(tensor,
-                             src=(pynccl_comm.rank - 1) %
-                             pynccl_comm.world_size)
+    if torch.distributed.get_rank() in [0, 1]:
+        pynccl_comm.send(tensor,
+                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
+    else:
+        pynccl_comm.recv(tensor,
+                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
     torch.cuda.synchronize()
     if torch.distributed.get_rank() in [0, 2]:
         assert torch.all(tensor == 1).cpu().item()
@@ -1,4 +1,3 @@
-from contextlib import contextmanager
 from typing import Optional, Union
 
 # ===================== import region =====================
@@ -213,19 +212,3 @@ class PyNcclCommunicator:
         self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                 ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                 self.comm, cudaStream_t(stream.cuda_stream))
-
-    @contextmanager
-    def change_state(self, enable: Optional[bool] = None):
-        """
-        A context manager to change the state of the communicator.
-        """
-        if enable is None:
-            # guess a default value when not specified
-            enable = self.available
-
-        old_disable = self.disabled
-
-        self.disabled = not enable
-        yield
-
-        self.disabled = old_disable
@@ -305,14 +305,7 @@ class GroupCoordinator:
             stream.wait_stream(curr_stream)
 
         with torch.cuda.stream(stream), maybe_ca_context:
-            pynccl_comm = self.pynccl_comm
-            maybe_pynccl_context: Any
-            if not pynccl_comm:
-                maybe_pynccl_context = nullcontext()
-            else:
-                maybe_pynccl_context = pynccl_comm.change_state()
-            with maybe_pynccl_context:
-                yield graph_capture_context
+            yield graph_capture_context
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         """