diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 128096c88a8b1..e9c6fc3a255e4 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -243,6 +243,13 @@ class GroupCoordinator: ca_comm = self.ca_comm maybe_ca_context = nullcontext( ) if ca_comm is None else ca_comm.capture() + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = torch.cuda.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + with torch.cuda.stream(stream), maybe_ca_context: # In graph mode, we have to be very careful about the collective # operations. The current status is: