[V1] Support DP with Ray (#18779)
commit bdce64f236, parent 9e6f61e8c3
@@ -17,7 +17,7 @@ vector_quantize_pytorch # required for minicpmo_26 test
vocos # required for minicpmo_26 test
peft
pqdm
ray[cgraph]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests
ray[cgraph,default]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests
sentence-transformers # required for embedding tests
soundfile # required for audio tests
jiwer # required for audio tests

@@ -10,9 +10,13 @@ aiohappyeyeballs==2.4.3
# via aiohttp
aiohttp==3.10.11
# via
# aiohttp-cors
# datasets
# fsspec
# lm-eval
# ray
aiohttp-cors==0.8.1
# via ray
aiosignal==1.3.1
# via
# aiohttp
@@ -57,6 +61,8 @@ bounded-pool-executor==0.0.3
# via pqdm
buildkite-test-collector==0.1.9
# via -r requirements/test.in
cachetools==5.5.2
# via google-auth
certifi==2024.8.30
# via
# httpcore
@@ -81,6 +87,8 @@ colorama==0.4.6
# sacrebleu
# schemathesis
# tqdm-multiprocess
colorful==0.5.6
# via ray
contourpy==1.3.0
# via matplotlib
cramjam==2.9.0
@@ -108,6 +116,8 @@ dill==0.3.8
# evaluate
# lm-eval
# multiprocess
distlib==0.3.9
# via virtualenv
dnspython==2.7.0
# via email-validator
docopt==0.6.2
@@ -143,6 +153,7 @@ filelock==3.16.1
# ray
# torch
# transformers
# virtualenv
fonttools==4.54.1
# via matplotlib
fqdn==1.5.1
@@ -165,8 +176,16 @@ genai-perf==0.0.8
# via -r requirements/test.in
genson==1.3.0
# via datamodel-code-generator
google-api-core==2.24.2
# via opencensus
google-auth==2.40.2
# via google-api-core
googleapis-common-protos==1.70.0
# via google-api-core
graphql-core==3.2.6
# via hypothesis-graphql
grpcio==1.71.0
# via ray
h11==0.14.0
# via httpcore
harfile==0.3.0
@@ -392,6 +411,10 @@ nvidia-nvjitlink-cu12==12.8.61
# torch
nvidia-nvtx-cu12==12.8.55
# via torch
opencensus==0.11.4
# via ray
opencensus-context==0.1.3
# via opencensus
opencv-python-headless==4.11.0.86
# via
# -r requirements/test.in
@@ -445,6 +468,7 @@ platformdirs==4.3.6
# via
# black
# pooch
# virtualenv
plotly==5.24.1
# via genai-perf
pluggy==1.5.0
@@ -457,10 +481,17 @@ portalocker==2.10.1
# via sacrebleu
pqdm==0.2.0
# via -r requirements/test.in
prometheus-client==0.22.0
# via ray
propcache==0.2.0
# via yarl
proto-plus==1.26.1
# via google-api-core
protobuf==5.28.3
# via
# google-api-core
# googleapis-common-protos
# proto-plus
# ray
# tensorizer
psutil==6.1.0
@@ -470,10 +501,18 @@ psutil==6.1.0
# tensorizer
py==1.11.0
# via pytest-forked
py-spy==0.4.0
# via ray
pyarrow==18.0.0
# via
# datasets
# genai-perf
pyasn1==0.6.1
# via
# pyasn1-modules
# rsa
pyasn1-modules==0.4.2
# via google-auth
pybind11==2.13.6
# via lm-eval
pycparser==2.22
@@ -486,6 +525,7 @@ pydantic==2.11.5
# datamodel-code-generator
# mistral-common
# mteb
# ray
pydantic-core==2.33.2
# via pydantic
pygments==2.18.0
@@ -573,6 +613,7 @@ requests==2.32.3
# buildkite-test-collector
# datasets
# evaluate
# google-api-core
# huggingface-hub
# lm-eval
# mistral-common
@@ -601,6 +642,8 @@ rpds-py==0.20.1
# via
# jsonschema
# referencing
rsa==4.9.1
# via google-auth
runai-model-streamer==0.11.0
# via -r requirements/test.in
runai-model-streamer-s3==0.11.0
@@ -648,9 +691,12 @@ shellingham==1.5.4
six==1.16.0
# via
# junit-xml
# opencensus
# python-dateutil
# rfc3339-validator
# rouge-score
smart-open==7.1.0
# via ray
sniffio==1.3.1
# via
# anyio
@@ -801,6 +847,8 @@ urllib3==2.2.3
# tritonclient
vector-quantize-pytorch==1.21.2
# via -r requirements/test.in
virtualenv==20.31.2
# via ray
vocos==0.1.0
# via -r requirements/test.in
webcolors==24.11.1
@@ -809,6 +857,8 @@ werkzeug==3.1.3
# via schemathesis
word2number==1.1
# via lm-eval
wrapt==1.17.2
# via smart-open
xxhash==3.5.0
# via
# datasets

@@ -59,14 +59,22 @@ async def generate(engine: AsyncLLM,


@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
"output_kind",
[
RequestOutputKind.DELTA,
RequestOutputKind.FINAL_ONLY,
],
)
@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
@pytest.mark.asyncio
async def test_load(output_kind: RequestOutputKind):
async def test_load(output_kind: RequestOutputKind,
data_parallel_backend: str):

with ExitStack() as after:

prompt = "This is a test of data parallel"

engine_args.data_parallel_backend = data_parallel_backend
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)

@@ -82,7 +90,6 @@ async def test_load(output_kind: RequestOutputKind):
asyncio.create_task(
generate(engine, request_id, prompt, output_kind,
NUM_EXPECTED_TOKENS)))

# Confirm that we got all the EXPECTED tokens from the requests.
done, pending = await asyncio.wait(tasks,
return_when=asyncio.FIRST_EXCEPTION)

@@ -1742,6 +1742,8 @@ class ParallelConfig:
"""Port for data parallel messaging."""
data_parallel_master_port: int = 29500
"""Port of the data parallel master."""
data_parallel_backend: str = "mp"
"""Backend to use for data parallel, either "mp" or "ray"."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
max_parallel_loading_workers: Optional[int] = None
@@ -1911,6 +1913,10 @@ class ParallelConfig:
"please install Ray with `pip install "
"ray`.") from ray_utils.ray_import_err
backend = "ray"
elif self.data_parallel_backend == "ray":
logger.info("Using ray distributed inference because "
"data_parallel_backend is ray")
backend = "ray"
elif ray_found:
if self.placement_group:
backend = "ray"

@@ -39,7 +39,7 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
GiB_bytes, is_in_ray_actor)
GiB_bytes, get_ip, is_in_ray_actor)

# yapf: enable

@@ -292,6 +292,7 @@ class EngineArgs:
data_parallel_size_local: Optional[int] = None
data_parallel_address: Optional[str] = None
data_parallel_rpc_port: Optional[int] = None
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
max_parallel_loading_workers: Optional[
int] = ParallelConfig.max_parallel_loading_workers
@@ -624,6 +625,12 @@ class EngineArgs:
type=int,
help='Port for data parallel RPC '
'communication.')
parallel_group.add_argument('--data-parallel-backend',
'-dpb',
type=str,
default='mp',
help='Backend for data parallel, either '
'"mp" or "ray".')
parallel_group.add_argument(
"--enable-expert-parallel",
**parallel_kwargs["enable_expert_parallel"])
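Editor's note: the new --data-parallel-backend flag above maps to EngineArgs.data_parallel_backend. A minimal sketch (not part of this diff) of driving the same option programmatically, mirroring the updated test_load above; the model name, prompt, and token count are placeholder assumptions:

# Editor's sketch, not part of the commit: exercising data_parallel_backend
# via AsyncEngineArgs, the same way the updated async test does.
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM


async def main():
    engine_args = AsyncEngineArgs(
        model="facebook/opt-125m",    # placeholder model
        data_parallel_size=2,         # two DP engine ranks
        data_parallel_backend="ray",  # or "mp" (the default)
    )
    engine = AsyncLLM.from_engine_args(engine_args)
    try:
        async for out in engine.generate(
                prompt="This is a test of data parallel",
                sampling_params=SamplingParams(max_tokens=10),
                request_id="req-0"):
            if out.finished:
                print(out.outputs[0].text)
    finally:
        engine.shutdown()


asyncio.run(main())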
@@ -1059,9 +1066,20 @@ class EngineArgs:

# DP address, used in multi-node case for torch distributed group
# and ZMQ sockets.
data_parallel_address = self.data_parallel_address if (
self.data_parallel_address
is not None) else ParallelConfig.data_parallel_master_ip
if self.data_parallel_address is None:
if self.data_parallel_backend == "ray":
host_ip = get_ip()
logger.info(
"Using host IP %s as ray-based data parallel address",
host_ip)
data_parallel_address = host_ip
else:
assert self.data_parallel_backend == "mp", (
"data_parallel_backend can only be ray or mp, got %s",
self.data_parallel_backend)
data_parallel_address = ParallelConfig.data_parallel_master_ip
else:
data_parallel_address = self.data_parallel_address

# This port is only used when there are remote data parallel engines,
# otherwise the local IPC transport is used.
@@ -1069,6 +1087,8 @@ class EngineArgs:
self.data_parallel_rpc_port
is not None) else ParallelConfig.data_parallel_rpc_port

data_parallel_backend = self.data_parallel_backend

parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
@@ -1076,6 +1096,7 @@ class EngineArgs:
data_parallel_size_local=data_parallel_size_local,
data_parallel_master_ip=data_parallel_address,
data_parallel_rpc_port=data_parallel_rpc_port,
data_parallel_backend=data_parallel_backend,
enable_expert_parallel=self.enable_expert_parallel,
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,

@@ -27,7 +27,8 @@ from vllm.v1.engine.core_client import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
EngineZmqAddresses, get_engine_client_zmq_addr,
CoreEngineActorManager, EngineZmqAddresses,
get_engine_client_zmq_addr,
wait_for_completion_or_failure,
wait_for_engine_startup)

@@ -229,6 +230,31 @@ def run_multi_api_server(args: argparse.Namespace):
logger.info("Started DP Coordinator process (PID: %d)",
coordinator.proc.pid)

if parallel_config.data_parallel_backend == "ray":
logger.info("Starting ray-based data parallel backend")

engine_actor_manager = CoreEngineActorManager(
vllm_config=vllm_config,
addresses=addresses,
executor_class=Executor.get_class(vllm_config),
log_stats=not engine_args.disable_log_stats,
)
# Start API servers using the manager
api_server_manager = APIServerProcessManager(
target_server_fn=run_api_server_worker_proc,
listen_address=listen_address,
sock=sock,
args=args,
num_servers=num_api_servers,
input_addresses=input_addresses,
output_addresses=output_addresses,
stats_update_address=stats_update_address)

wait_for_completion_or_failure(api_server_manager=api_server_manager,
engine_manager=engine_actor_manager,
coordinator=coordinator)
return

handshake_address = get_engine_client_zmq_addr(
local_only, host, parallel_config.data_parallel_rpc_port)

@@ -277,10 +303,9 @@ def run_multi_api_server(args: argparse.Namespace):
)

# Wait for API servers
wait_for_completion_or_failure(
api_server_manager=api_server_manager,
local_engine_manager=local_engine_manager,
coordinator=coordinator)
wait_for_completion_or_failure(api_server_manager=api_server_manager,
engine_manager=local_engine_manager,
coordinator=coordinator)


def run_api_server_worker_proc(listen_address,

@@ -27,7 +27,8 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cdiv
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
from vllm.v1.engine.core_client import (AsyncMPClient, DPAsyncMPClient,
RayDPClient)
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector)
@@ -119,9 +120,13 @@ class AsyncLLM(EngineClient):
log_stats=self.log_stats)

# EngineCore (starts the engine in background process).
core_client_class = AsyncMPClient if (
vllm_config.parallel_config.data_parallel_size
== 1) else DPAsyncMPClient
core_client_class: type[AsyncMPClient]
if vllm_config.parallel_config.data_parallel_size == 1:
core_client_class = AsyncMPClient
elif vllm_config.parallel_config.data_parallel_backend == "ray":
core_client_class = RayDPClient
else:
core_client_class = DPAsyncMPClient

self.engine_core = core_client_class(
vllm_config=vllm_config,

@@ -6,8 +6,9 @@ import sys
import threading
import time
from collections import deque
from collections.abc import Generator
from concurrent.futures import Future
from contextlib import ExitStack
from contextlib import ExitStack, contextmanager
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, Callable, Optional, TypeVar, Union
@@ -367,42 +368,66 @@ class EngineCoreProc(EngineCore):
log_stats: bool,
engine_index: int = 0,
):
input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()

executor_fail_callback = lambda: input_queue.put_nowait(
self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
bytes]]()
executor_fail_callback = lambda: self.input_queue.put_nowait(
(EngineCoreRequestType.EXECUTOR_FAILED, b''))

# Create input socket.
self.engine_index = engine_index
identity = self.engine_index.to_bytes(length=2, byteorder="little")
self.engines_running = False

with self._perform_handshake(handshake_address, identity, on_head_node,
vllm_config) as addresses:
self.client_count = len(addresses.outputs)

# Set up data parallel environment.
self.has_coordinator = addresses.coordinator_output is not None
self._init_data_parallel(vllm_config)

super().__init__(vllm_config, executor_class, log_stats,
executor_fail_callback)

self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)

# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
threading.Thread(target=self.process_input_sockets,
args=(addresses.inputs, addresses.coordinator_input,
identity),
daemon=True).start()
self.output_thread = threading.Thread(
target=self.process_output_sockets,
args=(addresses.outputs, addresses.coordinator_output,
self.engine_index),
daemon=True)
self.output_thread.start()

@contextmanager
def _perform_handshake(
self, handshake_address: str, identity: bytes, on_head_node: bool,
vllm_config: VllmConfig
) -> Generator[EngineZmqAddresses, None, None]:
input_ctx = zmq.Context()
identity = engine_index.to_bytes(length=2, byteorder="little")
with make_zmq_socket(input_ctx,
handshake_address,
zmq.DEALER,
identity=identity,
linger=5000,
bind=False) as handshake_socket:

# Register engine with front-end.
addresses = self.startup_handshake(handshake_socket, on_head_node,
vllm_config.parallel_config)
self.client_count = len(addresses.outputs)

# Update config which may have changed from the handshake.
# Update config which may have changed from the handshake
vllm_config.__post_init__()

# Set up data parallel environment.
self.has_coordinator = addresses.coordinator_output is not None
self._init_data_parallel(vllm_config)

# Initialize engine core and model.
super().__init__(vllm_config, executor_class, log_stats,
executor_fail_callback)

self.engine_index = engine_index
self.step_fn = (self.step if self.batch_queue is None else
self.step_with_batch_queue)
self.engines_running = False
self.last_counts = (0, 0)
yield addresses

# Send ready message.
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
@@ -413,25 +438,6 @@ class EngineCoreProc(EngineCore):
"num_gpu_blocks": num_gpu_blocks,
}))

# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue = input_queue
self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
bytes]]()
threading.Thread(target=self.process_input_sockets,
args=(addresses.inputs, addresses.coordinator_input,
identity),
daemon=True).start()
self.output_thread = threading.Thread(
target=self.process_output_sockets,
args=(addresses.outputs, addresses.coordinator_output,
engine_index),
daemon=True)
self.output_thread.start()

@staticmethod
def startup_handshake(
handshake_socket: zmq.Socket, on_head_node: bool,
@@ -743,6 +749,21 @@ class DPEngineCoreProc(EngineCoreProc):
executor_class: type[Executor],
log_stats: bool,
):

self._decorate_logs()

# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self.counter = 0
self.current_wave = 0
self.last_counts = (0, 0)

# Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__(vllm_config, on_head_node, handshake_address,
executor_class, log_stats, dp_rank)

def _decorate_logs(self):
# Add process-specific prefix to stdout and stderr before
# we initialize the engine.
from multiprocessing import current_process
@@ -751,16 +772,6 @@ class DPEngineCoreProc(EngineCoreProc):
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)

# Counts forward-passes of the model so that we can synchronize
# finished with DP peers every N steps.
self.counter = 0
self.current_wave = 0

# Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank
super().__init__(vllm_config, on_head_node, handshake_address,
executor_class, log_stats, dp_rank)

def _init_data_parallel(self, vllm_config: VllmConfig):

# Configure GPUs and stateless process group for data parallel.
@@ -880,3 +891,70 @@ class DPEngineCoreProc(EngineCoreProc):

return ParallelConfig.has_unfinished_dp(self.dp_group,
local_unfinished)


class DPEngineCoreActor(DPEngineCoreProc):
"""
Ray actor for running EngineCore in a data parallel context
"""

def __init__(
self,
vllm_config: VllmConfig,
on_head_node: bool,
addresses: EngineZmqAddresses,
executor_class: type[Executor],
log_stats: bool,
dp_rank: int = 0,
local_dp_rank: int = 0,
):
self.addresses = addresses
vllm_config.parallel_config.data_parallel_rank = dp_rank
vllm_config.parallel_config.data_parallel_rank_local = \
local_dp_rank

# Ray sets CUDA_VISIBLE_DEVICES to empty string,
# we clean this up to be able to properly initialize
# data parallel groups.
del os.environ['CUDA_VISIBLE_DEVICES']

super().__init__(vllm_config, on_head_node, "", executor_class,
log_stats)

def _decorate_logs(self):
pass

@contextmanager
def _perform_handshake(self, handshake_address: str, identity: bytes,
on_head_node: bool, vllm_config: VllmConfig):
"""
For Ray, we don't need to actually perform handshake.
All addresses information is known before the actor creation.
Therefore, we simply yield these addresses.
"""
yield self.addresses

def wait_for_init(self):
"""
Wait until the engine core is initialized.

This is just an empty method. When ray.get() on this method
(or any other method of the actor) returns, it is guaranteed
that actor creation (i.e., __init__) is complete.
"""
pass

def run(self):
"""
Run the engine core busy loop.
"""
try:
self.run_busy_loop()
except SystemExit:
logger.debug("EngineCore exiting.")
raise
except Exception:
logger.exception("EngineCore encountered a fatal error.")
raise
finally:
self.shutdown()

@@ -29,9 +29,9 @@ from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
from vllm.v1.utils import (CoreEngine, CoreEngineProcManager,
EngineZmqAddresses, get_engine_client_zmq_addr,
wait_for_engine_startup)
from vllm.v1.utils import (CoreEngine, CoreEngineActorManager,
CoreEngineProcManager, EngineZmqAddresses,
get_engine_client_zmq_addr, wait_for_engine_startup)

logger = init_logger(__name__)

@@ -68,6 +68,8 @@ class EngineCoreClient(ABC):

if multiprocess_mode and asyncio_mode:
if vllm_config.parallel_config.data_parallel_size > 1:
if vllm_config.parallel_config.data_parallel_backend == "ray":
return RayDPClient(vllm_config, executor_class, log_stats)
return DPAsyncMPClient(vllm_config, executor_class, log_stats)

return AsyncMPClient(vllm_config, executor_class, log_stats)
@@ -273,7 +275,10 @@ class BackgroundResources:
circular reference back to the client object."""

ctx: Union[zmq.Context]
local_engine_manager: Optional[CoreEngineProcManager] = None
# If CoreEngineProcManager, it manages local engines;
# if CoreEngineActorManager, it manages all engines.
engine_manager: Optional[Union[CoreEngineProcManager,
CoreEngineActorManager]] = None
coordinator: Optional[DPCoordinator] = None
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
@@ -290,8 +295,8 @@ class BackgroundResources:
"""Clean up background resources."""

self.engine_dead = True
if self.local_engine_manager is not None:
self.local_engine_manager.close()
if self.engine_manager is not None:
self.engine_manager.close()
if self.coordinator is not None:
self.coordinator.close()
@@ -457,7 +462,7 @@ class MPClient(EngineCoreClient):
if local_engine_count:
# In server mode, start_index and local_start_index will
# both be 0.
self.resources.local_engine_manager = CoreEngineProcManager(
self.resources.engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
executor_class=executor_class,
@@ -484,13 +489,18 @@ class MPClient(EngineCoreClient):
addresses.coordinator_input, addresses.coordinator_output = (
coordinator.get_engine_socket_addresses())

proc_manager = self.resources.engine_manager
assert isinstance(proc_manager, (type(None), CoreEngineProcManager)), (
"_wait_for_engine_startup should only be "
"called with CoreEngineProcManager")

wait_for_engine_startup(
handshake_socket,
addresses,
self.core_engines,
self.vllm_config.parallel_config,
self.vllm_config.cache_config,
self.resources.local_engine_manager,
proc_manager,
coordinator.proc if coordinator else None,
)
@@ -887,7 +897,6 @@ class DPAsyncMPClient(AsyncMPClient):
log_stats: bool,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0):

self.current_wave = 0
self.engines_running = False
# To route aborts to the correct engine.
@@ -1050,3 +1059,50 @@ class DPAsyncMPClient(AsyncMPClient):
if not self.resources.engine_dead:
await self._send_input(EngineCoreRequestType.ABORT, request_ids,
engine)


class RayDPClient(DPAsyncMPClient):
"""
Ray-based client for multi-proc, multi-engine (data parallel)
EngineCore.
"""

def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: Optional[dict[str, str]] = None,
client_index: int = 0,
):
super().__init__(vllm_config, executor_class, log_stats,
client_addresses, client_index)

def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
local_start_index: int, input_address: str,
output_address: str,
executor_class: type[Executor], log_stats: bool):
"""Self-contained client mode, launch engine and coordinator process
as needed."""

parallel_config = vllm_config.parallel_config
assert parallel_config.data_parallel_rank == 0
assert local_start_index == 0

addresses = EngineZmqAddresses(
inputs=[input_address],
outputs=[output_address],
)

if len(self.core_engines) > 1:
coordinator = DPCoordinator(parallel_config)
self.resources.coordinator = coordinator
addresses.coordinator_input, addresses.coordinator_output = (
coordinator.get_engine_socket_addresses())

# Start all engines.
self.resources.engine_manager = CoreEngineActorManager(
vllm_config=vllm_config,
addresses=addresses,
executor_class=executor_class,
log_stats=log_stats)

vllm/v1/utils.py (269 changed lines)
@@ -27,6 +27,8 @@ from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
from vllm.v1.executor.abstract import Executor

if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup

from vllm.attention.layer import Attention
from vllm.v1.engine.coordinator import DPCoordinator

@@ -112,6 +114,45 @@ def get_engine_client_zmq_addr(local_only: bool,
host, port or get_open_port()))


class CoreEngineState(Enum):
NEW = auto()
CONNECTED = auto()
READY = auto()


class CoreEngine:
"""One per data parallel rank."""

def __init__(self, index: int = 0, local: bool = True):
self.local = local
self.index = index
self.identity = index.to_bytes(2, "little")

self.state = CoreEngineState.NEW


@dataclass
class EngineZmqAddresses:
# ZMQ input socket addresses for each front-end client (requests)
inputs: list[str]
# ZMQ output socket addresses for each front-end client (responses)
outputs: list[str]
# ZMQ input socket address of DP coordinator if applicable
coordinator_input: Optional[str] = None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output: Optional[str] = None


@dataclass
class EngineHandshakeMetadata:
"""Metadata sent to each engine process during startup handshake,
including addresses of the front-end ZMQ queues that they should
connect to.
"""
addresses: EngineZmqAddresses
parallel_config: dict[str, Union[int, str]]


class APIServerProcessManager:
"""Manages a group of API server processes.

@@ -245,43 +286,168 @@ class CoreEngineProcManager:
}


class CoreEngineState(Enum):
NEW = auto()
CONNECTED = auto()
READY = auto()


class CoreEngine:
"""One per data parallel rank."""

def __init__(self, index: int = 0, local: bool = True):
self.local = local
self.index = index
self.identity = index.to_bytes(2, "little")

self.state = CoreEngineState.NEW


@dataclass
class EngineZmqAddresses:
# ZMQ input socket addresses for each front-end client (requests)
inputs: list[str]
# ZMQ output socket addresses for each front-end client (responses)
outputs: list[str]
# ZMQ input socket address of DP coordinator if applicable
coordinator_input: Optional[str] = None
# ZMQ output socket address of DP coordinator if applicable
coordinator_output: Optional[str] = None


@dataclass
class EngineHandshakeMetadata:
"""Metadata sent to each engine process during startup handshake,
including addresses of the front-end ZMQ queues that they should
connect to.
class CoreEngineActorManager:
"""
addresses: EngineZmqAddresses
parallel_config: dict[str, Union[int, str]]
Utility class to handle creation, readiness, and shutdown
of core engine Ray actors used by the AsyncLLM and LLMEngine.

Different from CoreEngineProcManager, this class manages
core engines for both local and remote nodes.
"""

def __init__(
self,
vllm_config: VllmConfig,
addresses: EngineZmqAddresses,
executor_class: type[Executor],
log_stats: bool,
placement_groups: Optional[list["PlacementGroup"]] = None,
local_dp_ranks: Optional[list[int]] = None,
):
import copy

import ray
from ray.util.scheduling_strategies import (
PlacementGroupSchedulingStrategy)

from vllm.v1.engine.core import DPEngineCoreActor

self.local_engine_actors: list[ray.ActorHandle] = []
self.remote_engine_actors: list[ray.ActorHandle] = []
dp_size = vllm_config.parallel_config.data_parallel_size
local_engine_count = \
vllm_config.parallel_config.data_parallel_size_local
world_size = vllm_config.parallel_config.world_size

if ray.is_initialized():
logger.info(
"Ray is already initialized. Skipping Ray initialization.")
else:
ray.init()

if placement_groups is not None:
assert local_dp_ranks is not None, (
"local_dp_ranks must be provided if "
"placement_groups is provided")
assert len(placement_groups) == len(local_dp_ranks), (
"placement_groups and local_dp_ranks must "
"have the same length")
logger.info("Using provided placement groups")
# TODO(rui): validate passed-in placement groups
self.created_placement_groups = []
else:
placement_groups, local_dp_ranks = \
CoreEngineActorManager.create_dp_placement_groups(vllm_config)
self.created_placement_groups = placement_groups
assert len(placement_groups) == dp_size, (
"Number of placement groups must match data parallel size")

refs = []
for index in range(dp_size):
local_index = local_dp_ranks[index]
dp_vllm_config = copy.deepcopy(vllm_config)
pg = placement_groups[index]
dp_vllm_config.parallel_config.placement_group = pg
on_head_node = index < local_engine_count
actor = ray.remote(DPEngineCoreActor).options(
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=pg,
placement_group_bundle_index=world_size,
)).remote(vllm_config=dp_vllm_config,
executor_class=executor_class,
log_stats=log_stats,
on_head_node=on_head_node,
addresses=addresses,
dp_rank=index,
local_dp_rank=local_index)
if on_head_node:
self.local_engine_actors.append(actor)
else:
self.remote_engine_actors.append(actor)
refs.append(actor.wait_for_init.remote())

ray.get(refs)
self.run_refs = []
for actor in self.local_engine_actors + self.remote_engine_actors:
self.run_refs.append(actor.run.remote())

@staticmethod
def create_dp_placement_groups(
vllm_config: VllmConfig
) -> tuple[list["PlacementGroup"], list[int]]:

import ray
from ray._private.state import available_resources_per_node
from ray.util.state import list_nodes

logger.info("Creating placement groups for data parallel")
dp_master_ip = \
vllm_config.parallel_config.data_parallel_master_ip
dp_size = vllm_config.parallel_config.data_parallel_size
local_engine_count = \
vllm_config.parallel_config.data_parallel_size_local

nodes = list_nodes()
nodes = sorted(list_nodes(),
key=lambda node: node.node_ip != dp_master_ip)
assert nodes[0].node_ip == dp_master_ip, (
"The first node must be the head node")
assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
"There can only be one head node")

available_resources = available_resources_per_node()
world_size = vllm_config.parallel_config.world_size
placement_groups: list[PlacementGroup] = []
local_dp_ranks: list[int] = []

for node in nodes:
node_ip = node.node_ip
node_resources = available_resources[node.node_id]
# For now, each DP rank can only be assigned to one node
# TODO(rui): support allocating a single DP rank
# to multiple nodes
available_engine_count = node_resources["GPU"] // world_size
if node_ip == dp_master_ip:
assert available_engine_count >= local_engine_count, (
"Not enough resources to allocate DP ranks "
f"on DP master node {node_ip}")
for i in range(local_engine_count):
bundles = [{
"GPU": 1.0,
"node:" + dp_master_ip: 0.001
}] * world_size + [{
"CPU": 1.0
}]
pg = ray.util.placement_group(
name=f"dp_rank_{len(placement_groups)}",
strategy="STRICT_PACK",
bundles=bundles,
)
placement_groups.append(pg)
local_dp_ranks.append(i)
else:
for i in range(available_engine_count):
if len(placement_groups) == dp_size:
break
bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
pg = ray.util.placement_group(
name=f"dp_rank_{len(placement_groups)}",
strategy="STRICT_PACK",
bundles=bundles,
)
placement_groups.append(pg)
local_dp_ranks.append(i)
return placement_groups, local_dp_ranks

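Editor's note: for orientation, an illustration (not part of this diff) of the bundle layout that create_dp_placement_groups builds for one DP rank on the head node, assuming a hypothetical world_size of 2 and a placeholder head-node IP:

# Editor's illustration, not part of the commit.
world_size = 2                 # assumed TP * PP size of one DP rank
dp_master_ip = "192.168.1.10"  # placeholder head-node IP

# One GPU bundle per worker in the rank's world, each pinned to the head
# node through the tiny "node:<ip>" custom resource, plus one CPU bundle
# that the DPEngineCoreActor itself is scheduled on
# (placement_group_bundle_index == world_size in CoreEngineActorManager).
bundles = [{
    "GPU": 1.0,
    "node:" + dp_master_ip: 0.001
}] * world_size + [{
    "CPU": 1.0
}]
print(bundles)
# [{'GPU': 1.0, 'node:192.168.1.10': 0.001},
#  {'GPU': 1.0, 'node:192.168.1.10': 0.001},
#  {'CPU': 1.0}]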
def get_run_refs(self):
return self.run_refs

def close(self):
import ray
for actor in self.local_engine_actors + self.remote_engine_actors:
ray.kill(actor)
for pg in self.created_placement_groups:
ray.util.remove_placement_group(pg)


def wait_for_engine_startup(
@@ -383,11 +549,19 @@ def wait_for_engine_startup(

def wait_for_completion_or_failure(
api_server_manager: APIServerProcessManager,
local_engine_manager: Optional[CoreEngineProcManager] = None,
engine_manager: Optional[Union[CoreEngineProcManager,
CoreEngineActorManager]] = None,
coordinator: Optional["DPCoordinator"] = None) -> None:
"""Wait for all processes to complete or detect if any fail.

Raises an exception if any process exits with a non-zero status.

Args:
api_server_manager: The manager for API servers.
engine_manager: The manager for engine processes.
If CoreEngineProcManager, it manages local engines;
if CoreEngineActorManager, it manages all engines.
coordinator: The coordinator for data parallel.
"""

try:
@@ -402,14 +576,18 @@ def wait_for_completion_or_failure(
if coordinator:
sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc

if local_engine_manager:
for proc in local_engine_manager.processes:
actor_run_refs = []
if isinstance(engine_manager, CoreEngineProcManager):
for proc in engine_manager.processes:
sentinel_to_proc[proc.sentinel] = proc
elif isinstance(engine_manager, CoreEngineActorManager):
actor_run_refs = engine_manager.get_run_refs()

# Check if any process terminates
while sentinel_to_proc:
while sentinel_to_proc or actor_run_refs:
# Wait for any process to terminate
ready_sentinels: list[Any] = connection.wait(sentinel_to_proc)
ready_sentinels: list[Any] = connection.wait(sentinel_to_proc,
timeout=5)

# Process any terminated processes
for sentinel in ready_sentinels:
@@ -420,6 +598,11 @@ def wait_for_completion_or_failure(
raise RuntimeError(
f"Process {proc.name} (PID: {proc.pid}) "
f"died with exit code {proc.exitcode}")

if actor_run_refs:
import ray
_, actor_run_refs = ray.wait(actor_run_refs, timeout=5)

except KeyboardInterrupt:
logger.info("Received KeyboardInterrupt, shutting down API servers...")
except Exception as e:
@@ -431,8 +614,8 @@ def wait_for_completion_or_failure(
api_server_manager.close()
if coordinator:
coordinator.close()
if local_engine_manager:
local_engine_manager.close()
if engine_manager:
engine_manager.close()


# Note(rob): shutdown function cannot be a bound method,