[distributed] remove pynccl's redundant stream (#11744)
This commit is contained in:
parent 4068f4b5b5
commit 635b897246
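The diff below removes the communicator-owned self.stream: NCCL collectives now follow torch.cuda.current_stream() at call time, and change_state() loses its stream argument. A minimal sketch of the before/after calling convention; pynccl_comm and x are stand-ins (an initialized PyNcclCommunicator and a CUDA tensor on its device), not names taken from this diff:

    # Before this change (sketch): the caller routed the communicator onto
    # the compute stream through change_state(..., stream=...).
    with pynccl_comm.change_state(enable=True,
                                  stream=torch.cuda.current_stream()):
        out = pynccl_comm.all_reduce(x)

    # After this change (sketch): NCCL ops are issued on whatever stream is
    # current when they are called, so only the enable flag remains.
    with pynccl_comm.change_state(enable=True):
        out = pynccl_comm.all_reduce(x)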
@@ -137,9 +137,8 @@ def worker_fn_with_cudagraph():
         # run something in the default stream to initialize torch engine
         a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
-        with torch.cuda.graph(
-                graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
-                    enable=True):
+        with torch.cuda.graph(graph), \
+                pynccl_comm.change_state(enable=True):
             a_out = pynccl_comm.all_reduce(a)
         torch.cuda.synchronize()
         graph.replay()
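The stream=pynccl_comm.stream argument to torch.cuda.graph is dropped because the communicator no longer owns a stream; the all_reduce is simply captured on whatever stream torch.cuda.graph makes current. If a specific capture stream is still wanted, it can be passed to torch.cuda.graph directly and pynccl will follow it. A hedged sketch, reusing graph, pynccl_comm and a from the test above:

    capture_stream = torch.cuda.Stream()
    with torch.cuda.graph(graph, stream=capture_stream), \
            pynccl_comm.change_state(enable=True):
        # all_reduce is recorded on capture_stream, which torch.cuda.graph
        # makes current for the duration of the capture.
        a_out = pynccl_comm.all_reduce(a)
    torch.cuda.synchronize()
    graph.replay()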
@@ -51,7 +51,6 @@ class PyNcclCommunicator:
         if self.world_size == 1:
             self.available = False
             self.disabled = True
-            self.stream = None
             return
         try:
             self.nccl = NCCLLibrary(library_path)
@@ -60,7 +59,6 @@ class PyNcclCommunicator:
             # e.g. in a non-GPU environment
             self.available = False
             self.disabled = True
-            self.stream = None
             return
 
         self.available = True
@@ -98,12 +96,12 @@ class PyNcclCommunicator:
         with torch.cuda.device(device):
             self.comm: ncclComm_t = self.nccl.ncclCommInitRank(
                 self.world_size, self.unique_id, self.rank)
-            self.stream = torch.cuda.Stream()
 
+            stream = torch.cuda.current_stream()
             # A small all_reduce for warmup.
             data = torch.zeros(1, device=device)
             self.all_reduce(data)
-            self.stream.synchronize()
+            stream.synchronize()
             del data
 
     def all_reduce(self,
@@ -122,7 +120,7 @@ class PyNcclCommunicator:
         out_tensor = torch.empty_like(in_tensor)
 
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
                                 buffer_type(out_tensor.data_ptr()),
                                 in_tensor.numel(),
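Since the default is now the caller's current stream, routing a collective onto a side stream is just a matter of making that stream current (an explicit stream= argument still overrides the default). A small sketch under the same assumptions as above (pynccl_comm initialized, x a CUDA tensor on its device):

    comm_stream = torch.cuda.Stream()
    with torch.cuda.stream(comm_stream), pynccl_comm.change_state(enable=True):
        # all_reduce picks up torch.cuda.current_stream(), i.e. comm_stream.
        out = pynccl_comm.all_reduce(x)
    comm_stream.synchronize()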
@@ -144,7 +142,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclAllGather(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
@@ -165,7 +163,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclReduceScatter(
             buffer_type(input_tensor.data_ptr()),
             buffer_type(output_tensor.data_ptr()), output_tensor.numel(),
@@ -180,7 +178,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -192,7 +190,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))
@@ -204,7 +202,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
-            stream = self.stream
+            stream = torch.cuda.current_stream()
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
@@ -217,9 +215,7 @@ class PyNcclCommunicator:
                                 self.comm, cudaStream_t(stream.cuda_stream))
 
     @contextmanager
-    def change_state(self,
-                     enable: Optional[bool] = None,
-                     stream: Optional[torch.cuda.Stream] = None):
+    def change_state(self, enable: Optional[bool] = None):
         """
         A context manager to change the state of the communicator.
         """
@@ -227,15 +223,9 @@ class PyNcclCommunicator:
             # guess a default value when not specified
             enable = self.available
 
-        if stream is None:
-            stream = self.stream
-
         old_disable = self.disabled
-        old_stream = self.stream
 
-        self.stream = stream
         self.disabled = not enable
         yield
 
         self.disabled = old_disable
-        self.stream = old_stream
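With the stream bookkeeping removed, change_state only saves and restores the disabled flag. A hedged usage sketch, with pynccl_comm and x as the stand-ins introduced above:

    # Temporarily enable the communicator; the previous enabled/disabled
    # state is restored on exit from the context manager.
    with pynccl_comm.change_state(enable=True):
        out = pynccl_comm.all_reduce(x)

    # Temporarily disable it, e.g. to force a different all-reduce path.
    with pynccl_comm.change_state(enable=False):
        assert pynccl_comm.disabled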
@@ -310,8 +310,7 @@ class GroupCoordinator:
             if not pynccl_comm:
                 maybe_pynccl_context = nullcontext()
             else:
-                maybe_pynccl_context = pynccl_comm.change_state(
-                    stream=torch.cuda.current_stream())
+                maybe_pynccl_context = pynccl_comm.change_state()
             with maybe_pynccl_context:
                 yield graph_capture_context