mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 16:49:08 +08:00
[V1] Support DP with Ray (#18779)
This commit is contained in:
parent
9e6f61e8c3
commit
bdce64f236
@ -17,7 +17,7 @@ vector_quantize_pytorch # required for minicpmo_26 test
|
|||||||
vocos # required for minicpmo_26 test
|
vocos # required for minicpmo_26 test
|
||||||
peft
|
peft
|
||||||
pqdm
|
pqdm
|
||||||
ray[cgraph]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests
|
ray[cgraph,default]>=2.43.0, !=2.44.* # Ray Compiled Graph, required by pipeline parallelism tests
|
||||||
sentence-transformers # required for embedding tests
|
sentence-transformers # required for embedding tests
|
||||||
soundfile # required for audio tests
|
soundfile # required for audio tests
|
||||||
jiwer # required for audio tests
|
jiwer # required for audio tests
|
||||||
|
|||||||
@ -10,9 +10,13 @@ aiohappyeyeballs==2.4.3
|
|||||||
# via aiohttp
|
# via aiohttp
|
||||||
aiohttp==3.10.11
|
aiohttp==3.10.11
|
||||||
# via
|
# via
|
||||||
|
# aiohttp-cors
|
||||||
# datasets
|
# datasets
|
||||||
# fsspec
|
# fsspec
|
||||||
# lm-eval
|
# lm-eval
|
||||||
|
# ray
|
||||||
|
aiohttp-cors==0.8.1
|
||||||
|
# via ray
|
||||||
aiosignal==1.3.1
|
aiosignal==1.3.1
|
||||||
# via
|
# via
|
||||||
# aiohttp
|
# aiohttp
|
||||||
@ -57,6 +61,8 @@ bounded-pool-executor==0.0.3
|
|||||||
# via pqdm
|
# via pqdm
|
||||||
buildkite-test-collector==0.1.9
|
buildkite-test-collector==0.1.9
|
||||||
# via -r requirements/test.in
|
# via -r requirements/test.in
|
||||||
|
cachetools==5.5.2
|
||||||
|
# via google-auth
|
||||||
certifi==2024.8.30
|
certifi==2024.8.30
|
||||||
# via
|
# via
|
||||||
# httpcore
|
# httpcore
|
||||||
@ -81,6 +87,8 @@ colorama==0.4.6
|
|||||||
# sacrebleu
|
# sacrebleu
|
||||||
# schemathesis
|
# schemathesis
|
||||||
# tqdm-multiprocess
|
# tqdm-multiprocess
|
||||||
|
colorful==0.5.6
|
||||||
|
# via ray
|
||||||
contourpy==1.3.0
|
contourpy==1.3.0
|
||||||
# via matplotlib
|
# via matplotlib
|
||||||
cramjam==2.9.0
|
cramjam==2.9.0
|
||||||
@ -108,6 +116,8 @@ dill==0.3.8
|
|||||||
# evaluate
|
# evaluate
|
||||||
# lm-eval
|
# lm-eval
|
||||||
# multiprocess
|
# multiprocess
|
||||||
|
distlib==0.3.9
|
||||||
|
# via virtualenv
|
||||||
dnspython==2.7.0
|
dnspython==2.7.0
|
||||||
# via email-validator
|
# via email-validator
|
||||||
docopt==0.6.2
|
docopt==0.6.2
|
||||||
@ -143,6 +153,7 @@ filelock==3.16.1
|
|||||||
# ray
|
# ray
|
||||||
# torch
|
# torch
|
||||||
# transformers
|
# transformers
|
||||||
|
# virtualenv
|
||||||
fonttools==4.54.1
|
fonttools==4.54.1
|
||||||
# via matplotlib
|
# via matplotlib
|
||||||
fqdn==1.5.1
|
fqdn==1.5.1
|
||||||
@ -165,8 +176,16 @@ genai-perf==0.0.8
|
|||||||
# via -r requirements/test.in
|
# via -r requirements/test.in
|
||||||
genson==1.3.0
|
genson==1.3.0
|
||||||
# via datamodel-code-generator
|
# via datamodel-code-generator
|
||||||
|
google-api-core==2.24.2
|
||||||
|
# via opencensus
|
||||||
|
google-auth==2.40.2
|
||||||
|
# via google-api-core
|
||||||
|
googleapis-common-protos==1.70.0
|
||||||
|
# via google-api-core
|
||||||
graphql-core==3.2.6
|
graphql-core==3.2.6
|
||||||
# via hypothesis-graphql
|
# via hypothesis-graphql
|
||||||
|
grpcio==1.71.0
|
||||||
|
# via ray
|
||||||
h11==0.14.0
|
h11==0.14.0
|
||||||
# via httpcore
|
# via httpcore
|
||||||
harfile==0.3.0
|
harfile==0.3.0
|
||||||
@ -392,6 +411,10 @@ nvidia-nvjitlink-cu12==12.8.61
|
|||||||
# torch
|
# torch
|
||||||
nvidia-nvtx-cu12==12.8.55
|
nvidia-nvtx-cu12==12.8.55
|
||||||
# via torch
|
# via torch
|
||||||
|
opencensus==0.11.4
|
||||||
|
# via ray
|
||||||
|
opencensus-context==0.1.3
|
||||||
|
# via opencensus
|
||||||
opencv-python-headless==4.11.0.86
|
opencv-python-headless==4.11.0.86
|
||||||
# via
|
# via
|
||||||
# -r requirements/test.in
|
# -r requirements/test.in
|
||||||
@ -445,6 +468,7 @@ platformdirs==4.3.6
|
|||||||
# via
|
# via
|
||||||
# black
|
# black
|
||||||
# pooch
|
# pooch
|
||||||
|
# virtualenv
|
||||||
plotly==5.24.1
|
plotly==5.24.1
|
||||||
# via genai-perf
|
# via genai-perf
|
||||||
pluggy==1.5.0
|
pluggy==1.5.0
|
||||||
@ -457,10 +481,17 @@ portalocker==2.10.1
|
|||||||
# via sacrebleu
|
# via sacrebleu
|
||||||
pqdm==0.2.0
|
pqdm==0.2.0
|
||||||
# via -r requirements/test.in
|
# via -r requirements/test.in
|
||||||
|
prometheus-client==0.22.0
|
||||||
|
# via ray
|
||||||
propcache==0.2.0
|
propcache==0.2.0
|
||||||
# via yarl
|
# via yarl
|
||||||
|
proto-plus==1.26.1
|
||||||
|
# via google-api-core
|
||||||
protobuf==5.28.3
|
protobuf==5.28.3
|
||||||
# via
|
# via
|
||||||
|
# google-api-core
|
||||||
|
# googleapis-common-protos
|
||||||
|
# proto-plus
|
||||||
# ray
|
# ray
|
||||||
# tensorizer
|
# tensorizer
|
||||||
psutil==6.1.0
|
psutil==6.1.0
|
||||||
@ -470,10 +501,18 @@ psutil==6.1.0
|
|||||||
# tensorizer
|
# tensorizer
|
||||||
py==1.11.0
|
py==1.11.0
|
||||||
# via pytest-forked
|
# via pytest-forked
|
||||||
|
py-spy==0.4.0
|
||||||
|
# via ray
|
||||||
pyarrow==18.0.0
|
pyarrow==18.0.0
|
||||||
# via
|
# via
|
||||||
# datasets
|
# datasets
|
||||||
# genai-perf
|
# genai-perf
|
||||||
|
pyasn1==0.6.1
|
||||||
|
# via
|
||||||
|
# pyasn1-modules
|
||||||
|
# rsa
|
||||||
|
pyasn1-modules==0.4.2
|
||||||
|
# via google-auth
|
||||||
pybind11==2.13.6
|
pybind11==2.13.6
|
||||||
# via lm-eval
|
# via lm-eval
|
||||||
pycparser==2.22
|
pycparser==2.22
|
||||||
@ -486,6 +525,7 @@ pydantic==2.11.5
|
|||||||
# datamodel-code-generator
|
# datamodel-code-generator
|
||||||
# mistral-common
|
# mistral-common
|
||||||
# mteb
|
# mteb
|
||||||
|
# ray
|
||||||
pydantic-core==2.33.2
|
pydantic-core==2.33.2
|
||||||
# via pydantic
|
# via pydantic
|
||||||
pygments==2.18.0
|
pygments==2.18.0
|
||||||
@ -573,6 +613,7 @@ requests==2.32.3
|
|||||||
# buildkite-test-collector
|
# buildkite-test-collector
|
||||||
# datasets
|
# datasets
|
||||||
# evaluate
|
# evaluate
|
||||||
|
# google-api-core
|
||||||
# huggingface-hub
|
# huggingface-hub
|
||||||
# lm-eval
|
# lm-eval
|
||||||
# mistral-common
|
# mistral-common
|
||||||
@ -601,6 +642,8 @@ rpds-py==0.20.1
|
|||||||
# via
|
# via
|
||||||
# jsonschema
|
# jsonschema
|
||||||
# referencing
|
# referencing
|
||||||
|
rsa==4.9.1
|
||||||
|
# via google-auth
|
||||||
runai-model-streamer==0.11.0
|
runai-model-streamer==0.11.0
|
||||||
# via -r requirements/test.in
|
# via -r requirements/test.in
|
||||||
runai-model-streamer-s3==0.11.0
|
runai-model-streamer-s3==0.11.0
|
||||||
@ -648,9 +691,12 @@ shellingham==1.5.4
|
|||||||
six==1.16.0
|
six==1.16.0
|
||||||
# via
|
# via
|
||||||
# junit-xml
|
# junit-xml
|
||||||
|
# opencensus
|
||||||
# python-dateutil
|
# python-dateutil
|
||||||
# rfc3339-validator
|
# rfc3339-validator
|
||||||
# rouge-score
|
# rouge-score
|
||||||
|
smart-open==7.1.0
|
||||||
|
# via ray
|
||||||
sniffio==1.3.1
|
sniffio==1.3.1
|
||||||
# via
|
# via
|
||||||
# anyio
|
# anyio
|
||||||
@ -801,6 +847,8 @@ urllib3==2.2.3
|
|||||||
# tritonclient
|
# tritonclient
|
||||||
vector-quantize-pytorch==1.21.2
|
vector-quantize-pytorch==1.21.2
|
||||||
# via -r requirements/test.in
|
# via -r requirements/test.in
|
||||||
|
virtualenv==20.31.2
|
||||||
|
# via ray
|
||||||
vocos==0.1.0
|
vocos==0.1.0
|
||||||
# via -r requirements/test.in
|
# via -r requirements/test.in
|
||||||
webcolors==24.11.1
|
webcolors==24.11.1
|
||||||
@ -809,6 +857,8 @@ werkzeug==3.1.3
|
|||||||
# via schemathesis
|
# via schemathesis
|
||||||
word2number==1.1
|
word2number==1.1
|
||||||
# via lm-eval
|
# via lm-eval
|
||||||
|
wrapt==1.17.2
|
||||||
|
# via smart-open
|
||||||
xxhash==3.5.0
|
xxhash==3.5.0
|
||||||
# via
|
# via
|
||||||
# datasets
|
# datasets
|
||||||
|
|||||||
@ -59,14 +59,22 @@ async def generate(engine: AsyncLLM,
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
|
"output_kind",
|
||||||
|
[
|
||||||
|
RequestOutputKind.DELTA,
|
||||||
|
RequestOutputKind.FINAL_ONLY,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_load(output_kind: RequestOutputKind):
|
async def test_load(output_kind: RequestOutputKind,
|
||||||
|
data_parallel_backend: str):
|
||||||
|
|
||||||
with ExitStack() as after:
|
with ExitStack() as after:
|
||||||
|
|
||||||
prompt = "This is a test of data parallel"
|
prompt = "This is a test of data parallel"
|
||||||
|
|
||||||
|
engine_args.data_parallel_backend = data_parallel_backend
|
||||||
engine = AsyncLLM.from_engine_args(engine_args)
|
engine = AsyncLLM.from_engine_args(engine_args)
|
||||||
after.callback(engine.shutdown)
|
after.callback(engine.shutdown)
|
||||||
|
|
||||||
@ -82,7 +90,6 @@ async def test_load(output_kind: RequestOutputKind):
|
|||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
generate(engine, request_id, prompt, output_kind,
|
generate(engine, request_id, prompt, output_kind,
|
||||||
NUM_EXPECTED_TOKENS)))
|
NUM_EXPECTED_TOKENS)))
|
||||||
|
|
||||||
# Confirm that we got all the EXPECTED tokens from the requests.
|
# Confirm that we got all the EXPECTED tokens from the requests.
|
||||||
done, pending = await asyncio.wait(tasks,
|
done, pending = await asyncio.wait(tasks,
|
||||||
return_when=asyncio.FIRST_EXCEPTION)
|
return_when=asyncio.FIRST_EXCEPTION)
|
||||||
|
|||||||
@ -1742,6 +1742,8 @@ class ParallelConfig:
|
|||||||
"""Port for data parallel messaging."""
|
"""Port for data parallel messaging."""
|
||||||
data_parallel_master_port: int = 29500
|
data_parallel_master_port: int = 29500
|
||||||
"""Port of the data parallel master."""
|
"""Port of the data parallel master."""
|
||||||
|
data_parallel_backend: str = "mp"
|
||||||
|
"""Backend to use for data parallel, either "mp" or "ray"."""
|
||||||
enable_expert_parallel: bool = False
|
enable_expert_parallel: bool = False
|
||||||
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
|
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
|
||||||
max_parallel_loading_workers: Optional[int] = None
|
max_parallel_loading_workers: Optional[int] = None
|
||||||
@ -1911,6 +1913,10 @@ class ParallelConfig:
|
|||||||
"please install Ray with `pip install "
|
"please install Ray with `pip install "
|
||||||
"ray`.") from ray_utils.ray_import_err
|
"ray`.") from ray_utils.ray_import_err
|
||||||
backend = "ray"
|
backend = "ray"
|
||||||
|
elif self.data_parallel_backend == "ray":
|
||||||
|
logger.info("Using ray distributed inference because "
|
||||||
|
"data_parallel_backend is ray")
|
||||||
|
backend = "ray"
|
||||||
elif ray_found:
|
elif ray_found:
|
||||||
if self.placement_group:
|
if self.placement_group:
|
||||||
backend = "ray"
|
backend = "ray"
|
||||||
|
|||||||
@ -39,7 +39,7 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
|
|||||||
from vllm.transformers_utils.utils import check_gguf_file
|
from vllm.transformers_utils.utils import check_gguf_file
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
|
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
|
||||||
GiB_bytes, is_in_ray_actor)
|
GiB_bytes, get_ip, is_in_ray_actor)
|
||||||
|
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
|
|
||||||
@ -292,6 +292,7 @@ class EngineArgs:
|
|||||||
data_parallel_size_local: Optional[int] = None
|
data_parallel_size_local: Optional[int] = None
|
||||||
data_parallel_address: Optional[str] = None
|
data_parallel_address: Optional[str] = None
|
||||||
data_parallel_rpc_port: Optional[int] = None
|
data_parallel_rpc_port: Optional[int] = None
|
||||||
|
data_parallel_backend: str = ParallelConfig.data_parallel_backend
|
||||||
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
||||||
max_parallel_loading_workers: Optional[
|
max_parallel_loading_workers: Optional[
|
||||||
int] = ParallelConfig.max_parallel_loading_workers
|
int] = ParallelConfig.max_parallel_loading_workers
|
||||||
@ -624,6 +625,12 @@ class EngineArgs:
|
|||||||
type=int,
|
type=int,
|
||||||
help='Port for data parallel RPC '
|
help='Port for data parallel RPC '
|
||||||
'communication.')
|
'communication.')
|
||||||
|
parallel_group.add_argument('--data-parallel-backend',
|
||||||
|
'-dpb',
|
||||||
|
type=str,
|
||||||
|
default='mp',
|
||||||
|
help='Backend for data parallel, either '
|
||||||
|
'"mp" or "ray".')
|
||||||
parallel_group.add_argument(
|
parallel_group.add_argument(
|
||||||
"--enable-expert-parallel",
|
"--enable-expert-parallel",
|
||||||
**parallel_kwargs["enable_expert_parallel"])
|
**parallel_kwargs["enable_expert_parallel"])
|
||||||
@ -1059,9 +1066,20 @@ class EngineArgs:
|
|||||||
|
|
||||||
# DP address, used in multi-node case for torch distributed group
|
# DP address, used in multi-node case for torch distributed group
|
||||||
# and ZMQ sockets.
|
# and ZMQ sockets.
|
||||||
data_parallel_address = self.data_parallel_address if (
|
if self.data_parallel_address is None:
|
||||||
self.data_parallel_address
|
if self.data_parallel_backend == "ray":
|
||||||
is not None) else ParallelConfig.data_parallel_master_ip
|
host_ip = get_ip()
|
||||||
|
logger.info(
|
||||||
|
"Using host IP %s as ray-based data parallel address",
|
||||||
|
host_ip)
|
||||||
|
data_parallel_address = host_ip
|
||||||
|
else:
|
||||||
|
assert self.data_parallel_backend == "mp", (
|
||||||
|
"data_parallel_backend can only be ray or mp, got %s",
|
||||||
|
self.data_parallel_backend)
|
||||||
|
data_parallel_address = ParallelConfig.data_parallel_master_ip
|
||||||
|
else:
|
||||||
|
data_parallel_address = self.data_parallel_address
|
||||||
|
|
||||||
# This port is only used when there are remote data parallel engines,
|
# This port is only used when there are remote data parallel engines,
|
||||||
# otherwise the local IPC transport is used.
|
# otherwise the local IPC transport is used.
|
||||||
@ -1069,6 +1087,8 @@ class EngineArgs:
|
|||||||
self.data_parallel_rpc_port
|
self.data_parallel_rpc_port
|
||||||
is not None) else ParallelConfig.data_parallel_rpc_port
|
is not None) else ParallelConfig.data_parallel_rpc_port
|
||||||
|
|
||||||
|
data_parallel_backend = self.data_parallel_backend
|
||||||
|
|
||||||
parallel_config = ParallelConfig(
|
parallel_config = ParallelConfig(
|
||||||
pipeline_parallel_size=self.pipeline_parallel_size,
|
pipeline_parallel_size=self.pipeline_parallel_size,
|
||||||
tensor_parallel_size=self.tensor_parallel_size,
|
tensor_parallel_size=self.tensor_parallel_size,
|
||||||
@ -1076,6 +1096,7 @@ class EngineArgs:
|
|||||||
data_parallel_size_local=data_parallel_size_local,
|
data_parallel_size_local=data_parallel_size_local,
|
||||||
data_parallel_master_ip=data_parallel_address,
|
data_parallel_master_ip=data_parallel_address,
|
||||||
data_parallel_rpc_port=data_parallel_rpc_port,
|
data_parallel_rpc_port=data_parallel_rpc_port,
|
||||||
|
data_parallel_backend=data_parallel_backend,
|
||||||
enable_expert_parallel=self.enable_expert_parallel,
|
enable_expert_parallel=self.enable_expert_parallel,
|
||||||
max_parallel_loading_workers=self.max_parallel_loading_workers,
|
max_parallel_loading_workers=self.max_parallel_loading_workers,
|
||||||
disable_custom_all_reduce=self.disable_custom_all_reduce,
|
disable_custom_all_reduce=self.disable_custom_all_reduce,
|
||||||
|
|||||||
@ -27,7 +27,8 @@ from vllm.v1.engine.core_client import CoreEngineProcManager
|
|||||||
from vllm.v1.executor.abstract import Executor
|
from vllm.v1.executor.abstract import Executor
|
||||||
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
|
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
|
||||||
from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
|
from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
|
||||||
EngineZmqAddresses, get_engine_client_zmq_addr,
|
CoreEngineActorManager, EngineZmqAddresses,
|
||||||
|
get_engine_client_zmq_addr,
|
||||||
wait_for_completion_or_failure,
|
wait_for_completion_or_failure,
|
||||||
wait_for_engine_startup)
|
wait_for_engine_startup)
|
||||||
|
|
||||||
@ -229,6 +230,31 @@ def run_multi_api_server(args: argparse.Namespace):
|
|||||||
logger.info("Started DP Coordinator process (PID: %d)",
|
logger.info("Started DP Coordinator process (PID: %d)",
|
||||||
coordinator.proc.pid)
|
coordinator.proc.pid)
|
||||||
|
|
||||||
|
if parallel_config.data_parallel_backend == "ray":
|
||||||
|
logger.info("Starting ray-based data parallel backend")
|
||||||
|
|
||||||
|
engine_actor_manager = CoreEngineActorManager(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
addresses=addresses,
|
||||||
|
executor_class=Executor.get_class(vllm_config),
|
||||||
|
log_stats=not engine_args.disable_log_stats,
|
||||||
|
)
|
||||||
|
# Start API servers using the manager
|
||||||
|
api_server_manager = APIServerProcessManager(
|
||||||
|
target_server_fn=run_api_server_worker_proc,
|
||||||
|
listen_address=listen_address,
|
||||||
|
sock=sock,
|
||||||
|
args=args,
|
||||||
|
num_servers=num_api_servers,
|
||||||
|
input_addresses=input_addresses,
|
||||||
|
output_addresses=output_addresses,
|
||||||
|
stats_update_address=stats_update_address)
|
||||||
|
|
||||||
|
wait_for_completion_or_failure(api_server_manager=api_server_manager,
|
||||||
|
engine_manager=engine_actor_manager,
|
||||||
|
coordinator=coordinator)
|
||||||
|
return
|
||||||
|
|
||||||
handshake_address = get_engine_client_zmq_addr(
|
handshake_address = get_engine_client_zmq_addr(
|
||||||
local_only, host, parallel_config.data_parallel_rpc_port)
|
local_only, host, parallel_config.data_parallel_rpc_port)
|
||||||
|
|
||||||
@ -277,10 +303,9 @@ def run_multi_api_server(args: argparse.Namespace):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Wait for API servers
|
# Wait for API servers
|
||||||
wait_for_completion_or_failure(
|
wait_for_completion_or_failure(api_server_manager=api_server_manager,
|
||||||
api_server_manager=api_server_manager,
|
engine_manager=local_engine_manager,
|
||||||
local_engine_manager=local_engine_manager,
|
coordinator=coordinator)
|
||||||
coordinator=coordinator)
|
|
||||||
|
|
||||||
|
|
||||||
def run_api_server_worker_proc(listen_address,
|
def run_api_server_worker_proc(listen_address,
|
||||||
|
|||||||
@ -27,7 +27,8 @@ from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
|||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import Device, cdiv
|
from vllm.utils import Device, cdiv
|
||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
|
from vllm.v1.engine.core_client import (AsyncMPClient, DPAsyncMPClient,
|
||||||
|
RayDPClient)
|
||||||
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
|
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
|
||||||
from vllm.v1.engine.output_processor import (OutputProcessor,
|
from vllm.v1.engine.output_processor import (OutputProcessor,
|
||||||
RequestOutputCollector)
|
RequestOutputCollector)
|
||||||
@ -119,9 +120,13 @@ class AsyncLLM(EngineClient):
|
|||||||
log_stats=self.log_stats)
|
log_stats=self.log_stats)
|
||||||
|
|
||||||
# EngineCore (starts the engine in background process).
|
# EngineCore (starts the engine in background process).
|
||||||
core_client_class = AsyncMPClient if (
|
core_client_class: type[AsyncMPClient]
|
||||||
vllm_config.parallel_config.data_parallel_size
|
if vllm_config.parallel_config.data_parallel_size == 1:
|
||||||
== 1) else DPAsyncMPClient
|
core_client_class = AsyncMPClient
|
||||||
|
elif vllm_config.parallel_config.data_parallel_backend == "ray":
|
||||||
|
core_client_class = RayDPClient
|
||||||
|
else:
|
||||||
|
core_client_class = DPAsyncMPClient
|
||||||
|
|
||||||
self.engine_core = core_client_class(
|
self.engine_core = core_client_class(
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
|
|||||||
@ -6,8 +6,9 @@ import sys
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from collections.abc import Generator
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack, contextmanager
|
||||||
from inspect import isclass, signature
|
from inspect import isclass, signature
|
||||||
from logging import DEBUG
|
from logging import DEBUG
|
||||||
from typing import Any, Callable, Optional, TypeVar, Union
|
from typing import Any, Callable, Optional, TypeVar, Union
|
||||||
@ -367,42 +368,66 @@ class EngineCoreProc(EngineCore):
|
|||||||
log_stats: bool,
|
log_stats: bool,
|
||||||
engine_index: int = 0,
|
engine_index: int = 0,
|
||||||
):
|
):
|
||||||
input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
|
self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
|
||||||
|
self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
|
||||||
executor_fail_callback = lambda: input_queue.put_nowait(
|
bytes]]()
|
||||||
|
executor_fail_callback = lambda: self.input_queue.put_nowait(
|
||||||
(EngineCoreRequestType.EXECUTOR_FAILED, b''))
|
(EngineCoreRequestType.EXECUTOR_FAILED, b''))
|
||||||
|
|
||||||
# Create input socket.
|
self.engine_index = engine_index
|
||||||
|
identity = self.engine_index.to_bytes(length=2, byteorder="little")
|
||||||
|
self.engines_running = False
|
||||||
|
|
||||||
|
with self._perform_handshake(handshake_address, identity, on_head_node,
|
||||||
|
vllm_config) as addresses:
|
||||||
|
self.client_count = len(addresses.outputs)
|
||||||
|
|
||||||
|
# Set up data parallel environment.
|
||||||
|
self.has_coordinator = addresses.coordinator_output is not None
|
||||||
|
self._init_data_parallel(vllm_config)
|
||||||
|
|
||||||
|
super().__init__(vllm_config, executor_class, log_stats,
|
||||||
|
executor_fail_callback)
|
||||||
|
|
||||||
|
self.step_fn = (self.step if self.batch_queue is None else
|
||||||
|
self.step_with_batch_queue)
|
||||||
|
|
||||||
|
# Background Threads and Queues for IO. These enable us to
|
||||||
|
# overlap ZMQ socket IO with GPU since they release the GIL,
|
||||||
|
# and to overlap some serialization/deserialization with the
|
||||||
|
# model forward pass.
|
||||||
|
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
|
||||||
|
threading.Thread(target=self.process_input_sockets,
|
||||||
|
args=(addresses.inputs, addresses.coordinator_input,
|
||||||
|
identity),
|
||||||
|
daemon=True).start()
|
||||||
|
self.output_thread = threading.Thread(
|
||||||
|
target=self.process_output_sockets,
|
||||||
|
args=(addresses.outputs, addresses.coordinator_output,
|
||||||
|
self.engine_index),
|
||||||
|
daemon=True)
|
||||||
|
self.output_thread.start()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _perform_handshake(
|
||||||
|
self, handshake_address: str, identity: bytes, on_head_node: bool,
|
||||||
|
vllm_config: VllmConfig
|
||||||
|
) -> Generator[EngineZmqAddresses, None, None]:
|
||||||
input_ctx = zmq.Context()
|
input_ctx = zmq.Context()
|
||||||
identity = engine_index.to_bytes(length=2, byteorder="little")
|
|
||||||
with make_zmq_socket(input_ctx,
|
with make_zmq_socket(input_ctx,
|
||||||
handshake_address,
|
handshake_address,
|
||||||
zmq.DEALER,
|
zmq.DEALER,
|
||||||
identity=identity,
|
identity=identity,
|
||||||
linger=5000,
|
linger=5000,
|
||||||
bind=False) as handshake_socket:
|
bind=False) as handshake_socket:
|
||||||
|
|
||||||
# Register engine with front-end.
|
# Register engine with front-end.
|
||||||
addresses = self.startup_handshake(handshake_socket, on_head_node,
|
addresses = self.startup_handshake(handshake_socket, on_head_node,
|
||||||
vllm_config.parallel_config)
|
vllm_config.parallel_config)
|
||||||
self.client_count = len(addresses.outputs)
|
|
||||||
|
|
||||||
# Update config which may have changed from the handshake.
|
# Update config which may have changed from the handshake
|
||||||
vllm_config.__post_init__()
|
vllm_config.__post_init__()
|
||||||
|
|
||||||
# Set up data parallel environment.
|
yield addresses
|
||||||
self.has_coordinator = addresses.coordinator_output is not None
|
|
||||||
self._init_data_parallel(vllm_config)
|
|
||||||
|
|
||||||
# Initialize engine core and model.
|
|
||||||
super().__init__(vllm_config, executor_class, log_stats,
|
|
||||||
executor_fail_callback)
|
|
||||||
|
|
||||||
self.engine_index = engine_index
|
|
||||||
self.step_fn = (self.step if self.batch_queue is None else
|
|
||||||
self.step_with_batch_queue)
|
|
||||||
self.engines_running = False
|
|
||||||
self.last_counts = (0, 0)
|
|
||||||
|
|
||||||
# Send ready message.
|
# Send ready message.
|
||||||
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
|
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
|
||||||
@ -413,25 +438,6 @@ class EngineCoreProc(EngineCore):
|
|||||||
"num_gpu_blocks": num_gpu_blocks,
|
"num_gpu_blocks": num_gpu_blocks,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
# Background Threads and Queues for IO. These enable us to
|
|
||||||
# overlap ZMQ socket IO with GPU since they release the GIL,
|
|
||||||
# and to overlap some serialization/deserialization with the
|
|
||||||
# model forward pass.
|
|
||||||
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
|
|
||||||
self.input_queue = input_queue
|
|
||||||
self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
|
|
||||||
bytes]]()
|
|
||||||
threading.Thread(target=self.process_input_sockets,
|
|
||||||
args=(addresses.inputs, addresses.coordinator_input,
|
|
||||||
identity),
|
|
||||||
daemon=True).start()
|
|
||||||
self.output_thread = threading.Thread(
|
|
||||||
target=self.process_output_sockets,
|
|
||||||
args=(addresses.outputs, addresses.coordinator_output,
|
|
||||||
engine_index),
|
|
||||||
daemon=True)
|
|
||||||
self.output_thread.start()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def startup_handshake(
|
def startup_handshake(
|
||||||
handshake_socket: zmq.Socket, on_head_node: bool,
|
handshake_socket: zmq.Socket, on_head_node: bool,
|
||||||
@ -743,6 +749,21 @@ class DPEngineCoreProc(EngineCoreProc):
|
|||||||
executor_class: type[Executor],
|
executor_class: type[Executor],
|
||||||
log_stats: bool,
|
log_stats: bool,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
self._decorate_logs()
|
||||||
|
|
||||||
|
# Counts forward-passes of the model so that we can synchronize
|
||||||
|
# finished with DP peers every N steps.
|
||||||
|
self.counter = 0
|
||||||
|
self.current_wave = 0
|
||||||
|
self.last_counts = (0, 0)
|
||||||
|
|
||||||
|
# Initialize the engine.
|
||||||
|
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||||
|
super().__init__(vllm_config, on_head_node, handshake_address,
|
||||||
|
executor_class, log_stats, dp_rank)
|
||||||
|
|
||||||
|
def _decorate_logs(self):
|
||||||
# Add process-specific prefix to stdout and stderr before
|
# Add process-specific prefix to stdout and stderr before
|
||||||
# we initialize the engine.
|
# we initialize the engine.
|
||||||
from multiprocessing import current_process
|
from multiprocessing import current_process
|
||||||
@ -751,16 +772,6 @@ class DPEngineCoreProc(EngineCoreProc):
|
|||||||
_add_prefix(sys.stdout, process_name, pid)
|
_add_prefix(sys.stdout, process_name, pid)
|
||||||
_add_prefix(sys.stderr, process_name, pid)
|
_add_prefix(sys.stderr, process_name, pid)
|
||||||
|
|
||||||
# Counts forward-passes of the model so that we can synchronize
|
|
||||||
# finished with DP peers every N steps.
|
|
||||||
self.counter = 0
|
|
||||||
self.current_wave = 0
|
|
||||||
|
|
||||||
# Initialize the engine.
|
|
||||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
|
||||||
super().__init__(vllm_config, on_head_node, handshake_address,
|
|
||||||
executor_class, log_stats, dp_rank)
|
|
||||||
|
|
||||||
def _init_data_parallel(self, vllm_config: VllmConfig):
|
def _init_data_parallel(self, vllm_config: VllmConfig):
|
||||||
|
|
||||||
# Configure GPUs and stateless process group for data parallel.
|
# Configure GPUs and stateless process group for data parallel.
|
||||||
@ -880,3 +891,70 @@ class DPEngineCoreProc(EngineCoreProc):
|
|||||||
|
|
||||||
return ParallelConfig.has_unfinished_dp(self.dp_group,
|
return ParallelConfig.has_unfinished_dp(self.dp_group,
|
||||||
local_unfinished)
|
local_unfinished)
|
||||||
|
|
||||||
|
|
||||||
|
class DPEngineCoreActor(DPEngineCoreProc):
|
||||||
|
"""
|
||||||
|
Ray actor for running EngineCore in a data parallel context
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
on_head_node: bool,
|
||||||
|
addresses: EngineZmqAddresses,
|
||||||
|
executor_class: type[Executor],
|
||||||
|
log_stats: bool,
|
||||||
|
dp_rank: int = 0,
|
||||||
|
local_dp_rank: int = 0,
|
||||||
|
):
|
||||||
|
self.addresses = addresses
|
||||||
|
vllm_config.parallel_config.data_parallel_rank = dp_rank
|
||||||
|
vllm_config.parallel_config.data_parallel_rank_local = \
|
||||||
|
local_dp_rank
|
||||||
|
|
||||||
|
# Ray sets CUDA_VISIBLE_DEVICES to empty string,
|
||||||
|
# we clean this up to be able to properly initialize
|
||||||
|
# data parallel groups.
|
||||||
|
del os.environ['CUDA_VISIBLE_DEVICES']
|
||||||
|
|
||||||
|
super().__init__(vllm_config, on_head_node, "", executor_class,
|
||||||
|
log_stats)
|
||||||
|
|
||||||
|
def _decorate_logs(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _perform_handshake(self, handshake_address: str, identity: bytes,
|
||||||
|
on_head_node: bool, vllm_config: VllmConfig):
|
||||||
|
"""
|
||||||
|
For Ray, we don't need to actually perform handshake.
|
||||||
|
All addresses information is known before the actor creation.
|
||||||
|
Therefore, we simply yield these addresses.
|
||||||
|
"""
|
||||||
|
yield self.addresses
|
||||||
|
|
||||||
|
def wait_for_init(self):
|
||||||
|
"""
|
||||||
|
Wait until the engine core is initialized.
|
||||||
|
|
||||||
|
This is just an empty method. When ray.get() on this method
|
||||||
|
(or any other method of the actor) returns, it is guaranteed
|
||||||
|
that actor creation (i.e., __init__) is complete.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""
|
||||||
|
Run the engine core busy loop.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.run_busy_loop()
|
||||||
|
except SystemExit:
|
||||||
|
logger.debug("EngineCore exiting.")
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
logger.exception("EngineCore encountered a fatal error.")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.shutdown()
|
||||||
|
|||||||
@ -29,9 +29,9 @@ from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
|||||||
from vllm.v1.engine.exceptions import EngineDeadError
|
from vllm.v1.engine.exceptions import EngineDeadError
|
||||||
from vllm.v1.executor.abstract import Executor
|
from vllm.v1.executor.abstract import Executor
|
||||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
|
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
|
||||||
from vllm.v1.utils import (CoreEngine, CoreEngineProcManager,
|
from vllm.v1.utils import (CoreEngine, CoreEngineActorManager,
|
||||||
EngineZmqAddresses, get_engine_client_zmq_addr,
|
CoreEngineProcManager, EngineZmqAddresses,
|
||||||
wait_for_engine_startup)
|
get_engine_client_zmq_addr, wait_for_engine_startup)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -68,6 +68,8 @@ class EngineCoreClient(ABC):
|
|||||||
|
|
||||||
if multiprocess_mode and asyncio_mode:
|
if multiprocess_mode and asyncio_mode:
|
||||||
if vllm_config.parallel_config.data_parallel_size > 1:
|
if vllm_config.parallel_config.data_parallel_size > 1:
|
||||||
|
if vllm_config.parallel_config.data_parallel_backend == "ray":
|
||||||
|
return RayDPClient(vllm_config, executor_class, log_stats)
|
||||||
return DPAsyncMPClient(vllm_config, executor_class, log_stats)
|
return DPAsyncMPClient(vllm_config, executor_class, log_stats)
|
||||||
|
|
||||||
return AsyncMPClient(vllm_config, executor_class, log_stats)
|
return AsyncMPClient(vllm_config, executor_class, log_stats)
|
||||||
@ -273,7 +275,10 @@ class BackgroundResources:
|
|||||||
circular reference back to the client object."""
|
circular reference back to the client object."""
|
||||||
|
|
||||||
ctx: Union[zmq.Context]
|
ctx: Union[zmq.Context]
|
||||||
local_engine_manager: Optional[CoreEngineProcManager] = None
|
# If CoreEngineProcManager, it manages local engines;
|
||||||
|
# if CoreEngineActorManager, it manages all engines.
|
||||||
|
engine_manager: Optional[Union[CoreEngineProcManager,
|
||||||
|
CoreEngineActorManager]] = None
|
||||||
coordinator: Optional[DPCoordinator] = None
|
coordinator: Optional[DPCoordinator] = None
|
||||||
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
|
output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
|
||||||
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
|
input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
|
||||||
@ -290,8 +295,8 @@ class BackgroundResources:
|
|||||||
"""Clean up background resources."""
|
"""Clean up background resources."""
|
||||||
|
|
||||||
self.engine_dead = True
|
self.engine_dead = True
|
||||||
if self.local_engine_manager is not None:
|
if self.engine_manager is not None:
|
||||||
self.local_engine_manager.close()
|
self.engine_manager.close()
|
||||||
if self.coordinator is not None:
|
if self.coordinator is not None:
|
||||||
self.coordinator.close()
|
self.coordinator.close()
|
||||||
|
|
||||||
@ -457,7 +462,7 @@ class MPClient(EngineCoreClient):
|
|||||||
if local_engine_count:
|
if local_engine_count:
|
||||||
# In server mode, start_index and local_start_index will
|
# In server mode, start_index and local_start_index will
|
||||||
# both be 0.
|
# both be 0.
|
||||||
self.resources.local_engine_manager = CoreEngineProcManager(
|
self.resources.engine_manager = CoreEngineProcManager(
|
||||||
EngineCoreProc.run_engine_core,
|
EngineCoreProc.run_engine_core,
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
executor_class=executor_class,
|
executor_class=executor_class,
|
||||||
@ -484,13 +489,18 @@ class MPClient(EngineCoreClient):
|
|||||||
addresses.coordinator_input, addresses.coordinator_output = (
|
addresses.coordinator_input, addresses.coordinator_output = (
|
||||||
coordinator.get_engine_socket_addresses())
|
coordinator.get_engine_socket_addresses())
|
||||||
|
|
||||||
|
proc_manager = self.resources.engine_manager
|
||||||
|
assert isinstance(proc_manager, (type(None), CoreEngineProcManager)), (
|
||||||
|
"_wait_for_engine_startup should only be "
|
||||||
|
"called with CoreEngineProcManager")
|
||||||
|
|
||||||
wait_for_engine_startup(
|
wait_for_engine_startup(
|
||||||
handshake_socket,
|
handshake_socket,
|
||||||
addresses,
|
addresses,
|
||||||
self.core_engines,
|
self.core_engines,
|
||||||
self.vllm_config.parallel_config,
|
self.vllm_config.parallel_config,
|
||||||
self.vllm_config.cache_config,
|
self.vllm_config.cache_config,
|
||||||
self.resources.local_engine_manager,
|
proc_manager,
|
||||||
coordinator.proc if coordinator else None,
|
coordinator.proc if coordinator else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -887,7 +897,6 @@ class DPAsyncMPClient(AsyncMPClient):
|
|||||||
log_stats: bool,
|
log_stats: bool,
|
||||||
client_addresses: Optional[dict[str, str]] = None,
|
client_addresses: Optional[dict[str, str]] = None,
|
||||||
client_index: int = 0):
|
client_index: int = 0):
|
||||||
|
|
||||||
self.current_wave = 0
|
self.current_wave = 0
|
||||||
self.engines_running = False
|
self.engines_running = False
|
||||||
# To route aborts to the correct engine.
|
# To route aborts to the correct engine.
|
||||||
@ -1050,3 +1059,50 @@ class DPAsyncMPClient(AsyncMPClient):
|
|||||||
if not self.resources.engine_dead:
|
if not self.resources.engine_dead:
|
||||||
await self._send_input(EngineCoreRequestType.ABORT, request_ids,
|
await self._send_input(EngineCoreRequestType.ABORT, request_ids,
|
||||||
engine)
|
engine)
|
||||||
|
|
||||||
|
|
||||||
|
class RayDPClient(DPAsyncMPClient):
|
||||||
|
"""
|
||||||
|
Ray-based client for multi-proc, multi-engine (data parallel)
|
||||||
|
EngineCore.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
executor_class: type[Executor],
|
||||||
|
log_stats: bool,
|
||||||
|
client_addresses: Optional[dict[str, str]] = None,
|
||||||
|
client_index: int = 0,
|
||||||
|
):
|
||||||
|
super().__init__(vllm_config, executor_class, log_stats,
|
||||||
|
client_addresses, client_index)
|
||||||
|
|
||||||
|
def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
|
||||||
|
local_start_index: int, input_address: str,
|
||||||
|
output_address: str,
|
||||||
|
executor_class: type[Executor], log_stats: bool):
|
||||||
|
"""Self-contained client mode, launch engine and coordinator process
|
||||||
|
as needed."""
|
||||||
|
|
||||||
|
parallel_config = vllm_config.parallel_config
|
||||||
|
assert parallel_config.data_parallel_rank == 0
|
||||||
|
assert local_start_index == 0
|
||||||
|
|
||||||
|
addresses = EngineZmqAddresses(
|
||||||
|
inputs=[input_address],
|
||||||
|
outputs=[output_address],
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(self.core_engines) > 1:
|
||||||
|
coordinator = DPCoordinator(parallel_config)
|
||||||
|
self.resources.coordinator = coordinator
|
||||||
|
addresses.coordinator_input, addresses.coordinator_output = (
|
||||||
|
coordinator.get_engine_socket_addresses())
|
||||||
|
|
||||||
|
# Start all engines.
|
||||||
|
self.resources.engine_manager = CoreEngineActorManager(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
addresses=addresses,
|
||||||
|
executor_class=executor_class,
|
||||||
|
log_stats=log_stats)
|
||||||
|
|||||||
269
vllm/v1/utils.py
269
vllm/v1/utils.py
@ -27,6 +27,8 @@ from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
|
|||||||
from vllm.v1.executor.abstract import Executor
|
from vllm.v1.executor.abstract import Executor
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from ray.util.placement_group import PlacementGroup
|
||||||
|
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.v1.engine.coordinator import DPCoordinator
|
from vllm.v1.engine.coordinator import DPCoordinator
|
||||||
|
|
||||||
@ -112,6 +114,45 @@ def get_engine_client_zmq_addr(local_only: bool,
|
|||||||
host, port or get_open_port()))
|
host, port or get_open_port()))
|
||||||
|
|
||||||
|
|
||||||
|
class CoreEngineState(Enum):
|
||||||
|
NEW = auto()
|
||||||
|
CONNECTED = auto()
|
||||||
|
READY = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class CoreEngine:
|
||||||
|
"""One per data parallel rank."""
|
||||||
|
|
||||||
|
def __init__(self, index: int = 0, local: bool = True):
|
||||||
|
self.local = local
|
||||||
|
self.index = index
|
||||||
|
self.identity = index.to_bytes(2, "little")
|
||||||
|
|
||||||
|
self.state = CoreEngineState.NEW
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EngineZmqAddresses:
|
||||||
|
# ZMQ input socket addresses for each front-end client (requests)
|
||||||
|
inputs: list[str]
|
||||||
|
# ZMQ output socket addresses for each front-end client (responses)
|
||||||
|
outputs: list[str]
|
||||||
|
# ZMQ input socket address of DP coordinator if applicable
|
||||||
|
coordinator_input: Optional[str] = None
|
||||||
|
# ZMQ output socket address of DP coordinator if applicable
|
||||||
|
coordinator_output: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EngineHandshakeMetadata:
|
||||||
|
"""Metadata sent to each engine process during startup handshake,
|
||||||
|
including addresses of the front-end ZMQ queues that they should
|
||||||
|
connect to.
|
||||||
|
"""
|
||||||
|
addresses: EngineZmqAddresses
|
||||||
|
parallel_config: dict[str, Union[int, str]]
|
||||||
|
|
||||||
|
|
||||||
class APIServerProcessManager:
|
class APIServerProcessManager:
|
||||||
"""Manages a group of API server processes.
|
"""Manages a group of API server processes.
|
||||||
|
|
||||||
@ -245,43 +286,168 @@ class CoreEngineProcManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class CoreEngineState(Enum):
|
class CoreEngineActorManager:
|
||||||
NEW = auto()
|
|
||||||
CONNECTED = auto()
|
|
||||||
READY = auto()
|
|
||||||
|
|
||||||
|
|
||||||
class CoreEngine:
|
|
||||||
"""One per data parallel rank."""
|
|
||||||
|
|
||||||
def __init__(self, index: int = 0, local: bool = True):
|
|
||||||
self.local = local
|
|
||||||
self.index = index
|
|
||||||
self.identity = index.to_bytes(2, "little")
|
|
||||||
|
|
||||||
self.state = CoreEngineState.NEW
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class EngineZmqAddresses:
|
|
||||||
# ZMQ input socket addresses for each front-end client (requests)
|
|
||||||
inputs: list[str]
|
|
||||||
# ZMQ output socket addresses for each front-end client (responses)
|
|
||||||
outputs: list[str]
|
|
||||||
# ZMQ input socket address of DP coordinator if applicable
|
|
||||||
coordinator_input: Optional[str] = None
|
|
||||||
# ZMQ output socket address of DP coordinator if applicable
|
|
||||||
coordinator_output: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class EngineHandshakeMetadata:
|
|
||||||
"""Metadata sent to each engine process during startup handshake,
|
|
||||||
including addresses of the front-end ZMQ queues that they should
|
|
||||||
connect to.
|
|
||||||
"""
|
"""
|
||||||
addresses: EngineZmqAddresses
|
Utility class to handle creation, readiness, and shutdown
|
||||||
parallel_config: dict[str, Union[int, str]]
|
of core engine Ray actors used by the AsyncLLM and LLMEngine.
|
||||||
|
|
||||||
|
Different from CoreEngineProcManager, this class manages
|
||||||
|
core engines for both local and remote nodes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
addresses: EngineZmqAddresses,
|
||||||
|
executor_class: type[Executor],
|
||||||
|
log_stats: bool,
|
||||||
|
placement_groups: Optional[list["PlacementGroup"]] = None,
|
||||||
|
local_dp_ranks: Optional[list[int]] = None,
|
||||||
|
):
|
||||||
|
import copy
|
||||||
|
|
||||||
|
import ray
|
||||||
|
from ray.util.scheduling_strategies import (
|
||||||
|
PlacementGroupSchedulingStrategy)
|
||||||
|
|
||||||
|
from vllm.v1.engine.core import DPEngineCoreActor
|
||||||
|
|
||||||
|
self.local_engine_actors: list[ray.ActorHandle] = []
|
||||||
|
self.remote_engine_actors: list[ray.ActorHandle] = []
|
||||||
|
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||||
|
local_engine_count = \
|
||||||
|
vllm_config.parallel_config.data_parallel_size_local
|
||||||
|
world_size = vllm_config.parallel_config.world_size
|
||||||
|
|
||||||
|
if ray.is_initialized():
|
||||||
|
logger.info(
|
||||||
|
"Ray is already initialized. Skipping Ray initialization.")
|
||||||
|
else:
|
||||||
|
ray.init()
|
||||||
|
|
||||||
|
if placement_groups is not None:
|
||||||
|
assert local_dp_ranks is not None, (
|
||||||
|
"local_dp_ranks must be provided if "
|
||||||
|
"placement_groups is provided")
|
||||||
|
assert len(placement_groups) == len(local_dp_ranks), (
|
||||||
|
"placement_groups and local_dp_ranks must "
|
||||||
|
"have the same length")
|
||||||
|
logger.info("Using provided placement groups")
|
||||||
|
# TODO(rui): validate passed-in placement groups
|
||||||
|
self.created_placement_groups = []
|
||||||
|
else:
|
||||||
|
placement_groups, local_dp_ranks = \
|
||||||
|
CoreEngineActorManager.create_dp_placement_groups(vllm_config)
|
||||||
|
self.created_placement_groups = placement_groups
|
||||||
|
assert len(placement_groups) == dp_size, (
|
||||||
|
"Number of placement groups must match data parallel size")
|
||||||
|
|
||||||
|
refs = []
|
||||||
|
for index in range(dp_size):
|
||||||
|
local_index = local_dp_ranks[index]
|
||||||
|
dp_vllm_config = copy.deepcopy(vllm_config)
|
||||||
|
pg = placement_groups[index]
|
||||||
|
dp_vllm_config.parallel_config.placement_group = pg
|
||||||
|
on_head_node = index < local_engine_count
|
||||||
|
actor = ray.remote(DPEngineCoreActor).options(
|
||||||
|
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||||
|
placement_group=pg,
|
||||||
|
placement_group_bundle_index=world_size,
|
||||||
|
)).remote(vllm_config=dp_vllm_config,
|
||||||
|
executor_class=executor_class,
|
||||||
|
log_stats=log_stats,
|
||||||
|
on_head_node=on_head_node,
|
||||||
|
addresses=addresses,
|
||||||
|
dp_rank=index,
|
||||||
|
local_dp_rank=local_index)
|
||||||
|
if on_head_node:
|
||||||
|
self.local_engine_actors.append(actor)
|
||||||
|
else:
|
||||||
|
self.remote_engine_actors.append(actor)
|
||||||
|
refs.append(actor.wait_for_init.remote())
|
||||||
|
|
||||||
|
ray.get(refs)
|
||||||
|
self.run_refs = []
|
||||||
|
for actor in self.local_engine_actors + self.remote_engine_actors:
|
||||||
|
self.run_refs.append(actor.run.remote())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_dp_placement_groups(
|
||||||
|
vllm_config: VllmConfig
|
||||||
|
) -> tuple[list["PlacementGroup"], list[int]]:
|
||||||
|
|
||||||
|
import ray
|
||||||
|
from ray._private.state import available_resources_per_node
|
||||||
|
from ray.util.state import list_nodes
|
||||||
|
|
||||||
|
logger.info("Creating placement groups for data parallel")
|
||||||
|
dp_master_ip = \
|
||||||
|
vllm_config.parallel_config.data_parallel_master_ip
|
||||||
|
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||||
|
local_engine_count = \
|
||||||
|
vllm_config.parallel_config.data_parallel_size_local
|
||||||
|
|
||||||
|
nodes = list_nodes()
|
||||||
|
nodes = sorted(list_nodes(),
|
||||||
|
key=lambda node: node.node_ip != dp_master_ip)
|
||||||
|
assert nodes[0].node_ip == dp_master_ip, (
|
||||||
|
"The first node must be the head node")
|
||||||
|
assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
|
||||||
|
"There can only be one head node")
|
||||||
|
|
||||||
|
available_resources = available_resources_per_node()
|
||||||
|
world_size = vllm_config.parallel_config.world_size
|
||||||
|
placement_groups: list[PlacementGroup] = []
|
||||||
|
local_dp_ranks: list[int] = []
|
||||||
|
|
||||||
|
for node in nodes:
|
||||||
|
node_ip = node.node_ip
|
||||||
|
node_resources = available_resources[node.node_id]
|
||||||
|
# For now, each DP rank can only be assigned to one node
|
||||||
|
# TODO(rui): support allocating a single DP rank
|
||||||
|
# to multiple nodes
|
||||||
|
available_engine_count = node_resources["GPU"] // world_size
|
||||||
|
if node_ip == dp_master_ip:
|
||||||
|
assert available_engine_count >= local_engine_count, (
|
||||||
|
"Not enough resources to allocate DP ranks "
|
||||||
|
f"on DP master node {node_ip}")
|
||||||
|
for i in range(local_engine_count):
|
||||||
|
bundles = [{
|
||||||
|
"GPU": 1.0,
|
||||||
|
"node:" + dp_master_ip: 0.001
|
||||||
|
}] * world_size + [{
|
||||||
|
"CPU": 1.0
|
||||||
|
}]
|
||||||
|
pg = ray.util.placement_group(
|
||||||
|
name=f"dp_rank_{len(placement_groups)}",
|
||||||
|
strategy="STRICT_PACK",
|
||||||
|
bundles=bundles,
|
||||||
|
)
|
||||||
|
placement_groups.append(pg)
|
||||||
|
local_dp_ranks.append(i)
|
||||||
|
else:
|
||||||
|
for i in range(available_engine_count):
|
||||||
|
if len(placement_groups) == dp_size:
|
||||||
|
break
|
||||||
|
bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
|
||||||
|
pg = ray.util.placement_group(
|
||||||
|
name=f"dp_rank_{len(placement_groups)}",
|
||||||
|
strategy="STRICT_PACK",
|
||||||
|
bundles=bundles,
|
||||||
|
)
|
||||||
|
placement_groups.append(pg)
|
||||||
|
local_dp_ranks.append(i)
|
||||||
|
return placement_groups, local_dp_ranks
|
||||||
|
|
||||||
|
def get_run_refs(self):
|
||||||
|
return self.run_refs
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
import ray
|
||||||
|
for actor in self.local_engine_actors + self.remote_engine_actors:
|
||||||
|
ray.kill(actor)
|
||||||
|
for pg in self.created_placement_groups:
|
||||||
|
ray.util.remove_placement_group(pg)
|
||||||
|
|
||||||
|
|
||||||
def wait_for_engine_startup(
|
def wait_for_engine_startup(
|
||||||
@ -383,11 +549,19 @@ def wait_for_engine_startup(
|
|||||||
|
|
||||||
def wait_for_completion_or_failure(
|
def wait_for_completion_or_failure(
|
||||||
api_server_manager: APIServerProcessManager,
|
api_server_manager: APIServerProcessManager,
|
||||||
local_engine_manager: Optional[CoreEngineProcManager] = None,
|
engine_manager: Optional[Union[CoreEngineProcManager,
|
||||||
|
CoreEngineActorManager]] = None,
|
||||||
coordinator: Optional["DPCoordinator"] = None) -> None:
|
coordinator: Optional["DPCoordinator"] = None) -> None:
|
||||||
"""Wait for all processes to complete or detect if any fail.
|
"""Wait for all processes to complete or detect if any fail.
|
||||||
|
|
||||||
Raises an exception if any process exits with a non-zero status.
|
Raises an exception if any process exits with a non-zero status.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_server_manager: The manager for API servers.
|
||||||
|
engine_manager: The manager for engine processes.
|
||||||
|
If CoreEngineProcManager, it manages local engines;
|
||||||
|
if CoreEngineActorManager, it manages all engines.
|
||||||
|
coordinator: The coordinator for data parallel.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -402,14 +576,18 @@ def wait_for_completion_or_failure(
|
|||||||
if coordinator:
|
if coordinator:
|
||||||
sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc
|
sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc
|
||||||
|
|
||||||
if local_engine_manager:
|
actor_run_refs = []
|
||||||
for proc in local_engine_manager.processes:
|
if isinstance(engine_manager, CoreEngineProcManager):
|
||||||
|
for proc in engine_manager.processes:
|
||||||
sentinel_to_proc[proc.sentinel] = proc
|
sentinel_to_proc[proc.sentinel] = proc
|
||||||
|
elif isinstance(engine_manager, CoreEngineActorManager):
|
||||||
|
actor_run_refs = engine_manager.get_run_refs()
|
||||||
|
|
||||||
# Check if any process terminates
|
# Check if any process terminates
|
||||||
while sentinel_to_proc:
|
while sentinel_to_proc or actor_run_refs:
|
||||||
# Wait for any process to terminate
|
# Wait for any process to terminate
|
||||||
ready_sentinels: list[Any] = connection.wait(sentinel_to_proc)
|
ready_sentinels: list[Any] = connection.wait(sentinel_to_proc,
|
||||||
|
timeout=5)
|
||||||
|
|
||||||
# Process any terminated processes
|
# Process any terminated processes
|
||||||
for sentinel in ready_sentinels:
|
for sentinel in ready_sentinels:
|
||||||
@ -420,6 +598,11 @@ def wait_for_completion_or_failure(
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Process {proc.name} (PID: {proc.pid}) "
|
f"Process {proc.name} (PID: {proc.pid}) "
|
||||||
f"died with exit code {proc.exitcode}")
|
f"died with exit code {proc.exitcode}")
|
||||||
|
|
||||||
|
if actor_run_refs:
|
||||||
|
import ray
|
||||||
|
_, actor_run_refs = ray.wait(actor_run_refs, timeout=5)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("Received KeyboardInterrupt, shutting down API servers...")
|
logger.info("Received KeyboardInterrupt, shutting down API servers...")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -431,8 +614,8 @@ def wait_for_completion_or_failure(
|
|||||||
api_server_manager.close()
|
api_server_manager.close()
|
||||||
if coordinator:
|
if coordinator:
|
||||||
coordinator.close()
|
coordinator.close()
|
||||||
if local_engine_manager:
|
if engine_manager:
|
||||||
local_engine_manager.close()
|
engine_manager.close()
|
||||||
|
|
||||||
|
|
||||||
# Note(rob): shutdown function cannot be a bound method,
|
# Note(rob): shutdown function cannot be a bound method,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user