# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import random

import torch
import torch.multiprocessing as mp

from vllm.distributed.parallel_state import (
    init_distributed_environment,
)
from vllm.utils.system_utils import update_environment_variables

mp.set_start_method("spawn", force=True)


def distributed_run(fn, world_size, *args):
    # Launch `world_size` worker processes, each running
    # `fn(env, world_size, *args)` with the torch.distributed environment
    # variables for its rank.
    number_of_processes = world_size
    processes: list[mp.Process] = []
    for i in range(number_of_processes):
        env: dict[str, str] = {}
        env["RANK"] = str(i)
        env["LOCAL_RANK"] = str(i)
        env["WORLD_SIZE"] = str(number_of_processes)
        env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
        env["MASTER_ADDR"] = "localhost"
        env["MASTER_PORT"] = "12345"
        p = mp.Process(target=fn, args=(env, world_size, *args))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    # Every worker must exit cleanly.
    for p in processes:
        assert p.exitcode == 0


def set_env_vars_and_device(env: dict[str, str]) -> None:
    # Apply the per-rank environment, bind this process to its GPU, and
    # initialize the distributed environment.
    update_environment_variables(env)
    local_rank = os.environ["LOCAL_RANK"]
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    init_distributed_environment()
    # Ensure each worker process has the same random seed
    random.seed(42)
    torch.manual_seed(42)
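

# The sketch below illustrates how these helpers are typically combined: a
# worker function first calls set_env_vars_and_device(env), then performs
# collective work, and the parent launches it via distributed_run. The worker
# name `_example_worker`, the world size of 2, and the use of
# torch.distributed.all_reduce (assuming init_distributed_environment() sets
# up the default process group) are assumptions for illustration, not part of
# the original module.
def _example_worker(env: dict[str, str], world_size: int) -> None:
    # Hypothetical worker: set up rank/device, then all-reduce a tensor.
    set_env_vars_and_device(env)
    rank = int(os.environ["RANK"])
    t = torch.ones(1, device="cuda") * rank
    torch.distributed.all_reduce(t)
    # The sum of ranks 0..world_size-1 should appear on every rank.
    assert t.item() == sum(range(world_size))


if __name__ == "__main__":
    # Run the illustrative worker on 2 processes (assumes >= 2 CUDA devices).
    distributed_run(_example_worker, 2)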