From 0cdd213641d7adcf7ad0e9cf79867344ea1291d9 Mon Sep 17 00:00:00 2001 From: 22quinn <33176974+22quinn@users.noreply.github.com> Date: Mon, 8 Sep 2025 21:43:48 -0700 Subject: [PATCH] [Misc] Improve Worker process title and logging prefix (#22205) Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> --- vllm/utils/__init__.py | 10 ++---- vllm/v1/engine/core.py | 6 ++-- vllm/v1/engine/utils.py | 2 +- vllm/v1/executor/multiproc_executor.py | 42 ++++++++++++++++++-------- 4 files changed, 37 insertions(+), 23 deletions(-) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 6d0cb3710bb93..49c706bc37a84 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -3359,7 +3359,7 @@ def has_triton_kernels() -> bool: def set_process_title(name: str, suffix: str = "", - append: bool = False) -> None: + prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None: """ Set the current process title to a specific name with an optional suffix. @@ -3367,15 +3367,11 @@ def set_process_title(name: str, Args: name: The title to assign to the current process. suffix: An optional suffix to append to the base name. - append: Whether to append to the existing process title. + prefix: A prefix to prepend to the front separated by `::`. """ if suffix: name = f"{name}_{suffix}" - if append: - name = f"{setproctitle.getproctitle()}_{name}" - else: - name = f"{envs.VLLM_PROCESS_NAME_PREFIX}::{name}" - setproctitle.setproctitle(name) + setproctitle.setproctitle(f"{prefix}::{name}") def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e239e6cbba167..b46ae72ccdf1b 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -224,7 +224,7 @@ class EngineCore: def add_request(self, request: Request, request_wave: int = 0): """Add request to the scheduler. - + `request_wave`: indicate which wave of requests this is expected to belong to in DP case """ @@ -433,7 +433,7 @@ class EngineCore: def preprocess_add_request( self, request: EngineCoreRequest) -> tuple[Request, int]: """Preprocess the request. - + This function could be directly used in input processing thread to allow request initialization running in parallel with Model forward """ @@ -697,7 +697,7 @@ class EngineCoreProc(EngineCore): parallel_config: ParallelConfig = kwargs[ "vllm_config"].parallel_config if parallel_config.data_parallel_size > 1 or dp_rank > 0: - set_process_title("DPEngineCore", str(dp_rank)) + set_process_title("EngineCore", f"DP{dp_rank}") decorate_logs() # Set data parallel rank for this engine process. parallel_config.data_parallel_rank = dp_rank diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index ed0129fda9474..df2fd8d9df078 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -116,7 +116,7 @@ class CoreEngineProcManager: local_dp_ranks.append(local_index) self.processes.append( context.Process(target=target_fn, - name=f"EngineCore_{global_index}", + name=f"EngineCore_DP{global_index}", kwargs=common_kwargs | { "dp_rank": global_index, "local_dp_rank": local_index, diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index c3d6c20e22e2a..bcf6dda9c1e91 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -26,6 +26,8 @@ from vllm.distributed import (destroy_distributed_environment, from vllm.distributed.device_communicators.shm_broadcast import (Handle, MessageQueue) from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator +from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, + get_pp_group, get_tp_group) from vllm.executor.multiproc_worker_utils import ( set_multiprocessing_worker_envs) from vllm.logger import init_logger @@ -397,17 +399,6 @@ class WorkerProc: wrapper.init_worker(all_kwargs) self.worker = wrapper - pp_size = vllm_config.parallel_config.pipeline_parallel_size - tp_size = vllm_config.parallel_config.tensor_parallel_size - pp_str = f"PP{rank // tp_size}" if pp_size > 1 else "" - tp_str = f"TP{rank % tp_size}" if tp_size > 1 else "" - suffix = f"{pp_str}{'_' if pp_str and tp_str else ''}{tp_str}" - process_name = "VllmWorker" - if suffix: - set_process_title(suffix, append=True) - process_name = f"{process_name} {suffix}" - decorate_logs(process_name) - # Initialize MessageQueue for receiving SchedulerOutput self.rpc_broadcast_mq = MessageQueue.create_from_handle( input_shm_handle, self.worker.rank) @@ -425,8 +416,14 @@ class WorkerProc: name="WorkerAsyncOutputCopy") self.async_output_copy_thread.start() - # Initialize device and loads weights + # Initialize device self.worker.init_device() + + # Set process title and log prefix + self.setup_proc_title_and_log_prefix( + enable_ep=vllm_config.parallel_config.enable_expert_parallel) + + # Load model self.worker.load_model() @staticmethod @@ -663,3 +660,24 @@ class WorkerProc: if output_rank is None or self.rank == output_rank: self.handle_output(output) + + @staticmethod + def setup_proc_title_and_log_prefix(enable_ep: bool) -> None: + dp_size = get_dp_group().world_size + dp_rank = get_dp_group().rank_in_group + pp_size = get_pp_group().world_size + pp_rank = get_pp_group().rank_in_group + tp_size = get_tp_group().world_size + tp_rank = get_tp_group().rank_in_group + process_name = "Worker" + if dp_size > 1: + process_name += f"_DP{dp_rank}" + if pp_size > 1: + process_name += f"_PP{pp_rank}" + if tp_size > 1: + process_name += f"_TP{tp_rank}" + if enable_ep: + ep_rank = get_ep_group().rank_in_group + process_name += f"_EP{ep_rank}" + set_process_title(name=process_name) + decorate_logs(process_name)