[ROCm][CI] Fix "Cannot re-initialize CUDA in forked subprocess" in test_pynccl.py (#29119)

Signed-off-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
Micah Williamson 2025-11-22 23:05:00 -06:00 committed by GitHub
parent 3999442f1c
commit 55c21c8836
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 3 deletions

View File

@ -40,5 +40,8 @@ mteb[bm25s]>=1.38.11, <2
# Required for eval tests
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d
# Required for multiprocessed tests that use spawn method
multiprocess==0.70.16
# Plugins test
terratorch @ git+https://github.com/IBM/terratorch.git@07184fcf91a1324f831ff521dd238d97fe350e3e

View File

@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
import os
import multiprocess as mp
import numpy as np
import pytest
import torch
@ -20,10 +20,12 @@ from vllm.distributed.parallel_state import (
)
from vllm.utils.system_utils import update_environment_variables
mp.set_start_method("spawn", force=True)
def distributed_run(fn, world_size):
number_of_processes = world_size
processes: list[multiprocessing.Process] = [] processes: list[mp.Process] = []
for i in range(number_of_processes):
env: dict[str, str] = {}
env["RANK"] = str(i)
@ -32,7 +34,7 @@ def distributed_run(fn, world_size):
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = "12345"
p = multiprocessing.Process(target=fn, args=(env,)) p = mp.Process(target=fn, args=(env,))
processes.append(p)
p.start()