From 7234fe26858f2c621901494c307c90e65fe35340 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Tue, 29 Jul 2025 06:14:47 +0100
Subject: [PATCH] [Misc] Rework process titles (#21780)

Signed-off-by: Nick Hill
---
 vllm/entrypoints/cli/serve.py          |  6 ++++--
 vllm/entrypoints/openai/api_server.py  | 16 ++++++++++++----
 vllm/utils/__init__.py                 | 16 ++++++++++++----
 vllm/v1/engine/coordinator.py          |  7 +++----
 vllm/v1/engine/core.py                 |  7 ++++---
 vllm/v1/executor/multiproc_executor.py | 16 ++++++++++------
 vllm/v1/utils.py                       |  6 +++---
 7 files changed, 48 insertions(+), 26 deletions(-)

diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py
index 68eb2580991c8..a69363e3d98fe 100644
--- a/vllm/entrypoints/cli/serve.py
+++ b/vllm/entrypoints/cli/serve.py
@@ -21,7 +21,7 @@ from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG,
 from vllm.executor.multiproc_worker_utils import _add_prefix
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser, bind_process_name, get_tcp_uri
+from vllm.utils import FlexibleArgumentParser, get_tcp_uri
 from vllm.v1.engine.core import EngineCoreProc
 from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
 from vllm.v1.executor.abstract import Executor
@@ -77,7 +77,7 @@ def run_headless(args: argparse.Namespace):
     if args.api_server_count > 1:
         raise ValueError("api_server_count can't be set in headless mode")
-    bind_process_name("APIServer_Headless")
+    # set_process_title("Headless_ProcManager")
 
     # Create the EngineConfig.
     engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
     usage_context = UsageContext.OPENAI_API_SERVER
@@ -140,6 +140,8 @@ def run_multi_api_server(args: argparse.Namespace):
 
     num_api_servers = args.api_server_count
     assert num_api_servers > 0
 
+    # set_process_title("ProcManager")
+
     if num_api_servers > 1:
         setup_multiprocess_prometheus()
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 3d4c4a6b752a7..c375c8755108c 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -11,6 +11,7 @@ import multiprocessing
 import os
 import signal
 import socket
+import sys
 import tempfile
 import uuid
 from argparse import Namespace
@@ -94,15 +95,15 @@ from vllm.entrypoints.openai.serving_transcription import (
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
                                     log_non_default_args, with_cancellation)
+from vllm.executor.multiproc_worker_utils import _add_prefix
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParserManager
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.transformers_utils.tokenizer import MistralTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import (Device, FlexibleArgumentParser, bind_process_name,
-                        get_open_zmq_ipc_path, is_valid_ipv6_address,
-                        set_ulimit)
+from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
+                        is_valid_ipv6_address, set_process_title, set_ulimit)
 from vllm.v1.metrics.prometheus import get_prometheus_registry
 from vllm.version import __version__ as VLLM_VERSION
 
@@ -1805,6 +1806,13 @@ def setup_server(args):
 
 async def run_server(args, **uvicorn_kwargs) -> None:
     """Run a single-worker API server."""
+
+    # Add process-specific prefix to stdout and stderr.
+    process_name = "APIServer"
+    pid = os.getpid()
+    _add_prefix(sys.stdout, process_name, pid)
+    _add_prefix(sys.stderr, process_name, pid)
+
     listen_address, sock = setup_server(args)
     await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
 
@@ -1820,7 +1828,7 @@ async def run_server_worker(listen_address,
         ToolParserManager.import_tool_parser(args.tool_parser_plugin)
 
     server_index = client_config.get("client_index", 0) if client_config else 0
-    bind_process_name("APIServer", str(server_index))
+    set_process_title("APIServer", str(server_index))
     # Load logging config for uvicorn if specified
     log_config = load_log_config(args.log_config_file)
     if log_config is not None:
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 054037b8932b7..ae978c855a8e5 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -3282,14 +3282,22 @@ def has_deep_gemm() -> bool:
     return _has_module("deep_gemm")
 
 
-def bind_process_name(name: str, suffix: str = "") -> None:
-    """Bind the process name to a specific name with an optional suffix.
+def set_process_title(name: str,
+                      suffix: str = "",
+                      append: bool = False) -> None:
+    """
+    Set the current process title to a specific name with an
+    optional suffix.
 
     Args:
-        name: The base name to bind the process to.
+        name: The title to assign to the current process.
         suffix: An optional suffix to append to the base name.
+        append: Whether to append to the existing process title.
     """
-    name = f"{envs.VLLM_PROCESS_NAME_PREFIX}::{name}"
     if suffix:
         name = f"{name}_{suffix}"
+    if append:
+        name = f"{setproctitle.getproctitle()}_{name}"
+    else:
+        name = f"{envs.VLLM_PROCESS_NAME_PREFIX}::{name}"
     setproctitle.setproctitle(name)
diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py
index fc45eea3a73cf..440628576bcb7 100644
--- a/vllm/v1/engine/coordinator.py
+++ b/vllm/v1/engine/coordinator.py
@@ -10,11 +10,10 @@ import zmq
 
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import get_mp_context, make_zmq_socket
+from vllm.utils import get_mp_context, make_zmq_socket, set_process_title
 from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
 from vllm.v1.serial_utils import MsgpackDecoder
-from vllm.v1.utils import (bind_process_name, get_engine_client_zmq_addr,
-                           shutdown)
+from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
 
 logger = init_logger(__name__)
 
@@ -119,7 +118,7 @@ class DPCoordinatorProc:
     def __init__(self,
                  engine_count: int,
                  min_stats_update_interval_ms: int = 100):
-        bind_process_name(self.__class__.__name__)
+        set_process_title("DPCoordinator")
         self.ctx = zmq.Context()
 
         self.engines = [EngineState() for _ in range(engine_count)]
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 57f60c4b289bb..cad93061e65b0 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -26,8 +26,8 @@ from vllm.lora.request import LoRARequest
 from vllm.tasks import POOLING_TASKS, SupportedTask
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
-from vllm.utils import (bind_process_name, make_zmq_socket,
-                        resolve_obj_by_qualname)
+from vllm.utils import (make_zmq_socket, resolve_obj_by_qualname,
+                        set_process_title)
 from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                          unify_kv_cache_configs)
 from vllm.v1.core.sched.interface import SchedulerInterface
@@ -425,7 +425,6 @@ class EngineCoreProc(EngineCore):
         client_handshake_address: Optional[str] = None,
         engine_index: int = 0,
     ):
-        bind_process_name(self.__class__.__name__, f"{engine_index}")
         self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
         self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
                                               bytes]]()
@@ -630,11 +629,13 @@ class EngineCoreProc(EngineCore):
             parallel_config: ParallelConfig = kwargs[
                 "vllm_config"].parallel_config
             if parallel_config.data_parallel_size > 1 or dp_rank > 0:
+                set_process_title("DPEngineCore", str(dp_rank))
                 # Set data parallel rank for this engine process.
                 parallel_config.data_parallel_rank = dp_rank
                 parallel_config.data_parallel_rank_local = local_dp_rank
                 engine_core = DPEngineCoreProc(*args, **kwargs)
             else:
+                set_process_title("EngineCore")
                 engine_core = EngineCoreProc(*args, **kwargs)
 
             engine_core.run_busy_loop()
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index 897174c1599df..8270385053852 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -30,8 +30,8 @@ from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.executor.multiproc_worker_utils import (
     _add_prefix, set_multiprocessing_worker_envs)
 from vllm.logger import init_logger
-from vllm.utils import (bind_process_name, get_distributed_init_method,
-                        get_loopback_ip, get_mp_context, get_open_port)
+from vllm.utils import (get_distributed_init_method, get_loopback_ip,
+                        get_mp_context, get_open_port, set_process_title)
 from vllm.v1.executor.abstract import Executor, FailureCallback
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -376,10 +376,14 @@ class WorkerProc:
         }
         wrapper.init_worker(all_kwargs)
         self.worker = wrapper
-        bind_process_name(
-            self.worker.worker.__class__.__name__,
-            f"TP{self.rank}_DP{vllm_config.parallel_config.data_parallel_rank}"
-        )
+
+        pp_size = vllm_config.parallel_config.pipeline_parallel_size
+        tp_size = vllm_config.parallel_config.tensor_parallel_size
+        pp_str = f"PP{rank // tp_size}" if pp_size > 1 else ""
+        tp_str = f"TP{rank % tp_size}" if tp_size > 1 else ""
+        suffix = f"{pp_str}{'_' if pp_str and tp_str else ''}{tp_str}"
+        if suffix:
+            set_process_title(suffix, append=True)
         pid = os.getpid()
         _add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
         _add_prefix(sys.stderr, f"VllmWorker rank={rank}", pid)
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index bb5a36f38386b..c74d8c543f76c 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -15,8 +15,8 @@ import torch
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
-from vllm.utils import (bind_process_name, get_open_port,
-                        get_open_zmq_ipc_path, get_tcp_uri, kill_process_tree)
+from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
+                        kill_process_tree)
 
 if TYPE_CHECKING:
     from vllm.v1.engine.coordinator import DPCoordinator
@@ -144,7 +144,7 @@ class APIServerProcessManager:
         self.listen_address = listen_address
         self.sock = sock
         self.args = args
-        bind_process_name(self.__class__.__name__)
+
         # Start API servers
         spawn_context = multiprocessing.get_context("spawn")
         self.processes: list[BaseProcess] = []
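
A minimal, self-contained sketch of the title scheme this patch produces, mirroring set_process_title() as a pure function so the resulting titles can be read off directly. compose_title() and the "VLLM" fallback for VLLM_PROCESS_NAME_PREFIX are assumptions made for illustration; the real default lives in vllm.envs, and the real helper applies the result via setproctitle.setproctitle().

import os

# Assumed fallback value; the real default comes from vllm.envs.
PREFIX = os.environ.get("VLLM_PROCESS_NAME_PREFIX", "VLLM")


def compose_title(current: str, name: str, suffix: str = "",
                  append: bool = False) -> str:
    """Pure mirror of set_process_title(): returns the title the patched
    helper would pass to setproctitle.setproctitle()."""
    if suffix:
        name = f"{name}_{suffix}"
    # append=True chains onto the existing title (worker processes);
    # otherwise the global prefix is applied (top-level processes).
    return f"{current}_{name}" if append else f"{PREFIX}::{name}"


# Titles as the patched call sites would produce them:
assert compose_title("", "EngineCore") == "VLLM::EngineCore"
assert compose_title("", "DPEngineCore", suffix="1") == "VLLM::DPEngineCore_1"

# WorkerProc derives a PP/TP suffix from its rank and appends it to the
# title already set in the engine process:
rank, tp_size, pp_size = 5, 4, 2
pp_str = f"PP{rank // tp_size}" if pp_size > 1 else ""
tp_str = f"TP{rank % tp_size}" if tp_size > 1 else ""
suffix = f"{pp_str}{'_' if pp_str and tp_str else ''}{tp_str}"
assert compose_title("VLLM::DPEngineCore_1", suffix,
                     append=True) == "VLLM::DPEngineCore_1_PP1_TP1"

The append=True path is the reason a ps listing can show the full lineage (e.g. VLLM::DPEngineCore_1_PP1_TP1): each worker extends the title its engine process already set instead of replacing it.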