diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 9de0ce0de416e..733c7395811ef 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -2,7 +2,7 @@ Run `pytest tests/distributed/test_comm_ops.py --forked`. """ -from multiprocessing import Process +from multiprocessing import Process, set_start_method import pytest import torch @@ -70,6 +70,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, @pytest.mark.parametrize("test_target", [all_reduce_test_worker, all_gather_test_worker]) def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): + set_start_method("spawn", force=True) distributed_init_port = get_open_port() processes = [] for rank in range(tensor_parallel_size):