[distributed] remove pynccl's redundant change_state (#11749)
parent 33fc1e2e86
commit 9e764e7b10
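The commit drops the `change_state` context manager from `PyNcclCommunicator`: the communicator no longer has to be explicitly enabled around collectives, so call sites in the tests and in `GroupCoordinator.graph_capture` invoke the ops directly. Below is a minimal sketch of the new call pattern, modeled on the updated `worker_fn` test in the diff; it assumes torch.distributed and vLLM's world group are already initialized on every rank (as the test harness does), and the positional group argument is an assumption, since the diff only shows the `device=` keyword.

import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.parallel_state import get_world_group

# Build a communicator over the existing world group; the cpu_group argument
# follows the test code but is not shown in this hunk, so treat it as assumed.
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                 device=get_world_group().device)

tensor = torch.ones(16, 1024, 1024,
                    dtype=torch.float32).cuda(pynccl_comm.rank)
# No `with pynccl_comm.change_state(enable=True):` wrapper any more;
# the collective is called directly on the current stream.
tensor = pynccl_comm.all_reduce(tensor)
torch.cuda.synchronize()
assert torch.all(tensor == pynccl_comm.world_size).cpu().item()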
@@ -59,8 +59,7 @@ def worker_fn():
                                      device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
-    with pynccl_comm.change_state(enable=True):
-        tensor = pynccl_comm.all_reduce(tensor)
+    tensor = pynccl_comm.all_reduce(tensor)
     torch.cuda.synchronize()
     assert torch.all(tensor == pynccl_comm.world_size).cpu().item()

@@ -81,17 +80,16 @@ def multiple_allreduce_worker_fn():
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     pynccl_comm = PyNcclCommunicator(group=group, device=device)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with pynccl_comm.change_state(enable=True):
-        # two groups can communicate independently
-        if torch.distributed.get_rank() in [0, 1]:
-            tensor = pynccl_comm.all_reduce(tensor)
-            tensor = pynccl_comm.all_reduce(tensor)
-            torch.cuda.synchronize()
-            assert torch.all(tensor == 4).cpu().item()
-        else:
-            tensor = pynccl_comm.all_reduce(tensor)
-            torch.cuda.synchronize()
-            assert torch.all(tensor == 2).cpu().item()
+    # two groups can communicate independently
+    if torch.distributed.get_rank() in [0, 1]:
+        tensor = pynccl_comm.all_reduce(tensor)
+        tensor = pynccl_comm.all_reduce(tensor)
+        torch.cuda.synchronize()
+        assert torch.all(tensor == 4).cpu().item()
+    else:
+        tensor = pynccl_comm.all_reduce(tensor)
+        torch.cuda.synchronize()
+        assert torch.all(tensor == 2).cpu().item()


 @pytest.mark.skipif(torch.cuda.device_count() < 4,

@@ -137,8 +135,7 @@ def worker_fn_with_cudagraph():
         # run something in the default stream to initialize torch engine
         a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
-        with torch.cuda.graph(graph), \
-                pynccl_comm.change_state(enable=True):
+        with torch.cuda.graph(graph):
             a_out = pynccl_comm.all_reduce(a)
         torch.cuda.synchronize()
         graph.replay()

@@ -167,8 +164,7 @@ def all_gather_worker_fn():
         for r in range(world_size)
     ]).to(device)

-    with pynccl_comm.change_state(enable=True):
-        pynccl_comm.all_gather(result, tensor)
+    pynccl_comm.all_gather(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

@@ -205,8 +201,7 @@ def reduce_scatter_worker_fn():
     expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                    for tensor in all_tensors).to(device)

-    with pynccl_comm.change_state(enable=True):
-        pynccl_comm.reduce_scatter(result, tensor)
+    pynccl_comm.reduce_scatter(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)

@@ -233,15 +228,13 @@ def send_recv_worker_fn():
     else:
         tensor = torch.empty(16, 1024, 1024,
                              dtype=torch.float32).cuda(pynccl_comm.rank)
-    with pynccl_comm.change_state(enable=True):
-        if pynccl_comm.rank == 0:
-            pynccl_comm.send(tensor,
-                             dst=(pynccl_comm.rank + 1) %
-                             pynccl_comm.world_size)
-        else:
-            pynccl_comm.recv(tensor,
-                             src=(pynccl_comm.rank - 1) %
-                             pynccl_comm.world_size)
+
+    if pynccl_comm.rank == 0:
+        pynccl_comm.send(tensor,
+                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
+    else:
+        pynccl_comm.recv(tensor,
+                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
     torch.cuda.synchronize()
     assert torch.all(tensor == 1).cpu().item()

@@ -272,15 +265,12 @@ def multiple_send_recv_worker_fn():
                          1024,
                          dtype=torch.float32,
                          device=device)
-    with pynccl_comm.change_state(enable=True):
-        if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.send(tensor,
-                             dst=(pynccl_comm.rank + 1) %
-                             pynccl_comm.world_size)
-        else:
-            pynccl_comm.recv(tensor,
-                             src=(pynccl_comm.rank - 1) %
-                             pynccl_comm.world_size)
+    if torch.distributed.get_rank() in [0, 1]:
+        pynccl_comm.send(tensor,
+                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
+    else:
+        pynccl_comm.recv(tensor,
+                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
     torch.cuda.synchronize()
     if torch.distributed.get_rank() in [0, 2]:
         assert torch.all(tensor == 1).cpu().item()

@@ -1,4 +1,3 @@
-from contextlib import contextmanager
 from typing import Optional, Union

 # ===================== import region =====================

@@ -213,19 +212,3 @@ class PyNcclCommunicator:
         self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(),
                                 ncclDataTypeEnum.from_torch(tensor.dtype), src,
                                 self.comm, cudaStream_t(stream.cuda_stream))
-
-    @contextmanager
-    def change_state(self, enable: Optional[bool] = None):
-        """
-        A context manager to change the state of the communicator.
-        """
-        if enable is None:
-            # guess a default value when not specified
-            enable = self.available
-
-        old_disable = self.disabled
-
-        self.disabled = not enable
-        yield
-
-        self.disabled = old_disable

@@ -305,14 +305,7 @@ class GroupCoordinator:
             stream.wait_stream(curr_stream)

         with torch.cuda.stream(stream), maybe_ca_context:
-            pynccl_comm = self.pynccl_comm
-            maybe_pynccl_context: Any
-            if not pynccl_comm:
-                maybe_pynccl_context = nullcontext()
-            else:
-                maybe_pynccl_context = pynccl_comm.change_state()
-            with maybe_pynccl_context:
-                yield graph_capture_context
+            yield graph_capture_context

     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         """
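With the per-call state toggle gone, graph_capture only has to yield the capture context, and a collective captured under torch.cuda.graph goes straight through the communicator. Below is a hedged sketch of the capture path after this change, mirroring worker_fn_with_cudagraph above; pynccl_comm is assumed to be an already-constructed PyNcclCommunicator as in the earlier sketch, and this must run on every rank of an initialized distributed group.

import torch

graph = torch.cuda.CUDAGraph()
a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
torch.cuda.synchronize()

# Capture the collective; no pynccl_comm.change_state(...) wrapper is needed.
with torch.cuda.graph(graph):
    a_out = pynccl_comm.all_reduce(a)
torch.cuda.synchronize()

# Replaying the graph re-runs the captured all-reduce, so with `a` all ones
# the result equals the world size on every rank.
graph.replay()
torch.cuda.synchronize()
assert torch.all(a_out == pynccl_comm.world_size).cpu().item()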