mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:25:32 +08:00
[V0 Deprecation] Remove V0 MP executor (#25329)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
12dbd834cf
commit
7ed82d1974
@ -1,244 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import cloudpickle
|
||||
|
||||
from vllm.executor.executor_base import DistributedExecutorBase
|
||||
from vllm.executor.multiproc_worker_utils import (
|
||||
ProcessWorkerWrapper, ResultHandler, WorkerMonitor,
|
||||
set_multiprocessing_worker_envs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
|
||||
get_distributed_init_method, get_ip, get_open_port,
|
||||
make_async, run_method, update_environment_variables)
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiprocessingDistributedExecutor(DistributedExecutorBase):
|
||||
"""Python multiprocessing-based distributed executor"""
|
||||
|
||||
uses_ray: bool = False
|
||||
|
||||
def _check_cuda(self) -> None:
|
||||
"""Check that the number of GPUs is sufficient for the parallel
|
||||
configuration. Separate from _init_executor to reduce the number of
|
||||
indented blocks.
|
||||
"""
|
||||
parallel_config = self.parallel_config
|
||||
world_size = parallel_config.world_size
|
||||
tensor_parallel_size = parallel_config.tensor_parallel_size
|
||||
|
||||
cuda_device_count = cuda_device_count_stateless()
|
||||
# Use confusing message for more common TP-only case.
|
||||
if tensor_parallel_size > cuda_device_count:
|
||||
raise RuntimeError(
|
||||
f"please set tensor_parallel_size ({tensor_parallel_size}) "
|
||||
f"to less than max local gpu count ({cuda_device_count})")
|
||||
|
||||
if world_size > cuda_device_count:
|
||||
raise RuntimeError(
|
||||
f"please ensure that world_size ({world_size}) "
|
||||
f"is less than than max local gpu count ({cuda_device_count})")
|
||||
|
||||
# Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
|
||||
if "CUDA_VISIBLE_DEVICES" not in os.environ:
|
||||
update_environment_variables({
|
||||
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
|
||||
})
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
if current_platform.is_cuda_alike():
|
||||
self._check_cuda()
|
||||
|
||||
# Create the parallel GPU workers.
|
||||
world_size = self.parallel_config.world_size
|
||||
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
||||
|
||||
# Set multiprocessing envs that are common to V0 and V1
|
||||
set_multiprocessing_worker_envs(self.parallel_config)
|
||||
|
||||
# Multiprocessing-based executor does not support multi-node setting.
|
||||
# Since it only works for single node, we can use the loopback address
|
||||
# 127.0.0.1 for communication.
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
"127.0.0.1", get_open_port())
|
||||
|
||||
self.workers: List[ProcessWorkerWrapper] = []
|
||||
# This is the list of workers that are rank 0 of each TP group EXCEPT
|
||||
# global rank 0. These are the workers that will broadcast to the
|
||||
# rest of the workers.
|
||||
self.tp_driver_workers: List[ProcessWorkerWrapper] = []
|
||||
# This is the list of workers that are not drivers and not the first
|
||||
# worker in a TP group. These are the workers that will be
|
||||
# broadcasted to.
|
||||
self.non_driver_workers: List[ProcessWorkerWrapper] = []
|
||||
|
||||
if world_size == 1:
|
||||
self.worker_monitor = None
|
||||
else:
|
||||
result_handler = ResultHandler()
|
||||
for rank in range(1, world_size):
|
||||
worker = ProcessWorkerWrapper(result_handler,
|
||||
WorkerWrapperBase,
|
||||
self.vllm_config, rank)
|
||||
self.workers.append(worker)
|
||||
if rank % tensor_parallel_size == 0:
|
||||
self.tp_driver_workers.append(worker)
|
||||
else:
|
||||
self.non_driver_workers.append(worker)
|
||||
|
||||
self.worker_monitor = WorkerMonitor(self.workers, result_handler)
|
||||
result_handler.start()
|
||||
self.worker_monitor.start()
|
||||
|
||||
# Set up signal handlers to shut down the executor cleanly
|
||||
# sometimes gc does not work well
|
||||
|
||||
self.driver_worker = WorkerWrapperBase(self.vllm_config, 0)
|
||||
|
||||
all_kwargs = []
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
for i in range(world_size):
|
||||
local_rank = i
|
||||
rank = i
|
||||
kwargs = dict(
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=(not self.parallel_config)
|
||||
or (rank % self.parallel_config.tensor_parallel_size == 0),
|
||||
)
|
||||
all_kwargs.append(kwargs)
|
||||
self._run_workers("init_worker", all_kwargs)
|
||||
self._run_workers("init_device")
|
||||
self._run_workers("load_model",
|
||||
max_concurrent_workers=self.parallel_config.
|
||||
max_parallel_loading_workers)
|
||||
self.driver_exec_model = make_async(self.driver_worker.execute_model)
|
||||
self.pp_locks: Optional[List[asyncio.Lock]] = None
|
||||
|
||||
def shutdown(self):
|
||||
if (worker_monitor := getattr(self, "worker_monitor",
|
||||
None)) is not None:
|
||||
worker_monitor.close()
|
||||
|
||||
def _driver_execute_model(
|
||||
self, execute_model_req: Optional[ExecuteModelRequest]
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
"""Run execute_model in the driver worker.
|
||||
|
||||
Passing None will cause the driver to stop the model execution
|
||||
loop running in each of the remote workers.
|
||||
"""
|
||||
return self.driver_worker.execute_model(execute_model_req)
|
||||
|
||||
def _run_workers(
|
||||
self,
|
||||
method: Union[str, Callable],
|
||||
*args,
|
||||
async_run_tensor_parallel_workers_only: bool = False,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> List[Any]:
|
||||
"""Runs the given method on all workers.
|
||||
|
||||
Args:
|
||||
async_run_tensor_parallel_workers_only: If True the method will be
|
||||
run only in the remote TP workers, not the driver worker.
|
||||
It will also be run asynchronously and return a list of futures
|
||||
rather than blocking on the results.
|
||||
"""
|
||||
if isinstance(method, str):
|
||||
sent_method = method
|
||||
else:
|
||||
sent_method = cloudpickle.dumps(method)
|
||||
del method
|
||||
|
||||
if max_concurrent_workers:
|
||||
raise NotImplementedError(
|
||||
"max_concurrent_workers is not supported yet.")
|
||||
|
||||
if async_run_tensor_parallel_workers_only:
|
||||
# Run only non-driver workers and just return futures.
|
||||
return [
|
||||
worker.execute_method(sent_method, *args, **kwargs)
|
||||
for worker in self.non_driver_workers
|
||||
]
|
||||
|
||||
# Start all remote workers first.
|
||||
worker_outputs = [
|
||||
worker.execute_method(sent_method, *args, **kwargs)
|
||||
for worker in self.workers
|
||||
]
|
||||
|
||||
driver_worker_output = run_method(self.driver_worker, sent_method,
|
||||
args, kwargs)
|
||||
|
||||
# Get the results of the workers.
|
||||
return [driver_worker_output
|
||||
] + [output.get() for output in worker_outputs]
|
||||
|
||||
def check_health(self) -> None:
|
||||
"""Raises an error if engine is unhealthy."""
|
||||
if self.worker_monitor is not None and not self.worker_monitor.is_alive(
|
||||
):
|
||||
raise RuntimeError("Worker processes are not running")
|
||||
|
||||
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
|
||||
"""Wait for futures returned from _run_workers() with
|
||||
async_run_remote_workers_only to complete."""
|
||||
for result in parallel_worker_tasks:
|
||||
result.get()
|
||||
|
||||
async def _driver_execute_model_async(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
if not self.tp_driver_workers:
|
||||
return await self.driver_exec_model(execute_model_req)
|
||||
|
||||
if self.pp_locks is None:
|
||||
# This locks each pipeline parallel stage so multiple virtual
|
||||
# engines can't execute on the same stage at the same time
|
||||
# We create the locks here to avoid creating them in the constructor
|
||||
# which uses a different asyncio loop.
|
||||
self.pp_locks = [
|
||||
asyncio.Lock()
|
||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(
|
||||
_run_task_with_lock(self.driver_exec_model, self.pp_locks[0],
|
||||
execute_model_req))
|
||||
]
|
||||
for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
|
||||
start=1):
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
_run_task_with_lock(driver_worker.execute_method_async,
|
||||
self.pp_locks[pp_rank],
|
||||
"execute_model", execute_model_req)))
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Only the last PP stage has the final results.
|
||||
return results[-1]
|
||||
|
||||
async def _start_worker_execution_loop(self):
|
||||
coros = [
|
||||
worker.execute_method_async("start_worker_execution_loop")
|
||||
for worker in self.non_driver_workers
|
||||
]
|
||||
return await asyncio.gather(*coros)
|
||||
@ -1,279 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from multiprocessing import Queue
|
||||
from multiprocessing.connection import wait
|
||||
from multiprocessing.process import BaseProcess
|
||||
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (_maybe_force_spawn, decorate_logs, get_mp_context,
|
||||
run_method)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
_TERMINATE = "TERMINATE" # sentinel
|
||||
|
||||
JOIN_TIMEOUT_S = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class Result(Generic[T]):
|
||||
"""Result of task dispatched to worker"""
|
||||
|
||||
task_id: uuid.UUID
|
||||
value: Optional[T] = None
|
||||
exception: Optional[BaseException] = None
|
||||
|
||||
|
||||
class ResultFuture(threading.Event, Generic[T]):
|
||||
"""Synchronous future for non-async case"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.result: Optional[Result[T]] = None
|
||||
|
||||
def set_result(self, result: Result[T]):
|
||||
self.result = result
|
||||
self.set()
|
||||
|
||||
def get(self) -> T:
|
||||
self.wait()
|
||||
assert self.result is not None
|
||||
if self.result.exception is not None:
|
||||
raise self.result.exception
|
||||
return self.result.value # type: ignore[return-value]
|
||||
|
||||
|
||||
def _set_future_result(future: Union[ResultFuture, asyncio.Future],
|
||||
result: Result):
|
||||
if isinstance(future, ResultFuture):
|
||||
future.set_result(result)
|
||||
return
|
||||
loop = future.get_loop()
|
||||
if not loop.is_closed():
|
||||
if result.exception is not None:
|
||||
loop.call_soon_threadsafe(future.set_exception, result.exception)
|
||||
else:
|
||||
loop.call_soon_threadsafe(future.set_result, result.value)
|
||||
|
||||
|
||||
class ResultHandler(threading.Thread):
|
||||
"""Handle results from all workers (in background thread)"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(daemon=True)
|
||||
self.result_queue = get_mp_context().Queue()
|
||||
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
|
||||
|
||||
def run(self):
|
||||
for result in iter(self.result_queue.get, _TERMINATE):
|
||||
future = self.tasks.pop(result.task_id)
|
||||
_set_future_result(future, result)
|
||||
# Ensure that all waiters will receive an exception
|
||||
for task_id, future in self.tasks.items():
|
||||
_set_future_result(
|
||||
future,
|
||||
Result(task_id=task_id,
|
||||
exception=ChildProcessError("worker died")))
|
||||
|
||||
def close(self):
|
||||
self.result_queue.put(_TERMINATE)
|
||||
|
||||
|
||||
class WorkerMonitor(threading.Thread):
|
||||
"""Monitor worker status (in background thread)"""
|
||||
|
||||
def __init__(self, workers: List['ProcessWorkerWrapper'],
|
||||
result_handler: ResultHandler):
|
||||
super().__init__(daemon=True)
|
||||
self.workers = workers
|
||||
self.result_handler = result_handler
|
||||
self._close = False
|
||||
|
||||
def run(self) -> None:
|
||||
# Blocks until any worker exits
|
||||
dead_sentinels = wait([w.process.sentinel for w in self.workers])
|
||||
if not self._close:
|
||||
self._close = True
|
||||
|
||||
# Kill / cleanup all workers
|
||||
for worker in self.workers:
|
||||
process = worker.process
|
||||
if process.sentinel in dead_sentinels:
|
||||
process.join(JOIN_TIMEOUT_S)
|
||||
if process.exitcode is not None and process.exitcode != 0:
|
||||
logger.error("Worker %s pid %s died, exit code: %s",
|
||||
process.name, process.pid, process.exitcode)
|
||||
# Cleanup any remaining workers
|
||||
if logger:
|
||||
logger.info("Killing local vLLM worker processes")
|
||||
for worker in self.workers:
|
||||
worker.kill_worker()
|
||||
# Must be done after worker task queues are all closed
|
||||
self.result_handler.close()
|
||||
|
||||
for worker in self.workers:
|
||||
worker.process.join(JOIN_TIMEOUT_S)
|
||||
|
||||
def close(self):
|
||||
if self._close:
|
||||
return
|
||||
self._close = True
|
||||
logger.info("Terminating local vLLM worker processes")
|
||||
for worker in self.workers:
|
||||
worker.terminate_worker()
|
||||
# Must be done after worker task queues are all closed
|
||||
self.result_handler.close()
|
||||
|
||||
|
||||
class ProcessWorkerWrapper:
|
||||
"""Local process wrapper for vllm.worker.Worker,
|
||||
for handling single-node multi-GPU tensor parallel."""
|
||||
|
||||
def __init__(self, result_handler: ResultHandler,
|
||||
worker_factory: Callable[[VllmConfig, int], Any],
|
||||
vllm_config: VllmConfig, rank: int) -> None:
|
||||
self.mp = get_mp_context()
|
||||
self._task_queue = self.mp.Queue()
|
||||
self.result_queue = result_handler.result_queue
|
||||
self.tasks = result_handler.tasks
|
||||
self.process: BaseProcess = self.mp.Process( # type: ignore[attr-defined]
|
||||
target=_run_worker_process,
|
||||
name="VllmWorkerProcess",
|
||||
kwargs=dict(
|
||||
worker_factory=worker_factory,
|
||||
task_queue=self._task_queue,
|
||||
result_queue=self.result_queue,
|
||||
vllm_config=vllm_config,
|
||||
rank=rank,
|
||||
),
|
||||
daemon=True)
|
||||
|
||||
self.process.start()
|
||||
|
||||
def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
|
||||
method: Union[str, bytes], args, kwargs):
|
||||
task_id = uuid.uuid4()
|
||||
self.tasks[task_id] = future
|
||||
try:
|
||||
self._task_queue.put((task_id, method, args, kwargs))
|
||||
except SystemExit:
|
||||
raise
|
||||
except BaseException as e:
|
||||
del self.tasks[task_id]
|
||||
raise ChildProcessError("worker died") from e
|
||||
|
||||
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
|
||||
future: ResultFuture = ResultFuture()
|
||||
self._enqueue_task(future, method, args, kwargs)
|
||||
return future
|
||||
|
||||
async def execute_method_async(self, method: Union[str, bytes], *args,
|
||||
**kwargs):
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
self._enqueue_task(future, method, args, kwargs)
|
||||
return await future
|
||||
|
||||
def terminate_worker(self):
|
||||
try:
|
||||
self._task_queue.put(_TERMINATE)
|
||||
except ValueError:
|
||||
self.process.kill()
|
||||
self._task_queue.close()
|
||||
|
||||
def kill_worker(self):
|
||||
self._task_queue.close()
|
||||
self.process.kill()
|
||||
|
||||
|
||||
def _run_worker_process(
|
||||
worker_factory: Callable[[VllmConfig, int], Any],
|
||||
task_queue: Queue,
|
||||
result_queue: Queue,
|
||||
vllm_config: VllmConfig,
|
||||
rank: int,
|
||||
) -> None:
|
||||
"""Worker process event loop"""
|
||||
|
||||
# Add process-specific prefix to stdout and stderr
|
||||
process_name = get_mp_context().current_process().name
|
||||
decorate_logs(process_name)
|
||||
|
||||
# Initialize worker
|
||||
worker = worker_factory(vllm_config, rank)
|
||||
del worker_factory
|
||||
|
||||
# Accept tasks from the engine in task_queue
|
||||
# and return task output in result_queue
|
||||
logger.info("Worker ready; awaiting tasks")
|
||||
try:
|
||||
for items in iter(task_queue.get, _TERMINATE):
|
||||
output = None
|
||||
exception = None
|
||||
task_id, method, args, kwargs = items
|
||||
try:
|
||||
output = run_method(worker, method, args, kwargs)
|
||||
except SystemExit:
|
||||
raise
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
except BaseException as e:
|
||||
logger.exception(
|
||||
"Exception in worker %s while processing method %s.",
|
||||
process_name, method)
|
||||
exception = e
|
||||
result_queue.put(
|
||||
Result(task_id=task_id, value=output, exception=exception))
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("Worker failed")
|
||||
|
||||
# Flush TunableOp results when TunableOp is enabled and
|
||||
# online (in situ) tuning is enabled.
|
||||
# Offline tuning API (record_untuned_is_enabled()) only
|
||||
# available in PyTorch 2.6 or later.
|
||||
if torch.cuda.is_available():
|
||||
import torch.cuda.tunable as tunable
|
||||
if (tunable.is_enabled() and tunable.tuning_is_enabled()
|
||||
and not tunable.record_untuned_is_enabled()):
|
||||
tunable.write_file()
|
||||
|
||||
logger.info("Worker exiting")
|
||||
|
||||
|
||||
def set_multiprocessing_worker_envs(parallel_config):
|
||||
""" Set up environment variables that should be used when there are workers
|
||||
in a multiprocessing environment. This should be called by the parent
|
||||
process before worker processes are created"""
|
||||
|
||||
_maybe_force_spawn()
|
||||
|
||||
# Configure thread parallelism if OMP_NUM_THREADS isn't set
|
||||
#
|
||||
# Helps to avoid CPU contention. The default of spawning a thread per
|
||||
# core combined with multiprocessing for each GPU can have a negative
|
||||
# impact on performance. The contention is amplified when running in a
|
||||
# container where CPU limits can cause throttling.
|
||||
default_omp_num_threads = 1
|
||||
if "OMP_NUM_THREADS" not in os.environ and (
|
||||
current_parallelism :=
|
||||
torch.get_num_threads()) > default_omp_num_threads:
|
||||
logger.warning(
|
||||
"Reducing Torch parallelism from %d threads to %d to avoid "
|
||||
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
|
||||
"external environment to tune this value as needed.",
|
||||
current_parallelism, default_omp_num_threads)
|
||||
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
|
||||
torch.set_num_threads(default_omp_num_threads)
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import queue
|
||||
import signal
|
||||
@ -19,6 +20,7 @@ from threading import Thread
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
|
||||
import cloudpickle
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
@ -28,14 +30,12 @@ from vllm.distributed.device_communicators.shm_broadcast import (Handle,
|
||||
MessageQueue)
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
|
||||
get_pp_group, get_tp_group)
|
||||
from vllm.executor.multiproc_worker_utils import (
|
||||
set_multiprocessing_worker_envs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.cache import worker_receiver_cache_from_config
|
||||
from vllm.utils import (decorate_logs, get_distributed_init_method,
|
||||
get_loopback_ip, get_mp_context, get_open_port,
|
||||
set_process_title)
|
||||
from vllm.utils import (_maybe_force_spawn, decorate_logs,
|
||||
get_distributed_init_method, get_loopback_ip,
|
||||
get_mp_context, get_open_port, set_process_title)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.executor.abstract import Executor, FailureCallback
|
||||
from vllm.v1.executor.utils import get_and_update_mm_cache
|
||||
@ -67,8 +67,8 @@ class MultiprocExecutor(Executor):
|
||||
f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
|
||||
f"_parallel_size ({pp_parallel_size}). ")
|
||||
|
||||
# Set multiprocessing envs that are common to V0 and V1
|
||||
set_multiprocessing_worker_envs(self.parallel_config)
|
||||
# Set multiprocessing envs
|
||||
set_multiprocessing_worker_envs()
|
||||
|
||||
# Multiprocessing-based executor does not support multi-node setting.
|
||||
# Since it only works for single node, we can use the loopback address
|
||||
@ -698,3 +698,29 @@ class WorkerProc:
|
||||
process_name += f"_EP{ep_rank}"
|
||||
set_process_title(name=process_name)
|
||||
decorate_logs(process_name)
|
||||
|
||||
|
||||
def set_multiprocessing_worker_envs():
|
||||
""" Set up environment variables that should be used when there are workers
|
||||
in a multiprocessing environment. This should be called by the parent
|
||||
process before worker processes are created"""
|
||||
|
||||
_maybe_force_spawn()
|
||||
|
||||
# Configure thread parallelism if OMP_NUM_THREADS isn't set
|
||||
#
|
||||
# Helps to avoid CPU contention. The default of spawning a thread per
|
||||
# core combined with multiprocessing for each GPU can have a negative
|
||||
# impact on performance. The contention is amplified when running in a
|
||||
# container where CPU limits can cause throttling.
|
||||
default_omp_num_threads = 1
|
||||
if "OMP_NUM_THREADS" not in os.environ and (
|
||||
current_parallelism :=
|
||||
torch.get_num_threads()) > default_omp_num_threads:
|
||||
logger.warning(
|
||||
"Reducing Torch parallelism from %d threads to %d to avoid "
|
||||
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
|
||||
"external environment to tune this value as needed.",
|
||||
current_parallelism, default_omp_num_threads)
|
||||
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
|
||||
torch.set_num_threads(default_omp_num_threads)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user