[V1] DP scale-out (2/N): Decouple engine process management and comms (#15977)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-05-13 10:48:21 -07:00 committed by GitHub
parent 0b217da646
commit 55aa7af994
10 changed files with 516 additions and 243 deletions

View File

@ -41,7 +41,7 @@ class MockEngine:
self.abort_request_calls = 0
self.request_id = None
# Ugly, remove dependency when possible
self.parallel_config = ParallelConfig(1, 1, False)
self.parallel_config = ParallelConfig()
self.model_config = MockModelConfig()
async def step_async(self, virtual_engine):

View File

@ -18,9 +18,10 @@ from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.engine.core_client import (AsyncMPClient, CoreEngine,
EngineCoreClient, SyncMPClient)
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
SyncMPClient)
from vllm.v1.executor.abstract import Executor
from vllm.v1.utils import CoreEngineProcManager
from ...distributed.conftest import MockSubscriber
from ...utils import create_new_process_for_each_test
@ -348,13 +349,13 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
# Monkey-patch to extract core process pid while it's starting.
core_proc_pid = [None]
ce_ctor = CoreEngine.__init__
cepm_ctor = CoreEngineProcManager.__init__
def patched_ce_ctor(self, *args, **kwargs):
ce_ctor(self, *args, **kwargs)
core_proc_pid[0] = self.proc_handle.proc.pid
def patched_cepm_ctor(self: CoreEngineProcManager, *args, **kwargs):
cepm_ctor(self, *args, **kwargs)
core_proc_pid[0] = self.processes[0].pid
m.setattr(CoreEngine, "__init__", patched_ce_ctor)
m.setattr(CoreEngineProcManager, "__init__", patched_cepm_ctor)
t = time.time()
engine_args = EngineArgs(model=MODEL_NAME)

View File

@ -1668,25 +1668,17 @@ class ParallelConfig:
data_parallel_size: int = 1
"""Number of data parallel groups. MoE layers will be sharded according to
the product of the tensor parallel size and data parallel size."""
data_parallel_size_local: int = 1
"""Number of local data parallel groups."""
data_parallel_rank: int = 0
"""Rank of the data parallel group."""
_data_parallel_rank_local: Optional[int] = field(default=None, init=False)
"""Private field to store the local rank of the data parallel group."""
@property
def data_parallel_rank_local(self) -> int:
"""Local rank of the data parallel group, defaults to global rank."""
if self._data_parallel_rank_local is None:
return self.data_parallel_rank
return self._data_parallel_rank_local
@data_parallel_rank_local.setter
def data_parallel_rank_local(self, value: int) -> None:
"""Set the local rank of the data parallel group."""
self._data_parallel_rank_local = value
data_parallel_rank_local: Optional[int] = None
"""Local rank of the data parallel group,
set only in SPMD mode."""
data_parallel_master_ip: str = "127.0.0.1"
"""IP of the data parallel master."""
data_parallel_rpc_port: int = 29550
"""Port for data parallel messaging."""
data_parallel_master_port: int = 29500
"""Port of the data parallel master."""
enable_expert_parallel: bool = False
@ -1734,13 +1726,16 @@ class ParallelConfig:
world_size: int = field(init=False)
"""world_size is TPxPP, it affects the number of workers we create."""
world_size_across_dp: int = field(init=False)
"""world_size_across_dp is TPxPPxDP, it is the size of the world
including data parallelism."""
rank: int = 0
"""Global rank in distributed setup."""
@property
def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world
including data parallelism."""
return self.world_size * self.data_parallel_size
def get_next_dp_init_port(self) -> int:
"""
We might need to initialize process groups in multiple
@ -1800,10 +1795,14 @@ class ParallelConfig:
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size
if self.data_parallel_size > 1:
if self.data_parallel_size_local > self.data_parallel_size:
raise ValueError(
f"data_parallel_size_local ({self.data_parallel_size_local}) "
f"must be <= data_parallel_size ({self.data_parallel_size})")
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
# Data parallel was specified in the engine args.
self.data_parallel_master_port = get_open_port()
# TODO multi-node
else:
# Otherwise fall back to env vars (e.g. for offline SPMD case).
self.data_parallel_size = envs.VLLM_DP_SIZE
@ -1812,8 +1811,6 @@ class ParallelConfig:
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
self.world_size_across_dp = self.world_size * self.data_parallel_size
if self.distributed_executor_backend == "external_launcher":
import os
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
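
Note on the config rework above: world_size_across_dp changed from a field computed in __post_init__ to a property. That matters because data_parallel_size can be overwritten after construction (the startup handshake below does exactly that). A minimal sketch of the pattern, using an illustrative stand-in class rather than the real ParallelConfig:

from dataclasses import dataclass

@dataclass
class DPConfigSketch:
    world_size: int = 1          # TP x PP
    data_parallel_size: int = 1

    @property
    def world_size_across_dp(self) -> int:
        # Recomputed on every access, so it cannot go stale when
        # data_parallel_size is updated after construction.
        return self.world_size * self.data_parallel_size

cfg = DPConfigSketch(world_size=2, data_parallel_size=4)
cfg.data_parallel_size = 8   # e.g. overridden by the handshake
assert cfg.world_size_across_dp == 16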

View File

@ -22,6 +22,7 @@ from torch.distributed.rendezvous import rendezvous
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import get_tcp_uri
logger = init_logger(__name__)
@ -303,7 +304,7 @@ def stateless_init_torch_distributed_process_group(
always formed with process 1, 2, ..., 8, and the additional communication
channel is formed with process 9 and 10.
"""
init_method = f"tcp://{host}:{port}"
init_method = get_tcp_uri(host, port)
backend = Backend(backend) # it is basically string
timeout = _get_default_timeout(backend)

View File

@ -283,6 +283,9 @@ class EngineArgs:
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_size_local: Optional[int] = None
data_parallel_address: Optional[str] = None
data_parallel_rpc_port: Optional[int] = None
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
max_parallel_loading_workers: Optional[
int] = ParallelConfig.max_parallel_loading_workers
@ -596,6 +599,21 @@ class EngineArgs:
**parallel_kwargs["tensor_parallel_size"])
parallel_group.add_argument("--data-parallel-size", "-dp",
**parallel_kwargs["data_parallel_size"])
parallel_group.add_argument('--data-parallel-size-local',
'-dpl',
type=int,
help='Number of data parallel replicas '
'to run on this node.')
parallel_group.add_argument('--data-parallel-address',
'-dpa',
type=str,
help='Address of data parallel cluster '
'head-node.')
parallel_group.add_argument('--data-parallel-rpc-port',
'-dpp',
type=int,
help='Port for data parallel RPC '
'communication.')
parallel_group.add_argument(
"--enable-expert-parallel",
**parallel_kwargs["enable_expert_parallel"])
@ -1019,10 +1037,30 @@ class EngineArgs:
# but we should not do this here.
placement_group = ray.util.get_current_placement_group()
# Local DP size defaults to global DP size if not set.
data_parallel_size_local = self.data_parallel_size if (
self.data_parallel_size_local
is None) else self.data_parallel_size_local
# DP address, used in multi-node case for torch distributed group
# and ZMQ sockets.
data_parallel_address = self.data_parallel_address if (
self.data_parallel_address
is not None) else ParallelConfig.data_parallel_master_ip
# This port is only used when there are remote data parallel engines,
# otherwise the local IPC transport is used.
data_parallel_rpc_port = self.data_parallel_rpc_port if (
self.data_parallel_rpc_port
is not None) else ParallelConfig.data_parallel_rpc_port
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
data_parallel_size=self.data_parallel_size,
data_parallel_size_local=data_parallel_size_local,
data_parallel_master_ip=data_parallel_address,
data_parallel_rpc_port=data_parallel_rpc_port,
enable_expert_parallel=self.enable_expert_parallel,
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,
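
For example (values illustrative): with -dp 4 and no -dpl, data_parallel_size_local resolves to 4; with no -dpa, the address falls back to ParallelConfig.data_parallel_master_ip (127.0.0.1); with no -dpp, the port falls back to ParallelConfig.data_parallel_rpc_port (29550).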

View File

@ -1,14 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import signal
import uvloop
import vllm.envs as envs
from vllm import AsyncEngineArgs
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
from vllm.utils import FlexibleArgumentParser
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_tcp_uri
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.core_client import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor
logger = init_logger(__name__)
class ServeSubcommand(CLISubcommand):
@ -24,7 +34,10 @@ class ServeSubcommand(CLISubcommand):
if hasattr(args, 'model_tag') and args.model_tag is not None:
args.model = args.model_tag
uvloop.run(run_server(args))
if args.headless:
run_headless(args)
else:
uvloop.run(run_server(args))
def validate(self, args: argparse.Namespace) -> None:
validate_parsed_serve_args(args)
@ -42,6 +55,18 @@ class ServeSubcommand(CLISubcommand):
nargs='?',
help="The model tag to serve "
"(optional if specified in config)")
serve_parser.add_argument(
"--headless",
action='store_true',
default=False,
help="Run in headless mode. See multi-node data parallel "
"documentation for more details.")
serve_parser.add_argument(
'--data-parallel-start-rank',
'-dpr',
type=int,
default=0,
help='Starting data parallel rank for secondary nodes.')
serve_parser.add_argument(
"--config",
type=str,
@ -57,3 +82,55 @@ class ServeSubcommand(CLISubcommand):
def cmd_init() -> list[CLISubcommand]:
return [ServeSubcommand()]
def run_headless(args: argparse.Namespace):
# Create the EngineConfig.
engine_args = AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
if not envs.VLLM_USE_V1:
raise RuntimeError("Headless mode is only supported for V1")
parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip
port = engine_args.data_parallel_rpc_port # add to config too
input_address = get_tcp_uri(host, port)
if local_engine_count <= 0:
raise RuntimeError("data_parallel_size_local must be > 0 in "
"headless mode")
# Catch SIGTERM and SIGINT to allow graceful shutdown.
def signal_handler(signum, frame):
logger.debug("Received %d signal.", signum)
raise SystemExit
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
logger.info(
"Launching %d data parallel engine(s) in headless mode, "
"with head node address %s.", local_engine_count, input_address)
# Create the engines.
engine_manager = CoreEngineProcManager(
target_fn=EngineCoreProc.run_engine_core,
local_engine_count=local_engine_count,
start_index=args.data_parallel_start_rank,
local_start_index=0,
vllm_config=vllm_config,
on_head_node=False,
input_address=input_address,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
)
try:
engine_manager.join_first()
finally:
logger.info("Shutting down.")
engine_manager.close()
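
Putting the new flags together, a hypothetical two-node deployment with four DP ranks could look like this (model name and IP are illustrative, not from this commit):

# Node 0 (head node, 10.0.0.1): serves the API, runs local ranks 0-1.
vllm serve meta-llama/Llama-3.1-8B-Instruct -dp 4 -dpl 2 -dpa 10.0.0.1

# Node 1: headless, runs ranks 2-3 and connects back to the head node.
vllm serve meta-llama/Llama-3.1-8B-Instruct --headless -dp 4 -dpl 2 -dpr 2 -dpa 10.0.0.1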

View File

@ -613,6 +613,10 @@ def is_valid_ipv6_address(address: str) -> bool:
def get_distributed_init_method(ip: str, port: int) -> str:
return get_tcp_uri(ip, port)
def get_tcp_uri(ip: str, port: int) -> str:
# Brackets are not permitted in ipv4 addresses,
# see https://github.com/python/cpython/issues/103848
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"

View File

@ -1,5 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
import json
import os
import queue
import signal
@ -23,7 +22,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import resolve_obj_by_qualname, zmq_socket_ctx
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
@ -43,6 +42,7 @@ from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
POLLING_TIMEOUT_S = 2.5
HANDSHAKE_TIMEOUT_MINS = 5
_R = TypeVar('_R') # Return type for collective_rpc
@ -348,9 +348,9 @@ class EngineCoreProc(EngineCore):
def __init__(
self,
input_path: str,
output_path: str,
vllm_config: VllmConfig,
on_head_node: bool,
input_address: str,
executor_class: type[Executor],
log_stats: bool,
engine_index: int = 0,
@ -360,28 +360,91 @@ class EngineCoreProc(EngineCore):
executor_fail_callback = lambda: input_queue.put_nowait(
(EngineCoreRequestType.EXECUTOR_FAILED, b''))
super().__init__(vllm_config, executor_class, log_stats,
executor_fail_callback)
# Create input socket.
input_ctx = zmq.Context()
identity = engine_index.to_bytes(length=2, byteorder="little")
input_socket = make_zmq_socket(input_ctx,
input_address,
zmq.DEALER,
identity=identity,
bind=False)
try:
# Register engine with front-end.
output_address = self.startup_handshake(
input_socket, on_head_node, vllm_config.parallel_config)
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
self.engines_running = False
# Update config which may have changed from the handshake.
vllm_config.__post_init__()
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue = input_queue
self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
threading.Thread(target=self.process_input_socket,
args=(input_path, engine_index),
daemon=True).start()
self.output_thread = threading.Thread(
target=self.process_output_socket,
args=(output_path, engine_index),
daemon=True)
self.output_thread.start()
# Set up data parallel environment.
self._init_data_parallel(vllm_config)
# Initialize engine core and model.
super().__init__(vllm_config, executor_class, log_stats,
executor_fail_callback)
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
self.engines_running = False
# Send ready message.
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
input_socket.send(
msgspec.msgpack.encode({
"status": "READY",
"local": on_head_node,
"num_gpu_blocks": num_gpu_blocks,
}))
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue = input_queue
self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
threading.Thread(target=self.process_input_socket,
args=(input_socket, ),
daemon=True).start()
input_socket = None
self.output_thread = threading.Thread(
target=self.process_output_socket,
args=(output_address, engine_index),
daemon=True)
self.output_thread.start()
finally:
if input_socket is not None:
input_socket.close(linger=0)
@staticmethod
def startup_handshake(input_socket: zmq.Socket, on_head_node: bool,
parallel_config: ParallelConfig) -> str:
# Send registration message.
input_socket.send(
msgspec.msgpack.encode({
"status": "HELLO",
"local": on_head_node,
}))
# Receive initialization message.
logger.info("Waiting for init message from front-end.")
if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000):
raise RuntimeError("Did not receive response from front-end "
f"process within {HANDSHAKE_TIMEOUT_MINS} "
f"minutes")
init_bytes = input_socket.recv()
init_message = msgspec.msgpack.decode(init_bytes)
logger.debug("Received init message: %s", init_message)
output_socket_address = init_message["output_socket_address"]
#TBD(nick) maybe replace IP with configured head node address
received_parallel_config = init_message["parallel_config"]
for key, value in received_parallel_config.items():
setattr(parallel_config, key, value)
return output_socket_address
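
For reference, the engine-side wire protocol above condenses to the following sketch (illustrative names; only the zmq/msgspec calls mirror the real code):

import msgspec.msgpack
import zmq

def handshake_sketch(front_end_addr: str, index: int, local: bool) -> str:
    sock = zmq.Context().socket(zmq.DEALER)
    # The 2-byte identity encodes the DP rank so the ROUTER front-end
    # can address this engine directly.
    sock.setsockopt(zmq.IDENTITY, index.to_bytes(2, "little"))
    sock.connect(front_end_addr)
    # 1) HELLO announces the engine and whether it runs on the head node.
    sock.send(msgspec.msgpack.encode({"status": "HELLO", "local": local}))
    # 2) The front-end replies with the output socket address plus
    #    parallel-config overrides (DP master IP/port, DP size).
    init = msgspec.msgpack.decode(sock.recv())
    # 3) READY (carrying num_gpu_blocks) is sent later, after model init.
    return init["output_socket_address"]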
@staticmethod
def run_engine_core(*args,
@ -412,7 +475,7 @@ class EngineCoreProc(EngineCore):
try:
parallel_config: ParallelConfig = kwargs[
"vllm_config"].parallel_config
if parallel_config.data_parallel_size > 1:
if parallel_config.data_parallel_size > 1 or dp_rank > 0:
# Set data parallel rank for this engine process.
parallel_config.data_parallel_rank = dp_rank
parallel_config.data_parallel_rank_local = local_dp_rank
@ -436,6 +499,9 @@ class EngineCoreProc(EngineCore):
if engine_core is not None:
engine_core.shutdown()
def _init_data_parallel(self, vllm_config: VllmConfig):
pass
def run_busy_loop(self):
"""Core busy loop of the EngineCore."""
@ -527,40 +593,25 @@ class EngineCoreProc(EngineCore):
logger.fatal("vLLM shutdown signal from EngineCore failed "
"to send. Please report this issue.")
def process_input_socket(self, input_path: str, engine_index: int):
def process_input_socket(self, input_socket: zmq.Socket):
"""Input socket IO thread."""
# Msgpack serialization decoding.
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
generic_decoder = MsgpackDecoder()
identity = engine_index.to_bytes(length=2, byteorder="little")
with zmq_socket_ctx(input_path,
zmq.DEALER,
identity=identity,
bind=False) as socket:
while True:
# (RequestType, RequestData)
type_frame, *data_frames = input_socket.recv_multipart(copy=False)
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
# Send ready message to front-end once input socket is connected.
message_dict = {
'type': 'READY',
'num_gpu_blocks': self.vllm_config.cache_config.num_gpu_blocks,
}
message = json.dumps(message_dict).encode('utf-8')
socket.send(message)
# Deserialize the request data.
decoder = add_request_decoder if (
request_type == EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frames)
while True:
# (RequestType, RequestData)
type_frame, *data_frames = socket.recv_multipart(copy=False)
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
# Deserialize the request data.
decoder = add_request_decoder if (
request_type
== EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frames)
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
def process_output_socket(self, output_path: str, engine_index: int):
"""Output socket IO thread."""
@ -609,9 +660,9 @@ class DPEngineCoreProc(EngineCoreProc):
def __init__(
self,
input_path: str,
output_path: str,
vllm_config: VllmConfig,
on_head_node: bool,
input_address: str,
executor_class: type[Executor],
log_stats: bool,
):
@ -623,8 +674,20 @@ class DPEngineCoreProc(EngineCoreProc):
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
dp_size = vllm_config.parallel_config.data_parallel_size
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self.counter = 0
# Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__(vllm_config, on_head_node, input_address,
executor_class, log_stats, dp_rank)
def _init_data_parallel(self, vllm_config: VllmConfig):
# Configure GPUs and stateless process group for data parallel.
dp_rank = vllm_config.parallel_config.data_parallel_rank
dp_size = vllm_config.parallel_config.data_parallel_size
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
assert dp_size > 1
@ -632,24 +695,16 @@ class DPEngineCoreProc(EngineCoreProc):
from vllm.platforms import current_platform
device_control_env_var = current_platform.device_control_env_var
tp_size = vllm_config.parallel_config.tensor_parallel_size
world_size = vllm_config.parallel_config.world_size
os.environ[device_control_env_var] = ",".join(
str(current_platform.device_id_to_physical_device_id(i))
for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
tp_size))
for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
world_size))
self.local_dp_rank = local_dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
self.current_wave = 0
# Initialize the engine after setting up environment.
super().__init__(input_path, output_path, vllm_config, executor_class,
log_stats, dp_rank)
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self.counter = 0
def shutdown(self):
super().shutdown()
if dp_group := getattr(self, "dp_group", None):
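
Note that device pinning in _init_data_parallel now spans the engine's full world size (TP x PP) rather than TP alone. A quick worked example of the index arithmetic (the real code additionally maps logical to physical device ids):

world_size, local_dp_rank = 4, 1   # e.g. TP=2, PP=2, second local DP rank
devices = ",".join(
    str(i) for i in range(local_dp_rank * world_size,
                          (local_dp_rank + 1) * world_size))
assert devices == "4,5,6,7"   # value for e.g. CUDA_VISIBLE_DEVICES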

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import contextlib
import json
import queue
import uuid
import weakref
@ -9,25 +8,27 @@ from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Awaitable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum, auto
from threading import Thread
from typing import Any, Callable, Optional, TypeVar, Union
import msgspec
import zmq
import zmq.asyncio
from vllm.config import VllmConfig
from vllm.config import ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path,
make_zmq_socket)
from vllm.utils import (get_open_port, get_open_zmq_inproc_path,
get_open_zmq_ipc_path, get_tcp_uri, make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
from vllm.v1.utils import BackgroundProcHandle
from vllm.v1.utils import CoreEngineProcManager
logger = init_logger(__name__)
@ -264,45 +265,22 @@ class InprocClient(EngineCoreClient):
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
class CoreEngineState(Enum):
NEW = auto()
CONNECTED = auto()
READY = auto()
class CoreEngine:
"""One per data parallel rank."""
def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
input_path: str,
output_path: str,
index: int = 0,
local_dp_rank: int = 0,
):
def __init__(self, index: int = 0, local: bool = True):
self.local = local
self.index = index
self.identity = index.to_bytes(length=2, byteorder="little")
try:
# Start EngineCore in background process.
self.proc_handle = BackgroundProcHandle(
input_path=input_path,
output_path=output_path,
process_name=f"EngineCore_{index}",
target_fn=EngineCoreProc.run_engine_core,
process_kwargs={
"vllm_config": vllm_config,
"dp_rank": index,
"local_dp_rank": local_dp_rank,
"executor_class": executor_class,
"log_stats": log_stats,
})
self.num_reqs_in_flight = 0
finally:
if not hasattr(self, "num_reqs_in_flight"):
# Ensure socket is closed if process fails to start.
self.close()
def close(self):
if proc_handle := getattr(self, "proc_handle", None):
proc_handle.shutdown()
self.state = CoreEngineState.NEW
self.num_reqs_in_flight = 0
@dataclass
@ -311,7 +289,7 @@ class BackgroundResources:
circular reference back to the client object."""
ctx: Union[zmq.Context]
core_engines: list[CoreEngine] = field(default_factory=list)
local_engine_manager: Optional[CoreEngineProcManager] = None
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
output_queue_task: Optional[asyncio.Task] = None
@ -325,8 +303,8 @@ class BackgroundResources:
"""Clean up background resources."""
self.engine_dead = True
for core_engine in self.core_engines:
core_engine.close()
if self.local_engine_manager is not None:
self.local_engine_manager.close()
if self.output_queue_task is not None:
self.output_queue_task.cancel()
@ -388,25 +366,56 @@ class MPClient(EngineCoreClient):
self._finalizer = weakref.finalize(self, self.resources)
success = False
try:
# Paths and sockets for IPC.
self.output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()
self.input_socket = make_zmq_socket(self.ctx,
input_path,
zmq.ROUTER,
bind=True)
self.resources.input_socket = self.input_socket
parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local
start_index = parallel_config.data_parallel_rank
local_start_index = parallel_config.data_parallel_rank_local
new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
vllm_config, executor_class, log_stats, input_path, self.
output_path, index, local_dp_rank)
# SPMD mode is where there is an LLM instance per DP rank and
# one core engine per LLM, see
# examples/offline_inference/data_parallel.py.
spmd_mode = local_start_index is not None
if spmd_mode:
assert local_engine_count == 1
self.core_engines = [
CoreEngine(index=local_start_index, local=True)
]
else:
assert start_index == 0
local_start_index = 0
self.core_engines = [
CoreEngine(index=i, local=(i < local_engine_count))
for i in range(parallel_config.data_parallel_size)
]
# Start engine core process(es).
self._init_core_engines(vllm_config, new_core_engine,
self.resources.core_engines)
input_address, output_address = self._get_zmq_addresses(
parallel_config, spmd_mode)
# Create input and output sockets.
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx, input_address, zmq.ROUTER, bind=True)
self.resources.output_socket = make_zmq_socket(
self.ctx, output_address, zmq.constants.PULL)
# Start local engines.
if local_engine_count:
# In server mode, start_index and local_start_index will
# both be 0.
self.resources.local_engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=log_stats,
input_address=input_address,
on_head_node=True,
local_engine_count=local_engine_count,
start_index=start_index,
local_start_index=local_start_index)
self.core_engine = self.core_engines[0]
# Wait for engine core process(es) to start.
self._wait_for_engine_startup()
self._wait_for_engine_startup(output_address, parallel_config)
self.utility_results: dict[int, AnyFuture] = {}
@ -420,56 +429,116 @@ class MPClient(EngineCoreClient):
if not success:
self._finalizer()
def _wait_for_engine_startup(self):
@staticmethod
def _get_zmq_addresses(parallel_config: ParallelConfig,
spmd_mode: bool) -> tuple[str, str]:
"""Returns (input_address, output_address)."""
dp_size = parallel_config.data_parallel_size
local_engine_count = parallel_config.data_parallel_size_local
if local_engine_count == dp_size or spmd_mode:
input_address = get_open_zmq_ipc_path()
output_address = get_open_zmq_ipc_path()
else:
host = parallel_config.data_parallel_master_ip
input_port = parallel_config.data_parallel_rpc_port
output_port = get_open_port()
input_address = get_tcp_uri(host, input_port)
output_address = get_tcp_uri(host, output_port)
return input_address, output_address
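
In short: when every engine lives on this node (or in SPMD mode), both channels use process-local IPC; as soon as any engine is remote, both bind TCP on the DP master address. Illustrative outcomes (paths and ports are examples only):

# all engines local, or SPMD mode:
#   ("ipc:///tmp/5a7b1c...", "ipc:///tmp/9c2dee...")
# some engines remote, master 10.0.0.1, rpc port 29550:
#   ("tcp://10.0.0.1:29550", "tcp://10.0.0.1:41233")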
def _wait_for_engine_startup(self, output_address: str,
parallel_config: ParallelConfig):
# Get a sync handle to the socket which can be sync or async.
sync_input_socket = zmq.Socket.shadow(self.input_socket)
# Wait for engine core process(es) to send ready messages.
identities = set(eng.index for eng in self.resources.core_engines)
local_count = parallel_config.data_parallel_size_local
remote_count = len(self.core_engines) - local_count
# [local, remote] counts
conn_pending, start_pending = [local_count, remote_count], [0, 0]
poller = zmq.Poller()
poller.register(sync_input_socket, zmq.POLLIN)
for eng in self.resources.core_engines:
poller.register(eng.proc_handle, zmq.POLLIN)
while identities:
proc_manager = self.resources.local_engine_manager
if proc_manager is not None:
for sentinel in proc_manager.sentinels():
poller.register(sentinel, zmq.POLLIN)
while any(conn_pending) or any(start_pending):
events = poller.poll(STARTUP_POLL_PERIOD_MS)
if not events:
logger.debug("Waiting for %d core engine proc(s) to start: %s",
len(identities), identities)
if any(conn_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to connect.", *conn_pending)
if any(start_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to start.", *start_pending)
continue
if len(events) > 1 or events[0][0] != sync_input_socket:
# One of the core processes exited.
# One of the local core processes exited.
finished = proc_manager.finished_procs(
) if proc_manager else {}
raise RuntimeError("Engine core initialization failed. "
"See root cause above.")
"See root cause above. "
f"Failed core proc(s): {finished}")
eng_id_bytes, data = sync_input_socket.recv_multipart()
eng_id = int.from_bytes(eng_id_bytes, byteorder="little")
if eng_id not in identities:
raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}")
message_dict = json.loads(data.decode('utf-8'))
if message_dict['type'] != 'READY':
raise RuntimeError(f"Engine {eng_id} failed: {data.decode()}")
logger.info("Core engine process %d ready.", eng_id)
identities.discard(eng_id)
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
num_gpu_blocks = self.vllm_config.cache_config.num_gpu_blocks or 0
num_gpu_blocks += message_dict['num_gpu_blocks']
self.vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
# Receive HELLO and READY messages from the input socket.
eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart()
eng_index = int.from_bytes(eng_identity, byteorder="little")
engine = next(
(e for e in self.core_engines if e.identity == eng_identity),
None)
if engine is None:
raise RuntimeError(f"Message from engine with unexpected data "
f"parallel rank: {eng_index}")
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local = msg["status"], msg["local"]
if local != engine.local:
raise RuntimeError(f"{status} message from "
f"{'local' if local else 'remote'} "
f"engine {eng_index}, expected it to be "
f"{'local' if engine.local else 'remote'}")
def _init_core_engines(
self,
vllm_config: VllmConfig,
new_core_engine: Callable[[int, Optional[int]], CoreEngine],
core_engines: list[CoreEngine],
) -> None:
if status == "HELLO" and engine.state == CoreEngineState.NEW:
# Default case - single core engine.
core_engine = new_core_engine(
vllm_config.parallel_config.data_parallel_rank,
vllm_config.parallel_config.data_parallel_rank_local,
)
core_engines.append(core_engine)
self.core_engine = core_engine
# Send init message with DP config info.
init_message = self.encoder.encode({
"output_socket_address": output_address,
"parallel_config": {
"data_parallel_master_ip":
parallel_config.data_parallel_master_ip,
"data_parallel_master_port":
parallel_config.data_parallel_master_port,
"data_parallel_size":
parallel_config.data_parallel_size,
},
})
sync_input_socket.send_multipart((eng_identity, *init_message),
copy=False)
conn_pending[0 if local else 1] -= 1
start_pending[0 if local else 1] += 1
engine.state = CoreEngineState.CONNECTED
elif status == "READY" and (engine.state
== CoreEngineState.CONNECTED):
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
cache_config = self.vllm_config.cache_config
num_gpu_blocks = cache_config.num_gpu_blocks or 0
num_gpu_blocks += msg['num_gpu_blocks']
cache_config.num_gpu_blocks = num_gpu_blocks
start_pending[0 if local else 1] -= 1
engine.state = CoreEngineState.READY
else:
raise RuntimeError(f"Unexpected {status} message for "
f"{'local' if local else 'remote'} engine "
f"{eng_index} in {engine.state} state.")
logger.debug("%s from %s core engine process %s.", status,
"local" if local else "remote", eng_index)
def shutdown(self):
# Terminate background resources.
@ -520,7 +589,8 @@ class SyncMPClient(MPClient):
# Ensure that the outputs socket processing thread does not have
# a ref to the client which prevents gc.
ctx = self.ctx
output_path = self.output_path
out_socket = self.resources.output_socket
assert out_socket is not None
decoder = self.decoder
utility_results = self.utility_results
outputs_queue = self.outputs_queue
@ -531,7 +601,6 @@ class SyncMPClient(MPClient):
def process_outputs_socket():
shutdown_socket = ctx.socket(zmq.PAIR)
out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL)
try:
shutdown_socket.bind(shutdown_path)
poller = zmq.Poller()
@ -566,6 +635,9 @@ class SyncMPClient(MPClient):
daemon=True)
self.output_queue_thread.start()
# The thread takes on responsibility for closing the socket.
self.resources.output_socket = None
def get_output(self) -> EngineCoreOutputs:
# If an exception arises in process_outputs_socket task,
# it is forwarded to the outputs_queue so we can raise it
@ -693,10 +765,8 @@ class AsyncMPClient(MPClient):
self.__class__,
"process_engine_outputs", None)
_self_ref = weakref.ref(self) if output_handler else None
output_path = self.output_path
output_socket = make_zmq_socket(self.ctx, output_path,
zmq.constants.PULL)
resources.output_socket = output_socket
output_socket = resources.output_socket
assert output_socket is not None
async def process_outputs_socket():
try:
@ -861,21 +931,6 @@ class DPAsyncMPClient(AsyncMPClient):
assert len(self.core_engines) > 1
def _init_core_engines(
self,
vllm_config: VllmConfig,
new_core_engine: Callable[[int, Optional[int]], CoreEngine],
core_engines: list[CoreEngine],
) -> None:
# Launch a core engine for each data parallel rank.
dp_size = vllm_config.parallel_config.data_parallel_size
for i in range(dp_size):
# Multi-node not yet supported so local_dp_rank == dp_rank.
core_engines.append(new_core_engine(i, i))
self.core_engines = core_engines
async def call_utility_async(self, method: str, *args) -> Any:
# Only the result from the first engine is returned.
return (await asyncio.gather(*[

View File

@ -1,20 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
import os
import time
import weakref
from collections import defaultdict
from collections.abc import Sequence
from multiprocessing import Process
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
Union, overload)
from multiprocessing import Process, connection
from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union,
overload)
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import get_mp_context, kill_process_tree
from vllm.v1.executor.abstract import Executor
if TYPE_CHECKING:
from vllm.attention.layer import Attention
@ -92,7 +95,7 @@ class ConstantList(Generic[T], Sequence):
return f"ConstantList({self._x})"
class BackgroundProcHandle:
class CoreEngineProcManager:
"""
Utility class to handle creation, readiness, and shutdown
of background processes used by the AsyncLLM and LLMEngine.
@ -100,49 +103,91 @@ class BackgroundProcHandle:
def __init__(
self,
input_path: str,
output_path: str,
process_name: str,
target_fn: Callable,
process_kwargs: dict[Any, Any],
local_engine_count: int,
start_index: int,
local_start_index: int,
vllm_config: VllmConfig,
on_head_node: bool,
input_address: str,
executor_class: type[Executor],
log_stats: bool,
):
context = get_mp_context()
common_kwargs = {
"vllm_config": vllm_config,
"on_head_node": on_head_node,
"input_address": input_address,
"executor_class": executor_class,
"log_stats": log_stats,
}
assert ("input_path" not in process_kwargs
and "output_path" not in process_kwargs)
process_kwargs["input_path"] = input_path
process_kwargs["output_path"] = output_path
self.processes: list[Process] = []
for index in range(local_engine_count):
local_index = local_start_index + index
global_index = start_index + index
# Start EngineCore in background process.
self.processes.append(
context.Process(target=target_fn,
name=f"EngineCore_{global_index}",
kwargs=common_kwargs | {
"dp_rank": global_index,
"local_dp_rank": local_index,
}))
# Run busy loop in background process.
self.proc: Process = context.Process(target=target_fn,
kwargs=process_kwargs,
name=process_name)
self._finalizer = weakref.finalize(self, shutdown, self.proc,
input_path, output_path)
self.proc.start()
self._finalizer = weakref.finalize(self, shutdown, self.processes,
input_address)
try:
for proc in self.processes:
proc.start()
finally:
# Kill other procs if not all are running.
if self.finished_procs():
self.close()
def fileno(self):
return self.proc.sentinel
def shutdown(self):
def close(self):
"""Shutdown all procs."""
self._finalizer()
def join_first(self):
"""Wait for any process to exit."""
connection.wait(proc.sentinel for proc in self.processes)
def sentinels(self) -> list:
return [proc.sentinel for proc in self.processes]
def finished_procs(self) -> dict[str, int]:
"""Returns dict of proc name -> exit code for any finished procs."""
return {
proc.name: proc.exitcode
for proc in self.processes if proc.exitcode is not None
}
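
For reference, the manager's lifecycle as exercised by run_headless above boils down to this sketch (vllm_config and executor_class assumed constructed elsewhere; address illustrative):

manager = CoreEngineProcManager(
    target_fn=EngineCoreProc.run_engine_core,
    local_engine_count=2,
    start_index=2,                # global DP rank of the first local engine
    local_start_index=0,
    vllm_config=vllm_config,
    on_head_node=False,
    input_address="tcp://10.0.0.1:29550",
    executor_class=executor_class,
    log_stats=True,
)
try:
    manager.join_first()   # block until any engine process exits
finally:
    manager.close()        # terminate, join with a shared deadline, kill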
# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def shutdown(proc: Process, input_path: str, output_path: str):
# else the gc cannot collect the object.
def shutdown(procs: list[Process], input_address: str):
# Shutdown the process.
if proc.is_alive():
proc.terminate()
proc.join(5)
for proc in procs:
if proc.is_alive():
proc.terminate()
# Allow 5 seconds for remaining procs to terminate.
deadline = time.monotonic() + 5
for proc in procs:
remaining = deadline - time.monotonic()
if remaining <= 0:
break
if proc.is_alive():
proc.join(remaining)
for proc in procs:
if proc.is_alive() and (pid := proc.pid) is not None:
kill_process_tree(pid)
# Remove zmq ipc socket files.
ipc_sockets = [output_path, input_path]
for ipc_socket in ipc_sockets:
socket_file = ipc_socket.replace("ipc://", "")
if input_address.startswith("ipc://"):
socket_file = input_address[len("ipc://"):]
if os and os.path.exists(socket_file):
os.remove(socket_file)