[V1] Support DP with Ray (#18779)

Rui Qiao 2025-06-02 21:15:13 -07:00 committed by GitHub
parent 9e6f61e8c3
commit bdce64f236
10 changed files with 551 additions and 120 deletions

View File

@ -17,7 +17,7 @@ vector_quantize_pytorch # required for minicpmo_26 test
vocos # required for minicpmo_26 test
peft
pqdm
ray[cgraph]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests
ray[cgraph,default]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests
sentence-transformers # required for embedding tests
soundfile # required for audio tests
jiwer # required for audio tests

View File

@ -10,9 +10,13 @@ aiohappyeyeballs==2.4.3
# via aiohttp
aiohttp==3.10.11
# via
# aiohttp-cors
# datasets
# fsspec
# lm-eval
# ray
aiohttp-cors==0.8.1
# via ray
aiosignal==1.3.1
# via
# aiohttp
@ -57,6 +61,8 @@ bounded-pool-executor==0.0.3
# via pqdm
buildkite-test-collector==0.1.9
# via -r requirements/test.in
cachetools==5.5.2
# via google-auth
certifi==2024.8.30
# via
# httpcore
@ -81,6 +87,8 @@ colorama==0.4.6
# sacrebleu
# schemathesis
# tqdm-multiprocess
colorful==0.5.6
# via ray
contourpy==1.3.0
# via matplotlib
cramjam==2.9.0
@ -108,6 +116,8 @@ dill==0.3.8
# evaluate
# lm-eval
# multiprocess
distlib==0.3.9
# via virtualenv
dnspython==2.7.0
# via email-validator
docopt==0.6.2
@ -143,6 +153,7 @@ filelock==3.16.1
# ray
# torch
# transformers
# virtualenv
fonttools==4.54.1
# via matplotlib
fqdn==1.5.1
@ -165,8 +176,16 @@ genai-perf==0.0.8
# via -r requirements/test.in
genson==1.3.0
# via datamodel-code-generator
google-api-core==2.24.2
# via opencensus
google-auth==2.40.2
# via google-api-core
googleapis-common-protos==1.70.0
# via google-api-core
graphql-core==3.2.6
# via hypothesis-graphql
grpcio==1.71.0
# via ray
h11==0.14.0
# via httpcore
harfile==0.3.0
@ -392,6 +411,10 @@ nvidia-nvjitlink-cu12==12.8.61
# torch
nvidia-nvtx-cu12==12.8.55
# via torch
opencensus==0.11.4
# via ray
opencensus-context==0.1.3
# via opencensus
opencv-python-headless==4.11.0.86
# via
# -r requirements/test.in
@ -445,6 +468,7 @@ platformdirs==4.3.6
# via
# black
# pooch
# virtualenv
plotly==5.24.1
# via genai-perf
pluggy==1.5.0
@ -457,10 +481,17 @@ portalocker==2.10.1
# via sacrebleu
pqdm==0.2.0
# via -r requirements/test.in
prometheus-client==0.22.0
# via ray
propcache==0.2.0
# via yarl
proto-plus==1.26.1
# via google-api-core
protobuf==5.28.3
# via
# google-api-core
# googleapis-common-protos
# proto-plus
# ray
# tensorizer
psutil==6.1.0
@ -470,10 +501,18 @@ psutil==6.1.0
# tensorizer
py==1.11.0
# via pytest-forked
py-spy==0.4.0
# via ray
pyarrow==18.0.0
# via
# datasets
# genai-perf
pyasn1==0.6.1
# via
# pyasn1-modules
# rsa
pyasn1-modules==0.4.2
# via google-auth
pybind11==2.13.6
# via lm-eval
pycparser==2.22
@ -486,6 +525,7 @@ pydantic==2.11.5
# datamodel-code-generator
# mistral-common
# mteb
# ray
pydantic-core==2.33.2
# via pydantic
pygments==2.18.0
@ -573,6 +613,7 @@ requests==2.32.3
# buildkite-test-collector
# datasets
# evaluate
# google-api-core
# huggingface-hub
# lm-eval
# mistral-common
@ -601,6 +642,8 @@ rpds-py==0.20.1
# via
# jsonschema
# referencing
rsa==4.9.1
# via google-auth
runai-model-streamer==0.11.0
# via -r requirements/test.in
runai-model-streamer-s3==0.11.0
@ -648,9 +691,12 @@ shellingham==1.5.4
six==1.16.0
# via
# junit-xml
# opencensus
# python-dateutil
# rfc3339-validator
# rouge-score
smart-open==7.1.0
# via ray
sniffio==1.3.1
# via
# anyio
@ -801,6 +847,8 @@ urllib3==2.2.3
# tritonclient
vector-quantize-pytorch==1.21.2
# via -r requirements/test.in
virtualenv==20.31.2
# via ray
vocos==0.1.0
# via -r requirements/test.in
webcolors==24.11.1
@ -809,6 +857,8 @@ werkzeug==3.1.3
# via schemathesis
word2number==1.1
# via lm-eval
wrapt==1.17.2
# via smart-open
xxhash==3.5.0
# via
# datasets

View File

@ -59,14 +59,22 @@ async def generate(engine: AsyncLLM,
@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
"output_kind",
[
RequestOutputKind.DELTA,
RequestOutputKind.FINAL_ONLY,
],
)
@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
@pytest.mark.asyncio
async def test_load(output_kind: RequestOutputKind):
async def test_load(output_kind: RequestOutputKind,
data_parallel_backend: str):
with ExitStack() as after:
prompt = "This is a test of data parallel"
engine_args.data_parallel_backend = data_parallel_backend
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
@ -82,7 +90,6 @@ async def test_load(output_kind: RequestOutputKind):
asyncio.create_task(
generate(engine, request_id, prompt, output_kind,
NUM_EXPECTED_TOKENS)))
# Confirm that we got all the EXPECTED tokens from the requests.
done, pending = await asyncio.wait(tasks,
return_when=asyncio.FIRST_EXCEPTION)

View File

@ -1742,6 +1742,8 @@ class ParallelConfig:
"""Port for data parallel messaging."""
data_parallel_master_port: int = 29500
"""Port of the data parallel master."""
data_parallel_backend: str = "mp"
"""Backend to use for data parallel, either "mp" or "ray"."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
max_parallel_loading_workers: Optional[int] = None
@ -1911,6 +1913,10 @@ class ParallelConfig:
"please install Ray with `pip install "
"ray`.") from ray_utils.ray_import_err
backend = "ray"
elif self.data_parallel_backend == "ray":
logger.info("Using ray distributed inference because "
"data_parallel_backend is ray")
backend = "ray"
elif ray_found:
if self.placement_group:
backend = "ray"

View File

@ -39,7 +39,7 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
GiB_bytes, is_in_ray_actor)
GiB_bytes, get_ip, is_in_ray_actor)
# yapf: enable
@ -292,6 +292,7 @@ class EngineArgs:
data_parallel_size_local: Optional[int] = None
data_parallel_address: Optional[str] = None
data_parallel_rpc_port: Optional[int] = None
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
max_parallel_loading_workers: Optional[
int] = ParallelConfig.max_parallel_loading_workers
@ -624,6 +625,12 @@ class EngineArgs:
type=int,
help='Port for data parallel RPC '
'communication.')
parallel_group.add_argument('--data-parallel-backend',
'-dpb',
type=str,
default='mp',
help='Backend for data parallel, either '
'"mp" or "ray".')
parallel_group.add_argument(
"--enable-expert-parallel",
**parallel_kwargs["enable_expert_parallel"])
@ -1059,9 +1066,20 @@ class EngineArgs:
# DP address, used in multi-node case for torch distributed group
# and ZMQ sockets.
data_parallel_address = self.data_parallel_address if (
self.data_parallel_address
is not None) else ParallelConfig.data_parallel_master_ip
if self.data_parallel_address is None:
if self.data_parallel_backend == "ray":
host_ip = get_ip()
logger.info(
"Using host IP %s as ray-based data parallel address",
host_ip)
data_parallel_address = host_ip
else:
assert self.data_parallel_backend == "mp", (
"data_parallel_backend can only be ray or mp, got %s",
self.data_parallel_backend)
data_parallel_address = ParallelConfig.data_parallel_master_ip
else:
data_parallel_address = self.data_parallel_address
# This port is only used when there are remote data parallel engines,
# otherwise the local IPC transport is used.
@ -1069,6 +1087,8 @@ class EngineArgs:
self.data_parallel_rpc_port
is not None) else ParallelConfig.data_parallel_rpc_port
data_parallel_backend = self.data_parallel_backend
parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
@ -1076,6 +1096,7 @@ class EngineArgs:
data_parallel_size_local=data_parallel_size_local,
data_parallel_master_ip=data_parallel_address,
data_parallel_rpc_port=data_parallel_rpc_port,
data_parallel_backend=data_parallel_backend,
enable_expert_parallel=self.enable_expert_parallel,
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,
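A hedged usage sketch of the new knob (this assumes the AsyncEngineArgs dataclass in vllm.engine.arg_utils; the model name is a placeholder). The CLI equivalent added above is --data-parallel-backend ray (short form -dpb ray).

```python
from vllm.engine.arg_utils import AsyncEngineArgs

engine_args = AsyncEngineArgs(
    model="facebook/opt-125m",        # placeholder model
    data_parallel_size=2,
    data_parallel_backend="ray",      # new field; the default remains "mp"
)
```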

View File

@ -27,7 +27,8 @@ from vllm.v1.engine.core_client import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
EngineZmqAddresses, get_engine_client_zmq_addr,
CoreEngineActorManager, EngineZmqAddresses,
get_engine_client_zmq_addr,
wait_for_completion_or_failure,
wait_for_engine_startup)
@ -229,6 +230,31 @@ def run_multi_api_server(args: argparse.Namespace):
logger.info("Started DP Coordinator process (PID: %d)",
coordinator.proc.pid)
if parallel_config.data_parallel_backend == "ray":
logger.info("Starting ray-based data parallel backend")
engine_actor_manager = CoreEngineActorManager(
vllm_config=vllm_config,
addresses=addresses,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
)
# Start API servers using the manager
api_server_manager = APIServerProcessManager(
target_server_fn=run_api_server_worker_proc,
listen_address=listen_address,
sock=sock,
args=args,
num_servers=num_api_servers,
input_addresses=input_addresses,
output_addresses=output_addresses,
stats_update_address=stats_update_address)
wait_for_completion_or_failure(api_server_manager=api_server_manager,
engine_manager=engine_actor_manager,
coordinator=coordinator)
return
handshake_address = get_engine_client_zmq_addr(
local_only, host, parallel_config.data_parallel_rpc_port)
@ -277,10 +303,9 @@ def run_multi_api_server(args: argparse.Namespace):
)
# Wait for API servers
wait_for_completion_or_failure(
api_server_manager=api_server_manager,
local_engine_manager=local_engine_manager,
coordinator=coordinator)
wait_for_completion_or_failure(api_server_manager=api_server_manager,
engine_manager=local_engine_manager,
coordinator=coordinator)
def run_api_server_worker_proc(listen_address,

View File

@ -27,7 +27,8 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cdiv
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
from vllm.v1.engine.core_client import (AsyncMPClient, DPAsyncMPClient,
RayDPClient)
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector)
@ -119,9 +120,13 @@ class AsyncLLM(EngineClient):
log_stats=self.log_stats)
# EngineCore (starts the engine in background process).
core_client_class = AsyncMPClient if (
vllm_config.parallel_config.data_parallel_size
== 1) else DPAsyncMPClient
core_client_class: type[AsyncMPClient]
if vllm_config.parallel_config.data_parallel_size == 1:
core_client_class = AsyncMPClient
elif vllm_config.parallel_config.data_parallel_backend == "ray":
core_client_class = RayDPClient
else:
core_client_class = DPAsyncMPClient
self.engine_core = core_client_class(
vllm_config=vllm_config,
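For context, a minimal end-to-end sketch of how this branch is reached (grounded in the test change above; the engine-arg values are placeholders): with data_parallel_size > 1 and data_parallel_backend="ray", AsyncLLM.from_engine_args picks RayDPClient as the core engine client.

```python
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM

engine_args = AsyncEngineArgs(model="facebook/opt-125m",  # placeholder model
                              data_parallel_size=2,
                              data_parallel_backend="ray")
engine = AsyncLLM.from_engine_args(engine_args)  # selects RayDPClient here
```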

View File

@ -6,8 +6,9 @@ import sys
import threading
import time
from collections import deque
from collections.abc import Generator
from concurrent.futures import Future
from contextlib import ExitStack
from contextlib import ExitStack, contextmanager
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, Callable, Optional, TypeVar, Union
@ -367,42 +368,66 @@ class EngineCoreProc(EngineCore):
log_stats: bool,
engine_index: int = 0,
):
input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
executor_fail_callback = lambda: input_queue.put_nowait(
self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
bytes]]()
executor_fail_callback = lambda: self.input_queue.put_nowait(
(EngineCoreRequestType.EXECUTOR_FAILED, b''))
# Create input socket.
self.engine_index = engine_index
identity = self.engine_index.to_bytes(length=2, byteorder="little")
self.engines_running = False
with self._perform_handshake(handshake_address, identity, on_head_node,
vllm_config) as addresses:
self.client_count = len(addresses.outputs)
# Set up data parallel environment.
self.has_coordinator = addresses.coordinator_output is not None
self._init_data_parallel(vllm_config)
super().__init__(vllm_config, executor_class, log_stats,
executor_fail_callback)
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
threading.Thread(target=self.process_input_sockets,
args=(addresses.inputs, addresses.coordinator_input,
identity),
daemon=True).start()
self.output_thread = threading.Thread(
target=self.process_output_sockets,
args=(addresses.outputs, addresses.coordinator_output,
self.engine_index),
daemon=True)
self.output_thread.start()
@contextmanager
def _perform_handshake(
self, handshake_address: str, identity: bytes, on_head_node: bool,
vllm_config: VllmConfig
) -> Generator[EngineZmqAddresses, None, None]:
input_ctx = zmq.Context()
identity = engine_index.to_bytes(length=2, byteorder="little")
with make_zmq_socket(input_ctx,
handshake_address,
zmq.DEALER,
identity=identity,
linger=5000,
bind=False) as handshake_socket:
# Register engine with front-end.
addresses = self.startup_handshake(handshake_socket, on_head_node,
vllm_config.parallel_config)
self.client_count = len(addresses.outputs)
# Update config which may have changed from the handshake.
# Update config which may have changed from the handshake
vllm_config.__post_init__()
# Set up data parallel environment.
self.has_coordinator = addresses.coordinator_output is not None
self._init_data_parallel(vllm_config)
# Initialize engine core and model.
super().__init__(vllm_config, executor_class, log_stats,
executor_fail_callback)
self.engine_index = engine_index
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
self.engines_running = False
self.last_counts = (0, 0)
yield addresses
# Send ready message.
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
@ -413,25 +438,6 @@ class EngineCoreProc(EngineCore):
"num_gpu_blocks": num_gpu_blocks,
}))
# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue = input_queue
self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
bytes]]()
threading.Thread(target=self.process_input_sockets,
args=(addresses.inputs, addresses.coordinator_input,
identity),
daemon=True).start()
self.output_thread = threading.Thread(
target=self.process_output_sockets,
args=(addresses.outputs, addresses.coordinator_output,
engine_index),
daemon=True)
self.output_thread.start()
@staticmethod
def startup_handshake(
handshake_socket: zmq.Socket, on_head_node: bool,
@ -743,6 +749,21 @@ class DPEngineCoreProc(EngineCoreProc):
executor_class: type[Executor],
log_stats: bool,
):
self._decorate_logs()
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self.counter = 0
self.current_wave = 0
self.last_counts = (0, 0)
# Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__(vllm_config, on_head_node, handshake_address,
executor_class, log_stats, dp_rank)
def _decorate_logs(self):
# Add process-specific prefix to stdout and stderr before
# we initialize the engine.
from multiprocessing import current_process
@ -751,16 +772,6 @@ class DPEngineCoreProc(EngineCoreProc):
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self.counter = 0
self.current_wave = 0
# Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__(vllm_config, on_head_node, handshake_address,
executor_class, log_stats, dp_rank)
def _init_data_parallel(self, vllm_config: VllmConfig):
# Configure GPUs and stateless process group for data parallel.
@ -880,3 +891,70 @@ class DPEngineCoreProc(EngineCoreProc):
return ParallelConfig.has_unfinished_dp(self.dp_group,
local_unfinished)
class DPEngineCoreActor(DPEngineCoreProc):
"""
Ray actor for running EngineCore in a data parallel context
"""
def __init__(
self,
vllm_config: VllmConfig,
on_head_node: bool,
addresses: EngineZmqAddresses,
executor_class: type[Executor],
log_stats: bool,
dp_rank: int = 0,
local_dp_rank: int = 0,
):
self.addresses = addresses
vllm_config.parallel_config.data_parallel_rank = dp_rank
vllm_config.parallel_config.data_parallel_rank_local = \
local_dp_rank
# Ray sets CUDA_VISIBLE_DEVICES to an empty string;
# clear it so that data parallel groups can be
# initialized properly.
del os.environ['CUDA_VISIBLE_DEVICES']
super().__init__(vllm_config, on_head_node, "", executor_class,
log_stats)
def _decorate_logs(self):
pass
@contextmanager
def _perform_handshake(self, handshake_address: str, identity: bytes,
on_head_node: bool, vllm_config: VllmConfig):
"""
For Ray, we don't need to actually perform the handshake:
all address information is known before actor creation,
so we simply yield those addresses.
"""
yield self.addresses
def wait_for_init(self):
"""
Wait until the engine core is initialized.
This is intentionally a no-op: when ray.get() on a remote call to
this method (or any other actor method) returns, actor creation
(i.e., __init__) is guaranteed to be complete.
"""
pass
def run(self):
"""
Run the engine core busy loop.
"""
try:
self.run_busy_loop()
except SystemExit:
logger.debug("EngineCore exiting.")
raise
except Exception:
logger.exception("EngineCore encountered a fatal error.")
raise
finally:
self.shutdown()
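A minimal sketch of the actor lifecycle this class expects (illustrative only; the real orchestration lives in CoreEngineActorManager in vllm.v1.utils, diffed below). vllm_config, addresses, and executor_class are assumed to have been prepared by the caller.

```python
import ray
from vllm.v1.engine.core import DPEngineCoreActor

# vllm_config, addresses, and executor_class are assumed to exist already.
actor = ray.remote(DPEngineCoreActor).remote(
    vllm_config=vllm_config,
    on_head_node=True,
    addresses=addresses,
    executor_class=executor_class,
    log_stats=True,
    dp_rank=0,
    local_dp_rank=0,
)
ray.get(actor.wait_for_init.remote())  # returns once __init__ has completed
run_ref = actor.run.remote()           # busy loop runs until shutdown/failure
```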

View File

@ -29,9 +29,9 @@ from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
from vllm.v1.utils import (CoreEngine, CoreEngineProcManager,
EngineZmqAddresses, get_engine_client_zmq_addr,
wait_for_engine_startup)
from vllm.v1.utils import (CoreEngine, CoreEngineActorManager,
CoreEngineProcManager, EngineZmqAddresses,
get_engine_client_zmq_addr, wait_for_engine_startup)
logger = init_logger(__name__)
@ -68,6 +68,8 @@ class EngineCoreClient(ABC):
if multiprocess_mode and asyncio_mode:
if vllm_config.parallel_config.data_parallel_size > 1:
if vllm_config.parallel_config.data_parallel_backend == "ray":
return RayDPClient(vllm_config, executor_class, log_stats)
return DPAsyncMPClient(vllm_config, executor_class, log_stats)
return AsyncMPClient(vllm_config, executor_class, log_stats)
@ -273,7 +275,10 @@ class BackgroundResources:
circular reference back to the client object."""
ctx: Union[zmq.Context]
local_engine_manager: Optional[CoreEngineProcManager] = None
# If CoreEngineProcManager, it manages local engines;
# if CoreEngineActorManager, it manages all engines.
engine_manager: Optional[Union[CoreEngineProcManager,
CoreEngineActorManager]] = None
coordinator: Optional[DPCoordinator] = None
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
@ -290,8 +295,8 @@ class BackgroundResources:
"""Clean up background resources."""
self.engine_dead = True
if self.local_engine_manager is not None:
self.local_engine_manager.close()
if self.engine_manager is not None:
self.engine_manager.close()
if self.coordinator is not None:
self.coordinator.close()
@ -457,7 +462,7 @@ class MPClient(EngineCoreClient):
if local_engine_count:
# In server mode, start_index and local_start_index will
# both be 0.
self.resources.local_engine_manager = CoreEngineProcManager(
self.resources.engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
executor_class=executor_class,
@ -484,13 +489,18 @@ class MPClient(EngineCoreClient):
addresses.coordinator_input, addresses.coordinator_output = (
coordinator.get_engine_socket_addresses())
proc_manager = self.resources.engine_manager
assert isinstance(proc_manager, (type(None), CoreEngineProcManager)), (
"_wait_for_engine_startup should only be "
"called with CoreEngineProcManager")
wait_for_engine_startup(
handshake_socket,
addresses,
self.core_engines,
self.vllm_config.parallel_config,
self.vllm_config.cache_config,
self.resources.local_engine_manager,
proc_manager,
coordinator.proc if coordinator else None,
)
@ -887,7 +897,6 @@ class DPAsyncMPClient(AsyncMPClient):
log_stats: bool,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0):
self.current_wave = 0
self.engines_running = False
# To route aborts to the correct engine.
@ -1050,3 +1059,50 @@ class DPAsyncMPClient(AsyncMPClient):
if not self.resources.engine_dead:
await self._send_input(EngineCoreRequestType.ABORT, request_ids,
engine)
class RayDPClient(DPAsyncMPClient):
"""
Ray-based client for multi-proc, multi-engine (data parallel)
EngineCore.
"""
def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
):
super().__init__(vllm_config, executor_class, log_stats,
client_addresses, client_index)
def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
local_start_index: int, input_address: str,
output_address: str,
executor_class: type[Executor], log_stats: bool):
"""Self-contained client mode, launch engine and coordinator process
as needed."""
parallel_config = vllm_config.parallel_config
assert parallel_config.data_parallel_rank == 0
assert local_start_index == 0
addresses = EngineZmqAddresses(
inputs=[input_address],
outputs=[output_address],
)
if len(self.core_engines) > 1:
coordinator = DPCoordinator(parallel_config)
self.resources.coordinator = coordinator
addresses.coordinator_input, addresses.coordinator_output = (
coordinator.get_engine_socket_addresses())
# Start all engines.
self.resources.engine_manager = CoreEngineActorManager(
vllm_config=vllm_config,
addresses=addresses,
executor_class=executor_class,
log_stats=log_stats)

View File

@ -27,6 +27,8 @@ from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
from vllm.v1.executor.abstract import Executor
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
from vllm.attention.layer import Attention
from vllm.v1.engine.coordinator import DPCoordinator
@ -112,6 +114,45 @@ def get_engine_client_zmq_addr(local_only: bool,
host, port or get_open_port()))
class CoreEngineState(Enum):
NEW = auto()
CONNECTED = auto()
READY = auto()
class CoreEngine:
"""One per data parallel rank."""
def __init__(self, index: int = 0, local: bool = True):
self.local = local
self.index = index
self.identity = index.to_bytes(2, "little")
self.state = CoreEngineState.NEW
@dataclass
class EngineZmqAddresses:
# ZMQ input socket addresses for each front-end client (requests)
inputs: list[str]
# ZMQ output socket addresses for each front-end client (responses)
outputs: list[str]
# ZMQ input socket address of DP coordinator if applicable
coordinator_input: Optional[str] = None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output: Optional[str] = None
@dataclass
class EngineHandshakeMetadata:
"""Metadata sent to each engine process during startup handshake,
including addresses of the front-end ZMQ queues that they should
connect to.
"""
addresses: EngineZmqAddresses
parallel_config: dict[str, Union[int, str]]
class APIServerProcessManager:
"""Manages a group of API server processes.
@ -245,43 +286,168 @@ class CoreEngineProcManager:
}
class CoreEngineState(Enum):
NEW = auto()
CONNECTED = auto()
READY = auto()
class CoreEngine:
"""One per data parallel rank."""
def __init__(self, index: int = 0, local: bool = True):
self.local = local
self.index = index
self.identity = index.to_bytes(2, "little")
self.state = CoreEngineState.NEW
@dataclass
class EngineZmqAddresses:
# ZMQ input socket addresses for each front-end client (requests)
inputs: list[str]
# ZMQ output socket addresses for each front-end client (responses)
outputs: list[str]
# ZMQ input socket address of DP coordinator if applicable
coordinator_input: Optional[str] = None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output: Optional[str] = None
@dataclass
class EngineHandshakeMetadata:
"""Metadata sent to each engine process during startup handshake,
including addresses of the front-end ZMQ queues that they should
connect to.
"""
addresses: EngineZmqAddresses
parallel_config: dict[str, Union[int, str]]
class CoreEngineActorManager:
"""
Utility class to handle creation, readiness, and shutdown
of core engine Ray actors used by the AsyncLLM and LLMEngine.
Different from CoreEngineProcManager, this class manages
core engines for both local and remote nodes.
"""
def __init__(
self,
vllm_config: VllmConfig,
addresses: EngineZmqAddresses,
executor_class: type[Executor],
log_stats: bool,
placement_groups: Optional[list["PlacementGroup"]] = None,
local_dp_ranks: Optional[list[int]] = None,
):
import copy
import ray
from ray.util.scheduling_strategies import (
PlacementGroupSchedulingStrategy)
from vllm.v1.engine.core import DPEngineCoreActor
self.local_engine_actors: list[ray.ActorHandle] = []
self.remote_engine_actors: list[ray.ActorHandle] = []
dp_size = vllm_config.parallel_config.data_parallel_size
local_engine_count = \
vllm_config.parallel_config.data_parallel_size_local
world_size = vllm_config.parallel_config.world_size
if ray.is_initialized():
logger.info(
"Ray is already initialized. Skipping Ray initialization.")
else:
ray.init()
if placement_groups is not None:
assert local_dp_ranks is not None, (
"local_dp_ranks must be provided if "
"placement_groups is provided")
assert len(placement_groups) == len(local_dp_ranks), (
"placement_groups and local_dp_ranks must "
"have the same length")
logger.info("Using provided placement groups")
# TODO(rui): validate passed-in placement groups
self.created_placement_groups = []
else:
placement_groups, local_dp_ranks = \
CoreEngineActorManager.create_dp_placement_groups(vllm_config)
self.created_placement_groups = placement_groups
assert len(placement_groups) == dp_size, (
"Number of placement groups must match data parallel size")
refs = []
for index in range(dp_size):
local_index = local_dp_ranks[index]
dp_vllm_config = copy.deepcopy(vllm_config)
pg = placement_groups[index]
dp_vllm_config.parallel_config.placement_group = pg
on_head_node = index < local_engine_count
actor = ray.remote(DPEngineCoreActor).options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_bundle_index=world_size,
)).remote(vllm_config=dp_vllm_config,
executor_class=executor_class,
log_stats=log_stats,
on_head_node=on_head_node,
addresses=addresses,
dp_rank=index,
local_dp_rank=local_index)
if on_head_node:
self.local_engine_actors.append(actor)
else:
self.remote_engine_actors.append(actor)
refs.append(actor.wait_for_init.remote())
ray.get(refs)
self.run_refs = []
for actor in self.local_engine_actors + self.remote_engine_actors:
self.run_refs.append(actor.run.remote())
@staticmethod
def create_dp_placement_groups(
vllm_config: VllmConfig
) -> tuple[list["PlacementGroup"], list[int]]:
import ray
from ray._private.state import available_resources_per_node
from ray.util.state import list_nodes
logger.info("Creating placement groups for data parallel")
dp_master_ip = \
vllm_config.parallel_config.data_parallel_master_ip
dp_size = vllm_config.parallel_config.data_parallel_size
local_engine_count = \
vllm_config.parallel_config.data_parallel_size_local
nodes = list_nodes()
nodes = sorted(list_nodes(),
key=lambda node: node.node_ip != dp_master_ip)
assert nodes[0].node_ip == dp_master_ip, (
"The first node must be the head node")
assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
"There can only be one head node")
available_resources = available_resources_per_node()
world_size = vllm_config.parallel_config.world_size
placement_groups: list[PlacementGroup] = []
local_dp_ranks: list[int] = []
for node in nodes:
node_ip = node.node_ip
node_resources = available_resources[node.node_id]
# For now, each DP rank can only be assigned to one node
# TODO(rui): support allocating a single DP rank
# to multiple nodes
available_engine_count = node_resources["GPU"] // world_size
if node_ip == dp_master_ip:
assert available_engine_count >= local_engine_count, (
"Not enough resources to allocate DP ranks "
f"on DP master node {node_ip}")
for i in range(local_engine_count):
bundles = [{
"GPU": 1.0,
"node:" + dp_master_ip: 0.001
}] * world_size + [{
"CPU": 1.0
}]
pg = ray.util.placement_group(
name=f"dp_rank_{len(placement_groups)}",
strategy="STRICT_PACK",
bundles=bundles,
)
placement_groups.append(pg)
local_dp_ranks.append(i)
else:
for i in range(available_engine_count):
if len(placement_groups) == dp_size:
break
bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
pg = ray.util.placement_group(
name=f"dp_rank_{len(placement_groups)}",
strategy="STRICT_PACK",
bundles=bundles,
)
placement_groups.append(pg)
local_dp_ranks.append(i)
return placement_groups, local_dp_ranks
def get_run_refs(self):
return self.run_refs
def close(self):
import ray
for actor in self.local_engine_actors + self.remote_engine_actors:
ray.kill(actor)
for pg in self.created_placement_groups:
ray.util.remove_placement_group(pg)
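A worked example of the bundle layout produced by create_dp_placement_groups (illustrative; the head-node IP and sizes below are assumed): with world_size=2 (e.g. tensor parallel size 2), each DP rank on the master node gets a STRICT_PACK placement group of two GPU bundles pinned to that node plus one CPU-only bundle, and the engine-core actor is scheduled on that CPU bundle via placement_group_bundle_index=world_size.

```python
import ray

ray.init(ignore_reinit_error=True)     # assumes a running Ray cluster

world_size = 2                         # e.g. tensor_parallel_size=2
dp_master_ip = "10.0.0.1"              # assumed head-node IP for this example
bundles = ([{"GPU": 1.0, "node:" + dp_master_ip: 0.001}] * world_size
           + [{"CPU": 1.0}])           # bundle index == world_size hosts the actor
pg = ray.util.placement_group(name="dp_rank_0",
                              strategy="STRICT_PACK",
                              bundles=bundles)
```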
def wait_for_engine_startup(
@ -383,11 +549,19 @@ def wait_for_engine_startup(
def wait_for_completion_or_failure(
api_server_manager: APIServerProcessManager,
local_engine_manager: Optional[CoreEngineProcManager] = None,
engine_manager: Optional[Union[CoreEngineProcManager,
CoreEngineActorManager]] = None,
coordinator: Optional["DPCoordinator"] = None) -> None:
"""Wait for all processes to complete or detect if any fail.
Raises an exception if any process exits with a non-zero status.
Args:
api_server_manager: The manager for API servers.
engine_manager: The manager for engine processes.
If CoreEngineProcManager, it manages local engines;
if CoreEngineActorManager, it manages all engines.
coordinator: The coordinator for data parallel.
"""
try:
@ -402,14 +576,18 @@ def wait_for_completion_or_failure(
if coordinator:
sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc
if local_engine_manager:
for proc in local_engine_manager.processes:
actor_run_refs = []
if isinstance(engine_manager, CoreEngineProcManager):
for proc in engine_manager.processes:
sentinel_to_proc[proc.sentinel] = proc
elif isinstance(engine_manager, CoreEngineActorManager):
actor_run_refs = engine_manager.get_run_refs()
# Check if any process terminates
while sentinel_to_proc:
while sentinel_to_proc or actor_run_refs:
# Wait for any process to terminate
ready_sentinels: list[Any] = connection.wait(sentinel_to_proc)
ready_sentinels: list[Any] = connection.wait(sentinel_to_proc,
timeout=5)
# Process any terminated processes
for sentinel in ready_sentinels:
@ -420,6 +598,11 @@ def wait_for_completion_or_failure(
raise RuntimeError(
f"Process {proc.name} (PID: {proc.pid}) "
f"died with exit code {proc.exitcode}")
if actor_run_refs:
import ray
_, actor_run_refs = ray.wait(actor_run_refs, timeout=5)
except KeyboardInterrupt:
logger.info("Received KeyboardInterrupt, shutting down API servers...")
except Exception as e:
@ -431,8 +614,8 @@ def wait_for_completion_or_failure(
api_server_manager.close()
if coordinator:
coordinator.close()
if local_engine_manager:
local_engine_manager.close()
if engine_manager:
engine_manager.close()
# Note(rob): shutdown function cannot be a bound method,