[Fix] Fix comm test (#1691)

This commit is contained in:
Zhuohan Li 2023-11-16 16:28:39 -08:00 committed by GitHub
parent 686f5e3210
commit 20d0699d49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,7 +2,7 @@
Run `pytest tests/distributed/test_comm_ops.py --forked`.
"""
from multiprocessing import Process
from multiprocessing import Process, set_start_method
import pytest
import torch
@ -70,6 +70,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
@pytest.mark.parametrize("test_target",
[all_reduce_test_worker, all_gather_test_worker])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
set_start_method("spawn", force=True)
distributed_init_port = get_open_port()
processes = []
for rank in range(tensor_parallel_size):