From 20d0699d49a730661434f8374ba495714a92f953 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 16 Nov 2023 16:28:39 -0800 Subject: [PATCH] [Fix] Fix comm test (#1691) --- tests/distributed/test_comm_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 9de0ce0de416..733c7395811e 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -2,7 +2,7 @@ Run `pytest tests/distributed/test_comm_ops.py --forked`. """ -from multiprocessing import Process +from multiprocessing import Process, set_start_method import pytest import torch @@ -70,6 +70,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, @pytest.mark.parametrize("test_target", [all_reduce_test_worker, all_gather_test_worker]) def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): + set_start_method("spawn", force=True) distributed_init_port = get_open_port() processes = [] for rank in range(tensor_parallel_size):