[DP] Support external DP Load Balancer mode (#19790)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-07-02 18:21:52 +01:00 committed by GitHub
parent a1aafc827a
commit 657f2f301a
11 changed files with 1254 additions and 787 deletions
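This change adds an "external" data parallel load balancer mode: each DP rank runs its own front-end (API server) on its own port, and an external load balancer routes requests across the ranks, instead of a single front-end balancing internally across all engines. The mode is enabled by passing the new --data-parallel-rank flag to vllm serve. A hedged launch sketch (model name, ports, and GPU assignment are illustrative placeholders; the flags mirror those used in the new test below):

import os
import subprocess

MODEL = "ibm-research/PowerMoE-3b"
DP_SIZE, TP_SIZE = 2, 1

procs = []
for rank in range(DP_SIZE):
    env = dict(os.environ)
    # Give each rank its own slice of GPUs.
    env["CUDA_VISIBLE_DEVICES"] = ",".join(
        str(i) for i in range(rank * TP_SIZE, (rank + 1) * TP_SIZE))
    procs.append(
        subprocess.Popen([
            "vllm", "serve", MODEL,
            "--data-parallel-size", str(DP_SIZE),
            "--data-parallel-rank", str(rank),  # new flag; enables external LB mode
            "--data-parallel-size-local", "1",
            "--tensor-parallel-size", str(TP_SIZE),
            "--port", str(8000 + rank),  # each rank serves on its own port
        ], env=env))
# An external load balancer (e.g. a reverse proxy) then routes traffic across
# ports 8000..8000+DP_SIZE-1.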

View File

@ -155,6 +155,7 @@ steps:
- examples/offline_inference/rlhf_colocate.py
- tests/examples/offline_inference/data_parallel.py
- tests/v1/test_async_llm_dp.py
- tests/v1/test_external_lb_dp.py
- tests/v1/engine/test_engine_core_client.py
commands:
# test with tp=2 and external_dp=2
@ -163,8 +164,9 @@ steps:
# test with tp=2 and pp=2
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
@ -682,10 +684,12 @@ steps:
- vllm/worker/model_runner.py
- entrypoints/llm/test_collective_rpc.py
- tests/v1/test_async_llm_dp.py
- tests/v1/test_external_lb_dp.py
- tests/v1/entrypoints/openai/test_multi_api_servers.py
- vllm/v1/engine/
commands:
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- pytest -v -s ./compile/test_basic_correctness.py

View File

@ -26,8 +26,8 @@ from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
SyncMPClient)
from vllm.v1.engine.utils import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor
from vllm.v1.utils import CoreEngineProcManager
from ...distributed.conftest import MockSubscriber
from ...utils import create_new_process_for_each_test
@ -563,7 +563,7 @@ def test_engine_core_proc_instantiation_cuda_empty(
m.setenv("VLLM_USE_V1", "1")
m.setenv("CUDA_VISIBLE_DEVICES", "") # No CUDA devices
from vllm.v1.utils import EngineZmqAddresses
from vllm.v1.engine.utils import EngineZmqAddresses
def mock_startup_handshake(self, handshake_socket, on_head_node,
parallel_config):
@ -580,7 +580,7 @@ def test_engine_core_proc_instantiation_cuda_empty(
trust_remote_code=True).create_engine_config()
engine_core_proc = EngineCoreProc(
vllm_config=vllm_config,
on_head_node=True,
local_client=True,
handshake_address="tcp://127.0.0.1:12345",
executor_class=mock_executor_class,
log_stats=False,

View File

@ -0,0 +1,312 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
import threading
import time
from contextlib import AsyncExitStack
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
from vllm.platforms import Platform
MODEL_NAME = "ibm-research/PowerMoE-3b"
# Number of data parallel ranks for external LB testing
DP_SIZE = int(os.getenv("DP_SIZE", "2"))
# Default tensor parallel size to use
TP_SIZE = int(os.getenv("TP_SIZE", "1"))
class ExternalLBServerManager:
"""Manages data parallel vLLM server instances for external
load balancer testing."""
def __init__(self,
model_name: str,
dp_size: int,
api_server_count: int,
base_server_args: list,
tp_size: int = TP_SIZE):
self.model_name = model_name
self.dp_size = dp_size
self.tp_size = tp_size
self.api_server_count = api_server_count
self.base_server_args = base_server_args
self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
self.server_threads: list[threading.Thread] = []
def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
"""Start all server instances for external LB mode."""
for rank in range(self.dp_size):
# Create server args for this specific rank
server_args = self.base_server_args.copy()
# Add external LB specific arguments
server_args.extend([
"--data-parallel-size",
str(self.dp_size),
"--data-parallel-rank",
str(rank),
"--data-parallel-size-local",
"1",
"--tensor-parallel-size",
str(self.tp_size),
"--port",
str(8000 + rank), # Different port for each rank
"--api-server-count",
str(self.api_server_count),
])
# Use a thread to start each server to allow parallel initialization
def start_server(r: int, sargs: list[str]):
try:
# Start the server
server = RemoteOpenAIServer(
self.model_name,
sargs,
auto_port=False,
env_dict={
"CUDA_VISIBLE_DEVICES":
",".join(
str(Platform.device_id_to_physical_device_id(
i))
for i in range(r * TP_SIZE, (r + 1) * TP_SIZE))
})
server.__enter__()
print(f"Server rank {r} started successfully with "
f"{self.api_server_count} API servers")
self.servers.append((server, sargs))
except Exception as e:
print(f"Failed to start server rank {r}: {e}")
raise
thread = threading.Thread(target=start_server,
args=(rank, server_args))
thread.start()
self.server_threads.append(thread)
# Wait for all servers to start
for thread in self.server_threads:
thread.join()
# Give servers additional time to fully initialize and coordinate
time.sleep(2)
if len(self.servers) != self.dp_size:
raise Exception("Servers failed to start")
return self.servers
def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop all server instances."""
while self.servers:
try:
self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
except Exception as e:
print(f"Error stopping server: {e}")
@pytest.fixture(scope="module")
def default_server_args():
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",
"128",
"--enforce-eager",
]
@pytest.fixture(scope="module", params=[1, 4])
def servers(request, default_server_args):
api_server_count = request.param
with ExternalLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
default_server_args) as server_list:
yield server_list
@pytest_asyncio.fixture
async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
# Create a client for each server
async with AsyncExitStack() as stack:
yield [
await stack.enter_async_context(server.get_async_client())
for server, _ in servers
]
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_external_lb_single_completion(clients: list[
openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]],
model_name: str) -> None:
async def make_request(client: openai.AsyncOpenAI):
completion = await client.completions.create(
model=model_name,
prompt="Hello, my name is",
max_tokens=10,
temperature=1.0)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
choice = completion.choices[0]
# The exact number of tokens can vary slightly with temperature=1.0,
# so we check for a reasonable minimum length.
assert len(choice.text) >= 1
# Finish reason might not always be 'length' if the model finishes early
# or due to other reasons, especially with high temperature.
# So, we'll accept 'length' or 'stop'.
assert choice.finish_reason in ("length", "stop")
# Token counts can also vary, so we check they are positive.
assert completion.usage.completion_tokens > 0
assert completion.usage.prompt_tokens > 0
assert completion.usage.total_tokens > 0
return completion
# Test single request to each server
for i, client in enumerate(clients):
result = await make_request(client)
assert result is not None
print(f"Server {i} handled single completion request successfully")
await asyncio.sleep(0.5)
# Send requests to all servers in round-robin fashion
num_requests_per_server = 25 # Total 50 requests across 2 servers
all_tasks = []
for i, client in enumerate(clients):
tasks = [make_request(client) for _ in range(num_requests_per_server)]
all_tasks.extend(tasks)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests_per_server * len(clients)
assert all(completion is not None for completion in results)
await asyncio.sleep(0.5)
# Second burst of requests
all_tasks = []
for i, client in enumerate(clients):
tasks = [make_request(client) for _ in range(num_requests_per_server)]
all_tasks.extend(tasks)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests_per_server * len(clients)
assert all(completion is not None for completion in results)
_, server_args = servers[0]
api_server_count = (
server_args.count('--api-server-count')
and server_args[server_args.index('--api-server-count') + 1] or 1)
print(
f"Successfully completed external LB test with {len(clients)} servers "
f"(API server count: {api_server_count})")
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_external_lb_completion_streaming(clients: list[
openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]],
model_name: str) -> None:
prompt = "What is an LLM?"
async def make_streaming_request(client: openai.AsyncOpenAI):
# Perform a non-streaming request to get the expected full output
single_completion = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
)
single_output = single_completion.choices[0].text
# Perform the streaming request
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True)
chunks: list[str] = []
finish_reason_count = 0
last_chunk = None
async for chunk in stream:
chunks.append(chunk.choices[0].text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
last_chunk = chunk # Keep track of the last chunk
# Finish reason should only be returned in the last chunk for the OpenAI API
assert finish_reason_count == 1, (
"Finish reason should appear exactly once.")
assert last_chunk is not None, (
"Stream should have yielded at least one chunk.")
assert last_chunk.choices[
0].finish_reason == "length", "Finish reason should be 'length'."
# Check that the combined text matches the non-streamed version.
assert "".join(
chunks
) == single_output, "Streamed output should match non-streamed output."
return True # Indicate success for this request
# Test single request to each server
for i, client in enumerate(clients):
result = await make_streaming_request(client)
assert result is not None
print(f"Server {i} handled single streaming request successfully")
await asyncio.sleep(0.5)
# Send streaming requests to all servers in round-robin fashion
num_requests_per_server = 25 # Total 50 requests across 2 servers
all_tasks = []
for i, client in enumerate(clients):
tasks = [
make_streaming_request(client)
for _ in range(num_requests_per_server)
]
all_tasks.extend(tasks)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests_per_server * len(clients)
assert all(results), "Not all streaming requests completed successfully."
await asyncio.sleep(0.5)
# Second burst of streaming requests
all_tasks = []
for i, client in enumerate(clients):
tasks = [
make_streaming_request(client)
for _ in range(num_requests_per_server)
]
all_tasks.extend(tasks)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests_per_server * len(clients)
assert all(results), "Not all streaming requests completed successfully."
_, server_args = servers[0]
api_server_count = (
server_args.count('--api-server-count')
and server_args[server_args.index('--api-server-count') + 1] or 1)
print(f"Successfully completed external LB streaming test with "
f"{len(clients)} servers (API server count: {api_server_count})")

View File

@ -1784,6 +1784,10 @@ class ParallelConfig:
"""Port of the data parallel master."""
data_parallel_backend: str = "mp"
"""Backend to use for data parallel, either "mp" or "ray"."""
data_parallel_external_lb: bool = False
"""Whether to use "external" DP LB mode. Applies only to online serving
and when data_parallel_size > 1. Set implicitly when
data_parallel_rank is provided explicitly to vllm serve."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb: bool = False
@ -1953,6 +1957,11 @@ class ParallelConfig:
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
# Data parallel was specified in the engine args.
self.data_parallel_master_port = get_open_port()
if not (0 <= self.data_parallel_rank < self.data_parallel_size):
raise ValueError(
f"data_parallel_rank ({self.data_parallel_rank})"
f" must be in the range [0, {self.data_parallel_size})")
else:
# Otherwise fall back to env vars (e.g. for offline SPMD case).
self.data_parallel_size = envs.VLLM_DP_SIZE
@ -1961,6 +1970,10 @@ class ParallelConfig:
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
if self.data_parallel_external_lb:
raise ValueError("data_parallel_external_lb can only "
"be set when data_parallel_size > 1")
if self.distributed_executor_backend == "external_launcher":
import os
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

View File

@ -318,6 +318,7 @@ class EngineArgs:
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: Optional[int] = None
data_parallel_size_local: Optional[int] = None
data_parallel_address: Optional[str] = None
data_parallel_rpc_port: Optional[int] = None
@ -655,6 +656,12 @@ class EngineArgs:
**parallel_kwargs["tensor_parallel_size"])
parallel_group.add_argument("--data-parallel-size", "-dp",
**parallel_kwargs["data_parallel_size"])
parallel_group.add_argument(
'--data-parallel-rank',
'-dpn',
type=int,
help='Data parallel rank of this instance. '
'When set, enables external load balancer mode.')
parallel_group.add_argument('--data-parallel-size-local',
'-dpl',
type=int,
@ -1126,10 +1133,17 @@ class EngineArgs:
# but we should not do this here.
placement_group = ray.util.get_current_placement_group()
# Local DP size defaults to global DP size if not set.
data_parallel_size_local = self.data_parallel_size if (
self.data_parallel_size_local
is None) else self.data_parallel_size_local
data_parallel_external_lb = self.data_parallel_rank is not None
if data_parallel_external_lb:
assert self.data_parallel_size_local in (1, None), (
"data_parallel_size_local must be 1 when data_parallel_rank "
"is set")
data_parallel_size_local = 1
elif self.data_parallel_size_local is not None:
data_parallel_size_local = self.data_parallel_size_local
else:
# Local DP size defaults to global DP size if not set.
data_parallel_size_local = self.data_parallel_size
# DP address, used in multi-node case for torch distributed group
# and ZMQ sockets.
@ -1154,16 +1168,16 @@ class EngineArgs:
self.data_parallel_rpc_port
is not None) else ParallelConfig.data_parallel_rpc_port
data_parallel_backend = self.data_parallel_backend
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
data_parallel_size=self.data_parallel_size,
data_parallel_rank=self.data_parallel_rank or 0,
data_parallel_external_lb=data_parallel_external_lb,
data_parallel_size_local=data_parallel_size_local,
data_parallel_master_ip=data_parallel_address,
data_parallel_rpc_port=data_parallel_rpc_port,
data_parallel_backend=data_parallel_backend,
data_parallel_backend=self.data_parallel_backend,
enable_expert_parallel=self.enable_expert_parallel,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.num_redundant_experts,
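The EngineArgs change above means: supplying --data-parallel-rank switches on external LB mode and forces the local DP size to 1; otherwise the local DP size defaults to the global DP size. A hedged stand-alone mirror of that derivation:

from typing import Optional

def derive_dp_settings(dp_size: int, dp_rank: Optional[int],
                       dp_size_local: Optional[int]) -> tuple[bool, int, int]:
    external_lb = dp_rank is not None
    if external_lb:
        assert dp_size_local in (1, None), (
            "data_parallel_size_local must be 1 when data_parallel_rank is set")
        dp_size_local = 1
    elif dp_size_local is None:
        # Local DP size defaults to global DP size if not set.
        dp_size_local = dp_size
    return external_lb, dp_size_local, dp_rank or 0

print(derive_dp_settings(2, 1, None))     # (True, 1, 1)
print(derive_dp_settings(2, None, None))  # (False, 2, 0)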

View File

@ -5,9 +5,9 @@ import argparse
import os
import signal
import sys
from typing import Optional
import uvloop
import zmq
import vllm
import vllm.envs as envs
@ -21,17 +21,13 @@ from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG,
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.utils import FlexibleArgumentParser, get_tcp_uri
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.engine.core_client import CoreEngineProcManager
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
CoreEngineActorManager, EngineZmqAddresses,
get_engine_client_zmq_addr,
wait_for_completion_or_failure,
wait_for_engine_startup)
from vllm.v1.utils import (APIServerProcessManager,
wait_for_completion_or_failure)
logger = init_logger(__name__)
@ -48,11 +44,15 @@ class ServeSubcommand(CLISubcommand):
if args.headless or args.api_server_count < 1:
run_headless(args)
elif args.api_server_count > 1:
run_multi_api_server(args)
else:
# Single API server (this process).
uvloop.run(run_server(args))
if args.data_parallel_start_rank:
raise ValueError("data_parallel_start_rank is only "
"applicable in headless mode")
if args.api_server_count > 1:
run_multi_api_server(args)
else:
# Single API server (this process).
uvloop.run(run_server(args))
def validate(self, args: argparse.Namespace) -> None:
validate_parsed_serve_args(args)
@ -121,14 +121,19 @@ def run_headless(args: argparse.Namespace):
parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip
port = engine_args.data_parallel_rpc_port # add to config too
handshake_address = get_tcp_uri(host, port)
if local_engine_count <= 0:
raise ValueError("data_parallel_size_local must be > 0 in "
"headless mode")
if parallel_config.data_parallel_rank is not None:
raise ValueError("data_parallel_rank is not applicable in "
"headless mode")
host = parallel_config.data_parallel_master_ip
port = engine_args.data_parallel_rpc_port # add to config too
handshake_address = get_tcp_uri(host, port)
# Catch SIGTERM and SIGINT to allow graceful shutdown.
def signal_handler(signum, frame):
logger.debug("Received %d signal.", signum)
@ -148,7 +153,7 @@ def run_headless(args: argparse.Namespace):
start_index=args.data_parallel_start_rank,
local_start_index=0,
vllm_config=vllm_config,
on_head_node=False,
local_client=False,
handshake_address=handshake_address,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
@ -192,117 +197,53 @@ def run_multi_api_server(args: argparse.Namespace):
" api_server_count > 1")
model_config.disable_mm_preprocessor_cache = True
executor_class = Executor.get_class(vllm_config)
log_stats = not engine_args.disable_log_stats
parallel_config = vllm_config.parallel_config
dp_rank = parallel_config.data_parallel_rank
external_dp_lb = parallel_config.data_parallel_external_lb
assert external_dp_lb or dp_rank == 0
assert parallel_config.data_parallel_rank == 0
api_server_manager: Optional[APIServerProcessManager] = None
dp_size = parallel_config.data_parallel_size
local_engine_count = parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip
local_only = local_engine_count == dp_size
with launch_core_engines(vllm_config, executor_class, log_stats,
num_api_servers) as (local_engine_manager,
coordinator, addresses):
# Set up input and output addresses.
input_addresses = [
get_engine_client_zmq_addr(local_only, host)
for _ in range(num_api_servers)
]
output_addresses = [
get_engine_client_zmq_addr(local_only, host)
for _ in range(num_api_servers)
]
addresses = EngineZmqAddresses(
inputs=input_addresses,
outputs=output_addresses,
)
# Set up coordinator for dp > 1.
coordinator = None
stats_update_address = None
if dp_size > 1:
coordinator = DPCoordinator(parallel_config)
addresses.coordinator_input, addresses.coordinator_output = (
coordinator.get_engine_socket_addresses())
stats_update_address = coordinator.get_stats_publish_address()
logger.info("Started DP Coordinator process (PID: %d)",
coordinator.proc.pid)
if parallel_config.data_parallel_backend == "ray":
logger.info("Starting ray-based data parallel backend")
engine_actor_manager = CoreEngineActorManager(
vllm_config=vllm_config,
addresses=addresses,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
)
# Start API servers using the manager
api_server_manager = APIServerProcessManager(
# Construct common args for the APIServerProcessManager up-front.
api_server_manager_kwargs = dict(
target_server_fn=run_api_server_worker_proc,
listen_address=listen_address,
sock=sock,
args=args,
num_servers=num_api_servers,
input_addresses=input_addresses,
output_addresses=output_addresses,
stats_update_address=stats_update_address)
input_addresses=addresses.inputs,
output_addresses=addresses.outputs,
stats_update_address=coordinator.get_stats_publish_address()
if coordinator else None)
wait_for_completion_or_failure(api_server_manager=api_server_manager,
engine_manager=engine_actor_manager,
coordinator=coordinator)
return
# For dp ranks > 0 in external DP LB mode, we must delay the
# start of the API servers until the local engine is started
# (after the launcher context manager exits),
# since we get the front-end stats update address from the coordinator
# via the handshake with the local engine.
if dp_rank == 0 or not external_dp_lb:
# Start API servers using the manager.
api_server_manager = APIServerProcessManager(
**api_server_manager_kwargs)
handshake_address = get_engine_client_zmq_addr(
local_only, host, parallel_config.data_parallel_rpc_port)
with zmq_socket_ctx(handshake_address, zmq.ROUTER,
bind=True) as handshake_socket:
# Start local engines.
if not local_engine_count:
local_engine_manager = None
else:
local_engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
handshake_address=handshake_address,
on_head_node=True,
local_engine_count=local_engine_count,
start_index=0,
local_start_index=0)
# Start API servers using the manager
# Start API servers now if they weren't already started.
if api_server_manager is None:
api_server_manager_kwargs["stats_update_address"] = (
addresses.frontend_stats_publish_address)
api_server_manager = APIServerProcessManager(
target_server_fn=run_api_server_worker_proc,
listen_address=listen_address,
sock=sock,
args=args,
num_servers=num_api_servers,
input_addresses=input_addresses,
output_addresses=output_addresses,
stats_update_address=stats_update_address)
**api_server_manager_kwargs)
# Wait for engine handshakes to complete.
core_engines = [
CoreEngine(index=i, local=(i < local_engine_count))
for i in range(dp_size)
]
wait_for_engine_startup(
handshake_socket,
addresses,
core_engines,
parallel_config,
vllm_config.cache_config,
local_engine_manager,
coordinator.proc if coordinator else None,
)
# Wait for API servers
wait_for_completion_or_failure(api_server_manager=api_server_manager,
engine_manager=local_engine_manager,
coordinator=coordinator)
# Wait for API servers
wait_for_completion_or_failure(api_server_manager=api_server_manager,
engine_manager=local_engine_manager,
coordinator=coordinator)
def run_api_server_worker_proc(listen_address,
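One external-LB-specific wrinkle in the restructured run_multi_api_server() above: for DP ranks > 0 the API servers are deferred until the local engine's handshake completes, since the coordinator's front-end stats publish address is only relayed through that handshake. A tiny self-contained model of just that decision (hypothetical helper name, not a vLLM API):

def should_defer_api_servers(dp_rank: int, external_dp_lb: bool) -> bool:
    """API servers start immediately for rank 0 or internal LB mode; an
    external-LB rank > 0 waits for the local engine handshake to relay the
    coordinator's stats publish address."""
    return external_dp_lb and dp_rank != 0

assert should_defer_api_servers(0, True) is False
assert should_defer_api_servers(1, True) is True
assert should_defer_api_servers(0, False) is False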

View File

@ -10,7 +10,7 @@ import zmq
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket
from vllm.utils import get_mp_context, make_zmq_socket
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
from vllm.v1.serial_utils import MsgpackDecoder
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
@ -48,20 +48,33 @@ class DPCoordinator:
Engines will move into running state when receiving a new request or
START_DP_WAVE message.
Note that when deployed in External LB mode, no stats will be published by
the engines and thus updates will only be sent to front-ends when the
request wave / running state changes.
"""
def __init__(self, parallel_config: ParallelConfig):
# Assume coordinator is colocated with front-end procs.
front_publish_address = get_open_zmq_ipc_path()
dp_size = parallel_config.data_parallel_size
assert dp_size > 1, "Coordinator only used for data parallel"
local_only = dp_size == parallel_config.data_parallel_size_local
host = parallel_config.data_parallel_master_ip
back_publish_address = get_engine_client_zmq_addr(local_only, host)
back_output_address = get_engine_client_zmq_addr(local_only, host)
external_lb = parallel_config.data_parallel_external_lb
# Assume coordinator is colocated with front-end procs when not in
# external DP LB mode.
front_publish_address = get_engine_client_zmq_addr(
local_only=not external_lb, host=host)
local_only_eng = dp_size == parallel_config.data_parallel_size_local
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
# When in external LB mode, load stats aren't published, only changes
# to request wave / running state, so we don't need to rate-limit the
# updates to the front-end proc(s).
min_stats_update_interval_ms = 0 if external_lb else 100
context = get_mp_context()
self.proc: multiprocessing.Process = context.Process(
@ -72,6 +85,7 @@ class DPCoordinator:
"front_publish_address": front_publish_address,
"back_output_address": back_output_address,
"back_publish_address": back_publish_address,
"min_stats_update_interval_ms": min_stats_update_interval_ms,
},
daemon=True)
self.proc.start()
@ -100,12 +114,16 @@ class EngineState:
class CoordinatorProc:
def __init__(self, engine_count: int):
def __init__(self,
engine_count: int,
min_stats_update_interval_ms: int = 100):
self.ctx = zmq.Context()
self.engines = [EngineState() for _ in range(engine_count)]
self.stats_update_interval_ms = min_stats_update_interval_ms
self.current_wave = 0
self.engines_running = False
self.stats_changed = False
@ -116,8 +134,11 @@ class CoordinatorProc:
front_publish_address: str,
back_output_address: str,
back_publish_address: str,
min_stats_update_interval_ms: int = 100,
):
coordinator = CoordinatorProc(engine_count=engine_count)
coordinator = CoordinatorProc(
engine_count=engine_count,
min_stats_update_interval_ms=min_stats_update_interval_ms)
try:
coordinator.process_input_socket(
front_publish_address,
@ -156,9 +177,10 @@ class CoordinatorProc:
last_publish_time = 0
while True:
elapsed = int(time.time() * 1000) - last_publish_time
# Send at 100 ms interval if the stats have changed,
# or otherwise every 3 seconds.
wait_for = 100 if self.stats_changed else 3000
# Send at stats_update_interval_ms interval if the stats have
# changed, or otherwise every 4 seconds.
wait_for = (self.stats_update_interval_ms
if self.stats_changed else 4000)
events = poller.poll(timeout=max(0, wait_for - elapsed))
if not events:
# Poller timeout - publish current stats to front-ends.
@ -174,7 +196,7 @@ class CoordinatorProc:
if publish_front in events:
buffer = publish_front.recv()
if buffer == b'\x01':
if buffer in (b'\x01', b'\x00'):
# Ignore subscription messages.
continue
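With no per-request load stats published in external LB mode, the coordinator only forwards wave/running-state changes, so the 100 ms rate limit is dropped (minimum interval 0) while the idle heartbeat stays at 4 seconds. A small self-contained mirror of the publish-cadence computation above:

def publish_wait_ms(stats_changed: bool, external_lb: bool) -> int:
    """How long the coordinator poller waits before publishing to front-ends."""
    min_stats_update_interval_ms = 0 if external_lb else 100
    # Send at the minimum interval when stats changed, otherwise every 4 seconds.
    return min_stats_update_interval_ms if stats_changed else 4000

assert publish_wait_ms(stats_changed=True, external_lb=True) == 0
assert publish_wait_ms(stats_changed=True, external_lb=False) == 100
assert publish_wait_ms(stats_changed=False, external_lb=True) == 4000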

View File

@ -34,6 +34,7 @@ from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
@ -41,7 +42,6 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@ -367,10 +367,11 @@ class EngineCoreProc(EngineCore):
def __init__(
self,
vllm_config: VllmConfig,
on_head_node: bool,
local_client: bool,
handshake_address: str,
executor_class: type[Executor],
log_stats: bool,
client_handshake_address: Optional[str] = None,
engine_index: int = 0,
):
self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
@ -383,12 +384,21 @@ class EngineCoreProc(EngineCore):
identity = self.engine_index.to_bytes(length=2, byteorder="little")
self.engines_running = False
with self._perform_handshake(handshake_address, identity, on_head_node,
vllm_config) as addresses:
with self._perform_handshakes(handshake_address, identity,
local_client, vllm_config,
client_handshake_address) as addresses:
self.client_count = len(addresses.outputs)
# Set up data parallel environment.
self.has_coordinator = addresses.coordinator_output is not None
self.frontend_stats_publish_address = (
addresses.frontend_stats_publish_address)
# Only publish request queue stats to coordinator for "internal"
# LB mode.
self.publish_dp_lb_stats = (
self.has_coordinator
and not vllm_config.parallel_config.data_parallel_external_lb)
self._init_data_parallel(vllm_config)
super().__init__(vllm_config, executor_class, log_stats,
@ -414,45 +424,102 @@ class EngineCoreProc(EngineCore):
self.output_thread.start()
@contextmanager
def _perform_handshake(
self, handshake_address: str, identity: bytes, on_head_node: bool,
vllm_config: VllmConfig
def _perform_handshakes(
self,
handshake_address: str,
identity: bytes,
local_client: bool,
vllm_config: VllmConfig,
client_handshake_address: Optional[str],
) -> Generator[EngineZmqAddresses, None, None]:
"""
Perform startup handshakes.
For DP=1 or offline mode, this is with the colocated front-end process.
For DP>1 with internal load balancing, this is with the shared front-end
process, which may reside on a different node.
For DP>1 with external load balancing, two handshakes are performed:
- With the rank 0 front-end process, from which the
DP Coordinator ZMQ addresses and DP process group address are retrieved.
- With the colocated front-end process, from which the
client input/output socket addresses are retrieved.
The rank 0 engine itself is the exception: it doesn't require
the second handshake.
Here, "front-end" process can mean the process containing the engine
core client (which is the API server process when the API server is
not scaled out), OR the launcher process running the
run_multi_api_server() function in serve.py.
"""
input_ctx = zmq.Context()
with make_zmq_socket(input_ctx,
is_local = local_client and client_handshake_address is None
handshake = self._perform_handshake(input_ctx, handshake_address,
identity, is_local, vllm_config,
vllm_config.parallel_config)
if client_handshake_address is None:
with handshake as addresses:
yield addresses
else:
local_handshake = self._perform_handshake(
input_ctx, client_handshake_address, identity, local_client,
vllm_config)
with handshake as addresses, local_handshake as client_addresses:
addresses.inputs = client_addresses.inputs
addresses.outputs = client_addresses.outputs
yield addresses
# Update config which may have changed from the handshake
vllm_config.__post_init__()
@contextmanager
def _perform_handshake(
self,
ctx: zmq.Context,
handshake_address: str,
identity: bytes,
local_client: bool,
vllm_config: VllmConfig,
parallel_config_to_update: Optional[ParallelConfig] = None,
) -> Generator[EngineZmqAddresses, None, None]:
with make_zmq_socket(ctx,
handshake_address,
zmq.DEALER,
identity=identity,
linger=5000,
bind=False) as handshake_socket:
# Register engine with front-end.
addresses = self.startup_handshake(handshake_socket, on_head_node,
vllm_config.parallel_config)
# Update config which may have changed from the handshake
vllm_config.__post_init__()
addresses = self.startup_handshake(handshake_socket, local_client,
parallel_config_to_update)
yield addresses
# Send ready message.
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
# We pass back the coordinator stats update address here for the
# external LB case for our colocated front-end to use (coordinator
# only runs with rank 0).
dp_stats_address = self.frontend_stats_publish_address
handshake_socket.send(
msgspec.msgpack.encode({
"status": "READY",
"local": on_head_node,
"local": local_client,
"num_gpu_blocks": num_gpu_blocks,
"dp_stats_address": dp_stats_address,
}))
@staticmethod
def startup_handshake(
handshake_socket: zmq.Socket, on_head_node: bool,
parallel_config: ParallelConfig) -> EngineZmqAddresses:
handshake_socket: zmq.Socket,
local_client: bool,
parallel_config: Optional[ParallelConfig] = None,
) -> EngineZmqAddresses:
# Send registration message.
handshake_socket.send(
msgspec.msgpack.encode({
"status": "HELLO",
"local": on_head_node,
"local": local_client,
}))
# Receive initialization message.
@ -466,9 +533,9 @@ class EngineCoreProc(EngineCore):
init_bytes, type=EngineHandshakeMetadata)
logger.debug("Received init message: %s", init_message)
received_parallel_config = init_message.parallel_config
for key, value in received_parallel_config.items():
setattr(parallel_config, key, value)
if parallel_config is not None:
for key, value in init_message.parallel_config.items():
setattr(parallel_config, key, value)
return init_message.addresses
@ -749,12 +816,12 @@ class DPEngineCoreProc(EngineCoreProc):
def __init__(
self,
vllm_config: VllmConfig,
on_head_node: bool,
local_client: bool,
handshake_address: str,
executor_class: type[Executor],
log_stats: bool,
client_handshake_address: Optional[str] = None,
):
self._decorate_logs()
# Counts forward-passes of the model so that we can synchronize
@ -765,8 +832,9 @@ class DPEngineCoreProc(EngineCoreProc):
# Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__(vllm_config, on_head_node, handshake_address,
executor_class, log_stats, dp_rank)
super().__init__(vllm_config, local_client, handshake_address,
executor_class, log_stats, client_handshake_address,
dp_rank)
def _decorate_logs(self):
# Add process-specific prefix to stdout and stderr before
@ -799,10 +867,18 @@ class DPEngineCoreProc(EngineCoreProc):
from vllm.platforms import current_platform
device_control_env_var = current_platform.device_control_env_var
world_size = vllm_config.parallel_config.world_size
os.environ[device_control_env_var] = ",".join(
str(current_platform.device_id_to_physical_device_id(i))
for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
world_size))
# Set CUDA_VISIBLE_DEVICES or equivalent.
try:
os.environ[device_control_env_var] = ",".join(
str(current_platform.device_id_to_physical_device_id(i))
for i in range(local_dp_rank *
world_size, (local_dp_rank + 1) * world_size))
except IndexError as e:
raise Exception(
f"Error setting {device_control_env_var}: "
f"local range: [{local_dp_rank * world_size}, "
f"{(local_dp_rank + 1) * world_size}) "
f"base value: \"{os.getenv(device_control_env_var)}\"") from e
self.dp_rank = dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
@ -839,7 +915,7 @@ class DPEngineCoreProc(EngineCoreProc):
super()._handle_client_request(request_type, request)
def _maybe_publish_request_counts(self):
if not self.has_coordinator:
if not self.publish_dp_lb_stats:
return
# Publish our request counts (if they've changed).
@ -892,9 +968,9 @@ class DPEngineCoreProc(EngineCoreProc):
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
# Optimization - only perform finish-sync all-reduce every 24 steps.
# Optimization - only perform finish-sync all-reduce every 32 steps.
self.counter += 1
if self.counter != 24:
if self.counter != 32:
return True
self.counter = 0
@ -910,7 +986,7 @@ class DPEngineCoreActor(DPEngineCoreProc):
def __init__(
self,
vllm_config: VllmConfig,
on_head_node: bool,
local_client: bool,
addresses: EngineZmqAddresses,
executor_class: type[Executor],
log_stats: bool,
@ -927,15 +1003,16 @@ class DPEngineCoreActor(DPEngineCoreProc):
# data parallel groups.
del os.environ['CUDA_VISIBLE_DEVICES']
super().__init__(vllm_config, on_head_node, "", executor_class,
super().__init__(vllm_config, local_client, "", executor_class,
log_stats)
def _decorate_logs(self):
pass
@contextmanager
def _perform_handshake(self, handshake_address: str, identity: bytes,
on_head_node: bool, vllm_config: VllmConfig):
def _perform_handshakes(self, handshake_address: str, identity: bytes,
local_client: bool, vllm_config: VllmConfig,
client_handshake_address: Optional[str]):
"""
For Ray, we don't need to actually perform the handshake.
All address information is known before the actor creation.
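The dual-handshake logic in _perform_handshakes() reduces to: every engine handshakes with the rank 0 front-end (coordinator addresses, DP process group), and in external LB mode ranks > 0 additionally handshake with their colocated front-end, taking the client input/output socket addresses from that second handshake. A simplified self-contained sketch (the dataclass stands in for EngineZmqAddresses):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Addresses:  # stand-in for EngineZmqAddresses
    inputs: list = field(default_factory=list)
    outputs: list = field(default_factory=list)
    coordinator_input: Optional[str] = None
    coordinator_output: Optional[str] = None

def merge_handshakes(rank0_addresses: Addresses,
                     client_addresses: Optional[Addresses]) -> Addresses:
    addresses = rank0_addresses
    if client_addresses is not None:
        # External LB, rank > 0: request/response sockets come from the
        # colocated front-end, everything else from the rank 0 handshake.
        addresses.inputs = client_addresses.inputs
        addresses.outputs = client_addresses.outputs
    return addresses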

View File

@ -7,7 +7,7 @@ import sys
import uuid
import weakref
from abc import ABC, abstractmethod
from collections import deque
from collections import defaultdict, deque
from collections.abc import Awaitable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass
@ -21,18 +21,16 @@ import zmq.asyncio
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.utils import (get_open_zmq_inproc_path, make_zmq_socket,
zmq_socket_ctx)
from vllm.utils import get_open_zmq_inproc_path, make_zmq_socket
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.engine.utils import (CoreEngineActorManager,
CoreEngineProcManager, launch_core_engines)
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
from vllm.v1.utils import (CoreEngine, CoreEngineActorManager,
CoreEngineProcManager, EngineZmqAddresses,
get_engine_client_zmq_addr, wait_for_engine_startup)
logger = init_logger(__name__)
@ -40,6 +38,8 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]]
_R = TypeVar('_R') # Return type for collective_rpc
EngineIdentity = bytes
class EngineCoreClient(ABC):
"""
@ -84,14 +84,16 @@ class EngineCoreClient(ABC):
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
) -> "MPClient":
if vllm_config.parallel_config.data_parallel_size > 1:
if vllm_config.parallel_config.data_parallel_backend == "ray":
return RayDPClient(vllm_config, executor_class, log_stats,
client_addresses, client_index)
return DPAsyncMPClient(vllm_config, executor_class, log_stats,
client_addresses, client_index)
return AsyncMPClient(vllm_config, executor_class, log_stats,
client_addresses, client_index)
parallel_config = vllm_config.parallel_config
client_args = (vllm_config, executor_class, log_stats,
client_addresses, client_index)
if parallel_config.data_parallel_size > 1:
if parallel_config.data_parallel_external_lb:
# External load balancer - client per DP rank.
return DPAsyncMPClient(*client_args)
# Internal load balancer - client balances to all DP ranks.
return DPLBAsyncMPClient(*client_args)
return AsyncMPClient(*client_args)
@abstractmethod
def shutdown(self):
@ -386,42 +388,32 @@ class MPClient(EngineCoreClient):
self._finalizer = weakref.finalize(self, self.resources)
success = False
try:
parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local
local_start_index = parallel_config.data_parallel_rank_local
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
# State used for data parallel.
self.engines_running = False
# SPMD mode is where there is an LLM instance per DP rank and
# one core engine per LLM, see
# examples/offline_inference/data_parallel.py.
spmd_mode = local_start_index is not None
if spmd_mode:
assert local_engine_count == 1
self.core_engines = [CoreEngine(index=dp_rank, local=True)]
else:
assert dp_rank == 0
local_start_index = 0
self.core_engines = [
CoreEngine(index=i, local=(i < local_engine_count))
for i in range(dp_size)
]
local_only = spmd_mode or local_engine_count == dp_size
self.stats_update_address: Optional[str] = None
if client_addresses is not None:
# Engines are managed externally to this client.
input_address = client_addresses["input_address"]
output_address = client_addresses["output_address"]
self.stats_update_address = client_addresses.get(
"stats_update_address")
else:
host = parallel_config.data_parallel_master_ip
input_address = get_engine_client_zmq_addr(local_only, host)
output_address = get_engine_client_zmq_addr(local_only, host)
# Engines are managed by this client.
with launch_core_engines(vllm_config, executor_class,
log_stats) as (engine_manager,
coordinator,
addresses):
self.resources.coordinator = coordinator
self.resources.engine_manager = engine_manager
(input_address, ) = addresses.inputs
(output_address, ) = addresses.outputs
self.stats_update_address = (
addresses.frontend_stats_publish_address)
if coordinator is not None:
assert self.stats_update_address == (
coordinator.get_stats_publish_address())
# Create input and output sockets.
self.input_socket = self.resources.input_socket = make_zmq_socket(
@ -429,18 +421,24 @@ class MPClient(EngineCoreClient):
self.resources.output_socket = make_zmq_socket(
self.ctx, output_address, zmq.PULL)
if client_addresses is None:
self._init_engines_direct(vllm_config, local_only,
local_start_index, input_address,
output_address, executor_class,
log_stats)
coordinator = self.resources.coordinator
if coordinator:
self.stats_update_address = (
coordinator.get_stats_publish_address())
parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
external_dp_lb = parallel_config.data_parallel_external_lb
offline_mode = parallel_config.data_parallel_rank_local is not None
engine_ranks = [dp_rank] if (offline_mode
or external_dp_lb) else range(dp_size)
assert parallel_config.data_parallel_size_local <= len(
engine_ranks)
# ZMQ identity of each engine that this client will talk to.
self.core_engines: list[EngineIdentity] = [
index.to_bytes(2, "little") for index in engine_ranks
]
# Wait for ready messages from each engine on the input socket.
identities = set(e.identity for e in self.core_engines)
identities = set(self.core_engines)
sync_input_socket = zmq.Socket.shadow(self.input_socket)
while identities:
if not sync_input_socket.poll(timeout=600_000):
@ -449,7 +447,7 @@ class MPClient(EngineCoreClient):
identity, _ = sync_input_socket.recv_multipart()
identities.remove(identity)
self.core_engine = self.core_engines[0]
self.core_engine: EngineIdentity = self.core_engines[0]
self.utility_results: dict[int, AnyFuture] = {}
# Request objects which may contain pytorch-allocated tensors
@ -462,73 +460,6 @@ class MPClient(EngineCoreClient):
if not success:
self._finalizer()
def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
local_start_index: int, input_address: str,
output_address: str,
executor_class: type[Executor], log_stats: bool):
"""Self-contained client mode, launch engine and coordinator process
as needed."""
parallel_config = vllm_config.parallel_config
local_engine_count = parallel_config.data_parallel_size_local
start_index = parallel_config.data_parallel_rank
host = parallel_config.data_parallel_master_ip
if len(self.core_engines) > 1:
self.resources.coordinator = DPCoordinator(parallel_config)
handshake_address = get_engine_client_zmq_addr(
local_only, host, parallel_config.data_parallel_rpc_port)
with zmq_socket_ctx(handshake_address, zmq.ROUTER,
bind=True) as handshake_socket:
# Start local engines.
if local_engine_count:
# In server mode, start_index and local_start_index will
# both be 0.
self.resources.engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=log_stats,
handshake_address=handshake_address,
on_head_node=True,
local_engine_count=local_engine_count,
start_index=start_index,
local_start_index=local_start_index)
# Wait for engine core process(es) to start.
self._wait_for_engine_startup(handshake_socket, input_address,
output_address)
def _wait_for_engine_startup(self, handshake_socket: zmq.Socket,
input_address: str, output_address: str):
addresses = EngineZmqAddresses(
inputs=[input_address],
outputs=[output_address],
)
coordinator = self.resources.coordinator
if coordinator is not None:
addresses.coordinator_input, addresses.coordinator_output = (
coordinator.get_engine_socket_addresses())
proc_manager = self.resources.engine_manager
assert isinstance(proc_manager, (type(None), CoreEngineProcManager)), (
"_wait_for_engine_startup should only be "
"called with CoreEngineProcManager")
wait_for_engine_startup(
handshake_socket,
addresses,
self.core_engines,
self.vllm_config.parallel_config,
self.vllm_config.cache_config,
proc_manager,
coordinator.proc if coordinator else None,
)
def shutdown(self):
# Terminate background resources.
self._finalizer()
@ -583,7 +514,6 @@ class SyncMPClient(MPClient):
# a ref to the client which prevents gc.
ctx = self.ctx
out_socket = self.resources.output_socket
assert out_socket is not None
decoder = self.decoder
utility_results = self.utility_results
outputs_queue = self.outputs_queue
@ -593,6 +523,7 @@ class SyncMPClient(MPClient):
resources.shutdown_path = shutdown_path
def process_outputs_socket():
assert isinstance(out_socket, zmq.Socket)
shutdown_socket = ctx.socket(zmq.PAIR)
try:
shutdown_socket.bind(shutdown_path)
@ -609,7 +540,7 @@ class SyncMPClient(MPClient):
frames = out_socket.recv_multipart(copy=False)
resources.validate_alive(frames)
outputs = decoder.decode(frames)
outputs: EngineCoreOutputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
@ -646,7 +577,7 @@ class SyncMPClient(MPClient):
self.ensure_alive()
self.free_pending_messages()
# (Identity, RequestType, SerializedRequest)
msg = (self.core_engine.identity, request_type.value,
msg = (self.core_engine, request_type.value,
*self.encoder.encode(request))
if len(msg) <= 3:
@ -812,7 +743,7 @@ class AsyncMPClient(MPClient):
def _send_input(self,
request_type: EngineCoreRequestType,
request: Any,
engine: Optional[CoreEngine] = None) -> Awaitable[Any]:
engine: Optional[EngineIdentity] = None) -> Awaitable[Any]:
if engine is None:
engine = self.core_engine
@ -820,7 +751,7 @@ class AsyncMPClient(MPClient):
return self._send_input_message(message, engine, request)
def _send_input_message(self, message: tuple[bytestr,
...], engine: CoreEngine,
...], engine: EngineIdentity,
objects: Any) -> Awaitable[Any]:
"""
objects is a reference to retain until zmq is finished with the
@ -829,7 +760,7 @@ class AsyncMPClient(MPClient):
self.ensure_alive()
self.free_pending_messages()
msg = (engine.identity, ) + message
msg = (engine, ) + message
if not objects or len(msg) <= 3:
# No auxiliary buffers => no tensor backing buffers in request.
return self.input_socket.send_multipart(msg, copy=False)
@ -850,7 +781,7 @@ class AsyncMPClient(MPClient):
engine=self.core_engine)
async def _call_utility_async(self, method: str, *args,
engine: CoreEngine) -> Any:
engine: EngineIdentity) -> Any:
call_id = uuid.uuid1().int >> 64
future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future
@ -921,7 +852,7 @@ class AsyncMPClient(MPClient):
class DPAsyncMPClient(AsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
EngineCore."""
EngineCore. Assumes external load-balancing by default."""
def __init__(self,
vllm_config: VllmConfig,
@ -930,15 +861,12 @@ class DPAsyncMPClient(AsyncMPClient):
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0):
self.current_wave = 0
# To route aborts to the correct engine.
self.reqs_in_flight: dict[str, CoreEngine] = {}
super().__init__(vllm_config, executor_class, log_stats,
client_addresses, client_index)
assert len(self.core_engines) > 1
# List of [waiting, running] pair per engine.
# Used only by DPLBAsyncMPClient subclass.
self.lb_engines: list[list[int]] = []
self.first_req_sock_addr = get_open_zmq_inproc_path()
@ -969,6 +897,8 @@ class DPAsyncMPClient(AsyncMPClient):
self.first_req_sock_addr,
zmq.PAIR,
bind=False) as first_req_rcv_socket:
assert isinstance(socket, zmq.asyncio.Socket)
assert isinstance(first_req_rcv_socket, zmq.asyncio.Socket)
# Send subscription message.
await socket.send(b'\x01')
@ -1012,34 +942,75 @@ class DPAsyncMPClient(AsyncMPClient):
resources.stats_update_task = asyncio.create_task(
run_engine_stats_update_task())
def get_core_engine_for_request(self,
dp_rank: Optional[int] = None
) -> CoreEngine:
if dp_rank is not None:
# engines are already in rank order
return self.core_engines[dp_rank]
async def add_request_async(self, request: EngineCoreRequest) -> None:
self._ensure_stats_update_task()
if not self.lb_engines:
return self.core_engines[0]
# TODO use P2C alg for larger DP sizes
num_engines = len(self.lb_engines)
min_counts = [sys.maxsize, sys.maxsize]
eng_index = 0
for i in range(num_engines):
# Start from client_index to help with balancing when engines
# are empty.
idx = (self.client_index + i) % num_engines
counts = self.lb_engines[idx]
if counts < min_counts:
min_counts = counts
eng_index = idx
# Adjust local counts for better balancing between stats updates
# from the coordinator (which happen every 100ms).
if min_counts[0]:
min_counts[0] += 1
else:
min_counts[1] += 1
return self.core_engines[eng_index]
request.current_wave = self.current_wave
request.client_index = self.client_index
chosen_engine = self.get_core_engine_for_request(request)
to_await = self._send_input(EngineCoreRequestType.ADD, request,
chosen_engine)
if not self.engines_running:
# Notify coordinator that we're sending a request
await self.first_req_send_socket.send(chosen_engine)
await to_await
self._ensure_output_queue_task()
def get_core_engine_for_request(self, request: EngineCoreRequest):
return self.core_engine
class DPLBAsyncMPClient(DPAsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
EngineCore. Load-balances between multiple engine processes."""
def __init__(self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0):
# To route aborts to the correct engine.
self.reqs_in_flight: dict[str, EngineIdentity] = {}
super().__init__(vllm_config, executor_class, log_stats,
client_addresses, client_index)
assert len(self.core_engines) > 1
def get_core_engine_for_request(
self, request: EngineCoreRequest) -> EngineIdentity:
# Engines are in rank order.
if (eng_index := request.data_parallel_rank) is None:
if not self.lb_engines:
return self.core_engine
# TODO use P2C alg for larger DP sizes
num_engines = len(self.lb_engines)
min_counts = [sys.maxsize, sys.maxsize]
eng_index = 0
for i in range(num_engines):
# Start from client_index to help with balancing when engines
# are empty.
idx = (self.client_index + i) % num_engines
counts = self.lb_engines[idx]
if counts < min_counts:
min_counts = counts
eng_index = idx
# Adjust local counts for better balancing between stats updates
# from the coordinator (which happen every 100ms).
if min_counts[0]:
min_counts[0] += 1
else:
min_counts[1] += 1
chosen_engine = self.core_engines[eng_index]
# Record which engine is chosen for this request, to handle aborts.
self.reqs_in_flight[request.request_id] = chosen_engine
return chosen_engine
async def call_utility_async(self, method: str, *args) -> Any:
# Only the result from the first engine is returned.
@ -1048,28 +1019,8 @@ class DPAsyncMPClient(AsyncMPClient):
for engine in self.core_engines
]))[0]
async def add_request_async(self, request: EngineCoreRequest) -> None:
self._ensure_stats_update_task()
request.current_wave = self.current_wave
request.client_index = self.client_index
chosen_engine = self.get_core_engine_for_request(
request.data_parallel_rank)
self.reqs_in_flight[request.request_id] = chosen_engine
to_await = self._send_input(EngineCoreRequestType.ADD, request,
chosen_engine)
if not self.engines_running:
# Notify coordinator that we're sending a request
await self.first_req_send_socket.send(chosen_engine.identity)
await to_await
self._ensure_output_queue_task()
@staticmethod
async def process_engine_outputs(self: "DPAsyncMPClient",
async def process_engine_outputs(self: "DPLBAsyncMPClient",
outputs: EngineCoreOutputs):
if outputs.finished_requests and self.reqs_in_flight:
for req_id in outputs.finished_requests:
@ -1085,61 +1036,14 @@ class DPAsyncMPClient(AsyncMPClient):
await self._abort_requests(request_ids, engine)
return
by_engine: dict[CoreEngine, list[str]] = {}
by_engine = defaultdict[EngineIdentity, list[str]](list)
for req_id in request_ids:
if engine := self.reqs_in_flight.get(req_id):
by_engine.setdefault(engine, []).append(req_id)
by_engine[engine].append(req_id)
for engine, req_ids in by_engine.items():
await self._abort_requests(req_ids, engine)
async def _abort_requests(self, request_ids: list[str],
engine: CoreEngine) -> None:
engine: EngineIdentity) -> None:
await self._send_input(EngineCoreRequestType.ABORT, request_ids,
engine)
class RayDPClient(DPAsyncMPClient):
"""
Ray-based client for multi-proc, multi-engine (data parallel)
EngineCore.
"""
def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
):
super().__init__(vllm_config, executor_class, log_stats,
client_addresses, client_index)
def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
local_start_index: int, input_address: str,
output_address: str,
executor_class: type[Executor], log_stats: bool):
"""Self-contained client mode, launch engine and coordinator process
as needed."""
parallel_config = vllm_config.parallel_config
assert parallel_config.data_parallel_rank == 0
assert local_start_index == 0
addresses = EngineZmqAddresses(
inputs=[input_address],
outputs=[output_address],
)
if len(self.core_engines) > 1:
coordinator = DPCoordinator(parallel_config)
self.resources.coordinator = coordinator
addresses.coordinator_input, addresses.coordinator_output = (
coordinator.get_engine_socket_addresses())
# Start all engines.
self.resources.engine_manager = CoreEngineActorManager(
vllm_config=vllm_config,
addresses=addresses,
executor_class=executor_class,
log_stats=log_stats)
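For the internal-LB client (DPLBAsyncMPClient), per-request engine selection compares the coordinator-published [waiting, running] pairs lexicographically, starting from the client's own index so multiple API servers don't all favor engine 0 when counts are equal; the external-LB client simply uses its single local engine. A self-contained sketch of the selection (counts illustrative; the real code also nudges the chosen engine's local counts between coordinator updates):

import sys

def choose_engine(lb_engines: list[list[int]], client_index: int) -> int:
    """Pick the engine with the smallest [waiting, running] pair."""
    num_engines = len(lb_engines)
    min_counts, eng_index = [sys.maxsize, sys.maxsize], 0
    for i in range(num_engines):
        # Start from client_index to help balancing when engines are empty.
        idx = (client_index + i) % num_engines
        if lb_engines[idx] < min_counts:
            min_counts, eng_index = lb_engines[idx], idx
    return eng_index

print(choose_engine([[0, 3], [1, 1], [0, 2]], client_index=0))  # -> 2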

vllm/v1/engine/utils.py (new file, 546 lines)
View File

@ -0,0 +1,546 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import weakref
from collections.abc import Iterator
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing.process import BaseProcess
from typing import TYPE_CHECKING, Callable, Optional, Union
import msgspec
import zmq
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, zmq_socket_ctx
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.executor.abstract import Executor
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
logger = init_logger(__name__)
STARTUP_POLL_PERIOD_MS = 10000
class CoreEngineState(Enum):
NEW = auto()
CONNECTED = auto()
READY = auto()
class CoreEngine:
"""One per data parallel rank, used to track state during handshaking."""
def __init__(self, index: int = 0, local: bool = True):
self.local = local
self.identity = index.to_bytes(2, "little")
self.state = CoreEngineState.NEW
@dataclass
class EngineZmqAddresses:
# ZMQ input socket addresses for each front-end client (requests)
inputs: list[str]
# ZMQ output socket addresses for each front-end client (responses)
outputs: list[str]
# ZMQ input socket address of DP coordinator if applicable
coordinator_input: Optional[str] = None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output: Optional[str] = None
# ZMQ socket for front-end to connect to DP coordinator.
# Not used by engine, just relayed to front-end in handshake response.
# Only required for external DP LB case.
frontend_stats_publish_address: Optional[str] = None
@dataclass
class EngineHandshakeMetadata:
"""Metadata sent to each engine process during startup handshake,
including addresses of the front-end ZMQ queues that they should
connect to.
"""
addresses: EngineZmqAddresses
parallel_config: dict[str, Union[int, str]]
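For illustration, a minimal sketch of the handshake payload round trip, assuming only msgspec and the two dataclasses above; the addresses and config values are hypothetical:

# Sketch of the metadata round trip (hypothetical addresses and values).
import msgspec

addresses = EngineZmqAddresses(
    inputs=["ipc:///tmp/in-0"],
    outputs=["ipc:///tmp/out-0"],
    frontend_stats_publish_address="tcp://192.0.2.10:5557",
)
payload = msgspec.msgpack.encode(
    EngineHandshakeMetadata(addresses=addresses,
                            parallel_config={"data_parallel_size": 2}))
# Engines decode without a type annotation, so nested dataclasses arrive
# as plain dicts.
decoded = msgspec.msgpack.decode(payload)
assert decoded["addresses"]["inputs"] == ["ipc:///tmp/in-0"]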
class CoreEngineProcManager:
"""
Utility class to handle creation, readiness, and shutdown
of background processes used by the AsyncLLM and LLMEngine.
"""
def __init__(
self,
target_fn: Callable,
local_engine_count: int,
start_index: int,
local_start_index: int,
vllm_config: VllmConfig,
local_client: bool,
handshake_address: str,
executor_class: type[Executor],
log_stats: bool,
client_handshake_address: Optional[str] = None,
):
context = get_mp_context()
common_kwargs = {
"vllm_config": vllm_config,
"local_client": local_client,
"handshake_address": handshake_address,
"executor_class": executor_class,
"log_stats": log_stats,
}
if client_handshake_address:
common_kwargs[
"client_handshake_address"] = client_handshake_address
self.processes: list[BaseProcess] = []
for index in range(local_engine_count):
local_index = local_start_index + index
global_index = start_index + index
# Start EngineCore in background process.
self.processes.append(
context.Process(target=target_fn,
name=f"EngineCore_{global_index}",
kwargs=common_kwargs | {
"dp_rank": global_index,
"local_dp_rank": local_index,
}))
self._finalizer = weakref.finalize(self, shutdown, self.processes)
try:
for proc in self.processes:
proc.start()
finally:
# Kill other procs if not all are running.
if self.finished_procs():
self.close()
def close(self):
"""Shutdown all procs."""
self._finalizer()
def join_first(self):
"""Wait for any process to exit."""
connection.wait(proc.sentinel for proc in self.processes)
def sentinels(self) -> list:
return [proc.sentinel for proc in self.processes]
def finished_procs(self) -> dict[str, int]:
"""Returns dict of proc name -> exit code for any finished procs."""
return {
proc.name: proc.exitcode
for proc in self.processes if proc.exitcode is not None
}
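The manager leans on two stdlib mechanisms: weakref.finalize for best-effort shutdown and multiprocessing sentinels to wait until the first background process exits. A self-contained sketch of that pattern with a toy worker (the names here are hypothetical, not vLLM APIs):

# Toy stand-in for the proc-manager pattern used above.
import time
import weakref
from multiprocessing import connection, get_context

def _worker(rank: int) -> None:
    time.sleep(0.1 * (rank + 1))

def _shutdown(procs) -> None:
    for p in procs:
        if p.is_alive():
            p.terminate()

class TinyProcManager:
    def __init__(self, count: int):
        ctx = get_context("spawn")
        self.processes = [
            ctx.Process(target=_worker, args=(i,), name=f"Worker_{i}")
            for i in range(count)
        ]
        # Runs _shutdown when this object is garbage collected or when
        # close() is called explicitly, whichever happens first.
        self._finalizer = weakref.finalize(self, _shutdown, self.processes)
        for p in self.processes:
            p.start()

    def join_first(self) -> None:
        # Returns as soon as any one process exits.
        connection.wait([p.sentinel for p in self.processes])

    def close(self) -> None:
        self._finalizer()

if __name__ == "__main__":
    mgr = TinyProcManager(3)
    mgr.join_first()
    mgr.close()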
class CoreEngineActorManager:
"""
Utility class to handle creation, readiness, and shutdown
of core engine Ray actors used by the AsyncLLM and LLMEngine.
Different from CoreEngineProcManager, this class manages
core engines for both local and remote nodes.
"""
def __init__(
self,
vllm_config: VllmConfig,
addresses: EngineZmqAddresses,
executor_class: type[Executor],
log_stats: bool,
placement_groups: Optional[list["PlacementGroup"]] = None,
local_dp_ranks: Optional[list[int]] = None,
):
import copy
import ray
from ray.util.scheduling_strategies import (
PlacementGroupSchedulingStrategy)
from vllm.v1.engine.core import DPEngineCoreActor
self.local_engine_actors: list[ray.ActorHandle] = []
self.remote_engine_actors: list[ray.ActorHandle] = []
dp_size = vllm_config.parallel_config.data_parallel_size
local_engine_count = \
vllm_config.parallel_config.data_parallel_size_local
world_size = vllm_config.parallel_config.world_size
if ray.is_initialized():
logger.info(
"Ray is already initialized. Skipping Ray initialization.")
else:
ray.init()
if placement_groups is not None:
assert local_dp_ranks is not None, (
"local_dp_ranks must be provided if "
"placement_groups is provided")
assert len(placement_groups) == len(local_dp_ranks), (
"placement_groups and local_dp_ranks must "
"have the same length")
logger.info("Using provided placement groups")
# TODO(rui): validate passed-in placement groups
self.created_placement_groups = []
else:
placement_groups, local_dp_ranks = \
CoreEngineActorManager.create_dp_placement_groups(vllm_config)
self.created_placement_groups = placement_groups
assert len(placement_groups) == dp_size, (
"Number of placement groups must match data parallel size")
refs = []
for index in range(dp_size):
local_index = local_dp_ranks[index]
dp_vllm_config = copy.deepcopy(vllm_config)
pg = placement_groups[index]
dp_vllm_config.parallel_config.placement_group = pg
local_client = index < local_engine_count
actor = ray.remote(DPEngineCoreActor).options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_bundle_index=world_size,
)).remote(vllm_config=dp_vllm_config,
executor_class=executor_class,
log_stats=log_stats,
local_client=local_client,
addresses=addresses,
dp_rank=index,
local_dp_rank=local_index)
if local_client:
self.local_engine_actors.append(actor)
else:
self.remote_engine_actors.append(actor)
refs.append(actor.wait_for_init.remote())
ray.get(refs)
self.run_refs = []
for actor in self.local_engine_actors + self.remote_engine_actors:
self.run_refs.append(actor.run.remote())
@staticmethod
def create_dp_placement_groups(
vllm_config: VllmConfig
) -> tuple[list["PlacementGroup"], list[int]]:
import ray
from ray._private.state import available_resources_per_node
from ray.util.state import list_nodes
logger.info("Creating placement groups for data parallel")
dp_master_ip = \
vllm_config.parallel_config.data_parallel_master_ip
dp_size = vllm_config.parallel_config.data_parallel_size
local_engine_count = \
vllm_config.parallel_config.data_parallel_size_local
nodes = sorted(list_nodes(),
key=lambda node: node.node_ip != dp_master_ip)
assert nodes[0].node_ip == dp_master_ip, (
"The first node must be the head node")
assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
"There can only be one head node")
available_resources = available_resources_per_node()
world_size = vllm_config.parallel_config.world_size
placement_groups: list[PlacementGroup] = []
local_dp_ranks: list[int] = []
for node in nodes:
node_ip = node.node_ip
node_resources = available_resources[node.node_id]
# For now, each DP rank can only be assigned to one node
# TODO(rui): support allocating a single DP rank
# to multiple nodes
available_engine_count = int(node_resources["GPU"]) // world_size
if node_ip == dp_master_ip:
assert available_engine_count >= local_engine_count, (
"Not enough resources to allocate DP ranks "
f"on DP master node {node_ip}")
for i in range(local_engine_count):
bundles = [{
"GPU": 1.0,
"node:" + dp_master_ip: 0.001
}] * world_size + [{
"CPU": 1.0
}]
pg = ray.util.placement_group(
name=f"dp_rank_{len(placement_groups)}",
strategy="STRICT_PACK",
bundles=bundles,
)
placement_groups.append(pg)
local_dp_ranks.append(i)
else:
for i in range(available_engine_count):
if len(placement_groups) == dp_size:
break
bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
pg = ray.util.placement_group(
name=f"dp_rank_{len(placement_groups)}",
strategy="STRICT_PACK",
bundles=bundles,
)
placement_groups.append(pg)
local_dp_ranks.append(i)
return placement_groups, local_dp_ranks
def get_run_refs(self):
return self.run_refs
def close(self):
import ray
for actor in self.local_engine_actors + self.remote_engine_actors:
ray.kill(actor)
for pg in self.created_placement_groups:
ray.util.remove_placement_group(pg)
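In create_dp_placement_groups above, each DP rank reserves a STRICT_PACK placement group of world_size GPU bundles plus one CPU bundle for the engine-core process, and a node can host at most GPUs // world_size ranks. A pure-Python sketch of that arithmetic for a hypothetical two-node cluster (no Ray required):

# Placement arithmetic only; the cluster shape and counts are made up.
world_size = 2          # TP * PP ranks per engine
dp_size = 4
local_engine_count = 2  # DP ranks pinned to the head node
node_gpus = {"head": 8, "worker": 4}

bundles_per_rank = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
print("each DP rank reserves", len(bundles_per_rank), "bundles:", bundles_per_rank)

ranks_assigned = 0
for node, gpus in node_gpus.items():
    capacity = gpus // world_size  # DP ranks this node can host
    if node == "head":
        assert capacity >= local_engine_count
        take = local_engine_count
    else:
        take = min(capacity, dp_size - ranks_assigned)
    ranks_assigned += take
    print(f"{node}: capacity={capacity}, ranks placed={take}")
assert ranks_assigned == dp_size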
@contextlib.contextmanager
def launch_core_engines(
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
num_api_servers: int = 1,
) -> Iterator[tuple[
Optional[Union[CoreEngineProcManager, CoreEngineActorManager]],
Optional[DPCoordinator],
EngineZmqAddresses,
]]:
"""Launch engine and DP coordinator processes as needed."""
parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
local_engine_count = parallel_config.data_parallel_size_local
local_start_index = parallel_config.data_parallel_rank_local
dp_rank = parallel_config.data_parallel_rank
host = parallel_config.data_parallel_master_ip
external_dp_lb = parallel_config.data_parallel_external_lb
# In offline mode there is an LLM instance per DP rank and
# one core engine per LLM, see
# examples/offline_inference/data_parallel.py.
offline_mode = local_start_index is not None
# client_local_only = True for cases where this front-end
# sends requests only to colocated engines.
client_local_only = offline_mode or external_dp_lb or (local_engine_count
== dp_size)
# Set up input and output addresses.
addresses = EngineZmqAddresses(
inputs=[
get_engine_client_zmq_addr(client_local_only, host)
for _ in range(num_api_servers)
],
outputs=[
get_engine_client_zmq_addr(client_local_only, host)
for _ in range(num_api_servers)
],
)
# Run the DP Coordinator process with rank 0 when in
# online DP mode.
run_coordinator = dp_size > 1 and not offline_mode and dp_rank == 0
if run_coordinator:
coordinator = DPCoordinator(parallel_config)
addresses.coordinator_input, addresses.coordinator_output = (
coordinator.get_engine_socket_addresses())
addresses.frontend_stats_publish_address = (
coordinator.get_stats_publish_address())
logger.info("Started DP Coordinator process (PID: %d)",
coordinator.proc.pid)
else:
coordinator = None
if parallel_config.data_parallel_backend == "ray":
logger.info("Starting ray-based data parallel backend")
engine_actor_manager = CoreEngineActorManager(
vllm_config=vllm_config,
addresses=addresses,
executor_class=executor_class,
log_stats=log_stats,
)
yield engine_actor_manager, coordinator, addresses
return
if offline_mode or (external_dp_lb and dp_rank > 0):
assert local_engine_count == 1
engines_to_handshake = [CoreEngine(index=dp_rank, local=True)]
else:
engines_to_handshake = [
CoreEngine(index=i, local=(i < local_engine_count))
for i in range(dp_size)
]
# Whether the started engines will handshake only with co-located
# front-end processes. In external_dp_lb mode, ranks > 0 handshake with
# their co-located frontend and also the rank 0 front-end, and hence this
# will be False.
handshake_local_only = offline_mode or local_engine_count == dp_size
handshake_address = get_engine_client_zmq_addr(
handshake_local_only, host, parallel_config.data_parallel_rpc_port)
if external_dp_lb and dp_rank > 0:
assert not handshake_local_only
local_handshake_address = get_open_zmq_ipc_path()
client_handshake_address = local_handshake_address
else:
local_handshake_address = handshake_address
client_handshake_address = None
with zmq_socket_ctx(local_handshake_address, zmq.ROUTER,
bind=True) as handshake_socket:
from vllm.v1.engine.core import EngineCoreProc
# Start local engines.
if local_engine_count:
# In server mode, start_index and local_start_index will
# both be 0.
local_engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=log_stats,
handshake_address=handshake_address,
client_handshake_address=client_handshake_address,
local_client=True,
local_engine_count=local_engine_count,
start_index=dp_rank,
local_start_index=local_start_index or 0)
else:
local_engine_manager = None
yield local_engine_manager, coordinator, addresses
# Now wait for engines to start.
wait_for_engine_startup(
handshake_socket,
addresses,
engines_to_handshake,
parallel_config,
vllm_config.cache_config,
local_engine_manager,
coordinator.proc if coordinator else None,
)
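The locality flags near the top of launch_core_engines interact in a non-obvious way; the snippet below re-evaluates the same expressions for a few representative deployments (the scenarios are hypothetical):

# Mirrors the client_local_only / handshake_local_only expressions above.
def locality_flags(dp_size, local_engine_count, offline_mode, external_dp_lb):
    client_local_only = (offline_mode or external_dp_lb
                         or local_engine_count == dp_size)
    handshake_local_only = offline_mode or local_engine_count == dp_size
    return client_local_only, handshake_local_only

scenarios = {
    "single-node online DP=2":        (2, 2, False, False),
    "two-node online DP=4, 2 local":  (4, 2, False, False),
    "external LB, 1 rank per server": (4, 1, False, True),
    "offline LLM per rank":           (4, 1, True, False),
}
for name, args in scenarios.items():
    client, handshake = locality_flags(*args)
    print(f"{name:32s} client_local_only={client!s:5s} "
          f"handshake_local_only={handshake}")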
def wait_for_engine_startup(
handshake_socket: zmq.Socket,
addresses: EngineZmqAddresses,
core_engines: list[CoreEngine],
parallel_config: ParallelConfig,
cache_config: CacheConfig,
proc_manager: Optional[CoreEngineProcManager],
coord_process: Optional[Process],
):
# Wait for engine core process(es) to send ready messages.
local_count = parallel_config.data_parallel_size_local
remote_count = len(core_engines) - local_count
# [local, remote] counts
conn_pending, start_pending = [local_count, remote_count], [0, 0]
poller = zmq.Poller()
poller.register(handshake_socket, zmq.POLLIN)
if proc_manager is not None:
for sentinel in proc_manager.sentinels():
poller.register(sentinel, zmq.POLLIN)
if coord_process is not None:
poller.register(coord_process.sentinel, zmq.POLLIN)
while any(conn_pending) or any(start_pending):
events = poller.poll(STARTUP_POLL_PERIOD_MS)
if not events:
if any(conn_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to connect.", *conn_pending)
if any(start_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to start.", *start_pending)
continue
if len(events) > 1 or events[0][0] != handshake_socket:
# One of the local core processes exited.
finished = proc_manager.finished_procs() if proc_manager else {}
if coord_process is not None and coord_process.exitcode is not None:
finished[coord_process.name] = coord_process.exitcode
raise RuntimeError("Engine core initialization failed. "
"See root cause above. "
f"Failed core proc(s): {finished}")
# Receive HELLO and READY messages on the handshake socket.
eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
eng_index = int.from_bytes(eng_identity, "little")
engine = next((e for e in core_engines if e.identity == eng_identity),
None)
if engine is None:
raise RuntimeError(f"Message from engine with unexpected data "
f"parallel rank: {eng_index}")
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local = msg["status"], msg["local"]
if local != engine.local:
raise RuntimeError(f"{status} message from "
f"{'local' if local else 'remote'} "
f"engine {eng_index}, expected it to be "
f"{'local' if engine.local else 'remote'}")
if status == "HELLO" and engine.state == CoreEngineState.NEW:
# Send init message with DP config info.
init_message = msgspec.msgpack.encode(
EngineHandshakeMetadata(
addresses=addresses,
parallel_config={
"data_parallel_master_ip":
parallel_config.data_parallel_master_ip,
"data_parallel_master_port":
parallel_config.data_parallel_master_port,
"data_parallel_size":
parallel_config.data_parallel_size,
}))
handshake_socket.send_multipart((eng_identity, init_message),
copy=False)
conn_pending[0 if local else 1] -= 1
start_pending[0 if local else 1] += 1
engine.state = CoreEngineState.CONNECTED
elif status == "READY" and engine.state == CoreEngineState.CONNECTED:
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
num_gpu_blocks = cache_config.num_gpu_blocks or 0
num_gpu_blocks += msg["num_gpu_blocks"]
cache_config.num_gpu_blocks = num_gpu_blocks
# In external DP LB mode, the coordinator address that the
# front-end procs connect to is obtained from rank 0 via
# one of the engine handshakes, and passed to the local
# front-end process in the response from the other.
if addresses.frontend_stats_publish_address is None:
addresses.frontend_stats_publish_address = msg.get(
"dp_stats_address")
start_pending[0 if local else 1] -= 1
engine.state = CoreEngineState.READY
else:
raise RuntimeError(f"Unexpected {status} message for "
f"{'local' if local else 'remote'} engine "
f"{eng_index} in {engine.state} state.")
logger.debug("%s from %s core engine process %s.", status,
"local" if local else "remote", eng_index)

View File

@ -1,44 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import multiprocessing
import time
import weakref
from collections import defaultdict
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from multiprocessing import Process, connection
from multiprocessing import connection
from multiprocessing.process import BaseProcess
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
Union, overload)
import msgspec
import torch
import zmq
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.utils import extract_layer_index
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
get_tcp_uri, kill_process_tree)
from vllm.v1.executor.abstract import Executor
from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
kill_process_tree)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
from vllm.attention.layer import Attention
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.utils import (CoreEngineActorManager,
CoreEngineProcManager)
logger = init_logger(__name__)
T = TypeVar("T")
STARTUP_POLL_PERIOD_MS = 10000
class ConstantList(Generic[T], Sequence):
@ -111,49 +102,18 @@ class ConstantList(Generic[T], Sequence):
def get_engine_client_zmq_addr(local_only: bool,
host: str,
port: int = 0) -> str:
"""Assign a new ZMQ socket address.
If local_only is True, participants are colocated and so a unique IPC
address will be returned.
Otherwise, the provided host and port will be used to construct a TCP
address (port == 0 means assign an available port)."""
return get_open_zmq_ipc_path() if local_only else (get_tcp_uri(
host, port or get_open_port()))
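Concretely, colocated participants get a fresh IPC endpoint while cross-node participants get a TCP endpoint on the DP master host; the two address shapes look roughly like this (hypothetical values):

# Hypothetical examples of the two shapes returned above.
local_addr = "ipc:///tmp/vllm-0a1b2c.sock"   # local_only=True  -> unique IPC path
remote_addr = "tcp://192.0.2.10:13345"       # local_only=False -> TCP on the DP master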
class CoreEngineState(Enum):
NEW = auto()
CONNECTED = auto()
READY = auto()
class CoreEngine:
"""One per data parallel rank."""
def __init__(self, index: int = 0, local: bool = True):
self.local = local
self.index = index
self.identity = index.to_bytes(2, "little")
self.state = CoreEngineState.NEW
@dataclass
class EngineZmqAddresses:
# ZMQ input socket addresses for each front-end client (requests)
inputs: list[str]
# ZMQ output socket addresses for each front-end client (responses)
outputs: list[str]
# ZMQ input socket address of DP coordinator if applicable
coordinator_input: Optional[str] = None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output: Optional[str] = None
@dataclass
class EngineHandshakeMetadata:
"""Metadata sent to each engine process during startup handshake,
including addresses of the front-end ZMQ queues that they should
connect to.
"""
addresses: EngineZmqAddresses
parallel_config: dict[str, Union[int, str]]
class APIServerProcessManager:
"""Manages a group of API server processes.
@ -219,339 +179,10 @@ class APIServerProcessManager:
self._finalizer()
class CoreEngineProcManager:
"""
Utility class to handle creation, readiness, and shutdown
of background processes used by the AsyncLLM and LLMEngine.
"""
def __init__(
self,
target_fn: Callable,
local_engine_count: int,
start_index: int,
local_start_index: int,
vllm_config: VllmConfig,
on_head_node: bool,
handshake_address: str,
executor_class: type[Executor],
log_stats: bool,
):
context = get_mp_context()
common_kwargs = {
"vllm_config": vllm_config,
"on_head_node": on_head_node,
"handshake_address": handshake_address,
"executor_class": executor_class,
"log_stats": log_stats,
}
self.processes: list[BaseProcess] = []
for index in range(local_engine_count):
local_index = local_start_index + index
global_index = start_index + index
# Start EngineCore in background process.
self.processes.append(
context.Process(target=target_fn,
name=f"EngineCore_{global_index}",
kwargs=common_kwargs | {
"dp_rank": global_index,
"local_dp_rank": local_index,
}))
self._finalizer = weakref.finalize(self, shutdown, self.processes)
try:
for proc in self.processes:
proc.start()
finally:
# Kill other procs if not all are running.
if self.finished_procs():
self.close()
def close(self):
"""Shutdown all procs."""
self._finalizer()
def join_first(self):
"""Wait for any process to exit."""
connection.wait(proc.sentinel for proc in self.processes)
def sentinels(self) -> list:
return [proc.sentinel for proc in self.processes]
def finished_procs(self) -> dict[str, int]:
"""Returns dict of proc name -> exit code for any finished procs."""
return {
proc.name: proc.exitcode
for proc in self.processes if proc.exitcode is not None
}
class CoreEngineActorManager:
"""
Utility class to handle creation, readiness, and shutdown
of core engine Ray actors used by the AsyncLLM and LLMEngine.
Different from CoreEngineProcManager, this class manages
core engines for both local and remote nodes.
"""
def __init__(
self,
vllm_config: VllmConfig,
addresses: EngineZmqAddresses,
executor_class: type[Executor],
log_stats: bool,
placement_groups: Optional[list["PlacementGroup"]] = None,
local_dp_ranks: Optional[list[int]] = None,
):
import copy
import ray
from ray.util.scheduling_strategies import (
PlacementGroupSchedulingStrategy)
from vllm.v1.engine.core import DPEngineCoreActor
self.local_engine_actors: list[ray.ActorHandle] = []
self.remote_engine_actors: list[ray.ActorHandle] = []
dp_size = vllm_config.parallel_config.data_parallel_size
local_engine_count = \
vllm_config.parallel_config.data_parallel_size_local
world_size = vllm_config.parallel_config.world_size
if ray.is_initialized():
logger.info(
"Ray is already initialized. Skipping Ray initialization.")
else:
ray.init()
if placement_groups is not None:
assert local_dp_ranks is not None, (
"local_dp_ranks must be provided if "
"placement_groups is provided")
assert len(placement_groups) == len(local_dp_ranks), (
"placement_groups and local_dp_ranks must "
"have the same length")
logger.info("Using provided placement groups")
# TODO(rui): validate passed-in placement groups
self.created_placement_groups = []
else:
placement_groups, local_dp_ranks = \
CoreEngineActorManager.create_dp_placement_groups(vllm_config)
self.created_placement_groups = placement_groups
assert len(placement_groups) == dp_size, (
"Number of placement groups must match data parallel size")
refs = []
for index in range(dp_size):
local_index = local_dp_ranks[index]
dp_vllm_config = copy.deepcopy(vllm_config)
pg = placement_groups[index]
dp_vllm_config.parallel_config.placement_group = pg
on_head_node = index < local_engine_count
actor = ray.remote(DPEngineCoreActor).options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_bundle_index=world_size,
)).remote(vllm_config=dp_vllm_config,
executor_class=executor_class,
log_stats=log_stats,
on_head_node=on_head_node,
addresses=addresses,
dp_rank=index,
local_dp_rank=local_index)
if on_head_node:
self.local_engine_actors.append(actor)
else:
self.remote_engine_actors.append(actor)
refs.append(actor.wait_for_init.remote())
ray.get(refs)
self.run_refs = []
for actor in self.local_engine_actors + self.remote_engine_actors:
self.run_refs.append(actor.run.remote())
@staticmethod
def create_dp_placement_groups(
vllm_config: VllmConfig
) -> tuple[list["PlacementGroup"], list[int]]:
import ray
from ray._private.state import available_resources_per_node
from ray.util.state import list_nodes
logger.info("Creating placement groups for data parallel")
dp_master_ip = \
vllm_config.parallel_config.data_parallel_master_ip
dp_size = vllm_config.parallel_config.data_parallel_size
local_engine_count = \
vllm_config.parallel_config.data_parallel_size_local
nodes = list_nodes()
nodes = sorted(list_nodes(),
key=lambda node: node.node_ip != dp_master_ip)
assert nodes[0].node_ip == dp_master_ip, (
"The first node must be the head node")
assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
"There can only be one head node")
available_resources = available_resources_per_node()
world_size = vllm_config.parallel_config.world_size
placement_groups: list[PlacementGroup] = []
local_dp_ranks: list[int] = []
for node in nodes:
node_ip = node.node_ip
node_resources = available_resources[node.node_id]
# For now, each DP rank can only be assigned to one node
# TODO(rui): support allocating a single DP rank
# to multiple nodes
available_engine_count = int(node_resources["GPU"]) // world_size
if node_ip == dp_master_ip:
assert available_engine_count >= local_engine_count, (
"Not enough resources to allocate DP ranks "
f"on DP master node {node_ip}")
for i in range(local_engine_count):
bundles = [{
"GPU": 1.0,
"node:" + dp_master_ip: 0.001
}] * world_size + [{
"CPU": 1.0
}]
pg = ray.util.placement_group(
name=f"dp_rank_{len(placement_groups)}",
strategy="STRICT_PACK",
bundles=bundles,
)
placement_groups.append(pg)
local_dp_ranks.append(i)
else:
for i in range(available_engine_count):
if len(placement_groups) == dp_size:
break
bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
pg = ray.util.placement_group(
name=f"dp_rank_{len(placement_groups)}",
strategy="STRICT_PACK",
bundles=bundles,
)
placement_groups.append(pg)
local_dp_ranks.append(i)
return placement_groups, local_dp_ranks
def get_run_refs(self):
return self.run_refs
def close(self):
import ray
for actor in self.local_engine_actors + self.remote_engine_actors:
ray.kill(actor)
for pg in self.created_placement_groups:
ray.util.remove_placement_group(pg)
def wait_for_engine_startup(
handshake_socket: zmq.Socket,
addresses: EngineZmqAddresses,
core_engines: list[CoreEngine],
parallel_config: ParallelConfig,
cache_config: CacheConfig,
proc_manager: Optional[CoreEngineProcManager],
coord_process: Optional[Process],
):
# Wait for engine core process(es) to send ready messages.
local_count = parallel_config.data_parallel_size_local
remote_count = len(core_engines) - local_count
# [local, remote] counts
conn_pending, start_pending = [local_count, remote_count], [0, 0]
poller = zmq.Poller()
poller.register(handshake_socket, zmq.POLLIN)
if proc_manager is not None:
for sentinel in proc_manager.sentinels():
poller.register(sentinel, zmq.POLLIN)
if coord_process is not None:
poller.register(coord_process.sentinel, zmq.POLLIN)
while any(conn_pending) or any(start_pending):
events = poller.poll(STARTUP_POLL_PERIOD_MS)
if not events:
if any(conn_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to connect.", *conn_pending)
if any(start_pending):
logger.debug(
"Waiting for %d local, %d remote core engine proc(s) "
"to start.", *start_pending)
continue
if len(events) > 1 or events[0][0] != handshake_socket:
# One of the local core processes exited.
finished = proc_manager.finished_procs() if proc_manager else {}
if coord_process is not None and coord_process.exitcode is not None:
finished[coord_process.name] = coord_process.exitcode
raise RuntimeError("Engine core initialization failed. "
"See root cause above. "
f"Failed core proc(s): {finished}")
# Receive HELLO and READY messages from the input socket.
eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
eng_index = int.from_bytes(eng_identity, "little")
engine = next((e for e in core_engines if e.identity == eng_identity),
None)
if engine is None:
raise RuntimeError(f"Message from engine with unexpected data "
f"parallel rank: {eng_index}")
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local = msg["status"], msg["local"]
if local != engine.local:
raise RuntimeError(f"{status} message from "
f"{'local' if local else 'remote'} "
f"engine {eng_index}, expected it to be "
f"{'local' if engine.local else 'remote'}")
if status == "HELLO" and engine.state == CoreEngineState.NEW:
# Send init message with DP config info.
init_message = msgspec.msgpack.encode(
EngineHandshakeMetadata(
addresses=addresses,
parallel_config={
"data_parallel_master_ip":
parallel_config.data_parallel_master_ip,
"data_parallel_master_port":
parallel_config.data_parallel_master_port,
"data_parallel_size":
parallel_config.data_parallel_size,
}))
handshake_socket.send_multipart((eng_identity, init_message),
copy=False)
conn_pending[0 if local else 1] -= 1
start_pending[0 if local else 1] += 1
engine.state = CoreEngineState.CONNECTED
elif status == "READY" and (engine.state == CoreEngineState.CONNECTED):
# Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case.
num_gpu_blocks = cache_config.num_gpu_blocks or 0
num_gpu_blocks += msg["num_gpu_blocks"]
cache_config.num_gpu_blocks = num_gpu_blocks
start_pending[0 if local else 1] -= 1
engine.state = CoreEngineState.READY
else:
raise RuntimeError(f"Unexpected {status} message for "
f"{'local' if local else 'remote'} engine "
f"{eng_index} in {engine.state} state.")
logger.debug("%s from %s core engine process %s.", status,
"local" if local else "remote", eng_index)
def wait_for_completion_or_failure(
api_server_manager: APIServerProcessManager,
engine_manager: Optional[Union[CoreEngineProcManager,
CoreEngineActorManager]] = None,
engine_manager: Optional[Union["CoreEngineProcManager",
"CoreEngineActorManager"]] = None,
coordinator: Optional["DPCoordinator"] = None) -> None:
"""Wait for all processes to complete or detect if any fail.
@ -565,6 +196,9 @@ def wait_for_completion_or_failure(
coordinator: The coordinator for data parallel.
"""
from vllm.v1.engine.utils import (CoreEngineActorManager,
CoreEngineProcManager)
try:
logger.info("Waiting for API servers to complete ...")
# Create a mapping of sentinels to their corresponding processes