mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 03:15:01 +08:00
[Fix] Fix comm test (#1691)
This commit is contained in:
parent
686f5e3210
commit
20d0699d49
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
Run `pytest tests/distributed/test_comm_ops.py --forked`.
|
Run `pytest tests/distributed/test_comm_ops.py --forked`.
|
||||||
"""
|
"""
|
||||||
from multiprocessing import Process
|
from multiprocessing import Process, set_start_method
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -70,6 +70,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
|
|||||||
@pytest.mark.parametrize("test_target",
|
@pytest.mark.parametrize("test_target",
|
||||||
[all_reduce_test_worker, all_gather_test_worker])
|
[all_reduce_test_worker, all_gather_test_worker])
|
||||||
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
|
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
|
||||||
|
set_start_method("spawn", force=True)
|
||||||
distributed_init_port = get_open_port()
|
distributed_init_port = get_open_port()
|
||||||
processes = []
|
processes = []
|
||||||
for rank in range(tensor_parallel_size):
|
for rank in range(tensor_parallel_size):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user