[V1] DP scale-out (2/N): Decouple engine process management and comms (#15977)

Signed-off-by: Nick Hill <nhill@redhat.com>

parent 0b217da646
commit 55aa7af994
@@ -41,7 +41,7 @@ class MockEngine:
        self.abort_request_calls = 0
        self.request_id = None
        # Ugly, remove dependency when possible
        self.parallel_config = ParallelConfig(1, 1, False)
        self.parallel_config = ParallelConfig()
        self.model_config = MockModelConfig()

    async def step_async(self, virtual_engine):

@@ -18,9 +18,10 @@ from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.engine.core_client import (AsyncMPClient, CoreEngine,
                                        EngineCoreClient, SyncMPClient)
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
                                        SyncMPClient)
from vllm.v1.executor.abstract import Executor
from vllm.v1.utils import CoreEngineProcManager

from ...distributed.conftest import MockSubscriber
from ...utils import create_new_process_for_each_test
@@ -348,13 +349,13 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch):

        # Monkey-patch to extract core process pid while it's starting.
        core_proc_pid = [None]
        ce_ctor = CoreEngine.__init__
        cepm_ctor = CoreEngineProcManager.__init__

        def patched_ce_ctor(self, *args, **kwargs):
            ce_ctor(self, *args, **kwargs)
            core_proc_pid[0] = self.proc_handle.proc.pid
        def patched_cepm_ctor(self: CoreEngineProcManager, *args, **kwargs):
            cepm_ctor(self, *args, **kwargs)
            core_proc_pid[0] = self.processes[0].pid

        m.setattr(CoreEngine, "__init__", patched_ce_ctor)
        m.setattr(CoreEngineProcManager, "__init__", patched_cepm_ctor)

        t = time.time()
        engine_args = EngineArgs(model=MODEL_NAME)

@@ -1668,25 +1668,17 @@ class ParallelConfig:
    data_parallel_size: int = 1
    """Number of data parallel groups. MoE layers will be sharded according to
    the product of the tensor parallel size and data parallel size."""
    data_parallel_size_local: int = 1
    """Number of local data parallel groups."""
    data_parallel_rank: int = 0
    """Rank of the data parallel group."""
    _data_parallel_rank_local: Optional[int] = field(default=None, init=False)
    """Private field to store the local rank of the data parallel group."""

    @property
    def data_parallel_rank_local(self) -> int:
        """Local rank of the data parallel group, defaults to global rank."""
        if self._data_parallel_rank_local is None:
            return self.data_parallel_rank
        return self._data_parallel_rank_local

    @data_parallel_rank_local.setter
    def data_parallel_rank_local(self, value: int) -> None:
        """Set the local rank of the data parallel group."""
        self._data_parallel_rank_local = value

    data_parallel_rank_local: Optional[int] = None
    """Local rank of the data parallel group,
    set only in SPMD mode."""
    data_parallel_master_ip: str = "127.0.0.1"
    """IP of the data parallel master."""
    data_parallel_rpc_port: int = 29550
    """Port for data parallel messaging."""
    data_parallel_master_port: int = 29500
    """Port of the data parallel master."""
    enable_expert_parallel: bool = False
@@ -1734,13 +1726,16 @@ class ParallelConfig:

    world_size: int = field(init=False)
    """world_size is TPxPP, it affects the number of workers we create."""
    world_size_across_dp: int = field(init=False)
    """world_size_across_dp is TPxPPxDP, it is the size of the world
    including data parallelism."""

    rank: int = 0
    """Global rank in distributed setup."""

    @property
    def world_size_across_dp(self) -> int:
        """world_size_across_dp is TPxPPxDP, it is the size of the world
        including data parallelism."""
        return self.world_size * self.data_parallel_size

    def get_next_dp_init_port(self) -> int:
        """
        We might need to initialize process groups in multiple
@@ -1800,10 +1795,14 @@ class ParallelConfig:
        self.world_size = self.pipeline_parallel_size * \
            self.tensor_parallel_size

        if self.data_parallel_size > 1:
        if self.data_parallel_size_local > self.data_parallel_size:
            raise ValueError(
                f"data_parallel_size_local ({self.data_parallel_size_local}) "
                f"must be <= data_parallel_size ({self.data_parallel_size})")

        if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
            # Data parallel was specified in the engine args.
            self.data_parallel_master_port = get_open_port()
            # TODO multi-node
        else:
            # Otherwise fall back to env vars (e.g. for offline SPMD case).
            self.data_parallel_size = envs.VLLM_DP_SIZE
@@ -1812,8 +1811,6 @@ class ParallelConfig:
            self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
            self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT

        self.world_size_across_dp = self.world_size * self.data_parallel_size

        if self.distributed_executor_backend == "external_launcher":
            import os
            os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

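Note (illustration, not part of the commit): the following standalone sketch shows how the reworked data parallel fields above relate to each other. The class name and values are made up; only the field names and the world_size_across_dp relationship come from the diff.

from dataclasses import dataclass
from typing import Optional

@dataclass
class DPConfigSketch:
    data_parallel_size: int = 1             # global number of DP ranks
    data_parallel_size_local: int = 1       # DP ranks launched on this node
    data_parallel_rank: int = 0             # global DP rank of this engine
    data_parallel_rank_local: Optional[int] = None  # set only in SPMD mode
    world_size: int = 1                     # TP x PP

    @property
    def world_size_across_dp(self) -> int:
        # Mirrors the property added above: TP x PP x DP.
        return self.world_size * self.data_parallel_size

cfg = DPConfigSketch(data_parallel_size=4, data_parallel_size_local=2,
                     world_size=2)
assert cfg.world_size_across_dp == 8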
@@ -22,6 +22,7 @@ from torch.distributed.rendezvous import rendezvous

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import get_tcp_uri

logger = init_logger(__name__)

@@ -303,7 +304,7 @@ def stateless_init_torch_distributed_process_group(
    always formed with process 1, 2, ..., 8, and the additional communication
    channel is formed with process 9 and 10.
    """
    init_method = f"tcp://{host}:{port}"
    init_method = get_tcp_uri(host, port)
    backend = Backend(backend)  # it is basically string
    timeout = _get_default_timeout(backend)

@@ -283,6 +283,9 @@ class EngineArgs:
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
    data_parallel_size: int = ParallelConfig.data_parallel_size
    data_parallel_size_local: Optional[int] = None
    data_parallel_address: Optional[str] = None
    data_parallel_rpc_port: Optional[int] = None
    enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
    max_parallel_loading_workers: Optional[
        int] = ParallelConfig.max_parallel_loading_workers
@@ -596,6 +599,21 @@ class EngineArgs:
                                    **parallel_kwargs["tensor_parallel_size"])
        parallel_group.add_argument("--data-parallel-size", "-dp",
                                    **parallel_kwargs["data_parallel_size"])
        parallel_group.add_argument('--data-parallel-size-local',
                                    '-dpl',
                                    type=int,
                                    help='Number of data parallel replicas '
                                    'to run on this node.')
        parallel_group.add_argument('--data-parallel-address',
                                    '-dpa',
                                    type=str,
                                    help='Address of data parallel cluster '
                                    'head-node.')
        parallel_group.add_argument('--data-parallel-rpc-port',
                                    '-dpp',
                                    type=int,
                                    help='Port for data parallel RPC '
                                    'communication.')
        parallel_group.add_argument(
            "--enable-expert-parallel",
            **parallel_kwargs["enable_expert_parallel"])
@@ -1019,10 +1037,30 @@ class EngineArgs:
            # but we should not do this here.
            placement_group = ray.util.get_current_placement_group()

        # Local DP size defaults to global DP size if not set.
        data_parallel_size_local = self.data_parallel_size if (
            self.data_parallel_size_local
            is None) else self.data_parallel_size_local

        # DP address, used in multi-node case for torch distributed group
        # and ZMQ sockets.
        data_parallel_address = self.data_parallel_address if (
            self.data_parallel_address
            is not None) else ParallelConfig.data_parallel_master_ip

        # This port is only used when there are remote data parallel engines,
        # otherwise the local IPC transport is used.
        data_parallel_rpc_port = self.data_parallel_rpc_port if (
            self.data_parallel_rpc_port
            is not None) else ParallelConfig.data_parallel_rpc_port

        parallel_config = ParallelConfig(
            pipeline_parallel_size=self.pipeline_parallel_size,
            tensor_parallel_size=self.tensor_parallel_size,
            data_parallel_size=self.data_parallel_size,
            data_parallel_size_local=data_parallel_size_local,
            data_parallel_master_ip=data_parallel_address,
            data_parallel_rpc_port=data_parallel_rpc_port,
            enable_expert_parallel=self.enable_expert_parallel,
            max_parallel_loading_workers=self.max_parallel_loading_workers,
            disable_custom_all_reduce=self.disable_custom_all_reduce,

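Note (illustration, not part of the commit): the defaulting above reduces to three simple rules. The helper below is a hypothetical stand-in; the fallback values come from the ParallelConfig defaults shown earlier in this diff (127.0.0.1 and port 29550).

from typing import Optional

def resolve_dp_args(dp_size: int,
                    dp_size_local: Optional[int],
                    dp_address: Optional[str],
                    dp_rpc_port: Optional[int]) -> tuple[int, str, int]:
    # Local DP size defaults to the global DP size, the DP address to the
    # default master IP, and the RPC port to the default RPC port.
    size_local = dp_size if dp_size_local is None else dp_size_local
    address = dp_address if dp_address is not None else "127.0.0.1"
    rpc_port = dp_rpc_port if dp_rpc_port is not None else 29550
    return size_local, address, rpc_port

# e.g. "-dp 4" with no local size, address or port overrides:
assert resolve_dp_args(4, None, None, None) == (4, "127.0.0.1", 29550)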
@@ -1,14 +1,24 @@
# SPDX-License-Identifier: Apache-2.0

import argparse
import signal

import uvloop

import vllm.envs as envs
from vllm import AsyncEngineArgs
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
from vllm.utils import FlexibleArgumentParser
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_tcp_uri
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.core_client import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor

logger = init_logger(__name__)


class ServeSubcommand(CLISubcommand):
@@ -24,7 +34,10 @@ class ServeSubcommand(CLISubcommand):
        if hasattr(args, 'model_tag') and args.model_tag is not None:
            args.model = args.model_tag

        uvloop.run(run_server(args))
        if args.headless:
            run_headless(args)
        else:
            uvloop.run(run_server(args))

    def validate(self, args: argparse.Namespace) -> None:
        validate_parsed_serve_args(args)
@@ -42,6 +55,18 @@ class ServeSubcommand(CLISubcommand):
                                  nargs='?',
                                  help="The model tag to serve "
                                  "(optional if specified in config)")
        serve_parser.add_argument(
            "--headless",
            action='store_true',
            default=False,
            help="Run in headless mode. See multi-node data parallel "
            "documentation for more details.")
        serve_parser.add_argument(
            '--data-parallel-start-rank',
            '-dpr',
            type=int,
            default=0,
            help='Starting data parallel rank for secondary nodes.')
        serve_parser.add_argument(
            "--config",
            type=str,
@@ -57,3 +82,55 @@ class ServeSubcommand(CLISubcommand):

def cmd_init() -> list[CLISubcommand]:
    return [ServeSubcommand()]


def run_headless(args: argparse.Namespace):

    # Create the EngineConfig.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    usage_context = UsageContext.OPENAI_API_SERVER
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    if not envs.VLLM_USE_V1:
        raise RuntimeError("Headless mode is only supported for V1")

    parallel_config = vllm_config.parallel_config
    local_engine_count = parallel_config.data_parallel_size_local
    host = parallel_config.data_parallel_master_ip
    port = engine_args.data_parallel_rpc_port  # add to config too
    input_address = get_tcp_uri(host, port)

    if local_engine_count <= 0:
        raise RuntimeError("data_parallel_size_local must be > 0 in "
                           "headless mode")

    # Catch SIGTERM and SIGINT to allow graceful shutdown.
    def signal_handler(signum, frame):
        logger.debug("Received %d signal.", signum)
        raise SystemExit

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    logger.info(
        "Launching %d data parallel engine(s) in headless mode, "
        "with head node address %s.", local_engine_count, input_address)

    # Create the engines.
    engine_manager = CoreEngineProcManager(
        target_fn=EngineCoreProc.run_engine_core,
        local_engine_count=local_engine_count,
        start_index=args.data_parallel_start_rank,
        local_start_index=0,
        vllm_config=vllm_config,
        on_head_node=False,
        input_address=input_address,
        executor_class=Executor.get_class(vllm_config),
        log_stats=not engine_args.disable_log_stats,
    )

    try:
        engine_manager.join_first()
    finally:
        logger.info("Shutting down.")
        engine_manager.close()

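Note (illustration, not part of the commit): based on the flags added in this file, a two-node data parallel launch might look like the following. The model placeholder, node address, port, ranks and sizes are assumptions, not values taken from the diff.

# Head node: runs the API server plus the first 2 of 4 engines.
vllm serve <model> -dp 4 -dpl 2 -dpa 10.0.0.1 -dpp 13345

# Secondary node: engines only, no API server.
vllm serve <model> --headless -dp 4 -dpl 2 -dpr 2 -dpa 10.0.0.1 -dpp 13345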
@@ -613,6 +613,10 @@ def is_valid_ipv6_address(address: str) -> bool:


def get_distributed_init_method(ip: str, port: int) -> str:
    return get_tcp_uri(ip, port)


def get_tcp_uri(ip: str, port: int) -> str:
    # Brackets are not permitted in ipv4 addresses,
    # see https://github.com/python/cpython/issues/103848
    return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"

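Note (usage illustration, not part of the commit): the helper above only brackets hosts that contain a colon, so both IPv4 and IPv6 addresses produce valid tcp:// URIs.

from vllm.utils import get_tcp_uri

assert get_tcp_uri("127.0.0.1", 29550) == "tcp://127.0.0.1:29550"  # IPv4: no brackets
assert get_tcp_uri("::1", 29550) == "tcp://[::1]:29550"            # IPv6: bracketed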
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
import json
import os
import queue
import signal
@@ -23,7 +22,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
from vllm.utils import resolve_obj_by_qualname, zmq_socket_ctx
from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                         unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
@@ -43,6 +42,7 @@ from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)

POLLING_TIMEOUT_S = 2.5
HANDSHAKE_TIMEOUT_MINS = 5

_R = TypeVar('_R')  # Return type for collective_rpc

@@ -348,9 +348,9 @@ class EngineCoreProc(EngineCore):

    def __init__(
        self,
        input_path: str,
        output_path: str,
        vllm_config: VllmConfig,
        on_head_node: bool,
        input_address: str,
        executor_class: type[Executor],
        log_stats: bool,
        engine_index: int = 0,
@@ -360,28 +360,91 @@ class EngineCoreProc(EngineCore):
        executor_fail_callback = lambda: input_queue.put_nowait(
            (EngineCoreRequestType.EXECUTOR_FAILED, b''))

        super().__init__(vllm_config, executor_class, log_stats,
                         executor_fail_callback)
        # Create input socket.
        input_ctx = zmq.Context()
        identity = engine_index.to_bytes(length=2, byteorder="little")
        input_socket = make_zmq_socket(input_ctx,
                                       input_address,
                                       zmq.DEALER,
                                       identity=identity,
                                       bind=False)
        try:
            # Register engine with front-end.
            output_address = self.startup_handshake(
                input_socket, on_head_node, vllm_config.parallel_config)

        self.step_fn = (self.step if self.batch_queue is None else
                        self.step_with_batch_queue)
        self.engines_running = False
            # Update config which may have changed from the handshake.
            vllm_config.__post_init__()

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        self.input_queue = input_queue
        self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
        threading.Thread(target=self.process_input_socket,
                         args=(input_path, engine_index),
                         daemon=True).start()
        self.output_thread = threading.Thread(
            target=self.process_output_socket,
            args=(output_path, engine_index),
            daemon=True)
        self.output_thread.start()
            # Set up data parallel environment.
            self._init_data_parallel(vllm_config)

            # Initialize engine core and model.
            super().__init__(vllm_config, executor_class, log_stats,
                             executor_fail_callback)

            self.step_fn = (self.step if self.batch_queue is None else
                            self.step_with_batch_queue)
            self.engines_running = False

            # Send ready message.
            num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
            input_socket.send(
                msgspec.msgpack.encode({
                    "status": "READY",
                    "local": on_head_node,
                    "num_gpu_blocks": num_gpu_blocks,
                }))

            # Background Threads and Queues for IO. These enable us to
            # overlap ZMQ socket IO with GPU since they release the GIL,
            # and to overlap some serialization/deserialization with the
            # model forward pass.
            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
            self.input_queue = input_queue
            self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
            threading.Thread(target=self.process_input_socket,
                             args=(input_socket, ),
                             daemon=True).start()
            input_socket = None
            self.output_thread = threading.Thread(
                target=self.process_output_socket,
                args=(output_address, engine_index),
                daemon=True)
            self.output_thread.start()
        finally:
            if input_socket is not None:
                input_socket.close(linger=0)

    @staticmethod
    def startup_handshake(input_socket: zmq.Socket, on_head_node: bool,
                          parallel_config: ParallelConfig) -> str:

        # Send registration message.
        input_socket.send(
            msgspec.msgpack.encode({
                "status": "HELLO",
                "local": on_head_node,
            }))

        # Receive initialization message.
        logger.info("Waiting for init message from front-end.")
        if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000):
            raise RuntimeError("Did not receive response from front-end "
                               f"process within {HANDSHAKE_TIMEOUT_MINS} "
                               f"minutes")
        init_bytes = input_socket.recv()
        init_message = msgspec.msgpack.decode(init_bytes)
        logger.debug("Received init message: %s", init_message)

        output_socket_address = init_message["output_socket_address"]
        #TBD(nick) maybe replace IP with configured head node address

        received_parallel_config = init_message["parallel_config"]
        for key, value in received_parallel_config.items():
            setattr(parallel_config, key, value)

        return output_socket_address

    @staticmethod
    def run_engine_core(*args,
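Note (schematic, not part of the commit): the handshake implemented above exchanges three msgpack payloads over the engine's DEALER socket. The field names below come from the code in this hunk; the concrete values are illustrative only.

# 1. Engine -> front-end, sent immediately after connecting:
hello = {"status": "HELLO", "local": True}

# 2. Front-end -> engine, the init message applied to parallel_config:
init = {
    "output_socket_address": "tcp://10.0.0.1:41729",  # assumed address
    "parallel_config": {
        "data_parallel_master_ip": "10.0.0.1",
        "data_parallel_master_port": 29500,
        "data_parallel_size": 4,
    },
}

# 3. Engine -> front-end, sent once the model is initialized:
ready = {"status": "READY", "local": True, "num_gpu_blocks": 8192}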
@@ -412,7 +475,7 @@ class EngineCoreProc(EngineCore):
        try:
            parallel_config: ParallelConfig = kwargs[
                "vllm_config"].parallel_config
            if parallel_config.data_parallel_size > 1:
            if parallel_config.data_parallel_size > 1 or dp_rank > 0:
                # Set data parallel rank for this engine process.
                parallel_config.data_parallel_rank = dp_rank
                parallel_config.data_parallel_rank_local = local_dp_rank
@@ -436,6 +499,9 @@ class EngineCoreProc(EngineCore):
        if engine_core is not None:
            engine_core.shutdown()

    def _init_data_parallel(self, vllm_config: VllmConfig):
        pass

    def run_busy_loop(self):
        """Core busy loop of the EngineCore."""

@@ -527,40 +593,25 @@ class EngineCoreProc(EngineCore):
            logger.fatal("vLLM shutdown signal from EngineCore failed "
                         "to send. Please report this issue.")

    def process_input_socket(self, input_path: str, engine_index: int):
    def process_input_socket(self, input_socket: zmq.Socket):
        """Input socket IO thread."""

        # Msgpack serialization decoding.
        add_request_decoder = MsgpackDecoder(EngineCoreRequest)
        generic_decoder = MsgpackDecoder()
        identity = engine_index.to_bytes(length=2, byteorder="little")

        with zmq_socket_ctx(input_path,
                            zmq.DEALER,
                            identity=identity,
                            bind=False) as socket:
        while True:
            # (RequestType, RequestData)
            type_frame, *data_frames = input_socket.recv_multipart(copy=False)
            request_type = EngineCoreRequestType(bytes(type_frame.buffer))

            # Send ready message to front-end once input socket is connected.
            message_dict = {
                'type': 'READY',
                'num_gpu_blocks': self.vllm_config.cache_config.num_gpu_blocks,
            }
            message = json.dumps(message_dict).encode('utf-8')
            socket.send(message)
            # Deserialize the request data.
            decoder = add_request_decoder if (
                request_type == EngineCoreRequestType.ADD) else generic_decoder
            request = decoder.decode(data_frames)

            while True:
                # (RequestType, RequestData)
                type_frame, *data_frames = socket.recv_multipart(copy=False)
                request_type = EngineCoreRequestType(bytes(type_frame.buffer))

                # Deserialize the request data.
                decoder = add_request_decoder if (
                    request_type
                    == EngineCoreRequestType.ADD) else generic_decoder
                request = decoder.decode(data_frames)

            # Push to input queue for core busy loop.
            self.input_queue.put_nowait((request_type, request))
                # Push to input queue for core busy loop.
                self.input_queue.put_nowait((request_type, request))

    def process_output_socket(self, output_path: str, engine_index: int):
        """Output socket IO thread."""
@@ -609,9 +660,9 @@ class DPEngineCoreProc(EngineCoreProc):

    def __init__(
        self,
        input_path: str,
        output_path: str,
        vllm_config: VllmConfig,
        on_head_node: bool,
        input_address: str,
        executor_class: type[Executor],
        log_stats: bool,
    ):
@@ -623,8 +674,20 @@ class DPEngineCoreProc(EngineCoreProc):
        _add_prefix(sys.stdout, process_name, pid)
        _add_prefix(sys.stderr, process_name, pid)

        dp_size = vllm_config.parallel_config.data_parallel_size
        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.counter = 0

        # Initialize the engine.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        super().__init__(vllm_config, on_head_node, input_address,
                         executor_class, log_stats, dp_rank)

    def _init_data_parallel(self, vllm_config: VllmConfig):

        # Configure GPUs and stateless process group for data parallel.
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        dp_size = vllm_config.parallel_config.data_parallel_size
        local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local

        assert dp_size > 1
@@ -632,24 +695,16 @@ class DPEngineCoreProc(EngineCoreProc):

        from vllm.platforms import current_platform
        device_control_env_var = current_platform.device_control_env_var
        tp_size = vllm_config.parallel_config.tensor_parallel_size
        world_size = vllm_config.parallel_config.world_size
        os.environ[device_control_env_var] = ",".join(
            str(current_platform.device_id_to_physical_device_id(i))
            for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
                           tp_size))
            for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
                           world_size))

        self.local_dp_rank = local_dp_rank
        self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
        self.current_wave = 0

        # Initialize the engine after setting up environment.
        super().__init__(input_path, output_path, vllm_config, executor_class,
                         log_stats, dp_rank)

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.counter = 0

    def shutdown(self):
        super().shutdown()
        if dp_group := getattr(self, "dp_group", None):

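Note (standalone sketch, not part of the commit): each local DP rank claims a contiguous block of world_size (TP x PP) devices via the platform's device-control environment variable, following the range computation above. The helper name below is a local stand-in and the physical-device mapping is omitted.

def visible_devices(local_dp_rank: int, world_size: int) -> str:
    # Same range computation as in _init_data_parallel above.
    return ",".join(
        str(i) for i in range(local_dp_rank * world_size,
                              (local_dp_rank + 1) * world_size))

assert visible_devices(local_dp_rank=0, world_size=2) == "0,1"
assert visible_devices(local_dp_rank=1, world_size=2) == "2,3"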
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import contextlib
import json
import queue
import uuid
import weakref
@@ -9,25 +8,27 @@ from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Awaitable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum, auto
from threading import Thread
from typing import Any, Callable, Optional, TypeVar, Union

import msgspec
import zmq
import zmq.asyncio

from vllm.config import VllmConfig
from vllm.config import ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path,
                        make_zmq_socket)
from vllm.utils import (get_open_port, get_open_zmq_inproc_path,
                        get_open_zmq_ipc_path, get_tcp_uri, make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                            EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
from vllm.v1.utils import BackgroundProcHandle
from vllm.v1.utils import CoreEngineProcManager

logger = init_logger(__name__)

@@ -264,45 +265,22 @@ class InprocClient(EngineCoreClient):
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)


class CoreEngineState(Enum):
    NEW = auto()
    CONNECTED = auto()
    READY = auto()


class CoreEngine:
    """One per data parallel rank."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        input_path: str,
        output_path: str,
        index: int = 0,
        local_dp_rank: int = 0,
    ):
    def __init__(self, index: int = 0, local: bool = True):
        self.local = local
        self.index = index
        self.identity = index.to_bytes(length=2, byteorder="little")
        try:
            # Start EngineCore in background process.
            self.proc_handle = BackgroundProcHandle(
                input_path=input_path,
                output_path=output_path,
                process_name=f"EngineCore_{index}",
                target_fn=EngineCoreProc.run_engine_core,
                process_kwargs={
                    "vllm_config": vllm_config,
                    "dp_rank": index,
                    "local_dp_rank": local_dp_rank,
                    "executor_class": executor_class,
                    "log_stats": log_stats,
                })

            self.num_reqs_in_flight = 0
        finally:
            if not hasattr(self, "num_reqs_in_flight"):
                # Ensure socket is closed if process fails to start.
                self.close()

    def close(self):
        if proc_handle := getattr(self, "proc_handle", None):
            proc_handle.shutdown()
        self.state = CoreEngineState.NEW
        self.num_reqs_in_flight = 0


@dataclass
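Note (standalone sketch, not part of the commit): the slimmed-down CoreEngine above no longer launches a process itself; it only tracks the engine's identity and handshake state, roughly as follows. The class names here are local stand-ins.

from enum import Enum, auto

class State(Enum):
    NEW = auto()        # not yet heard from
    CONNECTED = auto()  # HELLO received, init message sent
    READY = auto()      # READY received, engine initialized

class CoreEngineSketch:
    def __init__(self, index: int = 0, local: bool = True):
        self.index = index
        self.local = local
        self.identity = index.to_bytes(length=2, byteorder="little")
        self.state = State.NEW
        self.num_reqs_in_flight = 0

eng = CoreEngineSketch(index=3, local=False)
assert eng.identity == b"\x03\x00"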
@@ -311,7 +289,7 @@ class BackgroundResources:
    circular reference back to the client object."""

    ctx: Union[zmq.Context]
    core_engines: list[CoreEngine] = field(default_factory=list)
    local_engine_manager: Optional[CoreEngineProcManager] = None
    output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
    input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
    output_queue_task: Optional[asyncio.Task] = None
@@ -325,8 +303,8 @@ class BackgroundResources:
        """Clean up background resources."""

        self.engine_dead = True
        for core_engine in self.core_engines:
            core_engine.close()
        if self.local_engine_manager is not None:
            self.local_engine_manager.close()

        if self.output_queue_task is not None:
            self.output_queue_task.cancel()
@@ -388,25 +366,56 @@ class MPClient(EngineCoreClient):
        self._finalizer = weakref.finalize(self, self.resources)
        success = False
        try:
            # Paths and sockets for IPC.
            self.output_path = get_open_zmq_ipc_path()
            input_path = get_open_zmq_ipc_path()
            self.input_socket = make_zmq_socket(self.ctx,
                                                input_path,
                                                zmq.ROUTER,
                                                bind=True)
            self.resources.input_socket = self.input_socket
            parallel_config = vllm_config.parallel_config
            local_engine_count = parallel_config.data_parallel_size_local
            start_index = parallel_config.data_parallel_rank
            local_start_index = parallel_config.data_parallel_rank_local

            new_core_engine = lambda index, local_dp_rank=None: CoreEngine(
                vllm_config, executor_class, log_stats, input_path, self.
                output_path, index, local_dp_rank)
            # SPMD mode is where there is an LLM instance per DP rank and
            # one core engine per LLM, see
            # examples/offline_inference/data_parallel.py.
            spmd_mode = local_start_index is not None
            if spmd_mode:
                assert local_engine_count == 1
                self.core_engines = [
                    CoreEngine(index=local_start_index, local=True)
                ]
            else:
                assert start_index == 0
                local_start_index = 0
                self.core_engines = [
                    CoreEngine(index=i, local=(i < local_engine_count))
                    for i in range(parallel_config.data_parallel_size)
                ]

            # Start engine core process(es).
            self._init_core_engines(vllm_config, new_core_engine,
                                    self.resources.core_engines)
            input_address, output_address = self._get_zmq_addresses(
                parallel_config, spmd_mode)

            # Create input and output sockets.
            self.input_socket = self.resources.input_socket = make_zmq_socket(
                self.ctx, input_address, zmq.ROUTER, bind=True)

            self.resources.output_socket = make_zmq_socket(
                self.ctx, output_address, zmq.constants.PULL)
            # Start local engines.
            if local_engine_count:
                # In server mode, start_index and local_start_index will
                # both be 0.
                self.resources.local_engine_manager = CoreEngineProcManager(
                    EngineCoreProc.run_engine_core,
                    vllm_config=vllm_config,
                    executor_class=executor_class,
                    log_stats=log_stats,
                    input_address=input_address,
                    on_head_node=True,
                    local_engine_count=local_engine_count,
                    start_index=start_index,
                    local_start_index=local_start_index)

            self.core_engine = self.core_engines[0]

            # Wait for engine core process(es) to start.
            self._wait_for_engine_startup()
            self._wait_for_engine_startup(output_address, parallel_config)

            self.utility_results: dict[int, AnyFuture] = {}

@@ -420,56 +429,116 @@ class MPClient(EngineCoreClient):
            if not success:
                self._finalizer()

    def _wait_for_engine_startup(self):
    @staticmethod
    def _get_zmq_addresses(parallel_config: ParallelConfig,
                           spmd_mode: bool) -> tuple[str, str]:
        """Returns (input_address, output_address)."""
        dp_size = parallel_config.data_parallel_size
        local_engine_count = parallel_config.data_parallel_size_local

        if local_engine_count == dp_size or spmd_mode:
            input_address = get_open_zmq_ipc_path()
            output_address = get_open_zmq_ipc_path()
        else:
            host = parallel_config.data_parallel_master_ip
            input_port = parallel_config.data_parallel_rpc_port
            output_port = get_open_port()
            input_address = get_tcp_uri(host, input_port)
            output_address = get_tcp_uri(host, output_port)

        return input_address, output_address

    def _wait_for_engine_startup(self, output_address: str,
                                 parallel_config: ParallelConfig):
        # Get a sync handle to the socket which can be sync or async.
        sync_input_socket = zmq.Socket.shadow(self.input_socket)

        # Wait for engine core process(es) to send ready messages.
        identities = set(eng.index for eng in self.resources.core_engines)
        local_count = parallel_config.data_parallel_size_local
        remote_count = len(self.core_engines) - local_count
        # [local, remote] counts
        conn_pending, start_pending = [local_count, remote_count], [0, 0]

        poller = zmq.Poller()
        poller.register(sync_input_socket, zmq.POLLIN)
        for eng in self.resources.core_engines:
            poller.register(eng.proc_handle, zmq.POLLIN)
        while identities:
        proc_manager = self.resources.local_engine_manager
        if proc_manager is not None:
            for sentinel in proc_manager.sentinels():
                poller.register(sentinel, zmq.POLLIN)
        while any(conn_pending) or any(start_pending):
            events = poller.poll(STARTUP_POLL_PERIOD_MS)
            if not events:
                logger.debug("Waiting for %d core engine proc(s) to start: %s",
                             len(identities), identities)
                if any(conn_pending):
                    logger.debug(
                        "Waiting for %d local, %d remote core engine proc(s) "
                        "to connect.", *conn_pending)
                if any(start_pending):
                    logger.debug(
                        "Waiting for %d local, %d remote core engine proc(s) "
                        "to start.", *start_pending)
                continue
            if len(events) > 1 or events[0][0] != sync_input_socket:
                # One of the core processes exited.
                # One of the local core processes exited.
                finished = proc_manager.finished_procs(
                ) if proc_manager else {}
                raise RuntimeError("Engine core initialization failed. "
                                   "See root cause above.")
                                   "See root cause above. "
                                   f"Failed core proc(s): {finished}")

            eng_id_bytes, data = sync_input_socket.recv_multipart()
            eng_id = int.from_bytes(eng_id_bytes, byteorder="little")
            if eng_id not in identities:
                raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}")
            message_dict = json.loads(data.decode('utf-8'))
            if message_dict['type'] != 'READY':
                raise RuntimeError(f"Engine {eng_id} failed: {data.decode()}")
            logger.info("Core engine process %d ready.", eng_id)
            identities.discard(eng_id)
            # Setup KV cache config with initialization state from
            # engine core process. Sum values from all engines in DP case.
            num_gpu_blocks = self.vllm_config.cache_config.num_gpu_blocks or 0
            num_gpu_blocks += message_dict['num_gpu_blocks']
            self.vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
            # Receive HELLO and READY messages from the input socket.
            eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart()
            eng_index = int.from_bytes(eng_identity, byteorder="little")
            engine = next(
                (e for e in self.core_engines if e.identity == eng_identity),
                None)
            if engine is None:
                raise RuntimeError(f"Message from engine with unexpected data "
                                   f"parallel rank: {eng_index}")
            msg = msgspec.msgpack.decode(ready_msg_bytes)
            status, local = msg["status"], msg["local"]
            if local != engine.local:
                raise RuntimeError(f"{status} message from "
                                   f"{'local' if local else 'remote'} "
                                   f"engine {eng_index}, expected it to be "
                                   f"{'local' if engine.local else 'remote'}")

    def _init_core_engines(
        self,
        vllm_config: VllmConfig,
        new_core_engine: Callable[[int, Optional[int]], CoreEngine],
        core_engines: list[CoreEngine],
    ) -> None:
            if status == "HELLO" and engine.state == CoreEngineState.NEW:

        # Default case - single core engine.
        core_engine = new_core_engine(
            vllm_config.parallel_config.data_parallel_rank,
            vllm_config.parallel_config.data_parallel_rank_local,
        )
        core_engines.append(core_engine)
        self.core_engine = core_engine
                # Send init message with DP config info.
                init_message = self.encoder.encode({
                    "output_socket_address": output_address,
                    "parallel_config": {
                        "data_parallel_master_ip":
                        parallel_config.data_parallel_master_ip,
                        "data_parallel_master_port":
                        parallel_config.data_parallel_master_port,
                        "data_parallel_size":
                        parallel_config.data_parallel_size,
                    },
                })
                sync_input_socket.send_multipart((eng_identity, *init_message),
                                                 copy=False)
                conn_pending[0 if local else 1] -= 1
                start_pending[0 if local else 1] += 1
                engine.state = CoreEngineState.CONNECTED
            elif status == "READY" and (engine.state
                                        == CoreEngineState.CONNECTED):
                # Setup KV cache config with initialization state from
                # engine core process. Sum values from all engines in DP case.
                cache_config = self.vllm_config.cache_config
                num_gpu_blocks = cache_config.num_gpu_blocks or 0
                num_gpu_blocks += msg['num_gpu_blocks']
                cache_config.num_gpu_blocks = num_gpu_blocks

                start_pending[0 if local else 1] -= 1
                engine.state = CoreEngineState.READY
            else:
                raise RuntimeError(f"Unexpected {status} message for "
                                   f"{'local' if local else 'remote'} engine "
                                   f"{eng_index} in {engine.state} state.")

            logger.debug("%s from %s core engine process %s.", status,
                         "local" if local else "remote", eng_index)

    def shutdown(self):
        # Terminate background resources.
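Note (standalone sketch, not part of the commit): the transport choice made by _get_zmq_addresses() above boils down to a single rule; the helper name below is a local stand-in.

def choose_transport(dp_size: int, local_engine_count: int, spmd_mode: bool) -> str:
    # Local IPC when every engine lives on this node (or in SPMD mode),
    # TCP bound to the data parallel master address otherwise.
    return "ipc" if (local_engine_count == dp_size or spmd_mode) else "tcp"

assert choose_transport(dp_size=4, local_engine_count=4, spmd_mode=False) == "ipc"
assert choose_transport(dp_size=4, local_engine_count=2, spmd_mode=False) == "tcp"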
@@ -520,7 +589,8 @@ class SyncMPClient(MPClient):
        # Ensure that the outputs socket processing thread does not have
        # a ref to the client which prevents gc.
        ctx = self.ctx
        output_path = self.output_path
        out_socket = self.resources.output_socket
        assert out_socket is not None
        decoder = self.decoder
        utility_results = self.utility_results
        outputs_queue = self.outputs_queue
@@ -531,7 +601,6 @@ class SyncMPClient(MPClient):

        def process_outputs_socket():
            shutdown_socket = ctx.socket(zmq.PAIR)
            out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL)
            try:
                shutdown_socket.bind(shutdown_path)
                poller = zmq.Poller()
@@ -566,6 +635,9 @@ class SyncMPClient(MPClient):
                                                   daemon=True)
        self.output_queue_thread.start()

        # The thread takes on responsibility for closing the socket.
        self.resources.output_socket = None

    def get_output(self) -> EngineCoreOutputs:
        # If an exception arises in process_outputs_socket task,
        # it is forwarded to the outputs_queue so we can raise it
@@ -693,10 +765,8 @@ class AsyncMPClient(MPClient):
            self.__class__,
            "process_engine_outputs", None)
        _self_ref = weakref.ref(self) if output_handler else None
        output_path = self.output_path
        output_socket = make_zmq_socket(self.ctx, output_path,
                                        zmq.constants.PULL)
        resources.output_socket = output_socket
        output_socket = resources.output_socket
        assert output_socket is not None

        async def process_outputs_socket():
            try:
@@ -861,21 +931,6 @@ class DPAsyncMPClient(AsyncMPClient):

        assert len(self.core_engines) > 1

    def _init_core_engines(
        self,
        vllm_config: VllmConfig,
        new_core_engine: Callable[[int, Optional[int]], CoreEngine],
        core_engines: list[CoreEngine],
    ) -> None:

        # Launch a core engine for each data parallel rank.
        dp_size = vllm_config.parallel_config.data_parallel_size
        for i in range(dp_size):
            # Multi-node not yet supported so local_dp_rank == dp_rank.
            core_engines.append(new_core_engine(i, i))

        self.core_engines = core_engines

    async def call_utility_async(self, method: str, *args) -> Any:
        # Only the result from the first engine is returned.
        return (await asyncio.gather(*[

vllm/v1/utils.py
@@ -1,20 +1,23 @@
# SPDX-License-Identifier: Apache-2.0

import os
import time
import weakref
from collections import defaultdict
from collections.abc import Sequence
from multiprocessing import Process
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
                    Union, overload)
from multiprocessing import Process, connection
from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union,
                    overload)

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                  usage_message)
from vllm.utils import get_mp_context, kill_process_tree
from vllm.v1.executor.abstract import Executor

if TYPE_CHECKING:
    from vllm.attention.layer import Attention
@@ -92,7 +95,7 @@ class ConstantList(Generic[T], Sequence):
        return f"ConstantList({self._x})"


class BackgroundProcHandle:
class CoreEngineProcManager:
    """
    Utility class to handle creation, readiness, and shutdown
    of background processes used by the AsyncLLM and LLMEngine.
@@ -100,49 +103,91 @@ class BackgroundProcHandle:

    def __init__(
        self,
        input_path: str,
        output_path: str,
        process_name: str,
        target_fn: Callable,
        process_kwargs: dict[Any, Any],
        local_engine_count: int,
        start_index: int,
        local_start_index: int,
        vllm_config: VllmConfig,
        on_head_node: bool,
        input_address: str,
        executor_class: type[Executor],
        log_stats: bool,
    ):
        context = get_mp_context()
        common_kwargs = {
            "vllm_config": vllm_config,
            "on_head_node": on_head_node,
            "input_address": input_address,
            "executor_class": executor_class,
            "log_stats": log_stats,
        }

        assert ("input_path" not in process_kwargs
                and "output_path" not in process_kwargs)
        process_kwargs["input_path"] = input_path
        process_kwargs["output_path"] = output_path
        self.processes: list[Process] = []
        for index in range(local_engine_count):
            local_index = local_start_index + index
            global_index = start_index + index
            # Start EngineCore in background process.
            self.processes.append(
                context.Process(target=target_fn,
                                name=f"EngineCore_{global_index}",
                                kwargs=common_kwargs | {
                                    "dp_rank": global_index,
                                    "local_dp_rank": local_index,
                                }))

        # Run busy loop in background process.
        self.proc: Process = context.Process(target=target_fn,
                                             kwargs=process_kwargs,
                                             name=process_name)
        self._finalizer = weakref.finalize(self, shutdown, self.proc,
                                           input_path, output_path)
        self.proc.start()
        self._finalizer = weakref.finalize(self, shutdown, self.processes,
                                           input_address)
        try:
            for proc in self.processes:
                proc.start()
        finally:
            # Kill other procs if not all are running.
            if self.finished_procs():
                self.close()

    def fileno(self):
        return self.proc.sentinel

    def shutdown(self):
    def close(self):
        """Shutdown all procs."""
        self._finalizer()

    def join_first(self):
        """Wait for any process to exit."""
        connection.wait(proc.sentinel for proc in self.processes)

    def sentinels(self) -> list:
        return [proc.sentinel for proc in self.processes]

    def finished_procs(self) -> dict[str, int]:
        """Returns dict of proc name -> exit code for any finished procs."""
        return {
            proc.name: proc.exitcode
            for proc in self.processes if proc.exitcode is not None
        }


# Note(rob): shutdown function cannot be a bound method,
# else the gc cannot collect the object.
def shutdown(proc: Process, input_path: str, output_path: str):
# else the gc cannot collect the object.
def shutdown(procs: list[Process], input_address: str):
    # Shutdown the process.
    if proc.is_alive():
        proc.terminate()
        proc.join(5)
    for proc in procs:
        if proc.is_alive():
            proc.terminate()

    # Allow 5 seconds for remaining procs to terminate.
    deadline = time.monotonic() + 5
    for proc in procs:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        if proc.is_alive():
            proc.join(remaining)

    for proc in procs:
        if proc.is_alive() and (pid := proc.pid) is not None:
            kill_process_tree(pid)

    # Remove zmq ipc socket files.
    ipc_sockets = [output_path, input_path]
    for ipc_socket in ipc_sockets:
        socket_file = ipc_socket.replace("ipc://", "")
    if input_address.startswith("ipc://"):
        socket_file = input_address[len("ipc://"):]
        if os and os.path.exists(socket_file):
            os.remove(socket_file)

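Note (usage sketch, not part of the commit): CoreEngineProcManager is driven the same way by run_headless and MPClient above. A condensed headless-style launch looks roughly like this; the model placeholder, address, ranks and counts are assumed for illustration.

from vllm import AsyncEngineArgs
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.utils import CoreEngineProcManager

engine_args = AsyncEngineArgs(model="<model>")  # placeholder model name
vllm_config = engine_args.create_engine_config(
    usage_context=UsageContext.OPENAI_API_SERVER)

manager = CoreEngineProcManager(
    target_fn=EngineCoreProc.run_engine_core,
    local_engine_count=2,                   # engines launched on this node
    start_index=2,                          # global DP rank of first local engine
    local_start_index=0,
    vllm_config=vllm_config,
    on_head_node=False,
    input_address="tcp://10.0.0.1:13345",   # head node RPC address (assumed)
    executor_class=Executor.get_class(vllm_config),
    log_stats=True,
)
try:
    manager.join_first()   # block until any engine process exits
finally:
    manager.close()        # terminate and reap all engine processes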