[V1] DP scale-out (2/N): Decouple engine process management and comms (#15977)

Signed-off-by: Nick Hill <nhill@redhat.com>
Authored by Nick Hill on 2025-05-13 10:48:21 -07:00; committed by GitHub.
parent 0b217da646
commit 55aa7af994
10 changed files with 516 additions and 243 deletions

View File

@@ -41,7 +41,7 @@ class MockEngine:
         self.abort_request_calls = 0
         self.request_id = None
         # Ugly, remove dependency when possible
-        self.parallel_config = ParallelConfig(1, 1, False)
+        self.parallel_config = ParallelConfig()
         self.model_config = MockModelConfig()

     async def step_async(self, virtual_engine):

View File

@@ -18,9 +18,10 @@ from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core import EngineCore
-from vllm.v1.engine.core_client import (AsyncMPClient, CoreEngine,
-                                        EngineCoreClient, SyncMPClient)
+from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
+                                        SyncMPClient)
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.utils import CoreEngineProcManager

 from ...distributed.conftest import MockSubscriber
 from ...utils import create_new_process_for_each_test
@@ -348,13 +349,13 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
         # Monkey-patch to extract core process pid while it's starting.
         core_proc_pid = [None]
-        ce_ctor = CoreEngine.__init__
+        cepm_ctor = CoreEngineProcManager.__init__

-        def patched_ce_ctor(self, *args, **kwargs):
-            ce_ctor(self, *args, **kwargs)
-            core_proc_pid[0] = self.proc_handle.proc.pid
+        def patched_cepm_ctor(self: CoreEngineProcManager, *args, **kwargs):
+            cepm_ctor(self, *args, **kwargs)
+            core_proc_pid[0] = self.processes[0].pid

-        m.setattr(CoreEngine, "__init__", patched_ce_ctor)
+        m.setattr(CoreEngineProcManager, "__init__", patched_cepm_ctor)

         t = time.time()
         engine_args = EngineArgs(model=MODEL_NAME)

View File

@@ -1668,25 +1668,17 @@ class ParallelConfig:
     data_parallel_size: int = 1
     """Number of data parallel groups. MoE layers will be sharded according to
     the product of the tensor parallel size and data parallel size."""
+    data_parallel_size_local: int = 1
+    """Number of local data parallel groups."""
     data_parallel_rank: int = 0
     """Rank of the data parallel group."""
-    _data_parallel_rank_local: Optional[int] = field(default=None, init=False)
-    """Private field to store the local rank of the data parallel group."""
-
-    @property
-    def data_parallel_rank_local(self) -> int:
-        """Local rank of the data parallel group, defaults to global rank."""
-        if self._data_parallel_rank_local is None:
-            return self.data_parallel_rank
-        return self._data_parallel_rank_local
-
-    @data_parallel_rank_local.setter
-    def data_parallel_rank_local(self, value: int) -> None:
-        """Set the local rank of the data parallel group."""
-        self._data_parallel_rank_local = value
-
+    data_parallel_rank_local: Optional[int] = None
+    """Local rank of the data parallel group,
+    set only in SPMD mode."""
     data_parallel_master_ip: str = "127.0.0.1"
     """IP of the data parallel master."""
+    data_parallel_rpc_port: int = 29550
+    """Port for data parallel messaging."""
     data_parallel_master_port: int = 29500
     """Port of the data parallel master."""
     enable_expert_parallel: bool = False
@@ -1734,13 +1726,16 @@ class ParallelConfig:
     world_size: int = field(init=False)
     """world_size is TPxPP, it affects the number of workers we create."""
-    world_size_across_dp: int = field(init=False)
-    """world_size_across_dp is TPxPPxDP, it is the size of the world
-    including data parallelism."""

     rank: int = 0
     """Global rank in distributed setup."""

+    @property
+    def world_size_across_dp(self) -> int:
+        """world_size_across_dp is TPxPPxDP, it is the size of the world
+        including data parallelism."""
+        return self.world_size * self.data_parallel_size
+
     def get_next_dp_init_port(self) -> int:
         """
         We might need to initialize process groups in multiple
@@ -1800,10 +1795,14 @@ class ParallelConfig:
         self.world_size = self.pipeline_parallel_size * \
             self.tensor_parallel_size

-        if self.data_parallel_size > 1:
+        if self.data_parallel_size_local > self.data_parallel_size:
+            raise ValueError(
+                f"data_parallel_size_local ({self.data_parallel_size_local}) "
+                f"must be <= data_parallel_size ({self.data_parallel_size})")
+
+        if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
             # Data parallel was specified in the engine args.
             self.data_parallel_master_port = get_open_port()
-            # TODO multi-node
         else:
             # Otherwise fall back to env vars (e.g. for offline SPMD case).
             self.data_parallel_size = envs.VLLM_DP_SIZE
@@ -1812,8 +1811,6 @@ class ParallelConfig:
             self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
             self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT

-        self.world_size_across_dp = self.world_size * self.data_parallel_size
-
         if self.distributed_executor_backend == "external_launcher":
             import os
             os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
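Note (illustration, not part of this commit): a minimal sketch of how the new ParallelConfig fields fit together on the head node of a hypothetical 2-node, 4-rank deployment. The address is a placeholder and this assumes vLLM is importable in the current environment.

    from vllm.config import ParallelConfig

    # Head node of a 2-node deployment: 4 DP ranks total, 2 of them local.
    cfg = ParallelConfig(
        data_parallel_size=4,
        data_parallel_size_local=2,
        data_parallel_master_ip="192.0.2.10",  # placeholder head-node address
        data_parallel_rpc_port=29550,
    )
    # world_size is TPxPP (1 here), so the DP-wide world size is 1 * 4.
    assert cfg.world_size_across_dp == 4
    # A data_parallel_size_local larger than data_parallel_size is rejected
    # in __post_init__ with a ValueError (see the validation added above).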

View File

@@ -22,6 +22,7 @@ from torch.distributed.rendezvous import rendezvous

 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.utils import get_tcp_uri

 logger = init_logger(__name__)
@@ -303,7 +304,7 @@ def stateless_init_torch_distributed_process_group(
     always formed with process 1, 2, ..., 8, and the additional communication
     channel is formed with process 9 and 10.
     """
-    init_method = f"tcp://{host}:{port}"
+    init_method = get_tcp_uri(host, port)
     backend = Backend(backend)  # it is basically string
     timeout = _get_default_timeout(backend)

View File

@@ -283,6 +283,9 @@ class EngineArgs:
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
+    data_parallel_size_local: Optional[int] = None
+    data_parallel_address: Optional[str] = None
+    data_parallel_rpc_port: Optional[int] = None
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     max_parallel_loading_workers: Optional[
         int] = ParallelConfig.max_parallel_loading_workers
@@ -596,6 +599,21 @@ class EngineArgs:
                                     **parallel_kwargs["tensor_parallel_size"])
         parallel_group.add_argument("--data-parallel-size", "-dp",
                                     **parallel_kwargs["data_parallel_size"])
+        parallel_group.add_argument('--data-parallel-size-local',
+                                    '-dpl',
+                                    type=int,
+                                    help='Number of data parallel replicas '
+                                    'to run on this node.')
+        parallel_group.add_argument('--data-parallel-address',
+                                    '-dpa',
+                                    type=str,
+                                    help='Address of data parallel cluster '
+                                    'head-node.')
+        parallel_group.add_argument('--data-parallel-rpc-port',
+                                    '-dpp',
+                                    type=int,
+                                    help='Port for data parallel RPC '
+                                    'communication.')
         parallel_group.add_argument(
             "--enable-expert-parallel",
             **parallel_kwargs["enable_expert_parallel"])
@@ -1019,10 +1037,30 @@ class EngineArgs:
             # but we should not do this here.
             placement_group = ray.util.get_current_placement_group()

+        # Local DP size defaults to global DP size if not set.
+        data_parallel_size_local = self.data_parallel_size if (
+            self.data_parallel_size_local
+            is None) else self.data_parallel_size_local
+
+        # DP address, used in multi-node case for torch distributed group
+        # and ZMQ sockets.
+        data_parallel_address = self.data_parallel_address if (
+            self.data_parallel_address
+            is not None) else ParallelConfig.data_parallel_master_ip
+
+        # This port is only used when there are remote data parallel engines,
+        # otherwise the local IPC transport is used.
+        data_parallel_rpc_port = self.data_parallel_rpc_port if (
+            self.data_parallel_rpc_port
+            is not None) else ParallelConfig.data_parallel_rpc_port
+
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
             data_parallel_size=self.data_parallel_size,
+            data_parallel_size_local=data_parallel_size_local,
+            data_parallel_master_ip=data_parallel_address,
+            data_parallel_rpc_port=data_parallel_rpc_port,
             enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,

View File

@@ -1,14 +1,24 @@
 # SPDX-License-Identifier: Apache-2.0

 import argparse
+import signal

 import uvloop

+import vllm.envs as envs
+from vllm import AsyncEngineArgs
 from vllm.entrypoints.cli.types import CLISubcommand
 from vllm.entrypoints.openai.api_server import run_server
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                               validate_parsed_serve_args)
-from vllm.utils import FlexibleArgumentParser
+from vllm.logger import init_logger
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import FlexibleArgumentParser, get_tcp_uri
+from vllm.v1.engine.core import EngineCoreProc
+from vllm.v1.engine.core_client import CoreEngineProcManager
+from vllm.v1.executor.abstract import Executor
+
+logger = init_logger(__name__)


 class ServeSubcommand(CLISubcommand):
@@ -24,7 +34,10 @@ class ServeSubcommand(CLISubcommand):
         if hasattr(args, 'model_tag') and args.model_tag is not None:
             args.model = args.model_tag

-        uvloop.run(run_server(args))
+        if args.headless:
+            run_headless(args)
+        else:
+            uvloop.run(run_server(args))

     def validate(self, args: argparse.Namespace) -> None:
         validate_parsed_serve_args(args)
@@ -42,6 +55,18 @@ class ServeSubcommand(CLISubcommand):
             nargs='?',
             help="The model tag to serve "
             "(optional if specified in config)")
+        serve_parser.add_argument(
+            "--headless",
+            action='store_true',
+            default=False,
+            help="Run in headless mode. See multi-node data parallel "
+            "documentation for more details.")
+        serve_parser.add_argument(
+            '--data-parallel-start-rank',
+            '-dpr',
+            type=int,
+            default=0,
+            help='Starting data parallel rank for secondary nodes.')
         serve_parser.add_argument(
             "--config",
             type=str,
@@ -57,3 +82,55 @@ class ServeSubcommand(CLISubcommand):

 def cmd_init() -> list[CLISubcommand]:
     return [ServeSubcommand()]
+
+
+def run_headless(args: argparse.Namespace):
+
+    # Create the EngineConfig.
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+    usage_context = UsageContext.OPENAI_API_SERVER
+    vllm_config = engine_args.create_engine_config(usage_context=usage_context)
+
+    if not envs.VLLM_USE_V1:
+        raise RuntimeError("Headless mode is only supported for V1")
+
+    parallel_config = vllm_config.parallel_config
+    local_engine_count = parallel_config.data_parallel_size_local
+    host = parallel_config.data_parallel_master_ip
+    port = engine_args.data_parallel_rpc_port  # add to config too
+    input_address = get_tcp_uri(host, port)
+
+    if local_engine_count <= 0:
+        raise RuntimeError("data_parallel_size_local must be > 0 in "
+                           "headless mode")
+
+    # Catch SIGTERM and SIGINT to allow graceful shutdown.
+    def signal_handler(signum, frame):
+        logger.debug("Received %d signal.", signum)
+        raise SystemExit
+
+    signal.signal(signal.SIGTERM, signal_handler)
+    signal.signal(signal.SIGINT, signal_handler)
+
+    logger.info(
+        "Launching %d data parallel engine(s) in headless mode, "
+        "with head node address %s.", local_engine_count, input_address)
+
+    # Create the engines.
+    engine_manager = CoreEngineProcManager(
+        target_fn=EngineCoreProc.run_engine_core,
+        local_engine_count=local_engine_count,
+        start_index=args.data_parallel_start_rank,
+        local_start_index=0,
+        vllm_config=vllm_config,
+        on_head_node=False,
+        input_address=input_address,
+        executor_class=Executor.get_class(vllm_config),
+        log_stats=not engine_args.disable_log_stats,
+    )
+
+    try:
+        engine_manager.join_first()
+    finally:
+        logger.info("Shutting down.")
+        engine_manager.close()
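Note (usage illustration, not part of this commit): with the flags added in this file and in arg_utils.py, a hypothetical two-node deployment could be launched as follows; the model name and address are placeholders.

    # Head node: runs the API server plus DP ranks 0-1.
    vllm serve <model> -dp 4 -dpl 2 -dpa 192.0.2.10 -dpp 29550

    # Second node: headless, hosts DP ranks 2-3 and connects back to the head node.
    vllm serve <model> --headless -dp 4 -dpl 2 -dpr 2 -dpa 192.0.2.10 -dpp 29550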

View File

@@ -613,6 +613,10 @@ def is_valid_ipv6_address(address: str) -> bool:


 def get_distributed_init_method(ip: str, port: int) -> str:
+    return get_tcp_uri(ip, port)
+
+
+def get_tcp_uri(ip: str, port: int) -> str:
     # Brackets are not permitted in ipv4 addresses,
     # see https://github.com/python/cpython/issues/103848
     return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"

View File

@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-import json
 import os
 import queue
 import signal
@@ -23,7 +22,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
-from vllm.utils import resolve_obj_by_qualname, zmq_socket_ctx
+from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx
 from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                          unify_kv_cache_configs)
 from vllm.v1.core.sched.interface import SchedulerInterface
@@ -43,6 +42,7 @@ from vllm.version import __version__ as VLLM_VERSION
 logger = init_logger(__name__)

 POLLING_TIMEOUT_S = 2.5
+HANDSHAKE_TIMEOUT_MINS = 5

 _R = TypeVar('_R')  # Return type for collective_rpc
@@ -348,9 +348,9 @@ class EngineCoreProc(EngineCore):
     def __init__(
         self,
-        input_path: str,
-        output_path: str,
         vllm_config: VllmConfig,
+        on_head_node: bool,
+        input_address: str,
         executor_class: type[Executor],
         log_stats: bool,
         engine_index: int = 0,
@@ -360,28 +360,91 @@ class EngineCoreProc(EngineCore):
         executor_fail_callback = lambda: input_queue.put_nowait(
             (EngineCoreRequestType.EXECUTOR_FAILED, b''))

-        super().__init__(vllm_config, executor_class, log_stats,
-                         executor_fail_callback)
-
-        self.step_fn = (self.step if self.batch_queue is None else
-                        self.step_with_batch_queue)
-        self.engines_running = False
-
-        # Background Threads and Queues for IO. These enable us to
-        # overlap ZMQ socket IO with GPU since they release the GIL,
-        # and to overlap some serialization/deserialization with the
-        # model forward pass.
-        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
-        self.input_queue = input_queue
-        self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
-        threading.Thread(target=self.process_input_socket,
-                         args=(input_path, engine_index),
-                         daemon=True).start()
-        self.output_thread = threading.Thread(
-            target=self.process_output_socket,
-            args=(output_path, engine_index),
-            daemon=True)
-        self.output_thread.start()
+        # Create input socket.
+        input_ctx = zmq.Context()
+        identity = engine_index.to_bytes(length=2, byteorder="little")
+        input_socket = make_zmq_socket(input_ctx,
+                                       input_address,
+                                       zmq.DEALER,
+                                       identity=identity,
+                                       bind=False)
+        try:
+            # Register engine with front-end.
+            output_address = self.startup_handshake(
+                input_socket, on_head_node, vllm_config.parallel_config)
+
+            # Update config which may have changed from the handshake.
+            vllm_config.__post_init__()
+
+            # Set up data parallel environment.
+            self._init_data_parallel(vllm_config)
+
+            # Initialize engine core and model.
+            super().__init__(vllm_config, executor_class, log_stats,
+                             executor_fail_callback)
+
+            self.step_fn = (self.step if self.batch_queue is None else
+                            self.step_with_batch_queue)
+            self.engines_running = False
+
+            # Send ready message.
+            num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
+            input_socket.send(
+                msgspec.msgpack.encode({
+                    "status": "READY",
+                    "local": on_head_node,
+                    "num_gpu_blocks": num_gpu_blocks,
+                }))
+
+            # Background Threads and Queues for IO. These enable us to
+            # overlap ZMQ socket IO with GPU since they release the GIL,
+            # and to overlap some serialization/deserialization with the
+            # model forward pass.
+            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
+            self.input_queue = input_queue
+            self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
+            threading.Thread(target=self.process_input_socket,
+                             args=(input_socket, ),
+                             daemon=True).start()
+            input_socket = None
+            self.output_thread = threading.Thread(
+                target=self.process_output_socket,
+                args=(output_address, engine_index),
+                daemon=True)
+            self.output_thread.start()
+        finally:
+            if input_socket is not None:
+                input_socket.close(linger=0)
+
+    @staticmethod
+    def startup_handshake(input_socket: zmq.Socket, on_head_node: bool,
+                          parallel_config: ParallelConfig) -> str:
+
+        # Send registration message.
+        input_socket.send(
+            msgspec.msgpack.encode({
+                "status": "HELLO",
+                "local": on_head_node,
+            }))
+
+        # Receive initialization message.
+        logger.info("Waiting for init message from front-end.")
+        if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000):
+            raise RuntimeError("Did not receive response from front-end "
+                               f"process within {HANDSHAKE_TIMEOUT_MINS} "
+                               f"minutes")
+        init_bytes = input_socket.recv()
+        init_message = msgspec.msgpack.decode(init_bytes)
+        logger.debug("Received init message: %s", init_message)
+
+        output_socket_address = init_message["output_socket_address"]
+        #TBD(nick) maybe replace IP with configured head node address
+        received_parallel_config = init_message["parallel_config"]
+        for key, value in received_parallel_config.items():
+            setattr(parallel_config, key, value)
+
+        return output_socket_address

     @staticmethod
     def run_engine_core(*args,
@@ -412,7 +475,7 @@ class EngineCoreProc(EngineCore):
         try:
             parallel_config: ParallelConfig = kwargs[
                 "vllm_config"].parallel_config
-            if parallel_config.data_parallel_size > 1:
+            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                 # Set data parallel rank for this engine process.
                 parallel_config.data_parallel_rank = dp_rank
                 parallel_config.data_parallel_rank_local = local_dp_rank
@@ -436,6 +499,9 @@ class EngineCoreProc(EngineCore):
             if engine_core is not None:
                 engine_core.shutdown()

+    def _init_data_parallel(self, vllm_config: VllmConfig):
+        pass
+
     def run_busy_loop(self):
         """Core busy loop of the EngineCore."""
@@ -527,40 +593,25 @@ class EngineCoreProc(EngineCore):
             logger.fatal("vLLM shutdown signal from EngineCore failed "
                          "to send. Please report this issue.")

-    def process_input_socket(self, input_path: str, engine_index: int):
+    def process_input_socket(self, input_socket: zmq.Socket):
         """Input socket IO thread."""

         # Msgpack serialization decoding.
         add_request_decoder = MsgpackDecoder(EngineCoreRequest)
         generic_decoder = MsgpackDecoder()
-        identity = engine_index.to_bytes(length=2, byteorder="little")

-        with zmq_socket_ctx(input_path,
-                            zmq.DEALER,
-                            identity=identity,
-                            bind=False) as socket:
-
-            # Send ready message to front-end once input socket is connected.
-            message_dict = {
-                'type': 'READY',
-                'num_gpu_blocks': self.vllm_config.cache_config.num_gpu_blocks,
-            }
-            message = json.dumps(message_dict).encode('utf-8')
-            socket.send(message)
-
-            while True:
-                # (RequestType, RequestData)
-                type_frame, *data_frames = socket.recv_multipart(copy=False)
-                request_type = EngineCoreRequestType(bytes(type_frame.buffer))
-
-                # Deserialize the request data.
-                decoder = add_request_decoder if (
-                    request_type
-                    == EngineCoreRequestType.ADD) else generic_decoder
-                request = decoder.decode(data_frames)
-
-                # Push to input queue for core busy loop.
-                self.input_queue.put_nowait((request_type, request))
+        while True:
+            # (RequestType, RequestData)
+            type_frame, *data_frames = input_socket.recv_multipart(copy=False)
+            request_type = EngineCoreRequestType(bytes(type_frame.buffer))
+
+            # Deserialize the request data.
+            decoder = add_request_decoder if (
+                request_type == EngineCoreRequestType.ADD) else generic_decoder
+            request = decoder.decode(data_frames)
+
+            # Push to input queue for core busy loop.
+            self.input_queue.put_nowait((request_type, request))

     def process_output_socket(self, output_path: str, engine_index: int):
         """Output socket IO thread."""
@@ -609,9 +660,9 @@ class DPEngineCoreProc(EngineCoreProc):
     def __init__(
         self,
-        input_path: str,
-        output_path: str,
         vllm_config: VllmConfig,
+        on_head_node: bool,
+        input_address: str,
         executor_class: type[Executor],
         log_stats: bool,
     ):
@@ -623,8 +674,20 @@ class DPEngineCoreProc(EngineCoreProc):
         _add_prefix(sys.stdout, process_name, pid)
         _add_prefix(sys.stderr, process_name, pid)

-        dp_size = vllm_config.parallel_config.data_parallel_size
+        # Counts forward-passes of the model so that we can synchronize
+        # finished with DP peers every N steps.
+        self.counter = 0
+
+        # Initialize the engine.
         dp_rank = vllm_config.parallel_config.data_parallel_rank
+        super().__init__(vllm_config, on_head_node, input_address,
+                         executor_class, log_stats, dp_rank)
+
+    def _init_data_parallel(self, vllm_config: VllmConfig):
+
+        # Configure GPUs and stateless process group for data parallel.
+        dp_rank = vllm_config.parallel_config.data_parallel_rank
+        dp_size = vllm_config.parallel_config.data_parallel_size
         local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

         assert dp_size > 1
@@ -632,24 +695,16 @@ class DPEngineCoreProc(EngineCoreProc):
         from vllm.platforms import current_platform
         device_control_env_var = current_platform.device_control_env_var

-        tp_size = vllm_config.parallel_config.tensor_parallel_size
+        world_size = vllm_config.parallel_config.world_size
         os.environ[device_control_env_var] = ",".join(
             str(current_platform.device_id_to_physical_device_id(i))
-            for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
-                           tp_size))
+            for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
+                           world_size))

         self.local_dp_rank = local_dp_rank
         self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
         self.current_wave = 0

-        # Initialize the engine after setting up environment.
-        super().__init__(input_path, output_path, vllm_config, executor_class,
-                         log_stats, dp_rank)
-
-        # Counts forward-passes of the model so that we can synchronize
-        # finished with DP peers every N steps.
-        self.counter = 0
-
     def shutdown(self):
         super().shutdown()
         if dp_group := getattr(self, "dp_group", None):
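Note (illustration, not part of this commit): the shape of the handshake traffic defined by startup_handshake above and consumed by the front-end's _wait_for_engine_startup in core_client.py. Payload values are placeholders; msgpack is the wire format, as in the diff.

    import msgspec

    # 1. Engine -> front-end, sent as soon as the DEALER input socket connects:
    hello = msgspec.msgpack.encode({"status": "HELLO", "local": False})

    # 2. Front-end -> engine, routed by the engine's 2-byte identity:
    init = msgspec.msgpack.encode({
        "output_socket_address": "tcp://192.0.2.10:41234",  # placeholder
        "parallel_config": {
            "data_parallel_master_ip": "192.0.2.10",
            "data_parallel_master_port": 29500,
            "data_parallel_size": 4,
        },
    })

    # 3. Engine -> front-end, after the model and KV cache are initialized:
    ready = msgspec.msgpack.encode({
        "status": "READY",
        "local": False,
        "num_gpu_blocks": 8192,  # placeholder; the client sums this across engines
    })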

View File

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import asyncio
 import contextlib
-import json
 import queue
 import uuid
 import weakref
@@ -9,25 +8,27 @@ from abc import ABC, abstractmethod
 from collections import deque
 from collections.abc import Awaitable, Sequence
 from concurrent.futures import Future
-from dataclasses import dataclass, field
+from dataclasses import dataclass
+from enum import Enum, auto
 from threading import Thread
 from typing import Any, Callable, Optional, TypeVar, Union

+import msgspec
 import zmq
 import zmq.asyncio

-from vllm.config import VllmConfig
+from vllm.config import ParallelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path,
-                        make_zmq_socket)
+from vllm.utils import (get_open_port, get_open_zmq_inproc_path,
+                        get_open_zmq_ipc_path, get_tcp_uri, make_zmq_socket)
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType, UtilityOutput)
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
 from vllm.v1.engine.exceptions import EngineDeadError
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
-from vllm.v1.utils import BackgroundProcHandle
+from vllm.v1.utils import CoreEngineProcManager

 logger = init_logger(__name__)
@@ -264,45 +265,22 @@ class InprocClient(EngineCoreClient):
         return self.engine_core.collective_rpc(method, timeout, args, kwargs)


+class CoreEngineState(Enum):
+    NEW = auto()
+    CONNECTED = auto()
+    READY = auto()
+
+
 class CoreEngine:
     """One per data parallel rank."""

-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        executor_class: type[Executor],
-        log_stats: bool,
-        input_path: str,
-        output_path: str,
-        index: int = 0,
-        local_dp_rank: int = 0,
-    ):
+    def __init__(self, index: int = 0, local: bool = True):
+        self.local = local
         self.index = index
         self.identity = index.to_bytes(length=2, byteorder="little")
-        try:
-            # Start EngineCore in background process.
-            self.proc_handle = BackgroundProcHandle(
-                input_path=input_path,
-                output_path=output_path,
-                process_name=f"EngineCore_{index}",
-                target_fn=EngineCoreProc.run_engine_core,
-                process_kwargs={
-                    "vllm_config": vllm_config,
-                    "dp_rank": index,
-                    "local_dp_rank": local_dp_rank,
-                    "executor_class": executor_class,
-                    "log_stats": log_stats,
-                })

-            self.num_reqs_in_flight = 0
-        finally:
-            if not hasattr(self, "num_reqs_in_flight"):
-                # Ensure socket is closed if process fails to start.
-                self.close()
-
-    def close(self):
-        if proc_handle := getattr(self, "proc_handle", None):
-            proc_handle.shutdown()
+        self.state = CoreEngineState.NEW
+        self.num_reqs_in_flight = 0


 @dataclass
@@ -311,7 +289,7 @@ class BackgroundResources:
     circular reference back to the client object."""

     ctx: Union[zmq.Context]
-    core_engines: list[CoreEngine] = field(default_factory=list)
+    local_engine_manager: Optional[CoreEngineProcManager] = None
     output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
     input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
     output_queue_task: Optional[asyncio.Task] = None
@@ -325,8 +303,8 @@ class BackgroundResources:
         """Clean up background resources."""
         self.engine_dead = True

-        for core_engine in self.core_engines:
-            core_engine.close()
+        if self.local_engine_manager is not None:
+            self.local_engine_manager.close()

         if self.output_queue_task is not None:
             self.output_queue_task.cancel()
@@ -388,25 +366,56 @@ class MPClient(EngineCoreClient):
         self._finalizer = weakref.finalize(self, self.resources)

         success = False
         try:
-            # Paths and sockets for IPC.
-            self.output_path = get_open_zmq_ipc_path()
-            input_path = get_open_zmq_ipc_path()
-            self.input_socket = make_zmq_socket(self.ctx,
-                                                input_path,
-                                                zmq.ROUTER,
-                                                bind=True)
-            self.resources.input_socket = self.input_socket
+            parallel_config = vllm_config.parallel_config
+            local_engine_count = parallel_config.data_parallel_size_local
+            start_index = parallel_config.data_parallel_rank
+            local_start_index = parallel_config.data_parallel_rank_local

-            new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
-                vllm_config, executor_class, log_stats, input_path, self.
-                output_path, index, local_dp_rank)
+            # SPMD mode is where there is an LLM instance per DP rank and
+            # one core engine per LLM, see
+            # examples/offline_inference/data_parallel.py.
+            spmd_mode = local_start_index is not None
+            if spmd_mode:
+                assert local_engine_count == 1
+                self.core_engines = [
+                    CoreEngine(index=local_start_index, local=True)
+                ]
+            else:
+                assert start_index == 0
+                local_start_index = 0
+                self.core_engines = [
+                    CoreEngine(index=i, local=(i < local_engine_count))
+                    for i in range(parallel_config.data_parallel_size)
+                ]

-            # Start engine core process(es).
-            self._init_core_engines(vllm_config, new_core_engine,
-                                    self.resources.core_engines)
+            input_address, output_address = self._get_zmq_addresses(
+                parallel_config, spmd_mode)
+
+            # Create input and output sockets.
+            self.input_socket = self.resources.input_socket = make_zmq_socket(
+                self.ctx, input_address, zmq.ROUTER, bind=True)
+            self.resources.output_socket = make_zmq_socket(
+                self.ctx, output_address, zmq.constants.PULL)
+
+            # Start local engines.
+            if local_engine_count:
+                # In server mode, start_index and local_start_index will
+                # both be 0.
+                self.resources.local_engine_manager = CoreEngineProcManager(
+                    EngineCoreProc.run_engine_core,
+                    vllm_config=vllm_config,
+                    executor_class=executor_class,
+                    log_stats=log_stats,
+                    input_address=input_address,
+                    on_head_node=True,
+                    local_engine_count=local_engine_count,
+                    start_index=start_index,
+                    local_start_index=local_start_index)
+
+            self.core_engine = self.core_engines[0]

             # Wait for engine core process(es) to start.
-            self._wait_for_engine_startup()
+            self._wait_for_engine_startup(output_address, parallel_config)

             self.utility_results: dict[int, AnyFuture] = {}
@@ -420,56 +429,116 @@ class MPClient(EngineCoreClient):
             if not success:
                 self._finalizer()

-    def _wait_for_engine_startup(self):
+    @staticmethod
+    def _get_zmq_addresses(parallel_config: ParallelConfig,
+                           spmd_mode: bool) -> tuple[str, str]:
+        """Returns (input_address, output_address)."""
+        dp_size = parallel_config.data_parallel_size
+        local_engine_count = parallel_config.data_parallel_size_local
+
+        if local_engine_count == dp_size or spmd_mode:
+            input_address = get_open_zmq_ipc_path()
+            output_address = get_open_zmq_ipc_path()
+        else:
+            host = parallel_config.data_parallel_master_ip
+            input_port = parallel_config.data_parallel_rpc_port
+            output_port = get_open_port()
+            input_address = get_tcp_uri(host, input_port)
+            output_address = get_tcp_uri(host, output_port)
+
+        return input_address, output_address
+
+    def _wait_for_engine_startup(self, output_address: str,
+                                 parallel_config: ParallelConfig):
         # Get a sync handle to the socket which can be sync or async.
         sync_input_socket = zmq.Socket.shadow(self.input_socket)

         # Wait for engine core process(es) to send ready messages.
-        identities = set(eng.index for eng in self.resources.core_engines)
+        local_count = parallel_config.data_parallel_size_local
+        remote_count = len(self.core_engines) - local_count
+        # [local, remote] counts
+        conn_pending, start_pending = [local_count, remote_count], [0, 0]

         poller = zmq.Poller()
         poller.register(sync_input_socket, zmq.POLLIN)
-        for eng in self.resources.core_engines:
-            poller.register(eng.proc_handle, zmq.POLLIN)
-        while identities:
+        proc_manager = self.resources.local_engine_manager
+        if proc_manager is not None:
+            for sentinel in proc_manager.sentinels():
+                poller.register(sentinel, zmq.POLLIN)
+        while any(conn_pending) or any(start_pending):
             events = poller.poll(STARTUP_POLL_PERIOD_MS)
             if not events:
-                logger.debug("Waiting for %d core engine proc(s) to start: %s",
-                             len(identities), identities)
+                if any(conn_pending):
+                    logger.debug(
+                        "Waiting for %d local, %d remote core engine proc(s) "
+                        "to connect.", *conn_pending)
+                if any(start_pending):
+                    logger.debug(
+                        "Waiting for %d local, %d remote core engine proc(s) "
+                        "to start.", *start_pending)
                 continue
             if len(events) > 1 or events[0][0] != sync_input_socket:
-                # One of the core processes exited.
+                # One of the local core processes exited.
+                finished = proc_manager.finished_procs(
+                ) if proc_manager else {}
                 raise RuntimeError("Engine core initialization failed. "
-                                   "See root cause above.")
+                                   "See root cause above. "
+                                   f"Failed core proc(s): {finished}")

-            eng_id_bytes, data = sync_input_socket.recv_multipart()
-            eng_id = int.from_bytes(eng_id_bytes, byteorder="little")
-            if eng_id not in identities:
-                raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}")
-            message_dict = json.loads(data.decode('utf-8'))
-            if message_dict['type'] != 'READY':
-                raise RuntimeError(f"Engine {eng_id} failed: {data.decode()}")
-            logger.info("Core engine process %d ready.", eng_id)
-            identities.discard(eng_id)
-
-            # Setup KV cache config with initialization state from
-            # engine core process. Sum values from all engines in DP case.
-            num_gpu_blocks = self.vllm_config.cache_config.num_gpu_blocks or 0
-            num_gpu_blocks += message_dict['num_gpu_blocks']
-            self.vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
-
-    def _init_core_engines(
-        self,
-        vllm_config: VllmConfig,
-        new_core_engine: Callable[[int, Optional[int]], CoreEngine],
-        core_engines: list[CoreEngine],
-    ) -> None:
-
-        # Default case - single core engine.
-        core_engine = new_core_engine(
-            vllm_config.parallel_config.data_parallel_rank,
-            vllm_config.parallel_config.data_parallel_rank_local,
-        )
-        core_engines.append(core_engine)
-        self.core_engine = core_engine
+            # Receive HELLO and READY messages from the input socket.
+            eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart()
+            eng_index = int.from_bytes(eng_identity, byteorder="little")
+            engine = next(
+                (e for e in self.core_engines if e.identity == eng_identity),
+                None)
+            if engine is None:
+                raise RuntimeError(f"Message from engine with unexpected data "
+                                   f"parallel rank: {eng_index}")
+            msg = msgspec.msgpack.decode(ready_msg_bytes)
+            status, local = msg["status"], msg["local"]
+            if local != engine.local:
+                raise RuntimeError(f"{status} message from "
+                                   f"{'local' if local else 'remote'} "
+                                   f"engine {eng_index}, expected it to be "
+                                   f"{'local' if engine.local else 'remote'}")
+
+            if status == "HELLO" and engine.state == CoreEngineState.NEW:
+
+                # Send init message with DP config info.
+                init_message = self.encoder.encode({
+                    "output_socket_address": output_address,
+                    "parallel_config": {
+                        "data_parallel_master_ip":
+                        parallel_config.data_parallel_master_ip,
+                        "data_parallel_master_port":
+                        parallel_config.data_parallel_master_port,
+                        "data_parallel_size":
+                        parallel_config.data_parallel_size,
+                    },
+                })
+                sync_input_socket.send_multipart((eng_identity, *init_message),
+                                                 copy=False)
+                conn_pending[0 if local else 1] -= 1
+                start_pending[0 if local else 1] += 1
+                engine.state = CoreEngineState.CONNECTED
+            elif status == "READY" and (engine.state
+                                        == CoreEngineState.CONNECTED):
+                # Setup KV cache config with initialization state from
+                # engine core process. Sum values from all engines in DP case.
+                cache_config = self.vllm_config.cache_config
+                num_gpu_blocks = cache_config.num_gpu_blocks or 0
+                num_gpu_blocks += msg['num_gpu_blocks']
+                cache_config.num_gpu_blocks = num_gpu_blocks
+
+                start_pending[0 if local else 1] -= 1
+                engine.state = CoreEngineState.READY
+            else:
+                raise RuntimeError(f"Unexpected {status} message for "
+                                   f"{'local' if local else 'remote'} engine "
+                                   f"{eng_index} in {engine.state} state.")
+
+            logger.debug("%s from %s core engine process %s.", status,
+                         "local" if local else "remote", eng_index)

     def shutdown(self):
         # Terminate background resources.
@@ -520,7 +589,8 @@ class SyncMPClient(MPClient):
         # Ensure that the outputs socket processing thread does not have
         # a ref to the client which prevents gc.
         ctx = self.ctx
-        output_path = self.output_path
+        out_socket = self.resources.output_socket
+        assert out_socket is not None
         decoder = self.decoder
         utility_results = self.utility_results
         outputs_queue = self.outputs_queue
@@ -531,7 +601,6 @@ class SyncMPClient(MPClient):

         def process_outputs_socket():
             shutdown_socket = ctx.socket(zmq.PAIR)
-            out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL)
             try:
                 shutdown_socket.bind(shutdown_path)
                 poller = zmq.Poller()
@@ -566,6 +635,9 @@ class SyncMPClient(MPClient):
                                             daemon=True)
         self.output_queue_thread.start()

+        # The thread takes on responsibility for closing the socket.
+        self.resources.output_socket = None
+
     def get_output(self) -> EngineCoreOutputs:
         # If an exception arises in process_outputs_socket task,
         # it is forwarded to the outputs_queue so we can raise it
@@ -693,10 +765,8 @@ class AsyncMPClient(MPClient):
             self.__class__,
             "process_engine_outputs", None)
         _self_ref = weakref.ref(self) if output_handler else None
-        output_path = self.output_path
-        output_socket = make_zmq_socket(self.ctx, output_path,
-                                        zmq.constants.PULL)
-        resources.output_socket = output_socket
+        output_socket = resources.output_socket
+        assert output_socket is not None

         async def process_outputs_socket():
             try:
assert len(self.core_engines) > 1 assert len(self.core_engines) > 1
def _init_core_engines(
self,
vllm_config: VllmConfig,
new_core_engine: Callable[[int, Optional[int]], CoreEngine],
core_engines: list[CoreEngine],
) -> None:
# Launch a core engine for each data parallel rank.
dp_size = vllm_config.parallel_config.data_parallel_size
for i in range(dp_size):
# Multi-node not yet supported so local_dp_rank == dp_rank.
core_engines.append(new_core_engine(i, i))
self.core_engines = core_engines
async def call_utility_async(self, method: str, *args) -> Any: async def call_utility_async(self, method: str, *args) -> Any:
# Only the result from the first engine is returned. # Only the result from the first engine is returned.
return (await asyncio.gather(*[ return (await asyncio.gather(*[

View File

@@ -1,20 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
+import time
 import weakref
 from collections import defaultdict
 from collections.abc import Sequence
-from multiprocessing import Process
-from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
-                    Union, overload)
+from multiprocessing import Process, connection
+from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union,
+                    overload)

 import torch

+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.models.utils import extract_layer_index
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
 from vllm.utils import get_mp_context, kill_process_tree
+from vllm.v1.executor.abstract import Executor

 if TYPE_CHECKING:
     from vllm.attention.layer import Attention
@@ -92,7 +95,7 @@ class ConstantList(Generic[T], Sequence):
         return f"ConstantList({self._x})"


-class BackgroundProcHandle:
+class CoreEngineProcManager:
     """
     Utility class to handle creation, readiness, and shutdown
     of background processes used by the AsyncLLM and LLMEngine.
@@ -100,49 +103,91 @@
     def __init__(
         self,
-        input_path: str,
-        output_path: str,
-        process_name: str,
         target_fn: Callable,
-        process_kwargs: dict[Any, Any],
+        local_engine_count: int,
+        start_index: int,
+        local_start_index: int,
+        vllm_config: VllmConfig,
+        on_head_node: bool,
+        input_address: str,
+        executor_class: type[Executor],
+        log_stats: bool,
     ):
         context = get_mp_context()
+        common_kwargs = {
+            "vllm_config": vllm_config,
+            "on_head_node": on_head_node,
+            "input_address": input_address,
+            "executor_class": executor_class,
+            "log_stats": log_stats,
+        }

-        assert ("input_path" not in process_kwargs
-                and "output_path" not in process_kwargs)
-        process_kwargs["input_path"] = input_path
-        process_kwargs["output_path"] = output_path
+        self.processes: list[Process] = []
+        for index in range(local_engine_count):
+            local_index = local_start_index + index
+            global_index = start_index + index
+            # Start EngineCore in background process.
+            self.processes.append(
+                context.Process(target=target_fn,
+                                name=f"EngineCore_{global_index}",
+                                kwargs=common_kwargs | {
+                                    "dp_rank": global_index,
+                                    "local_dp_rank": local_index,
+                                }))

-        # Run busy loop in background process.
-        self.proc: Process = context.Process(target=target_fn,
-                                             kwargs=process_kwargs,
-                                             name=process_name)
-        self._finalizer = weakref.finalize(self, shutdown, self.proc,
-                                           input_path, output_path)
-        self.proc.start()
+        self._finalizer = weakref.finalize(self, shutdown, self.processes,
+                                           input_address)
+        try:
+            for proc in self.processes:
+                proc.start()
+        finally:
+            # Kill other procs if not all are running.
+            if self.finished_procs():
+                self.close()

-    def fileno(self):
-        return self.proc.sentinel
-
-    def shutdown(self):
+    def close(self):
+        """Shutdown all procs."""
         self._finalizer()

+    def join_first(self):
+        """Wait for any process to exit."""
+        connection.wait(proc.sentinel for proc in self.processes)
+
+    def sentinels(self) -> list:
+        return [proc.sentinel for proc in self.processes]
+
+    def finished_procs(self) -> dict[str, int]:
+        """Returns dict of proc name -> exit code for any finished procs."""
+        return {
+            proc.name: proc.exitcode
+            for proc in self.processes if proc.exitcode is not None
+        }
+

 # Note(rob): shutdown function cannot be a bound method,
-# else the gc cannot collect the object.
-def shutdown(proc: Process, input_path: str, output_path: str):
+# else the gc cannot collect the objedecoupct.
+def shutdown(procs: list[Process], input_address: str):
     # Shutdown the process.
-    if proc.is_alive():
-        proc.terminate()
-        proc.join(5)
+    for proc in procs:
+        if proc.is_alive():
+            proc.terminate()

-    if proc.is_alive() and (pid := proc.pid) is not None:
-        kill_process_tree(pid)
+    # Allow 5 seconds for remaining procs to terminate.
+    deadline = time.monotonic() + 5
+    for proc in procs:
+        remaining = deadline - time.monotonic()
+        if remaining <= 0:
+            break
+        if proc.is_alive():
+            proc.join(remaining)
+
+    for proc in procs:
+        if proc.is_alive() and (pid := proc.pid) is not None:
+            kill_process_tree(pid)

     # Remove zmq ipc socket files.
-    ipc_sockets = [output_path, input_path]
-    for ipc_socket in ipc_sockets:
-        socket_file = ipc_socket.replace("ipc://", "")
+    if input_address.startswith("ipc://"):
+        socket_file = input_address[len("ipc://"):]
         if os and os.path.exists(socket_file):
             os.remove(socket_file)