mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-05 02:37:03 +08:00
[DP] Support external DP Load Balancer mode (#19790)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
a1aafc827a
commit
657f2f301a
@ -155,6 +155,7 @@ steps:
|
||||
- examples/offline_inference/rlhf_colocate.py
|
||||
- tests/examples/offline_inference/data_parallel.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/engine/test_engine_core_client.py
|
||||
commands:
|
||||
# test with tp=2 and external_dp=2
|
||||
@ -163,8 +164,9 @@ steps:
|
||||
# test with tp=2 and pp=2
|
||||
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
# test with internal dp
|
||||
- python3 ../examples/offline_inference/data_parallel.py
|
||||
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
|
||||
- pytest -v -s distributed/test_utils.py
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
@ -682,10 +684,12 @@ steps:
|
||||
- vllm/worker/model_runner.py
|
||||
- entrypoints/llm/test_collective_rpc.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- vllm/v1/engine/
|
||||
commands:
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s ./compile/test_basic_correctness.py
|
||||
|
||||
@ -26,8 +26,8 @@ from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
|
||||
SyncMPClient)
|
||||
from vllm.v1.engine.utils import CoreEngineProcManager
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.utils import CoreEngineProcManager
|
||||
|
||||
from ...distributed.conftest import MockSubscriber
|
||||
from ...utils import create_new_process_for_each_test
|
||||
@ -563,7 +563,7 @@ def test_engine_core_proc_instantiation_cuda_empty(
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
m.setenv("CUDA_VISIBLE_DEVICES", "") # No CUDA devices
|
||||
|
||||
from vllm.v1.utils import EngineZmqAddresses
|
||||
from vllm.v1.engine.utils import EngineZmqAddresses
|
||||
|
||||
def mock_startup_handshake(self, handshake_socket, on_head_node,
|
||||
parallel_config):
|
||||
@ -580,7 +580,7 @@ def test_engine_core_proc_instantiation_cuda_empty(
|
||||
trust_remote_code=True).create_engine_config()
|
||||
engine_core_proc = EngineCoreProc(
|
||||
vllm_config=vllm_config,
|
||||
on_head_node=True,
|
||||
local_client=True,
|
||||
handshake_address="tcp://127.0.0.1:12345",
|
||||
executor_class=mock_executor_class,
|
||||
log_stats=False,
|
||||
|
||||
312
tests/v1/test_external_lb_dp.py
Normal file
312
tests/v1/test_external_lb_dp.py
Normal file
@ -0,0 +1,312 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.platforms import Platform
|
||||
|
||||
MODEL_NAME = "ibm-research/PowerMoE-3b"
|
||||
|
||||
# Number of data parallel ranks for external LB testing
|
||||
DP_SIZE = int(os.getenv("DP_SIZE", "2"))
|
||||
# Default tensor parallell size to use
|
||||
TP_SIZE = int(os.getenv("TP_SIZE", "1"))
|
||||
|
||||
|
||||
class ExternalLBServerManager:
|
||||
"""Manages data parallel vLLM server instances for external
|
||||
load balancer testing."""
|
||||
|
||||
def __init__(self,
|
||||
model_name: str,
|
||||
dp_size: int,
|
||||
api_server_count: int,
|
||||
base_server_args: list,
|
||||
tp_size: int = TP_SIZE):
|
||||
self.model_name = model_name
|
||||
self.dp_size = dp_size
|
||||
self.tp_size = tp_size
|
||||
self.api_server_count = api_server_count
|
||||
self.base_server_args = base_server_args
|
||||
self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
|
||||
self.server_threads: list[threading.Thread] = []
|
||||
|
||||
def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
|
||||
"""Start all server instances for external LB mode."""
|
||||
for rank in range(self.dp_size):
|
||||
# Create server args for this specific rank
|
||||
server_args = self.base_server_args.copy()
|
||||
|
||||
# Add external LB specific arguments
|
||||
server_args.extend([
|
||||
"--data-parallel-size",
|
||||
str(self.dp_size),
|
||||
"--data-parallel-rank",
|
||||
str(rank),
|
||||
"--data-parallel-size-local",
|
||||
"1",
|
||||
"--tensor-parallel-size",
|
||||
str(self.tp_size),
|
||||
"--port",
|
||||
str(8000 + rank), # Different port for each rank
|
||||
"--api-server-count",
|
||||
str(self.api_server_count),
|
||||
])
|
||||
|
||||
# Use a thread to start each server to allow parallel initialization
|
||||
def start_server(r: int, sargs: list[str]):
|
||||
try:
|
||||
# Start the server
|
||||
server = RemoteOpenAIServer(
|
||||
self.model_name,
|
||||
sargs,
|
||||
auto_port=False,
|
||||
env_dict={
|
||||
"CUDA_VISIBLE_DEVICES":
|
||||
",".join(
|
||||
str(Platform.device_id_to_physical_device_id(
|
||||
i))
|
||||
for i in range(r * TP_SIZE, (r + 1) * TP_SIZE))
|
||||
})
|
||||
server.__enter__()
|
||||
print(f"Server rank {r} started successfully with "
|
||||
f"{self.api_server_count} API servers")
|
||||
self.servers.append((server, sargs))
|
||||
except Exception as e:
|
||||
print(f"Failed to start server rank {r}: {e}")
|
||||
raise
|
||||
|
||||
thread = threading.Thread(target=start_server,
|
||||
args=(rank, server_args))
|
||||
thread.start()
|
||||
|
||||
self.server_threads.append(thread)
|
||||
|
||||
# Wait for all servers to start
|
||||
for thread in self.server_threads:
|
||||
thread.join()
|
||||
|
||||
# Give servers additional time to fully initialize and coordinate
|
||||
time.sleep(2)
|
||||
|
||||
if len(self.servers) != self.dp_size:
|
||||
raise Exception("Servers failed to start")
|
||||
|
||||
return self.servers
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Stop all server instances."""
|
||||
while self.servers:
|
||||
try:
|
||||
self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
|
||||
except Exception as e:
|
||||
print(f"Error stopping server: {e}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def default_server_args():
|
||||
return [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
"--max-num-seqs",
|
||||
"128",
|
||||
"--enforce-eager",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[1, 4])
|
||||
def servers(request, default_server_args):
|
||||
api_server_count = request.param
|
||||
with ExternalLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
|
||||
default_server_args) as server_list:
|
||||
yield server_list
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
|
||||
# Create a client for each server
|
||||
async with AsyncExitStack() as stack:
|
||||
yield [
|
||||
await stack.enter_async_context(server.get_async_client())
|
||||
for server, _ in servers
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_external_lb_single_completion(clients: list[
|
||||
openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]],
|
||||
model_name: str) -> None:
|
||||
|
||||
async def make_request(client: openai.AsyncOpenAI):
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=10,
|
||||
temperature=1.0)
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 1
|
||||
|
||||
choice = completion.choices[0]
|
||||
# The exact number of tokens can vary slightly with temperature=1.0,
|
||||
# so we check for a reasonable minimum length.
|
||||
assert len(choice.text) >= 1
|
||||
# Finish reason might not always be 'length' if the model finishes early
|
||||
# or due to other reasons, especially with high temperature.
|
||||
# So, we'll accept 'length' or 'stop'.
|
||||
assert choice.finish_reason in ("length", "stop")
|
||||
|
||||
# Token counts can also vary, so we check they are positive.
|
||||
assert completion.usage.completion_tokens > 0
|
||||
assert completion.usage.prompt_tokens > 0
|
||||
assert completion.usage.total_tokens > 0
|
||||
return completion
|
||||
|
||||
# Test single request to each server
|
||||
for i, client in enumerate(clients):
|
||||
result = await make_request(client)
|
||||
assert result is not None
|
||||
print(f"Server {i} handled single completion request successfully")
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Send requests to all servers in round-robin fashion
|
||||
num_requests_per_server = 25 # Total 50 requests across 2 servers
|
||||
all_tasks = []
|
||||
|
||||
for i, client in enumerate(clients):
|
||||
tasks = [make_request(client) for _ in range(num_requests_per_server)]
|
||||
all_tasks.extend(tasks)
|
||||
|
||||
results = await asyncio.gather(*all_tasks)
|
||||
assert len(results) == num_requests_per_server * len(clients)
|
||||
assert all(completion is not None for completion in results)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Second burst of requests
|
||||
all_tasks = []
|
||||
for i, client in enumerate(clients):
|
||||
tasks = [make_request(client) for _ in range(num_requests_per_server)]
|
||||
all_tasks.extend(tasks)
|
||||
|
||||
results = await asyncio.gather(*all_tasks)
|
||||
assert len(results) == num_requests_per_server * len(clients)
|
||||
assert all(completion is not None for completion in results)
|
||||
|
||||
_, server_args = servers[0]
|
||||
api_server_count = (
|
||||
server_args.count('--api-server-count')
|
||||
and server_args[server_args.index('--api-server-count') + 1] or 1)
|
||||
print(
|
||||
f"Successfully completed external LB test with {len(clients)} servers "
|
||||
f"(API server count: {api_server_count})")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_external_lb_completion_streaming(clients: list[
|
||||
openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]],
|
||||
model_name: str) -> None:
|
||||
prompt = "What is an LLM?"
|
||||
|
||||
async def make_streaming_request(client: openai.AsyncOpenAI):
|
||||
# Perform a non-streaming request to get the expected full output
|
||||
single_completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
single_output = single_completion.choices[0].text
|
||||
|
||||
# Perform the streaming request
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True)
|
||||
chunks: list[str] = []
|
||||
finish_reason_count = 0
|
||||
last_chunk = None
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk.choices[0].text)
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason_count += 1
|
||||
last_chunk = chunk # Keep track of the last chunk
|
||||
|
||||
# finish reason should only return in the last block for OpenAI API
|
||||
assert finish_reason_count == 1, (
|
||||
"Finish reason should appear exactly once.")
|
||||
assert last_chunk is not None, (
|
||||
"Stream should have yielded at least one chunk.")
|
||||
assert last_chunk.choices[
|
||||
0].finish_reason == "length", "Finish reason should be 'length'."
|
||||
# Check that the combined text matches the non-streamed version.
|
||||
assert "".join(
|
||||
chunks
|
||||
) == single_output, "Streamed output should match non-streamed output."
|
||||
return True # Indicate success for this request
|
||||
|
||||
# Test single request to each server
|
||||
for i, client in enumerate(clients):
|
||||
result = await make_streaming_request(client)
|
||||
assert result is not None
|
||||
print(f"Server {i} handled single streaming request successfully")
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Send streaming requests to all servers in round-robin fashion
|
||||
num_requests_per_server = 25 # Total 50 requests across 2 servers
|
||||
all_tasks = []
|
||||
|
||||
for i, client in enumerate(clients):
|
||||
tasks = [
|
||||
make_streaming_request(client)
|
||||
for _ in range(num_requests_per_server)
|
||||
]
|
||||
all_tasks.extend(tasks)
|
||||
|
||||
results = await asyncio.gather(*all_tasks)
|
||||
assert len(results) == num_requests_per_server * len(clients)
|
||||
assert all(results), "Not all streaming requests completed successfully."
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Second burst of streaming requests
|
||||
all_tasks = []
|
||||
for i, client in enumerate(clients):
|
||||
tasks = [
|
||||
make_streaming_request(client)
|
||||
for _ in range(num_requests_per_server)
|
||||
]
|
||||
all_tasks.extend(tasks)
|
||||
|
||||
results = await asyncio.gather(*all_tasks)
|
||||
assert len(results) == num_requests_per_server * len(clients)
|
||||
assert all(results), "Not all streaming requests completed successfully."
|
||||
|
||||
_, server_args = servers[0]
|
||||
api_server_count = (
|
||||
server_args.count('--api-server-count')
|
||||
and server_args[server_args.index('--api-server-count') + 1] or 1)
|
||||
print(f"Successfully completed external LB streaming test with "
|
||||
f"{len(clients)} servers (API server count: {api_server_count})")
|
||||
@ -1784,6 +1784,10 @@ class ParallelConfig:
|
||||
"""Port of the data parallel master."""
|
||||
data_parallel_backend: str = "mp"
|
||||
"""Backend to use for data parallel, either "mp" or "ray"."""
|
||||
data_parallel_external_lb: bool = False
|
||||
"""Whether to use "external" DP LB mode. Applies only to online serving
|
||||
and when data_parallel_size > 0. Set implicitly when
|
||||
data_parallel_rank is provided explicitly to vllm serve."""
|
||||
enable_expert_parallel: bool = False
|
||||
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
|
||||
enable_eplb: bool = False
|
||||
@ -1953,6 +1957,11 @@ class ParallelConfig:
|
||||
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
|
||||
# Data parallel was specified in the engine args.
|
||||
self.data_parallel_master_port = get_open_port()
|
||||
|
||||
if not (0 <= self.data_parallel_rank < self.data_parallel_size):
|
||||
raise ValueError(
|
||||
f"data_parallel_rank ({self.data_parallel_rank})"
|
||||
f" must be in the range [0, {self.data_parallel_size})")
|
||||
else:
|
||||
# Otherwise fall back to env vars (e.g. for offline SPMD case).
|
||||
self.data_parallel_size = envs.VLLM_DP_SIZE
|
||||
@ -1961,6 +1970,10 @@ class ParallelConfig:
|
||||
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
|
||||
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
|
||||
|
||||
if self.data_parallel_external_lb:
|
||||
raise ValueError("data_parallel_external_lb can only "
|
||||
"be set when data_parallel_size > 1")
|
||||
|
||||
if self.distributed_executor_backend == "external_launcher":
|
||||
import os
|
||||
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
|
||||
|
||||
@ -318,6 +318,7 @@ class EngineArgs:
|
||||
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
|
||||
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
|
||||
data_parallel_size: int = ParallelConfig.data_parallel_size
|
||||
data_parallel_rank: Optional[int] = None
|
||||
data_parallel_size_local: Optional[int] = None
|
||||
data_parallel_address: Optional[str] = None
|
||||
data_parallel_rpc_port: Optional[int] = None
|
||||
@ -655,6 +656,12 @@ class EngineArgs:
|
||||
**parallel_kwargs["tensor_parallel_size"])
|
||||
parallel_group.add_argument("--data-parallel-size", "-dp",
|
||||
**parallel_kwargs["data_parallel_size"])
|
||||
parallel_group.add_argument(
|
||||
'--data-parallel-rank',
|
||||
'-dpn',
|
||||
type=int,
|
||||
help='Data parallel rank of this instance. '
|
||||
'When set, enables external load balancer mode.')
|
||||
parallel_group.add_argument('--data-parallel-size-local',
|
||||
'-dpl',
|
||||
type=int,
|
||||
@ -1126,10 +1133,17 @@ class EngineArgs:
|
||||
# but we should not do this here.
|
||||
placement_group = ray.util.get_current_placement_group()
|
||||
|
||||
# Local DP size defaults to global DP size if not set.
|
||||
data_parallel_size_local = self.data_parallel_size if (
|
||||
self.data_parallel_size_local
|
||||
is None) else self.data_parallel_size_local
|
||||
data_parallel_external_lb = self.data_parallel_rank is not None
|
||||
if data_parallel_external_lb:
|
||||
assert self.data_parallel_size_local in (1, None), (
|
||||
"data_parallel_size_local must be 1 when data_parallel_rank "
|
||||
"is set")
|
||||
data_parallel_size_local = 1
|
||||
elif self.data_parallel_size_local is not None:
|
||||
data_parallel_size_local = self.data_parallel_size_local
|
||||
else:
|
||||
# Local DP size defaults to global DP size if not set.
|
||||
data_parallel_size_local = self.data_parallel_size
|
||||
|
||||
# DP address, used in multi-node case for torch distributed group
|
||||
# and ZMQ sockets.
|
||||
@ -1154,16 +1168,16 @@ class EngineArgs:
|
||||
self.data_parallel_rpc_port
|
||||
is not None) else ParallelConfig.data_parallel_rpc_port
|
||||
|
||||
data_parallel_backend = self.data_parallel_backend
|
||||
|
||||
parallel_config = ParallelConfig(
|
||||
pipeline_parallel_size=self.pipeline_parallel_size,
|
||||
tensor_parallel_size=self.tensor_parallel_size,
|
||||
data_parallel_size=self.data_parallel_size,
|
||||
data_parallel_rank=self.data_parallel_rank or 0,
|
||||
data_parallel_external_lb=data_parallel_external_lb,
|
||||
data_parallel_size_local=data_parallel_size_local,
|
||||
data_parallel_master_ip=data_parallel_address,
|
||||
data_parallel_rpc_port=data_parallel_rpc_port,
|
||||
data_parallel_backend=data_parallel_backend,
|
||||
data_parallel_backend=self.data_parallel_backend,
|
||||
enable_expert_parallel=self.enable_expert_parallel,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.num_redundant_experts,
|
||||
|
||||
@ -5,9 +5,9 @@ import argparse
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import uvloop
|
||||
import zmq
|
||||
|
||||
import vllm
|
||||
import vllm.envs as envs
|
||||
@ -21,17 +21,13 @@ from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG,
|
||||
from vllm.executor.multiproc_worker_utils import _add_prefix
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, get_tcp_uri, zmq_socket_ctx
|
||||
from vllm.v1.engine.coordinator import DPCoordinator
|
||||
from vllm.utils import FlexibleArgumentParser, get_tcp_uri
|
||||
from vllm.v1.engine.core import EngineCoreProc
|
||||
from vllm.v1.engine.core_client import CoreEngineProcManager
|
||||
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
|
||||
from vllm.v1.utils import (APIServerProcessManager, CoreEngine,
|
||||
CoreEngineActorManager, EngineZmqAddresses,
|
||||
get_engine_client_zmq_addr,
|
||||
wait_for_completion_or_failure,
|
||||
wait_for_engine_startup)
|
||||
from vllm.v1.utils import (APIServerProcessManager,
|
||||
wait_for_completion_or_failure)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -48,11 +44,15 @@ class ServeSubcommand(CLISubcommand):
|
||||
|
||||
if args.headless or args.api_server_count < 1:
|
||||
run_headless(args)
|
||||
elif args.api_server_count > 1:
|
||||
run_multi_api_server(args)
|
||||
else:
|
||||
# Single API server (this process).
|
||||
uvloop.run(run_server(args))
|
||||
if args.data_parallel_start_rank:
|
||||
raise ValueError("data_parallel_start_rank is only "
|
||||
"applicable in headless mode")
|
||||
if args.api_server_count > 1:
|
||||
run_multi_api_server(args)
|
||||
else:
|
||||
# Single API server (this process).
|
||||
uvloop.run(run_server(args))
|
||||
|
||||
def validate(self, args: argparse.Namespace) -> None:
|
||||
validate_parsed_serve_args(args)
|
||||
@ -121,14 +121,19 @@ def run_headless(args: argparse.Namespace):
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
local_engine_count = parallel_config.data_parallel_size_local
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
port = engine_args.data_parallel_rpc_port # add to config too
|
||||
handshake_address = get_tcp_uri(host, port)
|
||||
|
||||
if local_engine_count <= 0:
|
||||
raise ValueError("data_parallel_size_local must be > 0 in "
|
||||
"headless mode")
|
||||
|
||||
if parallel_config.data_parallel_rank is not None:
|
||||
raise ValueError("data_parallel_rank is not applicable in "
|
||||
"headless mode")
|
||||
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
port = engine_args.data_parallel_rpc_port # add to config too
|
||||
handshake_address = get_tcp_uri(host, port)
|
||||
|
||||
# Catch SIGTERM and SIGINT to allow graceful shutdown.
|
||||
def signal_handler(signum, frame):
|
||||
logger.debug("Received %d signal.", signum)
|
||||
@ -148,7 +153,7 @@ def run_headless(args: argparse.Namespace):
|
||||
start_index=args.data_parallel_start_rank,
|
||||
local_start_index=0,
|
||||
vllm_config=vllm_config,
|
||||
on_head_node=False,
|
||||
local_client=False,
|
||||
handshake_address=handshake_address,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
@ -192,117 +197,53 @@ def run_multi_api_server(args: argparse.Namespace):
|
||||
" api_server_count > 1")
|
||||
model_config.disable_mm_preprocessor_cache = True
|
||||
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
log_stats = not engine_args.disable_log_stats
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
dp_rank = parallel_config.data_parallel_rank
|
||||
external_dp_lb = parallel_config.data_parallel_external_lb
|
||||
assert external_dp_lb or dp_rank == 0
|
||||
|
||||
assert parallel_config.data_parallel_rank == 0
|
||||
api_server_manager: Optional[APIServerProcessManager] = None
|
||||
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
local_engine_count = parallel_config.data_parallel_size_local
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
local_only = local_engine_count == dp_size
|
||||
with launch_core_engines(vllm_config, executor_class, log_stats,
|
||||
num_api_servers) as (local_engine_manager,
|
||||
coordinator, addresses):
|
||||
|
||||
# Set up input and output addresses.
|
||||
input_addresses = [
|
||||
get_engine_client_zmq_addr(local_only, host)
|
||||
for _ in range(num_api_servers)
|
||||
]
|
||||
output_addresses = [
|
||||
get_engine_client_zmq_addr(local_only, host)
|
||||
for _ in range(num_api_servers)
|
||||
]
|
||||
|
||||
addresses = EngineZmqAddresses(
|
||||
inputs=input_addresses,
|
||||
outputs=output_addresses,
|
||||
)
|
||||
|
||||
# Set up coordinator for dp > 1.
|
||||
coordinator = None
|
||||
stats_update_address = None
|
||||
if dp_size > 1:
|
||||
coordinator = DPCoordinator(parallel_config)
|
||||
addresses.coordinator_input, addresses.coordinator_output = (
|
||||
coordinator.get_engine_socket_addresses())
|
||||
stats_update_address = coordinator.get_stats_publish_address()
|
||||
logger.info("Started DP Coordinator process (PID: %d)",
|
||||
coordinator.proc.pid)
|
||||
|
||||
if parallel_config.data_parallel_backend == "ray":
|
||||
logger.info("Starting ray-based data parallel backend")
|
||||
|
||||
engine_actor_manager = CoreEngineActorManager(
|
||||
vllm_config=vllm_config,
|
||||
addresses=addresses,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
)
|
||||
# Start API servers using the manager
|
||||
api_server_manager = APIServerProcessManager(
|
||||
# Construct common args for the APIServerProcessManager up-front.
|
||||
api_server_manager_kwargs = dict(
|
||||
target_server_fn=run_api_server_worker_proc,
|
||||
listen_address=listen_address,
|
||||
sock=sock,
|
||||
args=args,
|
||||
num_servers=num_api_servers,
|
||||
input_addresses=input_addresses,
|
||||
output_addresses=output_addresses,
|
||||
stats_update_address=stats_update_address)
|
||||
input_addresses=addresses.inputs,
|
||||
output_addresses=addresses.outputs,
|
||||
stats_update_address=coordinator.get_stats_publish_address()
|
||||
if coordinator else None)
|
||||
|
||||
wait_for_completion_or_failure(api_server_manager=api_server_manager,
|
||||
engine_manager=engine_actor_manager,
|
||||
coordinator=coordinator)
|
||||
return
|
||||
# For dp ranks > 0 in external DP LB mode, we must delay the
|
||||
# start of the API servers until the local engine is started
|
||||
# (after the launcher context manager exits),
|
||||
# since we get the front-end stats update address from the coordinator
|
||||
# via the handshake with the local engine.
|
||||
if dp_rank == 0 or not external_dp_lb:
|
||||
# Start API servers using the manager.
|
||||
api_server_manager = APIServerProcessManager(
|
||||
**api_server_manager_kwargs)
|
||||
|
||||
handshake_address = get_engine_client_zmq_addr(
|
||||
local_only, host, parallel_config.data_parallel_rpc_port)
|
||||
|
||||
with zmq_socket_ctx(handshake_address, zmq.ROUTER,
|
||||
bind=True) as handshake_socket:
|
||||
|
||||
# Start local engines.
|
||||
if not local_engine_count:
|
||||
local_engine_manager = None
|
||||
else:
|
||||
local_engine_manager = CoreEngineProcManager(
|
||||
EngineCoreProc.run_engine_core,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=Executor.get_class(vllm_config),
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
handshake_address=handshake_address,
|
||||
on_head_node=True,
|
||||
local_engine_count=local_engine_count,
|
||||
start_index=0,
|
||||
local_start_index=0)
|
||||
|
||||
# Start API servers using the manager
|
||||
# Start API servers now if they weren't already started.
|
||||
if api_server_manager is None:
|
||||
api_server_manager_kwargs["stats_update_address"] = (
|
||||
addresses.frontend_stats_publish_address)
|
||||
api_server_manager = APIServerProcessManager(
|
||||
target_server_fn=run_api_server_worker_proc,
|
||||
listen_address=listen_address,
|
||||
sock=sock,
|
||||
args=args,
|
||||
num_servers=num_api_servers,
|
||||
input_addresses=input_addresses,
|
||||
output_addresses=output_addresses,
|
||||
stats_update_address=stats_update_address)
|
||||
**api_server_manager_kwargs)
|
||||
|
||||
# Wait for engine handshakes to complete.
|
||||
core_engines = [
|
||||
CoreEngine(index=i, local=(i < local_engine_count))
|
||||
for i in range(dp_size)
|
||||
]
|
||||
wait_for_engine_startup(
|
||||
handshake_socket,
|
||||
addresses,
|
||||
core_engines,
|
||||
parallel_config,
|
||||
vllm_config.cache_config,
|
||||
local_engine_manager,
|
||||
coordinator.proc if coordinator else None,
|
||||
)
|
||||
|
||||
# Wait for API servers
|
||||
wait_for_completion_or_failure(api_server_manager=api_server_manager,
|
||||
engine_manager=local_engine_manager,
|
||||
coordinator=coordinator)
|
||||
# Wait for API servers
|
||||
wait_for_completion_or_failure(api_server_manager=api_server_manager,
|
||||
engine_manager=local_engine_manager,
|
||||
coordinator=coordinator)
|
||||
|
||||
|
||||
def run_api_server_worker_proc(listen_address,
|
||||
|
||||
@ -10,7 +10,7 @@ import zmq
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, make_zmq_socket
|
||||
from vllm.utils import get_mp_context, make_zmq_socket
|
||||
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
|
||||
from vllm.v1.serial_utils import MsgpackDecoder
|
||||
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
|
||||
@ -48,20 +48,33 @@ class DPCoordinator:
|
||||
|
||||
Engines will move into running state when receiving a new request or
|
||||
START_DP_WAVE message.
|
||||
|
||||
Note that when deployed in External LB mode, no stats will be published by
|
||||
the engines and thus updates will only be sent to front-ends when the
|
||||
request wave / running state changes.
|
||||
"""
|
||||
|
||||
def __init__(self, parallel_config: ParallelConfig):
|
||||
|
||||
# Assume coordinator is colocated with front-end procs.
|
||||
front_publish_address = get_open_zmq_ipc_path()
|
||||
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
assert dp_size > 1, "Coordinator only used for data parallel"
|
||||
|
||||
local_only = dp_size == parallel_config.data_parallel_size_local
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
back_publish_address = get_engine_client_zmq_addr(local_only, host)
|
||||
back_output_address = get_engine_client_zmq_addr(local_only, host)
|
||||
external_lb = parallel_config.data_parallel_external_lb
|
||||
|
||||
# Assume coordinator is colocated with front-end procs when not in
|
||||
# external DP LB mode.
|
||||
front_publish_address = get_engine_client_zmq_addr(
|
||||
local_only=not external_lb, host=host)
|
||||
|
||||
local_only_eng = dp_size == parallel_config.data_parallel_size_local
|
||||
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
|
||||
back_output_address = get_engine_client_zmq_addr(local_only_eng, host)
|
||||
|
||||
# When in external LB mode, load stats aren't published, only changes
|
||||
# to request wave / running state, so we don't need to rate-limit the
|
||||
# updates to the front-end proc(s).
|
||||
min_stats_update_interval_ms = 0 if external_lb else 100
|
||||
|
||||
context = get_mp_context()
|
||||
self.proc: multiprocessing.Process = context.Process(
|
||||
@ -72,6 +85,7 @@ class DPCoordinator:
|
||||
"front_publish_address": front_publish_address,
|
||||
"back_output_address": back_output_address,
|
||||
"back_publish_address": back_publish_address,
|
||||
"min_stats_update_interval_ms": min_stats_update_interval_ms,
|
||||
},
|
||||
daemon=True)
|
||||
self.proc.start()
|
||||
@ -100,12 +114,16 @@ class EngineState:
|
||||
|
||||
class CoordinatorProc:
|
||||
|
||||
def __init__(self, engine_count: int):
|
||||
def __init__(self,
|
||||
engine_count: int,
|
||||
min_stats_update_interval_ms: int = 100):
|
||||
|
||||
self.ctx = zmq.Context()
|
||||
|
||||
self.engines = [EngineState() for _ in range(engine_count)]
|
||||
|
||||
self.stats_update_interval_ms = min_stats_update_interval_ms
|
||||
|
||||
self.current_wave = 0
|
||||
self.engines_running = False
|
||||
self.stats_changed = False
|
||||
@ -116,8 +134,11 @@ class CoordinatorProc:
|
||||
front_publish_address: str,
|
||||
back_output_address: str,
|
||||
back_publish_address: str,
|
||||
min_stats_update_interval_ms: int = 100,
|
||||
):
|
||||
coordinator = CoordinatorProc(engine_count=engine_count)
|
||||
coordinator = CoordinatorProc(
|
||||
engine_count=engine_count,
|
||||
min_stats_update_interval_ms=min_stats_update_interval_ms)
|
||||
try:
|
||||
coordinator.process_input_socket(
|
||||
front_publish_address,
|
||||
@ -156,9 +177,10 @@ class CoordinatorProc:
|
||||
last_publish_time = 0
|
||||
while True:
|
||||
elapsed = int(time.time() * 1000) - last_publish_time
|
||||
# Send at 100 ms interval if the stats have changed,
|
||||
# or otherwise every 3 seconds.
|
||||
wait_for = 100 if self.stats_changed else 3000
|
||||
# Send at stats_update_interval_ms interval if the stats have
|
||||
# changed, or otherwise every 4 seconds.
|
||||
wait_for = (self.stats_update_interval_ms
|
||||
if self.stats_changed else 4000)
|
||||
events = poller.poll(timeout=max(0, wait_for - elapsed))
|
||||
if not events:
|
||||
# Poller timeout - publish current stats to front-ends.
|
||||
@ -174,7 +196,7 @@ class CoordinatorProc:
|
||||
|
||||
if publish_front in events:
|
||||
buffer = publish_front.recv()
|
||||
if buffer == b'\x01':
|
||||
if buffer in (b'\x01', b'\x00'):
|
||||
# Ignore subscription messages.
|
||||
continue
|
||||
|
||||
|
||||
@ -34,6 +34,7 @@ from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
EngineCoreRequestType, UtilityOutput)
|
||||
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
|
||||
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
@ -41,7 +42,6 @@ from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -367,10 +367,11 @@ class EngineCoreProc(EngineCore):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
on_head_node: bool,
|
||||
local_client: bool,
|
||||
handshake_address: str,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
client_handshake_address: Optional[str] = None,
|
||||
engine_index: int = 0,
|
||||
):
|
||||
self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
|
||||
@ -383,12 +384,21 @@ class EngineCoreProc(EngineCore):
|
||||
identity = self.engine_index.to_bytes(length=2, byteorder="little")
|
||||
self.engines_running = False
|
||||
|
||||
with self._perform_handshake(handshake_address, identity, on_head_node,
|
||||
vllm_config) as addresses:
|
||||
with self._perform_handshakes(handshake_address, identity,
|
||||
local_client, vllm_config,
|
||||
client_handshake_address) as addresses:
|
||||
self.client_count = len(addresses.outputs)
|
||||
|
||||
# Set up data parallel environment.
|
||||
self.has_coordinator = addresses.coordinator_output is not None
|
||||
self.frontend_stats_publish_address = (
|
||||
addresses.frontend_stats_publish_address)
|
||||
# Only publish request queue stats to coordinator for "internal"
|
||||
# LB mode.
|
||||
self.publish_dp_lb_stats = (
|
||||
self.has_coordinator
|
||||
and not vllm_config.parallel_config.data_parallel_external_lb)
|
||||
|
||||
self._init_data_parallel(vllm_config)
|
||||
|
||||
super().__init__(vllm_config, executor_class, log_stats,
|
||||
@ -414,45 +424,102 @@ class EngineCoreProc(EngineCore):
|
||||
self.output_thread.start()
|
||||
|
||||
@contextmanager
|
||||
def _perform_handshake(
|
||||
self, handshake_address: str, identity: bytes, on_head_node: bool,
|
||||
vllm_config: VllmConfig
|
||||
def _perform_handshakes(
|
||||
self,
|
||||
handshake_address: str,
|
||||
identity: bytes,
|
||||
local_client: bool,
|
||||
vllm_config: VllmConfig,
|
||||
client_handshake_address: Optional[str],
|
||||
) -> Generator[EngineZmqAddresses, None, None]:
|
||||
"""
|
||||
Perform startup handshakes.
|
||||
|
||||
For DP=1 or offline mode, this is with the colocated front-end process.
|
||||
|
||||
For DP>1 with internal loadbalancing this is with the shared front-end
|
||||
process which may reside on a different node.
|
||||
|
||||
For DP>1 with external loadbalancing, two handshakes are performed:
|
||||
- With the rank 0 front-end process which retrieves the
|
||||
DP Coordinator ZMQ addresses and DP process group address.
|
||||
- With the colocated front-end process which retrieves the
|
||||
client input/output socket addresses.
|
||||
with the exception of the rank 0 engine itself which doesn't require
|
||||
the second handshake.
|
||||
|
||||
Here, "front-end" process can mean the process containing the engine
|
||||
core client (which is the API server process in the case the API
|
||||
server is not scaled out), OR the launcher process running the
|
||||
run_multi_api_server() function in serve.py.
|
||||
"""
|
||||
input_ctx = zmq.Context()
|
||||
with make_zmq_socket(input_ctx,
|
||||
is_local = local_client and client_handshake_address is None
|
||||
handshake = self._perform_handshake(input_ctx, handshake_address,
|
||||
identity, is_local, vllm_config,
|
||||
vllm_config.parallel_config)
|
||||
if client_handshake_address is None:
|
||||
with handshake as addresses:
|
||||
yield addresses
|
||||
else:
|
||||
local_handshake = self._perform_handshake(
|
||||
input_ctx, client_handshake_address, identity, local_client,
|
||||
vllm_config)
|
||||
with handshake as addresses, local_handshake as client_addresses:
|
||||
addresses.inputs = client_addresses.inputs
|
||||
addresses.outputs = client_addresses.outputs
|
||||
yield addresses
|
||||
|
||||
# Update config which may have changed from the handshake
|
||||
vllm_config.__post_init__()
|
||||
|
||||
@contextmanager
|
||||
def _perform_handshake(
|
||||
self,
|
||||
ctx: zmq.Context,
|
||||
handshake_address: str,
|
||||
identity: bytes,
|
||||
local_client: bool,
|
||||
vllm_config: VllmConfig,
|
||||
parallel_config_to_update: Optional[ParallelConfig] = None,
|
||||
) -> Generator[EngineZmqAddresses, None, None]:
|
||||
with make_zmq_socket(ctx,
|
||||
handshake_address,
|
||||
zmq.DEALER,
|
||||
identity=identity,
|
||||
linger=5000,
|
||||
bind=False) as handshake_socket:
|
||||
# Register engine with front-end.
|
||||
addresses = self.startup_handshake(handshake_socket, on_head_node,
|
||||
vllm_config.parallel_config)
|
||||
|
||||
# Update config which may have changed from the handshake
|
||||
vllm_config.__post_init__()
|
||||
|
||||
addresses = self.startup_handshake(handshake_socket, local_client,
|
||||
parallel_config_to_update)
|
||||
yield addresses
|
||||
|
||||
# Send ready message.
|
||||
num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
|
||||
# We pass back the coordinator stats update address here for the
|
||||
# external LB case for our colocated front-end to use (coordinator
|
||||
# only runs with rank 0).
|
||||
dp_stats_address = self.frontend_stats_publish_address
|
||||
handshake_socket.send(
|
||||
msgspec.msgpack.encode({
|
||||
"status": "READY",
|
||||
"local": on_head_node,
|
||||
"local": local_client,
|
||||
"num_gpu_blocks": num_gpu_blocks,
|
||||
"dp_stats_address": dp_stats_address,
|
||||
}))
|
||||
|
||||
@staticmethod
|
||||
def startup_handshake(
|
||||
handshake_socket: zmq.Socket, on_head_node: bool,
|
||||
parallel_config: ParallelConfig) -> EngineZmqAddresses:
|
||||
handshake_socket: zmq.Socket,
|
||||
local_client: bool,
|
||||
parallel_config: Optional[ParallelConfig] = None,
|
||||
) -> EngineZmqAddresses:
|
||||
|
||||
# Send registration message.
|
||||
handshake_socket.send(
|
||||
msgspec.msgpack.encode({
|
||||
"status": "HELLO",
|
||||
"local": on_head_node,
|
||||
"local": local_client,
|
||||
}))
|
||||
|
||||
# Receive initialization message.
|
||||
@ -466,9 +533,9 @@ class EngineCoreProc(EngineCore):
|
||||
init_bytes, type=EngineHandshakeMetadata)
|
||||
logger.debug("Received init message: %s", init_message)
|
||||
|
||||
received_parallel_config = init_message.parallel_config
|
||||
for key, value in received_parallel_config.items():
|
||||
setattr(parallel_config, key, value)
|
||||
if parallel_config is not None:
|
||||
for key, value in init_message.parallel_config.items():
|
||||
setattr(parallel_config, key, value)
|
||||
|
||||
return init_message.addresses
|
||||
|
||||
@ -749,12 +816,12 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
on_head_node: bool,
|
||||
local_client: bool,
|
||||
handshake_address: str,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
client_handshake_address: Optional[str] = None,
|
||||
):
|
||||
|
||||
self._decorate_logs()
|
||||
|
||||
# Counts forward-passes of the model so that we can synchronize
|
||||
@ -765,8 +832,9 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
|
||||
# Initialize the engine.
|
||||
dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
super().__init__(vllm_config, on_head_node, handshake_address,
|
||||
executor_class, log_stats, dp_rank)
|
||||
super().__init__(vllm_config, local_client, handshake_address,
|
||||
executor_class, log_stats, client_handshake_address,
|
||||
dp_rank)
|
||||
|
||||
def _decorate_logs(self):
|
||||
# Add process-specific prefix to stdout and stderr before
|
||||
@ -799,10 +867,18 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
from vllm.platforms import current_platform
|
||||
device_control_env_var = current_platform.device_control_env_var
|
||||
world_size = vllm_config.parallel_config.world_size
|
||||
os.environ[device_control_env_var] = ",".join(
|
||||
str(current_platform.device_id_to_physical_device_id(i))
|
||||
for i in range(local_dp_rank * world_size, (local_dp_rank + 1) *
|
||||
world_size))
|
||||
# Set CUDA_VISIBLE_DEVICES or equivalent.
|
||||
try:
|
||||
os.environ[device_control_env_var] = ",".join(
|
||||
str(current_platform.device_id_to_physical_device_id(i))
|
||||
for i in range(local_dp_rank *
|
||||
world_size, (local_dp_rank + 1) * world_size))
|
||||
except IndexError as e:
|
||||
raise Exception(
|
||||
f"Error setting {device_control_env_var}: "
|
||||
f"local range: [{local_dp_rank * world_size}, "
|
||||
f"{(local_dp_rank + 1) * world_size}) "
|
||||
f"base value: \"{os.getenv(device_control_env_var)}\"") from e
|
||||
|
||||
self.dp_rank = dp_rank
|
||||
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
|
||||
@ -839,7 +915,7 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
super()._handle_client_request(request_type, request)
|
||||
|
||||
def _maybe_publish_request_counts(self):
|
||||
if not self.has_coordinator:
|
||||
if not self.publish_dp_lb_stats:
|
||||
return
|
||||
|
||||
# Publish our request counts (if they've changed).
|
||||
@ -892,9 +968,9 @@ class DPEngineCoreProc(EngineCoreProc):
|
||||
|
||||
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
|
||||
|
||||
# Optimization - only perform finish-sync all-reduce every 24 steps.
|
||||
# Optimization - only perform finish-sync all-reduce every 32 steps.
|
||||
self.counter += 1
|
||||
if self.counter != 24:
|
||||
if self.counter != 32:
|
||||
return True
|
||||
self.counter = 0
|
||||
|
||||
@ -910,7 +986,7 @@ class DPEngineCoreActor(DPEngineCoreProc):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
on_head_node: bool,
|
||||
local_client: bool,
|
||||
addresses: EngineZmqAddresses,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
@ -927,15 +1003,16 @@ class DPEngineCoreActor(DPEngineCoreProc):
|
||||
# data parallel groups.
|
||||
del os.environ['CUDA_VISIBLE_DEVICES']
|
||||
|
||||
super().__init__(vllm_config, on_head_node, "", executor_class,
|
||||
super().__init__(vllm_config, local_client, "", executor_class,
|
||||
log_stats)
|
||||
|
||||
def _decorate_logs(self):
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def _perform_handshake(self, handshake_address: str, identity: bytes,
|
||||
on_head_node: bool, vllm_config: VllmConfig):
|
||||
def _perform_handshakes(self, handshake_address: str, identity: bytes,
|
||||
local_client: bool, vllm_config: VllmConfig,
|
||||
client_handshake_address: Optional[str]):
|
||||
"""
|
||||
For Ray, we don't need to actually perform handshake.
|
||||
All addresses information is known before the actor creation.
|
||||
|
||||
@ -7,7 +7,7 @@ import sys
|
||||
import uuid
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Awaitable, Sequence
|
||||
from concurrent.futures import Future
|
||||
from dataclasses import dataclass
|
||||
@ -21,18 +21,16 @@ import zmq.asyncio
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.utils import (get_open_zmq_inproc_path, make_zmq_socket,
|
||||
zmq_socket_ctx)
|
||||
from vllm.utils import get_open_zmq_inproc_path, make_zmq_socket
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
EngineCoreRequestType, UtilityOutput)
|
||||
from vllm.v1.engine.coordinator import DPCoordinator
|
||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||
from vllm.v1.engine.exceptions import EngineDeadError
|
||||
from vllm.v1.engine.utils import (CoreEngineActorManager,
|
||||
CoreEngineProcManager, launch_core_engines)
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
|
||||
from vllm.v1.utils import (CoreEngine, CoreEngineActorManager,
|
||||
CoreEngineProcManager, EngineZmqAddresses,
|
||||
get_engine_client_zmq_addr, wait_for_engine_startup)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -40,6 +38,8 @@ AnyFuture = Union[asyncio.Future[Any], Future[Any]]
|
||||
|
||||
_R = TypeVar('_R') # Return type for collective_rpc
|
||||
|
||||
EngineIdentity = bytes
|
||||
|
||||
|
||||
class EngineCoreClient(ABC):
|
||||
"""
|
||||
@ -84,14 +84,16 @@ class EngineCoreClient(ABC):
|
||||
client_addresses: Optional[dict[str, str]] = None,
|
||||
client_index: int = 0,
|
||||
) -> "MPClient":
|
||||
if vllm_config.parallel_config.data_parallel_size > 1:
|
||||
if vllm_config.parallel_config.data_parallel_backend == "ray":
|
||||
return RayDPClient(vllm_config, executor_class, log_stats,
|
||||
client_addresses, client_index)
|
||||
return DPAsyncMPClient(vllm_config, executor_class, log_stats,
|
||||
client_addresses, client_index)
|
||||
return AsyncMPClient(vllm_config, executor_class, log_stats,
|
||||
client_addresses, client_index)
|
||||
parallel_config = vllm_config.parallel_config
|
||||
client_args = (vllm_config, executor_class, log_stats,
|
||||
client_addresses, client_index)
|
||||
if parallel_config.data_parallel_size > 1:
|
||||
if parallel_config.data_parallel_external_lb:
|
||||
# External load balancer - client per DP rank.
|
||||
return DPAsyncMPClient(*client_args)
|
||||
# Internal load balancer - client balances to all DP ranks.
|
||||
return DPLBAsyncMPClient(*client_args)
|
||||
return AsyncMPClient(*client_args)
|
||||
|
||||
@abstractmethod
|
||||
def shutdown(self):
|
||||
@ -386,42 +388,32 @@ class MPClient(EngineCoreClient):
|
||||
self._finalizer = weakref.finalize(self, self.resources)
|
||||
success = False
|
||||
try:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
local_engine_count = parallel_config.data_parallel_size_local
|
||||
local_start_index = parallel_config.data_parallel_rank_local
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
dp_rank = parallel_config.data_parallel_rank
|
||||
|
||||
# State used for data parallel.
|
||||
self.engines_running = False
|
||||
|
||||
# SPMD mode is where there is an LLM instance per DP rank and
|
||||
# one core engine per LLM, see
|
||||
# examples/offline_inference/data_parallel.py.
|
||||
spmd_mode = local_start_index is not None
|
||||
if spmd_mode:
|
||||
assert local_engine_count == 1
|
||||
self.core_engines = [CoreEngine(index=dp_rank, local=True)]
|
||||
else:
|
||||
assert dp_rank == 0
|
||||
local_start_index = 0
|
||||
self.core_engines = [
|
||||
CoreEngine(index=i, local=(i < local_engine_count))
|
||||
for i in range(dp_size)
|
||||
]
|
||||
|
||||
local_only = spmd_mode or local_engine_count == dp_size
|
||||
|
||||
self.stats_update_address: Optional[str] = None
|
||||
if client_addresses is not None:
|
||||
# Engines are managed externally to this client.
|
||||
input_address = client_addresses["input_address"]
|
||||
output_address = client_addresses["output_address"]
|
||||
self.stats_update_address = client_addresses.get(
|
||||
"stats_update_address")
|
||||
else:
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
input_address = get_engine_client_zmq_addr(local_only, host)
|
||||
output_address = get_engine_client_zmq_addr(local_only, host)
|
||||
# Engines are managed by this client.
|
||||
with launch_core_engines(vllm_config, executor_class,
|
||||
log_stats) as (engine_manager,
|
||||
coordinator,
|
||||
addresses):
|
||||
self.resources.coordinator = coordinator
|
||||
self.resources.engine_manager = engine_manager
|
||||
|
||||
(input_address, ) = addresses.inputs
|
||||
(output_address, ) = addresses.outputs
|
||||
self.stats_update_address = (
|
||||
addresses.frontend_stats_publish_address)
|
||||
if coordinator is not None:
|
||||
assert self.stats_update_address == (
|
||||
coordinator.get_stats_publish_address())
|
||||
|
||||
# Create input and output sockets.
|
||||
self.input_socket = self.resources.input_socket = make_zmq_socket(
|
||||
@ -429,18 +421,24 @@ class MPClient(EngineCoreClient):
|
||||
self.resources.output_socket = make_zmq_socket(
|
||||
self.ctx, output_address, zmq.PULL)
|
||||
|
||||
if client_addresses is None:
|
||||
self._init_engines_direct(vllm_config, local_only,
|
||||
local_start_index, input_address,
|
||||
output_address, executor_class,
|
||||
log_stats)
|
||||
coordinator = self.resources.coordinator
|
||||
if coordinator:
|
||||
self.stats_update_address = (
|
||||
coordinator.get_stats_publish_address())
|
||||
parallel_config = vllm_config.parallel_config
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
dp_rank = parallel_config.data_parallel_rank
|
||||
external_dp_lb = parallel_config.data_parallel_external_lb
|
||||
|
||||
offline_mode = parallel_config.data_parallel_rank_local is not None
|
||||
engine_ranks = [dp_rank] if (offline_mode
|
||||
or external_dp_lb) else range(dp_size)
|
||||
assert parallel_config.data_parallel_size_local <= len(
|
||||
engine_ranks)
|
||||
|
||||
# ZMQ identity of each engine that this client will talk to.
|
||||
self.core_engines: list[EngineIdentity] = [
|
||||
index.to_bytes(2, "little") for index in engine_ranks
|
||||
]
|
||||
|
||||
# Wait for ready messages from each engine on the input socket.
|
||||
identities = set(e.identity for e in self.core_engines)
|
||||
identities = set(self.core_engines)
|
||||
sync_input_socket = zmq.Socket.shadow(self.input_socket)
|
||||
while identities:
|
||||
if not sync_input_socket.poll(timeout=600_000):
|
||||
@ -449,7 +447,7 @@ class MPClient(EngineCoreClient):
|
||||
identity, _ = sync_input_socket.recv_multipart()
|
||||
identities.remove(identity)
|
||||
|
||||
self.core_engine = self.core_engines[0]
|
||||
self.core_engine: EngineIdentity = self.core_engines[0]
|
||||
self.utility_results: dict[int, AnyFuture] = {}
|
||||
|
||||
# Request objects which may contain pytorch-allocated tensors
|
||||
@ -462,73 +460,6 @@ class MPClient(EngineCoreClient):
|
||||
if not success:
|
||||
self._finalizer()
|
||||
|
||||
def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
|
||||
local_start_index: int, input_address: str,
|
||||
output_address: str,
|
||||
executor_class: type[Executor], log_stats: bool):
|
||||
"""Self-contained client mode, launch engine and coordinator process
|
||||
as needed."""
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
local_engine_count = parallel_config.data_parallel_size_local
|
||||
start_index = parallel_config.data_parallel_rank
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
|
||||
if len(self.core_engines) > 1:
|
||||
self.resources.coordinator = DPCoordinator(parallel_config)
|
||||
|
||||
handshake_address = get_engine_client_zmq_addr(
|
||||
local_only, host, parallel_config.data_parallel_rpc_port)
|
||||
|
||||
with zmq_socket_ctx(handshake_address, zmq.ROUTER,
|
||||
bind=True) as handshake_socket:
|
||||
|
||||
# Start local engines.
|
||||
if local_engine_count:
|
||||
# In server mode, start_index and local_start_index will
|
||||
# both be 0.
|
||||
self.resources.engine_manager = CoreEngineProcManager(
|
||||
EngineCoreProc.run_engine_core,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=log_stats,
|
||||
handshake_address=handshake_address,
|
||||
on_head_node=True,
|
||||
local_engine_count=local_engine_count,
|
||||
start_index=start_index,
|
||||
local_start_index=local_start_index)
|
||||
|
||||
# Wait for engine core process(es) to start.
|
||||
self._wait_for_engine_startup(handshake_socket, input_address,
|
||||
output_address)
|
||||
|
||||
def _wait_for_engine_startup(self, handshake_socket: zmq.Socket,
|
||||
input_address: str, output_address: str):
|
||||
addresses = EngineZmqAddresses(
|
||||
inputs=[input_address],
|
||||
outputs=[output_address],
|
||||
)
|
||||
|
||||
coordinator = self.resources.coordinator
|
||||
if coordinator is not None:
|
||||
addresses.coordinator_input, addresses.coordinator_output = (
|
||||
coordinator.get_engine_socket_addresses())
|
||||
|
||||
proc_manager = self.resources.engine_manager
|
||||
assert isinstance(proc_manager, (type(None), CoreEngineProcManager)), (
|
||||
"_wait_for_engine_startup should only be "
|
||||
"called with CoreEngineProcManager")
|
||||
|
||||
wait_for_engine_startup(
|
||||
handshake_socket,
|
||||
addresses,
|
||||
self.core_engines,
|
||||
self.vllm_config.parallel_config,
|
||||
self.vllm_config.cache_config,
|
||||
proc_manager,
|
||||
coordinator.proc if coordinator else None,
|
||||
)
|
||||
|
||||
def shutdown(self):
|
||||
# Terminate background resources.
|
||||
self._finalizer()
|
||||
@ -583,7 +514,6 @@ class SyncMPClient(MPClient):
|
||||
# a ref to the client which prevents gc.
|
||||
ctx = self.ctx
|
||||
out_socket = self.resources.output_socket
|
||||
assert out_socket is not None
|
||||
decoder = self.decoder
|
||||
utility_results = self.utility_results
|
||||
outputs_queue = self.outputs_queue
|
||||
@ -593,6 +523,7 @@ class SyncMPClient(MPClient):
|
||||
resources.shutdown_path = shutdown_path
|
||||
|
||||
def process_outputs_socket():
|
||||
assert isinstance(out_socket, zmq.Socket)
|
||||
shutdown_socket = ctx.socket(zmq.PAIR)
|
||||
try:
|
||||
shutdown_socket.bind(shutdown_path)
|
||||
@ -609,7 +540,7 @@ class SyncMPClient(MPClient):
|
||||
|
||||
frames = out_socket.recv_multipart(copy=False)
|
||||
resources.validate_alive(frames)
|
||||
outputs = decoder.decode(frames)
|
||||
outputs: EngineCoreOutputs = decoder.decode(frames)
|
||||
if outputs.utility_output:
|
||||
_process_utility_output(outputs.utility_output,
|
||||
utility_results)
|
||||
@ -646,7 +577,7 @@ class SyncMPClient(MPClient):
|
||||
self.ensure_alive()
|
||||
self.free_pending_messages()
|
||||
# (Identity, RequestType, SerializedRequest)
|
||||
msg = (self.core_engine.identity, request_type.value,
|
||||
msg = (self.core_engine, request_type.value,
|
||||
*self.encoder.encode(request))
|
||||
|
||||
if len(msg) <= 3:
|
||||
@ -812,7 +743,7 @@ class AsyncMPClient(MPClient):
|
||||
def _send_input(self,
|
||||
request_type: EngineCoreRequestType,
|
||||
request: Any,
|
||||
engine: Optional[CoreEngine] = None) -> Awaitable[Any]:
|
||||
engine: Optional[EngineIdentity] = None) -> Awaitable[Any]:
|
||||
if engine is None:
|
||||
engine = self.core_engine
|
||||
|
||||
@ -820,7 +751,7 @@ class AsyncMPClient(MPClient):
|
||||
return self._send_input_message(message, engine, request)
|
||||
|
||||
def _send_input_message(self, message: tuple[bytestr,
|
||||
...], engine: CoreEngine,
|
||||
...], engine: EngineIdentity,
|
||||
objects: Any) -> Awaitable[Any]:
|
||||
"""
|
||||
objects is a reference to retain until zmq is finished with the
|
||||
@ -829,7 +760,7 @@ class AsyncMPClient(MPClient):
|
||||
self.ensure_alive()
|
||||
self.free_pending_messages()
|
||||
|
||||
msg = (engine.identity, ) + message
|
||||
msg = (engine, ) + message
|
||||
if not objects or len(msg) <= 3:
|
||||
# No auxiliary buffers => no tensor backing buffers in request.
|
||||
return self.input_socket.send_multipart(msg, copy=False)
|
||||
@ -850,7 +781,7 @@ class AsyncMPClient(MPClient):
|
||||
engine=self.core_engine)
|
||||
|
||||
async def _call_utility_async(self, method: str, *args,
|
||||
engine: CoreEngine) -> Any:
|
||||
engine: EngineIdentity) -> Any:
|
||||
call_id = uuid.uuid1().int >> 64
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
self.utility_results[call_id] = future
|
||||
@ -921,7 +852,7 @@ class AsyncMPClient(MPClient):
|
||||
|
||||
class DPAsyncMPClient(AsyncMPClient):
|
||||
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
|
||||
EngineCore."""
|
||||
EngineCore. Assumes external load-balancing by default."""
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
@ -930,15 +861,12 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
client_addresses: Optional[dict[str, str]] = None,
|
||||
client_index: int = 0):
|
||||
self.current_wave = 0
|
||||
# To route aborts to the correct engine.
|
||||
self.reqs_in_flight: dict[str, CoreEngine] = {}
|
||||
|
||||
super().__init__(vllm_config, executor_class, log_stats,
|
||||
client_addresses, client_index)
|
||||
|
||||
assert len(self.core_engines) > 1
|
||||
|
||||
# List of [waiting, running] pair per engine.
|
||||
# Used only by DPLBAsyncMPClient subclass.
|
||||
self.lb_engines: list[list[int]] = []
|
||||
|
||||
self.first_req_sock_addr = get_open_zmq_inproc_path()
|
||||
@ -969,6 +897,8 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
self.first_req_sock_addr,
|
||||
zmq.PAIR,
|
||||
bind=False) as first_req_rcv_socket:
|
||||
assert isinstance(socket, zmq.asyncio.Socket)
|
||||
assert isinstance(first_req_rcv_socket, zmq.asyncio.Socket)
|
||||
# Send subscription message.
|
||||
await socket.send(b'\x01')
|
||||
|
||||
@ -1012,34 +942,75 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
resources.stats_update_task = asyncio.create_task(
|
||||
run_engine_stats_update_task())
|
||||
|
||||
def get_core_engine_for_request(self,
|
||||
dp_rank: Optional[int] = None
|
||||
) -> CoreEngine:
|
||||
if dp_rank is not None:
|
||||
# engines are already in rank order
|
||||
return self.core_engines[dp_rank]
|
||||
async def add_request_async(self, request: EngineCoreRequest) -> None:
|
||||
self._ensure_stats_update_task()
|
||||
|
||||
if not self.lb_engines:
|
||||
return self.core_engines[0]
|
||||
# TODO use P2C alg for larger DP sizes
|
||||
num_engines = len(self.lb_engines)
|
||||
min_counts = [sys.maxsize, sys.maxsize]
|
||||
eng_index = 0
|
||||
for i in range(num_engines):
|
||||
# Start from client_index to help with balancing when engines
|
||||
# are empty.
|
||||
idx = (self.client_index + i) % num_engines
|
||||
counts = self.lb_engines[idx]
|
||||
if counts < min_counts:
|
||||
min_counts = counts
|
||||
eng_index = idx
|
||||
# Adjust local counts for better balancing between stats updates
|
||||
# from the coordinator (which happen every 100ms).
|
||||
if min_counts[0]:
|
||||
min_counts[0] += 1
|
||||
else:
|
||||
min_counts[1] += 1
|
||||
return self.core_engines[eng_index]
|
||||
request.current_wave = self.current_wave
|
||||
request.client_index = self.client_index
|
||||
|
||||
chosen_engine = self.get_core_engine_for_request(request)
|
||||
to_await = self._send_input(EngineCoreRequestType.ADD, request,
|
||||
chosen_engine)
|
||||
if not self.engines_running:
|
||||
# Notify coordinator that we're sending a request
|
||||
await self.first_req_send_socket.send(chosen_engine)
|
||||
|
||||
await to_await
|
||||
|
||||
self._ensure_output_queue_task()
|
||||
|
||||
def get_core_engine_for_request(self, request: EngineCoreRequest):
|
||||
return self.core_engine
|
||||
|
||||
|
||||
class DPLBAsyncMPClient(DPAsyncMPClient):
    """Asyncio-compatible client for multi-proc, multi-engine (data parallel)
    EngineCore. Load-balances between multiple engine processes."""

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 client_addresses: Optional[dict[str, str]] = None,
                 client_index: int = 0):

        # To route aborts to the correct engine.
        self.reqs_in_flight: dict[str, EngineIdentity] = {}

        super().__init__(vllm_config, executor_class, log_stats,
                         client_addresses, client_index)

        assert len(self.core_engines) > 1

    def get_core_engine_for_request(
            self, request: EngineCoreRequest) -> EngineIdentity:
        # Engines are in rank order.
        if (eng_index := request.data_parallel_rank) is None:
            if not self.lb_engines:
                return self.core_engine
            # TODO use P2C alg for larger DP sizes
            num_engines = len(self.lb_engines)
            min_counts = [sys.maxsize, sys.maxsize]
            eng_index = 0
            for i in range(num_engines):
                # Start from client_index to help with balancing when engines
                # are empty.
                idx = (self.client_index + i) % num_engines
                counts = self.lb_engines[idx]
                if counts < min_counts:
                    min_counts = counts
                    eng_index = idx
            # Adjust local counts for better balancing between stats updates
            # from the coordinator (which happen every 100ms).
            if min_counts[0]:
                min_counts[0] += 1
            else:
                min_counts[1] += 1

        chosen_engine = self.core_engines[eng_index]
        # Record which engine is chosen for this request, to handle aborts.
        self.reqs_in_flight[request.request_id] = chosen_engine
        return chosen_engine
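# Illustration (not part of this commit): the [waiting, running] pairs that
# the DP coordinator publishes drive the choice above. Python compares the
# two-element lists lexicographically, so an engine with fewer waiting
# requests always wins, and ties fall back to the running count. A minimal,
# self-contained sketch of just that selection rule, with made-up counts:
import sys

def pick_engine(lb_counts: list[list[int]], client_index: int = 0) -> int:
    min_counts, eng_index = [sys.maxsize, sys.maxsize], 0
    num_engines = len(lb_counts)
    for i in range(num_engines):
        # Start the scan at client_index, as above, so different front-end
        # processes prefer different engines when counts are tied.
        idx = (client_index + i) % num_engines
        if lb_counts[idx] < min_counts:
            min_counts, eng_index = lb_counts[idx], idx
    return eng_index

# Engine 1 has the fewest waiting requests, so it is chosen even though it
# has more running requests than engine 2.
assert pick_engine([[3, 8], [1, 9], [2, 0]]) == 1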
|
||||
|
||||
async def call_utility_async(self, method: str, *args) -> Any:
|
||||
# Only the result from the first engine is returned.
|
||||
@ -1048,28 +1019,8 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
for engine in self.core_engines
|
||||
]))[0]
|
||||
|
||||
async def add_request_async(self, request: EngineCoreRequest) -> None:
|
||||
self._ensure_stats_update_task()
|
||||
|
||||
request.current_wave = self.current_wave
|
||||
request.client_index = self.client_index
|
||||
|
||||
chosen_engine = self.get_core_engine_for_request(
|
||||
request.data_parallel_rank)
|
||||
self.reqs_in_flight[request.request_id] = chosen_engine
|
||||
|
||||
to_await = self._send_input(EngineCoreRequestType.ADD, request,
|
||||
chosen_engine)
|
||||
if not self.engines_running:
|
||||
# Notify coordinator that we're sending a request
|
||||
await self.first_req_send_socket.send(chosen_engine.identity)
|
||||
|
||||
await to_await
|
||||
|
||||
self._ensure_output_queue_task()
|
||||
|
||||
@staticmethod
|
||||
async def process_engine_outputs(self: "DPAsyncMPClient",
|
||||
async def process_engine_outputs(self: "DPLBAsyncMPClient",
|
||||
outputs: EngineCoreOutputs):
|
||||
if outputs.finished_requests and self.reqs_in_flight:
|
||||
for req_id in outputs.finished_requests:
|
||||
@ -1085,61 +1036,14 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
await self._abort_requests(request_ids, engine)
|
||||
return
|
||||
|
||||
by_engine: dict[CoreEngine, list[str]] = {}
|
||||
by_engine = defaultdict[EngineIdentity, list[str]](list)
|
||||
for req_id in request_ids:
|
||||
if engine := self.reqs_in_flight.get(req_id):
|
||||
by_engine.setdefault(engine, []).append(req_id)
|
||||
by_engine[engine].append(req_id)
|
||||
for engine, req_ids in by_engine.items():
|
||||
await self._abort_requests(req_ids, engine)
|
||||
|
||||
async def _abort_requests(self, request_ids: list[str],
|
||||
engine: CoreEngine) -> None:
|
||||
engine: EngineIdentity) -> None:
|
||||
await self._send_input(EngineCoreRequestType.ABORT, request_ids,
|
||||
engine)
|
||||
|
||||
|
||||
class RayDPClient(DPAsyncMPClient):
|
||||
"""
|
||||
Ray-based client for multi-proc, multi-engine (data parallel)
|
||||
EngineCore.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
client_addresses: Optional[dict[str, str]] = None,
|
||||
client_index: int = 0,
|
||||
):
|
||||
super().__init__(vllm_config, executor_class, log_stats,
|
||||
client_addresses, client_index)
|
||||
|
||||
def _init_engines_direct(self, vllm_config: VllmConfig, local_only: bool,
|
||||
local_start_index: int, input_address: str,
|
||||
output_address: str,
|
||||
executor_class: type[Executor], log_stats: bool):
|
||||
"""Self-contained client mode, launch engine and coordinator process
|
||||
as needed."""
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
assert parallel_config.data_parallel_rank == 0
|
||||
assert local_start_index == 0
|
||||
|
||||
addresses = EngineZmqAddresses(
|
||||
inputs=[input_address],
|
||||
outputs=[output_address],
|
||||
)
|
||||
|
||||
if len(self.core_engines) > 1:
|
||||
coordinator = DPCoordinator(parallel_config)
|
||||
self.resources.coordinator = coordinator
|
||||
addresses.coordinator_input, addresses.coordinator_output = (
|
||||
coordinator.get_engine_socket_addresses())
|
||||
|
||||
# Start all engines.
|
||||
self.resources.engine_manager = CoreEngineActorManager(
|
||||
vllm_config=vllm_config,
|
||||
addresses=addresses,
|
||||
executor_class=executor_class,
|
||||
log_stats=log_stats)
|
||||
|
||||
546 vllm/v1/engine/utils.py Normal file
@ -0,0 +1,546 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import weakref
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from multiprocessing import Process, connection
|
||||
from multiprocessing.process import BaseProcess
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||
|
||||
import msgspec
|
||||
import zmq
|
||||
|
||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, zmq_socket_ctx
|
||||
from vllm.v1.engine.coordinator import DPCoordinator
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
STARTUP_POLL_PERIOD_MS = 10000
|
||||
|
||||
|
||||
class CoreEngineState(Enum):
    NEW = auto()
    CONNECTED = auto()
    READY = auto()


class CoreEngine:
    """One per data parallel rank, used to track state during handshaking."""

    def __init__(self, index: int = 0, local: bool = True):
        self.local = local
        self.identity = index.to_bytes(2, "little")

        self.state = CoreEngineState.NEW
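# Illustration (not part of this commit): the DP rank doubles as the engine's
# ZMQ ROUTER identity, encoded as two little-endian bytes, so a handshake
# frame can be mapped back to a rank with int.from_bytes.
rank = 3
identity = rank.to_bytes(2, "little")            # b'\x03\x00'
assert int.from_bytes(identity, "little") == rank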
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineZmqAddresses:
|
||||
# ZMQ input socket addresses for each front-end client (requests)
|
||||
inputs: list[str]
|
||||
# ZMQ output socket addresses for each front-end client (responses)
|
||||
outputs: list[str]
|
||||
# ZMQ input socket address of DP coordinator if applicable
|
||||
coordinator_input: Optional[str] = None
|
||||
# ZMQ output socket address of DP coordinator if applicable
|
||||
coordinator_output: Optional[str] = None
|
||||
# ZMQ socket for front-end to connect to DP coordinator.
|
||||
# Not used by engine, just relayed to front-end in handshake response.
|
||||
# Only required for external DP LB case.
|
||||
frontend_stats_publish_address: Optional[str] = None
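# Hypothetical example (addresses made up): what gets relayed to each engine
# in the handshake when two API servers and a DP coordinator are running,
# using the dataclass defined above.
addresses = EngineZmqAddresses(
    inputs=["ipc:///tmp/client-0-input", "ipc:///tmp/client-1-input"],
    outputs=["ipc:///tmp/client-0-output", "ipc:///tmp/client-1-output"],
    coordinator_input="tcp://10.0.0.1:5571",
    coordinator_output="tcp://10.0.0.1:5572",
)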
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineHandshakeMetadata:
|
||||
"""Metadata sent to each engine process during startup handshake,
|
||||
including addresses of the front-end ZMQ queues that they should
|
||||
connect to.
|
||||
"""
|
||||
addresses: EngineZmqAddresses
|
||||
parallel_config: dict[str, Union[int, str]]
|
||||
|
||||
|
||||
class CoreEngineProcManager:
|
||||
"""
|
||||
Utility class to handle creation, readiness, and shutdown
|
||||
of background processes used by the AsyncLLM and LLMEngine.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_fn: Callable,
|
||||
local_engine_count: int,
|
||||
start_index: int,
|
||||
local_start_index: int,
|
||||
vllm_config: VllmConfig,
|
||||
local_client: bool,
|
||||
handshake_address: str,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
client_handshake_address: Optional[str] = None,
|
||||
):
|
||||
context = get_mp_context()
|
||||
common_kwargs = {
|
||||
"vllm_config": vllm_config,
|
||||
"local_client": local_client,
|
||||
"handshake_address": handshake_address,
|
||||
"executor_class": executor_class,
|
||||
"log_stats": log_stats,
|
||||
}
|
||||
|
||||
if client_handshake_address:
|
||||
common_kwargs[
|
||||
"client_handshake_address"] = client_handshake_address
|
||||
|
||||
self.processes: list[BaseProcess] = []
|
||||
for index in range(local_engine_count):
|
||||
local_index = local_start_index + index
|
||||
global_index = start_index + index
|
||||
# Start EngineCore in background process.
|
||||
self.processes.append(
|
||||
context.Process(target=target_fn,
|
||||
name=f"EngineCore_{global_index}",
|
||||
kwargs=common_kwargs | {
|
||||
"dp_rank": global_index,
|
||||
"local_dp_rank": local_index,
|
||||
}))
|
||||
|
||||
self._finalizer = weakref.finalize(self, shutdown, self.processes)
|
||||
try:
|
||||
for proc in self.processes:
|
||||
proc.start()
|
||||
finally:
|
||||
# Kill other procs if not all are running.
|
||||
if self.finished_procs():
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
"""Shutdown all procs."""
|
||||
self._finalizer()
|
||||
|
||||
def join_first(self):
|
||||
"""Wait for any process to exit."""
|
||||
connection.wait(proc.sentinel for proc in self.processes)
|
||||
|
||||
def sentinels(self) -> list:
|
||||
return [proc.sentinel for proc in self.processes]
|
||||
|
||||
def finished_procs(self) -> dict[str, int]:
|
||||
"""Returns dict of proc name -> exit code for any finished procs."""
|
||||
return {
|
||||
proc.name: proc.exitcode
|
||||
for proc in self.processes if proc.exitcode is not None
|
||||
}
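# Worked example (values assumed): each spawned EngineCore process receives
# the shared kwargs plus its own global/local DP rank via the dict-merge
# operator, e.g. for an external-LB front-end that owns dp_rank 2:
common_kwargs = {"local_client": True, "log_stats": False}
start_index, local_start_index, index = 2, 0, 0
kwargs = common_kwargs | {"dp_rank": start_index + index,
                          "local_dp_rank": local_start_index + index}
# -> {'local_client': True, 'log_stats': False, 'dp_rank': 2, 'local_dp_rank': 0}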
|
||||
|
||||
|
||||
class CoreEngineActorManager:
|
||||
"""
|
||||
Utility class to handle creation, readiness, and shutdown
|
||||
of core engine Ray actors used by the AsyncLLM and LLMEngine.
|
||||
|
||||
Different from CoreEngineProcManager, this class manages
|
||||
core engines for both local and remote nodes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
addresses: EngineZmqAddresses,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
placement_groups: Optional[list["PlacementGroup"]] = None,
|
||||
local_dp_ranks: Optional[list[int]] = None,
|
||||
):
|
||||
import copy
|
||||
|
||||
import ray
|
||||
from ray.util.scheduling_strategies import (
|
||||
PlacementGroupSchedulingStrategy)
|
||||
|
||||
from vllm.v1.engine.core import DPEngineCoreActor
|
||||
|
||||
self.local_engine_actors: list[ray.ActorHandle] = []
|
||||
self.remote_engine_actors: list[ray.ActorHandle] = []
|
||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
local_engine_count = \
|
||||
vllm_config.parallel_config.data_parallel_size_local
|
||||
world_size = vllm_config.parallel_config.world_size
|
||||
|
||||
if ray.is_initialized():
|
||||
logger.info(
|
||||
"Ray is already initialized. Skipping Ray initialization.")
|
||||
else:
|
||||
ray.init()
|
||||
|
||||
if placement_groups is not None:
|
||||
assert local_dp_ranks is not None, (
|
||||
"local_dp_ranks must be provided if "
|
||||
"placement_groups is provided")
|
||||
assert len(placement_groups) == len(local_dp_ranks), (
|
||||
"placement_groups and local_dp_ranks must "
|
||||
"have the same length")
|
||||
logger.info("Using provided placement groups")
|
||||
# TODO(rui): validate passed-in placement groups
|
||||
self.created_placement_groups = []
|
||||
else:
|
||||
placement_groups, local_dp_ranks = \
|
||||
CoreEngineActorManager.create_dp_placement_groups(vllm_config)
|
||||
self.created_placement_groups = placement_groups
|
||||
assert len(placement_groups) == dp_size, (
|
||||
"Number of placement groups must match data parallel size")
|
||||
|
||||
refs = []
|
||||
for index in range(dp_size):
|
||||
local_index = local_dp_ranks[index]
|
||||
dp_vllm_config = copy.deepcopy(vllm_config)
|
||||
pg = placement_groups[index]
|
||||
dp_vllm_config.parallel_config.placement_group = pg
|
||||
local_client = index < local_engine_count
|
||||
actor = ray.remote(DPEngineCoreActor).options(
|
||||
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||
placement_group=pg,
|
||||
placement_group_bundle_index=world_size,
|
||||
)).remote(vllm_config=dp_vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=log_stats,
|
||||
local_client=local_client,
|
||||
addresses=addresses,
|
||||
dp_rank=index,
|
||||
local_dp_rank=local_index)
|
||||
if local_client:
|
||||
self.local_engine_actors.append(actor)
|
||||
else:
|
||||
self.remote_engine_actors.append(actor)
|
||||
refs.append(actor.wait_for_init.remote())
|
||||
|
||||
ray.get(refs)
|
||||
self.run_refs = []
|
||||
for actor in self.local_engine_actors + self.remote_engine_actors:
|
||||
self.run_refs.append(actor.run.remote())
|
||||
|
||||
@staticmethod
|
||||
def create_dp_placement_groups(
|
||||
vllm_config: VllmConfig
|
||||
) -> tuple[list["PlacementGroup"], list[int]]:
|
||||
|
||||
import ray
|
||||
from ray._private.state import available_resources_per_node
|
||||
from ray.util.state import list_nodes
|
||||
|
||||
logger.info("Creating placement groups for data parallel")
|
||||
dp_master_ip = \
|
||||
vllm_config.parallel_config.data_parallel_master_ip
|
||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
local_engine_count = \
|
||||
vllm_config.parallel_config.data_parallel_size_local
|
||||
|
||||
nodes = sorted(list_nodes(),
|
||||
key=lambda node: node.node_ip != dp_master_ip)
|
||||
assert nodes[0].node_ip == dp_master_ip, (
|
||||
"The first node must be the head node")
|
||||
assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
|
||||
"There can only be one head node")
|
||||
|
||||
available_resources = available_resources_per_node()
|
||||
world_size = vllm_config.parallel_config.world_size
|
||||
placement_groups: list[PlacementGroup] = []
|
||||
local_dp_ranks: list[int] = []
|
||||
|
||||
for node in nodes:
|
||||
node_ip = node.node_ip
|
||||
node_resources = available_resources[node.node_id]
|
||||
# For now, each DP rank can only be assigned to one node
|
||||
# TODO(rui): support allocating a single DP rank
|
||||
# to multiple nodes
|
||||
available_engine_count = int(node_resources["GPU"]) // world_size
|
||||
if node_ip == dp_master_ip:
|
||||
assert available_engine_count >= local_engine_count, (
|
||||
"Not enough resources to allocate DP ranks "
|
||||
f"on DP master node {node_ip}")
|
||||
for i in range(local_engine_count):
|
||||
bundles = [{
|
||||
"GPU": 1.0,
|
||||
"node:" + dp_master_ip: 0.001
|
||||
}] * world_size + [{
|
||||
"CPU": 1.0
|
||||
}]
|
||||
pg = ray.util.placement_group(
|
||||
name=f"dp_rank_{len(placement_groups)}",
|
||||
strategy="STRICT_PACK",
|
||||
bundles=bundles,
|
||||
)
|
||||
placement_groups.append(pg)
|
||||
local_dp_ranks.append(i)
|
||||
else:
|
||||
for i in range(available_engine_count):
|
||||
if len(placement_groups) == dp_size:
|
||||
break
|
||||
bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
|
||||
pg = ray.util.placement_group(
|
||||
name=f"dp_rank_{len(placement_groups)}",
|
||||
strategy="STRICT_PACK",
|
||||
bundles=bundles,
|
||||
)
|
||||
placement_groups.append(pg)
|
||||
local_dp_ranks.append(i)
|
||||
return placement_groups, local_dp_ranks
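# Worked example (values assumed): with world_size == 2 and the DP master at
# 10.0.0.5, each local DP rank gets a STRICT_PACK placement group of two GPU
# bundles pinned to the head node plus one CPU bundle for the engine actor.
world_size, dp_master_ip = 2, "10.0.0.5"
bundles = [{"GPU": 1.0, "node:" + dp_master_ip: 0.001}] * world_size + [{"CPU": 1.0}]
# -> [{'GPU': 1.0, 'node:10.0.0.5': 0.001},
#     {'GPU': 1.0, 'node:10.0.0.5': 0.001},
#     {'CPU': 1.0}]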
|
||||
|
||||
def get_run_refs(self):
|
||||
return self.run_refs
|
||||
|
||||
def close(self):
|
||||
import ray
|
||||
for actor in self.local_engine_actors + self.remote_engine_actors:
|
||||
ray.kill(actor)
|
||||
for pg in self.created_placement_groups:
|
||||
ray.util.remove_placement_group(pg)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def launch_core_engines(
|
||||
vllm_config: VllmConfig,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
num_api_servers: int = 1,
|
||||
) -> Iterator[tuple[
|
||||
Optional[Union[CoreEngineProcManager, CoreEngineActorManager]],
|
||||
Optional[DPCoordinator],
|
||||
EngineZmqAddresses,
|
||||
]]:
|
||||
"""Launch engine and DP coordinator processes as needed."""
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
dp_size = parallel_config.data_parallel_size
|
||||
local_engine_count = parallel_config.data_parallel_size_local
|
||||
local_start_index = parallel_config.data_parallel_rank_local
|
||||
dp_rank = parallel_config.data_parallel_rank
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
external_dp_lb = parallel_config.data_parallel_external_lb
|
||||
|
||||
# In offline mode there is an LLM instance per DP rank and
|
||||
# one core engine per LLM, see
|
||||
# examples/offline_inference/data_parallel.py.
|
||||
offline_mode = local_start_index is not None
|
||||
|
||||
# client_local_only = True for cases where this front-end
|
||||
# sends requests only to colocated engines.
|
||||
client_local_only = offline_mode or external_dp_lb or (local_engine_count
|
||||
== dp_size)
|
||||
|
||||
# Set up input and output addresses.
|
||||
addresses = EngineZmqAddresses(
|
||||
inputs=[
|
||||
get_engine_client_zmq_addr(client_local_only, host)
|
||||
for _ in range(num_api_servers)
|
||||
],
|
||||
outputs=[
|
||||
get_engine_client_zmq_addr(client_local_only, host)
|
||||
for _ in range(num_api_servers)
|
||||
],
|
||||
)
|
||||
|
||||
# Run the DP Coordinator process with rank 0 when in
|
||||
# online DP mode.
|
||||
run_coordinator = dp_size > 1 and not offline_mode and dp_rank == 0
|
||||
|
||||
if run_coordinator:
|
||||
coordinator = DPCoordinator(parallel_config)
|
||||
|
||||
addresses.coordinator_input, addresses.coordinator_output = (
|
||||
coordinator.get_engine_socket_addresses())
|
||||
addresses.frontend_stats_publish_address = (
|
||||
coordinator.get_stats_publish_address())
|
||||
|
||||
logger.info("Started DP Coordinator process (PID: %d)",
|
||||
coordinator.proc.pid)
|
||||
else:
|
||||
coordinator = None
|
||||
|
||||
if parallel_config.data_parallel_backend == "ray":
|
||||
logger.info("Starting ray-based data parallel backend")
|
||||
|
||||
engine_actor_manager = CoreEngineActorManager(
|
||||
vllm_config=vllm_config,
|
||||
addresses=addresses,
|
||||
executor_class=executor_class,
|
||||
log_stats=log_stats,
|
||||
)
|
||||
|
||||
yield engine_actor_manager, coordinator, addresses
|
||||
return
|
||||
|
||||
if offline_mode or (external_dp_lb and dp_rank > 0):
|
||||
assert local_engine_count == 1
|
||||
engines_to_handshake = [CoreEngine(index=dp_rank, local=True)]
|
||||
else:
|
||||
engines_to_handshake = [
|
||||
CoreEngine(index=i, local=(i < local_engine_count))
|
||||
for i in range(dp_size)
|
||||
]
|
||||
|
||||
# Whether the started engines will handshake only with co-located
|
||||
# front-end processes. In external_dp_lb mode, ranks > 0 handshake with
|
||||
# their co-located frontend and also the rank 0 front-end, and hence this
|
||||
# will be False.
|
||||
handshake_local_only = offline_mode or local_engine_count == dp_size
|
||||
|
||||
handshake_address = get_engine_client_zmq_addr(
|
||||
handshake_local_only, host, parallel_config.data_parallel_rpc_port)
|
||||
|
||||
if external_dp_lb and dp_rank > 0:
|
||||
assert not handshake_local_only
|
||||
local_handshake_address = get_open_zmq_ipc_path()
|
||||
client_handshake_address = local_handshake_address
|
||||
else:
|
||||
local_handshake_address = handshake_address
|
||||
client_handshake_address = None
|
||||
|
||||
with zmq_socket_ctx(local_handshake_address, zmq.ROUTER,
|
||||
bind=True) as handshake_socket:
|
||||
|
||||
from vllm.v1.engine.core import EngineCoreProc
|
||||
|
||||
# Start local engines.
|
||||
if local_engine_count:
|
||||
# In server mode, start_index and local_start_index will
|
||||
# both be 0.
|
||||
local_engine_manager = CoreEngineProcManager(
|
||||
EngineCoreProc.run_engine_core,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=log_stats,
|
||||
handshake_address=handshake_address,
|
||||
client_handshake_address=client_handshake_address,
|
||||
local_client=True,
|
||||
local_engine_count=local_engine_count,
|
||||
start_index=dp_rank,
|
||||
local_start_index=local_start_index or 0)
|
||||
else:
|
||||
local_engine_manager = None
|
||||
|
||||
yield local_engine_manager, coordinator, addresses
|
||||
|
||||
# Now wait for engines to start.
|
||||
wait_for_engine_startup(
|
||||
handshake_socket,
|
||||
addresses,
|
||||
engines_to_handshake,
|
||||
parallel_config,
|
||||
vllm_config.cache_config,
|
||||
local_engine_manager,
|
||||
coordinator.proc if coordinator else None,
|
||||
)
|
||||
|
||||
|
||||
def wait_for_engine_startup(
|
||||
handshake_socket: zmq.Socket,
|
||||
addresses: EngineZmqAddresses,
|
||||
core_engines: list[CoreEngine],
|
||||
parallel_config: ParallelConfig,
|
||||
cache_config: CacheConfig,
|
||||
proc_manager: Optional[CoreEngineProcManager],
|
||||
coord_process: Optional[Process],
|
||||
):
|
||||
# Wait for engine core process(es) to send ready messages.
|
||||
local_count = parallel_config.data_parallel_size_local
|
||||
remote_count = len(core_engines) - local_count
|
||||
# [local, remote] counts
|
||||
conn_pending, start_pending = [local_count, remote_count], [0, 0]
|
||||
poller = zmq.Poller()
|
||||
poller.register(handshake_socket, zmq.POLLIN)
|
||||
|
||||
if proc_manager is not None:
|
||||
for sentinel in proc_manager.sentinels():
|
||||
poller.register(sentinel, zmq.POLLIN)
|
||||
if coord_process is not None:
|
||||
poller.register(coord_process.sentinel, zmq.POLLIN)
|
||||
while any(conn_pending) or any(start_pending):
|
||||
events = poller.poll(STARTUP_POLL_PERIOD_MS)
|
||||
if not events:
|
||||
if any(conn_pending):
|
||||
logger.debug(
|
||||
"Waiting for %d local, %d remote core engine proc(s) "
|
||||
"to connect.", *conn_pending)
|
||||
if any(start_pending):
|
||||
logger.debug(
|
||||
"Waiting for %d local, %d remote core engine proc(s) "
|
||||
"to start.", *start_pending)
|
||||
continue
|
||||
if len(events) > 1 or events[0][0] != handshake_socket:
|
||||
# One of the local core processes exited.
|
||||
finished = proc_manager.finished_procs() if proc_manager else {}
|
||||
if coord_process is not None and coord_process.exitcode is not None:
|
||||
finished[coord_process.name] = coord_process.exitcode
|
||||
raise RuntimeError("Engine core initialization failed. "
|
||||
"See root cause above. "
|
||||
f"Failed core proc(s): {finished}")
|
||||
|
||||
# Receive HELLO and READY messages from the input socket.
|
||||
eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
|
||||
eng_index = int.from_bytes(eng_identity, "little")
|
||||
engine = next((e for e in core_engines if e.identity == eng_identity),
|
||||
None)
|
||||
if engine is None:
|
||||
raise RuntimeError(f"Message from engine with unexpected data "
|
||||
f"parallel rank: {eng_index}")
|
||||
msg = msgspec.msgpack.decode(ready_msg_bytes)
|
||||
status, local = msg["status"], msg["local"]
|
||||
if local != engine.local:
|
||||
raise RuntimeError(f"{status} message from "
|
||||
f"{'local' if local else 'remote'} "
|
||||
f"engine {eng_index}, expected it to be "
|
||||
f"{'local' if engine.local else 'remote'}")
|
||||
|
||||
if status == "HELLO" and engine.state == CoreEngineState.NEW:
|
||||
|
||||
# Send init message with DP config info.
|
||||
init_message = msgspec.msgpack.encode(
|
||||
EngineHandshakeMetadata(
|
||||
addresses=addresses,
|
||||
parallel_config={
|
||||
"data_parallel_master_ip":
|
||||
parallel_config.data_parallel_master_ip,
|
||||
"data_parallel_master_port":
|
||||
parallel_config.data_parallel_master_port,
|
||||
"data_parallel_size":
|
||||
parallel_config.data_parallel_size,
|
||||
}))
|
||||
handshake_socket.send_multipart((eng_identity, init_message),
|
||||
copy=False)
|
||||
conn_pending[0 if local else 1] -= 1
|
||||
start_pending[0 if local else 1] += 1
|
||||
engine.state = CoreEngineState.CONNECTED
|
||||
elif status == "READY" and engine.state == CoreEngineState.CONNECTED:
|
||||
# Setup KV cache config with initialization state from
|
||||
# engine core process. Sum values from all engines in DP case.
|
||||
num_gpu_blocks = cache_config.num_gpu_blocks or 0
|
||||
num_gpu_blocks += msg["num_gpu_blocks"]
|
||||
cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
|
||||
# In external DP LB mode, the coordinator address that the
|
||||
# front-end procs connect to is obtained from rank 0 via
|
||||
# one of the engine handshakes, and passed to the local
|
||||
# front-end process in the response from the other.
|
||||
if addresses.frontend_stats_publish_address is None:
|
||||
addresses.frontend_stats_publish_address = msg.get(
|
||||
"dp_stats_address")
|
||||
|
||||
start_pending[0 if local else 1] -= 1
|
||||
engine.state = CoreEngineState.READY
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected {status} message for "
|
||||
f"{'local' if local else 'remote'} engine "
|
||||
f"{eng_index} in {engine.state} state.")
|
||||
|
||||
logger.debug("%s from %s core engine process %s.", status,
|
||||
"local" if local else "remote", eng_index)
|
||||
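# Schematic counterpart (illustration only, not vLLM's actual engine code):
# the engine side of the HELLO/READY exchange that wait_for_engine_startup
# handles above. The "status", "local" and "num_gpu_blocks" fields mirror
# what that function reads; everything else here is an assumption.
import msgspec
import zmq

def engine_handshake(handshake_address: str, dp_rank: int,
                     num_gpu_blocks: int) -> dict:
    ctx = zmq.Context()
    sock = ctx.socket(zmq.DEALER)
    # Identity must match CoreEngine.identity for ROUTER-side routing.
    sock.setsockopt(zmq.IDENTITY, dp_rank.to_bytes(2, "little"))
    sock.connect(handshake_address)

    sock.send(msgspec.msgpack.encode({"status": "HELLO", "local": True}))
    metadata = msgspec.msgpack.decode(sock.recv())  # addresses + DP config

    sock.send(msgspec.msgpack.encode({
        "status": "READY",
        "local": True,
        "num_gpu_blocks": num_gpu_blocks,
    }))
    return metadata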
402 vllm/v1/utils.py
@ -1,44 +1,35 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import multiprocessing
|
||||
import time
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from multiprocessing import Process, connection
|
||||
from multiprocessing import connection
|
||||
from multiprocessing.process import BaseProcess
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
|
||||
Union, overload)
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||
usage_message)
|
||||
from vllm.utils import (get_mp_context, get_open_port, get_open_zmq_ipc_path,
|
||||
get_tcp_uri, kill_process_tree)
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
|
||||
kill_process_tree)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.v1.engine.coordinator import DPCoordinator
|
||||
from vllm.v1.engine.utils import (CoreEngineActorManager,
|
||||
CoreEngineProcManager)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
STARTUP_POLL_PERIOD_MS = 10000
|
||||
|
||||
|
||||
class ConstantList(Generic[T], Sequence):
|
||||
|
||||
@ -111,49 +102,18 @@ class ConstantList(Generic[T], Sequence):
|
||||
def get_engine_client_zmq_addr(local_only: bool,
                               host: str,
                               port: int = 0) -> str:
    """Assign a new ZMQ socket address.

    If local_only is True, participants are colocated and so a unique IPC
    address will be returned.

    Otherwise, the provided host and port will be used to construct a TCP
    address (port == 0 means assign an available port)."""

    return get_open_zmq_ipc_path() if local_only else (get_tcp_uri(
        host, port or get_open_port()))
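# Hedged usage sketch: the literal address formats below are assumptions
# (they come from get_open_zmq_ipc_path / get_tcp_uri), but the selection
# logic is as implemented above.
addr_colocated = get_engine_client_zmq_addr(True, "10.0.0.1")
# -> unique per-call IPC endpoint, e.g. "ipc:///tmp/<random>"
addr_remote = get_engine_client_zmq_addr(False, "10.0.0.1", 5570)
# -> "tcp://10.0.0.1:5570"
addr_any_port = get_engine_client_zmq_addr(False, "10.0.0.1")
# -> TCP endpoint on a freshly assigned free port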
|
||||
|
||||
|
||||
class CoreEngineState(Enum):
|
||||
NEW = auto()
|
||||
CONNECTED = auto()
|
||||
READY = auto()
|
||||
|
||||
|
||||
class CoreEngine:
|
||||
"""One per data parallel rank."""
|
||||
|
||||
def __init__(self, index: int = 0, local: bool = True):
|
||||
self.local = local
|
||||
self.index = index
|
||||
self.identity = index.to_bytes(2, "little")
|
||||
|
||||
self.state = CoreEngineState.NEW
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineZmqAddresses:
|
||||
# ZMQ input socket addresses for each front-end client (requests)
|
||||
inputs: list[str]
|
||||
# ZMQ output socket addresses for each front-end client (responses)
|
||||
outputs: list[str]
|
||||
# ZMQ input socket address of DP coordinator if applicable
|
||||
coordinator_input: Optional[str] = None
|
||||
# ZMQ output socket address of DP coordinator if applicable
|
||||
coordinator_output: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EngineHandshakeMetadata:
|
||||
"""Metadata sent to each engine process during startup handshake,
|
||||
including addresses of the front-end ZMQ queues that they should
|
||||
connect to.
|
||||
"""
|
||||
addresses: EngineZmqAddresses
|
||||
parallel_config: dict[str, Union[int, str]]
|
||||
|
||||
|
||||
class APIServerProcessManager:
|
||||
"""Manages a group of API server processes.
|
||||
|
||||
@ -219,339 +179,10 @@ class APIServerProcessManager:
|
||||
self._finalizer()
|
||||
|
||||
|
||||
class CoreEngineProcManager:
|
||||
"""
|
||||
Utility class to handle creation, readiness, and shutdown
|
||||
of background processes used by the AsyncLLM and LLMEngine.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_fn: Callable,
|
||||
local_engine_count: int,
|
||||
start_index: int,
|
||||
local_start_index: int,
|
||||
vllm_config: VllmConfig,
|
||||
on_head_node: bool,
|
||||
handshake_address: str,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
):
|
||||
context = get_mp_context()
|
||||
common_kwargs = {
|
||||
"vllm_config": vllm_config,
|
||||
"on_head_node": on_head_node,
|
||||
"handshake_address": handshake_address,
|
||||
"executor_class": executor_class,
|
||||
"log_stats": log_stats,
|
||||
}
|
||||
|
||||
self.processes: list[BaseProcess] = []
|
||||
for index in range(local_engine_count):
|
||||
local_index = local_start_index + index
|
||||
global_index = start_index + index
|
||||
# Start EngineCore in background process.
|
||||
self.processes.append(
|
||||
context.Process(target=target_fn,
|
||||
name=f"EngineCore_{global_index}",
|
||||
kwargs=common_kwargs | {
|
||||
"dp_rank": global_index,
|
||||
"local_dp_rank": local_index,
|
||||
}))
|
||||
|
||||
self._finalizer = weakref.finalize(self, shutdown, self.processes)
|
||||
try:
|
||||
for proc in self.processes:
|
||||
proc.start()
|
||||
finally:
|
||||
# Kill other procs if not all are running.
|
||||
if self.finished_procs():
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
"""Shutdown all procs."""
|
||||
self._finalizer()
|
||||
|
||||
def join_first(self):
|
||||
"""Wait for any process to exit."""
|
||||
connection.wait(proc.sentinel for proc in self.processes)
|
||||
|
||||
def sentinels(self) -> list:
|
||||
return [proc.sentinel for proc in self.processes]
|
||||
|
||||
def finished_procs(self) -> dict[str, int]:
|
||||
"""Returns dict of proc name -> exit code for any finished procs."""
|
||||
return {
|
||||
proc.name: proc.exitcode
|
||||
for proc in self.processes if proc.exitcode is not None
|
||||
}
|
||||
|
||||
|
||||
class CoreEngineActorManager:
|
||||
"""
|
||||
Utility class to handle creation, readiness, and shutdown
|
||||
of core engine Ray actors used by the AsyncLLM and LLMEngine.
|
||||
|
||||
Different from CoreEngineProcManager, this class manages
|
||||
core engines for both local and remote nodes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
addresses: EngineZmqAddresses,
|
||||
executor_class: type[Executor],
|
||||
log_stats: bool,
|
||||
placement_groups: Optional[list["PlacementGroup"]] = None,
|
||||
local_dp_ranks: Optional[list[int]] = None,
|
||||
):
|
||||
import copy
|
||||
|
||||
import ray
|
||||
from ray.util.scheduling_strategies import (
|
||||
PlacementGroupSchedulingStrategy)
|
||||
|
||||
from vllm.v1.engine.core import DPEngineCoreActor
|
||||
|
||||
self.local_engine_actors: list[ray.ActorHandle] = []
|
||||
self.remote_engine_actors: list[ray.ActorHandle] = []
|
||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
local_engine_count = \
|
||||
vllm_config.parallel_config.data_parallel_size_local
|
||||
world_size = vllm_config.parallel_config.world_size
|
||||
|
||||
if ray.is_initialized():
|
||||
logger.info(
|
||||
"Ray is already initialized. Skipping Ray initialization.")
|
||||
else:
|
||||
ray.init()
|
||||
|
||||
if placement_groups is not None:
|
||||
assert local_dp_ranks is not None, (
|
||||
"local_dp_ranks must be provided if "
|
||||
"placement_groups is provided")
|
||||
assert len(placement_groups) == len(local_dp_ranks), (
|
||||
"placement_groups and local_dp_ranks must "
|
||||
"have the same length")
|
||||
logger.info("Using provided placement groups")
|
||||
# TODO(rui): validate passed-in placement groups
|
||||
self.created_placement_groups = []
|
||||
else:
|
||||
placement_groups, local_dp_ranks = \
|
||||
CoreEngineActorManager.create_dp_placement_groups(vllm_config)
|
||||
self.created_placement_groups = placement_groups
|
||||
assert len(placement_groups) == dp_size, (
|
||||
"Number of placement groups must match data parallel size")
|
||||
|
||||
refs = []
|
||||
for index in range(dp_size):
|
||||
local_index = local_dp_ranks[index]
|
||||
dp_vllm_config = copy.deepcopy(vllm_config)
|
||||
pg = placement_groups[index]
|
||||
dp_vllm_config.parallel_config.placement_group = pg
|
||||
on_head_node = index < local_engine_count
|
||||
actor = ray.remote(DPEngineCoreActor).options(
|
||||
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||
placement_group=pg,
|
||||
placement_group_bundle_index=world_size,
|
||||
)).remote(vllm_config=dp_vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=log_stats,
|
||||
on_head_node=on_head_node,
|
||||
addresses=addresses,
|
||||
dp_rank=index,
|
||||
local_dp_rank=local_index)
|
||||
if on_head_node:
|
||||
self.local_engine_actors.append(actor)
|
||||
else:
|
||||
self.remote_engine_actors.append(actor)
|
||||
refs.append(actor.wait_for_init.remote())
|
||||
|
||||
ray.get(refs)
|
||||
self.run_refs = []
|
||||
for actor in self.local_engine_actors + self.remote_engine_actors:
|
||||
self.run_refs.append(actor.run.remote())
|
||||
|
||||
@staticmethod
|
||||
def create_dp_placement_groups(
|
||||
vllm_config: VllmConfig
|
||||
) -> tuple[list["PlacementGroup"], list[int]]:
|
||||
|
||||
import ray
|
||||
from ray._private.state import available_resources_per_node
|
||||
from ray.util.state import list_nodes
|
||||
|
||||
logger.info("Creating placement groups for data parallel")
|
||||
dp_master_ip = \
|
||||
vllm_config.parallel_config.data_parallel_master_ip
|
||||
dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
local_engine_count = \
|
||||
vllm_config.parallel_config.data_parallel_size_local
|
||||
|
||||
nodes = list_nodes()
|
||||
nodes = sorted(list_nodes(),
|
||||
key=lambda node: node.node_ip != dp_master_ip)
|
||||
assert nodes[0].node_ip == dp_master_ip, (
|
||||
"The first node must be the head node")
|
||||
assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, (
|
||||
"There can only be one head node")
|
||||
|
||||
available_resources = available_resources_per_node()
|
||||
world_size = vllm_config.parallel_config.world_size
|
||||
placement_groups: list[PlacementGroup] = []
|
||||
local_dp_ranks: list[int] = []
|
||||
|
||||
for node in nodes:
|
||||
node_ip = node.node_ip
|
||||
node_resources = available_resources[node.node_id]
|
||||
# For now, each DP rank can only be assigned to one node
|
||||
# TODO(rui): support allocating a single DP rank
|
||||
# to multiple nodes
|
||||
available_engine_count = int(node_resources["GPU"]) // world_size
|
||||
if node_ip == dp_master_ip:
|
||||
assert available_engine_count >= local_engine_count, (
|
||||
"Not enough resources to allocate DP ranks "
|
||||
f"on DP master node {node_ip}")
|
||||
for i in range(local_engine_count):
|
||||
bundles = [{
|
||||
"GPU": 1.0,
|
||||
"node:" + dp_master_ip: 0.001
|
||||
}] * world_size + [{
|
||||
"CPU": 1.0
|
||||
}]
|
||||
pg = ray.util.placement_group(
|
||||
name=f"dp_rank_{len(placement_groups)}",
|
||||
strategy="STRICT_PACK",
|
||||
bundles=bundles,
|
||||
)
|
||||
placement_groups.append(pg)
|
||||
local_dp_ranks.append(i)
|
||||
else:
|
||||
for i in range(available_engine_count):
|
||||
if len(placement_groups) == dp_size:
|
||||
break
|
||||
bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}]
|
||||
pg = ray.util.placement_group(
|
||||
name=f"dp_rank_{len(placement_groups)}",
|
||||
strategy="STRICT_PACK",
|
||||
bundles=bundles,
|
||||
)
|
||||
placement_groups.append(pg)
|
||||
local_dp_ranks.append(i)
|
||||
return placement_groups, local_dp_ranks
|
||||
|
||||
def get_run_refs(self):
|
||||
return self.run_refs
|
||||
|
||||
def close(self):
|
||||
import ray
|
||||
for actor in self.local_engine_actors + self.remote_engine_actors:
|
||||
ray.kill(actor)
|
||||
for pg in self.created_placement_groups:
|
||||
ray.util.remove_placement_group(pg)
|
||||
|
||||
|
||||
def wait_for_engine_startup(
|
||||
handshake_socket: zmq.Socket,
|
||||
addresses: EngineZmqAddresses,
|
||||
core_engines: list[CoreEngine],
|
||||
parallel_config: ParallelConfig,
|
||||
cache_config: CacheConfig,
|
||||
proc_manager: Optional[CoreEngineProcManager],
|
||||
coord_process: Optional[Process],
|
||||
):
|
||||
|
||||
# Wait for engine core process(es) to send ready messages.
|
||||
local_count = parallel_config.data_parallel_size_local
|
||||
remote_count = len(core_engines) - local_count
|
||||
# [local, remote] counts
|
||||
conn_pending, start_pending = [local_count, remote_count], [0, 0]
|
||||
poller = zmq.Poller()
|
||||
poller.register(handshake_socket, zmq.POLLIN)
|
||||
|
||||
if proc_manager is not None:
|
||||
for sentinel in proc_manager.sentinels():
|
||||
poller.register(sentinel, zmq.POLLIN)
|
||||
if coord_process is not None:
|
||||
poller.register(coord_process.sentinel, zmq.POLLIN)
|
||||
while any(conn_pending) or any(start_pending):
|
||||
events = poller.poll(STARTUP_POLL_PERIOD_MS)
|
||||
if not events:
|
||||
if any(conn_pending):
|
||||
logger.debug(
|
||||
"Waiting for %d local, %d remote core engine proc(s) "
|
||||
"to connect.", *conn_pending)
|
||||
if any(start_pending):
|
||||
logger.debug(
|
||||
"Waiting for %d local, %d remote core engine proc(s) "
|
||||
"to start.", *start_pending)
|
||||
continue
|
||||
if len(events) > 1 or events[0][0] != handshake_socket:
|
||||
# One of the local core processes exited.
|
||||
finished = proc_manager.finished_procs() if proc_manager else {}
|
||||
if coord_process is not None and coord_process.exitcode is not None:
|
||||
finished[coord_process.name] = coord_process.exitcode
|
||||
raise RuntimeError("Engine core initialization failed. "
|
||||
"See root cause above. "
|
||||
f"Failed core proc(s): {finished}")
|
||||
|
||||
# Receive HELLO and READY messages from the input socket.
|
||||
eng_identity, ready_msg_bytes = handshake_socket.recv_multipart()
|
||||
eng_index = int.from_bytes(eng_identity, "little")
|
||||
engine = next((e for e in core_engines if e.identity == eng_identity),
|
||||
None)
|
||||
if engine is None:
|
||||
raise RuntimeError(f"Message from engine with unexpected data "
|
||||
f"parallel rank: {eng_index}")
|
||||
msg = msgspec.msgpack.decode(ready_msg_bytes)
|
||||
status, local = msg["status"], msg["local"]
|
||||
if local != engine.local:
|
||||
raise RuntimeError(f"{status} message from "
|
||||
f"{'local' if local else 'remote'} "
|
||||
f"engine {eng_index}, expected it to be "
|
||||
f"{'local' if engine.local else 'remote'}")
|
||||
|
||||
if status == "HELLO" and engine.state == CoreEngineState.NEW:
|
||||
|
||||
# Send init message with DP config info.
|
||||
init_message = msgspec.msgpack.encode(
|
||||
EngineHandshakeMetadata(
|
||||
addresses=addresses,
|
||||
parallel_config={
|
||||
"data_parallel_master_ip":
|
||||
parallel_config.data_parallel_master_ip,
|
||||
"data_parallel_master_port":
|
||||
parallel_config.data_parallel_master_port,
|
||||
"data_parallel_size":
|
||||
parallel_config.data_parallel_size,
|
||||
}))
|
||||
handshake_socket.send_multipart((eng_identity, init_message),
|
||||
copy=False)
|
||||
conn_pending[0 if local else 1] -= 1
|
||||
start_pending[0 if local else 1] += 1
|
||||
engine.state = CoreEngineState.CONNECTED
|
||||
elif status == "READY" and (engine.state == CoreEngineState.CONNECTED):
|
||||
# Setup KV cache config with initialization state from
|
||||
# engine core process. Sum values from all engines in DP case.
|
||||
num_gpu_blocks = cache_config.num_gpu_blocks or 0
|
||||
num_gpu_blocks += msg["num_gpu_blocks"]
|
||||
cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
|
||||
start_pending[0 if local else 1] -= 1
|
||||
engine.state = CoreEngineState.READY
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected {status} message for "
|
||||
f"{'local' if local else 'remote'} engine "
|
||||
f"{eng_index} in {engine.state} state.")
|
||||
|
||||
logger.debug("%s from %s core engine process %s.", status,
|
||||
"local" if local else "remote", eng_index)
|
||||
|
||||
|
||||
def wait_for_completion_or_failure(
|
||||
api_server_manager: APIServerProcessManager,
|
||||
engine_manager: Optional[Union[CoreEngineProcManager,
|
||||
CoreEngineActorManager]] = None,
|
||||
engine_manager: Optional[Union["CoreEngineProcManager",
|
||||
"CoreEngineActorManager"]] = None,
|
||||
coordinator: Optional["DPCoordinator"] = None) -> None:
|
||||
"""Wait for all processes to complete or detect if any fail.
|
||||
|
||||
@ -565,6 +196,9 @@ def wait_for_completion_or_failure(
|
||||
coordinator: The coordinator for data parallel.
|
||||
"""
|
||||
|
||||
from vllm.v1.engine.utils import (CoreEngineActorManager,
|
||||
CoreEngineProcManager)
|
||||
|
||||
try:
|
||||
logger.info("Waiting for API servers to complete ...")
|
||||
# Create a mapping of sentinels to their corresponding processes
|
||||
|
||||