[ROCm][CI] Fix "Cannot re-initialize CUDA in forked subprocess" in test_pynccl.py (#29119)
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
parent 3999442f1c
commit 55c21c8836
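For context, the error in the title occurs when a process that has already initialized CUDA forks a child: the child inherits the parent's CUDA state and PyTorch refuses to re-initialize it. Switching the test helper to the multiprocess package with the "spawn" start method gives each child a fresh interpreter. Below is a minimal sketch of that difference; it is illustrative only, not part of the commit, and assumes a CUDA-capable torch install to actually exercise the GPU path.

# Illustrative sketch, not part of the commit: show why spawn avoids the error.
import multiprocess as mp
import torch


def child() -> None:
    # Under the "fork" start method (the historical Linux default), this line
    # fails with "Cannot re-initialize CUDA in forked subprocess" once the
    # parent has touched CUDA; under "spawn" the child starts a fresh
    # interpreter and initializes CUDA on its own.
    torch.zeros(1, device="cuda")


if __name__ == "__main__":
    torch.zeros(1, device="cuda")  # parent initializes CUDA first
    mp.set_start_method("spawn", force=True)  # the approach this commit adopts
    p = mp.Process(target=child)
    p.start()
    p.join()
    assert p.exitcode == 0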
In the test requirements file:

@@ -40,5 +40,8 @@ mteb[bm25s]>=1.38.11, <2
 # Required for eval tests
 lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d
 
+# Required for multiprocessed tests that use spawn method
+multiprocess==0.70.16
+
 # Plugins test
 terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e
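The pinned multiprocess package is a third-party stand-in for the standard library's multiprocessing: it mirrors the same top-level API (Process, set_start_method, and so on) but serializes with dill, which is generally more forgiving about what spawn-started workers can receive. A quick sanity check, illustrative only and not from the diff:

# Illustrative check that multiprocess exposes the stdlib-compatible surface
# the updated test relies on; nothing here is part of the diff.
import multiprocess as mp

assert hasattr(mp, "Process")
assert hasattr(mp, "set_start_method")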
The remaining hunks modify test_pynccl.py:

@@ -1,9 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import multiprocessing
 import os
 
+import multiprocess as mp
 import numpy as np
 import pytest
 import torch
@@ -20,10 +20,12 @@ from vllm.distributed.parallel_state import (
 )
 from vllm.utils.system_utils import update_environment_variables
 
+mp.set_start_method("spawn", force=True)
+
 
 def distributed_run(fn, world_size):
     number_of_processes = world_size
-    processes: list[multiprocessing.Process] = []
+    processes: list[mp.Process] = []
     for i in range(number_of_processes):
         env: dict[str, str] = {}
         env["RANK"] = str(i)
@@ -32,7 +34,7 @@ def distributed_run(fn, world_size):
         env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
         env["MASTER_ADDR"] = "localhost"
         env["MASTER_PORT"] = "12345"
-        p = multiprocessing.Process(target=fn, args=(env,))
+        p = mp.Process(target=fn, args=(env,))
         processes.append(p)
         p.start()
 
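For reference, a self-contained sketch of how a worker function plugs into this spawn-based pattern; the worker name, the trimmed-down environment, and the join/exitcode checks are illustrative assumptions, not code from the diff.

# Illustrative, self-contained sketch of the spawn-based pattern above.
# "example_worker" and the minimal env are assumptions for demonstration.
import os

import multiprocess as mp

mp.set_start_method("spawn", force=True)


def example_worker(env: dict[str, str]) -> None:
    # The real test applies the env via vllm's update_environment_variables
    # before setting up the distributed state and the pynccl communicator.
    os.environ.update(env)
    print(f"rank {os.environ['RANK']} of {os.environ['WORLD_SIZE']} ready")


if __name__ == "__main__":
    world_size = 2
    procs = []
    for i in range(world_size):
        env = {"RANK": str(i), "WORLD_SIZE": str(world_size)}
        p = mp.Process(target=example_worker, args=(env,))
        procs.append(p)
        p.start()
    for p in procs:
        p.join()
        assert p.exitcode == 0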