[DP] Internal Load Balancing Per Node [one-pod-per-node] (#21238)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Robert Shaw 2025-07-23 23:57:32 -04:00 committed by GitHub
parent eec6942014
commit d5b981f8b1
12 changed files with 486 additions and 45 deletions

View File

@@ -166,6 +166,7 @@ steps:
   - tests/v1/test_async_llm_dp.py
   - tests/v1/test_external_lb_dp.py
   - tests/v1/test_internal_lb_dp.py
+  - tests/v1/test_hybrid_lb_dp.py
   - tests/v1/engine/test_engine_core_client.py
   commands:
     # test with tp=2 and external_dp=2
@@ -178,6 +179,7 @@ steps:
   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
   - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_internal_lb_dp.py
+  - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_hybrid_lb_dp.py
   - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
   - pytest -v -s distributed/test_utils.py
   - pytest -v -s compile/test_basic_correctness.py

View File

@@ -565,8 +565,8 @@ def test_engine_core_proc_instantiation_cuda_empty(
     from vllm.v1.engine.utils import EngineZmqAddresses

-    def mock_startup_handshake(self, handshake_socket, on_head_node,
-                               parallel_config):
+    def mock_startup_handshake(self, handshake_socket, local_client,
+                               headless, parallel_config):
         return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"],
                                   outputs=["tcp://127.0.0.1:5556"],
                                   coordinator_input=None,

View File

@@ -0,0 +1,352 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
import threading
import time
from contextlib import AsyncExitStack
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
from tests.v1.test_utils import check_request_balancing
from vllm.platforms import Platform
MODEL_NAME = "ibm-research/PowerMoE-3b"
# Number of data parallel ranks for hybrid LB testing (4 total)
DP_SIZE = int(os.getenv("DP_SIZE", "4"))
# Default tensor parallel size to use
TP_SIZE = int(os.getenv("TP_SIZE", "1"))
# Number of nodes (2 nodes, each with 2 DP ranks)
NUM_NODES = 2
DP_SIZE_LOCAL = DP_SIZE // NUM_NODES # 2 ranks per node
class HybridLBServerManager:
"""Manages hybrid data parallel vLLM server instances where each node
runs a single logical API server that balances requests only to the
DP engines running on that same node."""
def __init__(self,
model_name: str,
dp_size: int,
api_server_count: int,
base_server_args: list,
dp_size_local: int = DP_SIZE_LOCAL,
tp_size: int = TP_SIZE):
self.model_name = model_name
self.dp_size = dp_size
self.dp_size_local = dp_size_local
self.tp_size = tp_size
self.api_server_count = api_server_count
self.base_server_args = base_server_args
self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
self.server_threads: list[threading.Thread] = []
self.num_nodes = dp_size // dp_size_local
def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
"""Start all server instances for hybrid LB mode."""
for node_id in range(self.num_nodes):
# Create server args for this specific node
server_args = self.base_server_args.copy()
# Calculate start rank for this node
start_rank = node_id * self.dp_size_local
# Add hybrid LB specific arguments
server_args.extend([
"--data-parallel-size",
str(self.dp_size),
"--data-parallel-size-local",
str(self.dp_size_local),
"--data-parallel-start-rank",
str(start_rank),
"--data-parallel-hybrid-lb", # Enable hybrid LB mode
"--tensor-parallel-size",
str(self.tp_size),
"--port",
str(8000 + node_id), # Different port for each node
"--api-server-count",
str(self.api_server_count),
"--data-parallel-address",
"127.0.0.1",
"--data-parallel-rpc-port",
"13345",
])
# Use a thread to start each server to allow parallel initialization
def start_server(node: int, sargs: list[str]):
try:
# Calculate GPU devices for this node
gpus_per_node = self.dp_size_local * self.tp_size
gpu_start = node * gpus_per_node
gpu_end = gpu_start + gpus_per_node
# Start the server
server = RemoteOpenAIServer(
self.model_name,
sargs,
auto_port=False,
env_dict={
"CUDA_VISIBLE_DEVICES":
",".join(
str(Platform.device_id_to_physical_device_id(
i)) for i in range(gpu_start, gpu_end))
})
server.__enter__()
print(f"Hybrid LB node {node} started successfully with "
f"{self.dp_size_local} local DP ranks and "
f"{self.api_server_count} API servers")
self.servers.append((server, sargs))
except Exception as e:
print(f"Failed to start hybrid LB node {node}: {e}")
raise
thread = threading.Thread(target=start_server,
args=(node_id, server_args))
thread.start()
self.server_threads.append(thread)
# Wait for all servers to start
for thread in self.server_threads:
thread.join()
# Give servers additional time to fully initialize and coordinate
time.sleep(3)
if len(self.servers) != self.num_nodes:
raise Exception("Servers failed to start")
return self.servers
def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop all server instances."""
while self.servers:
try:
self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
except Exception as e:
print(f"Error stopping server: {e}")
@pytest.fixture(scope="module")
def default_server_args():
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",
"128",
"--enforce-eager",
]
@pytest.fixture(scope="module", params=[1]) # Only 1 API server for now
def servers(request, default_server_args):
api_server_count = request.param
with HybridLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
default_server_args, DP_SIZE_LOCAL,
TP_SIZE) as server_list:
yield server_list
@pytest_asyncio.fixture
async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
# Create a client for each node (each node has its own API endpoint)
async with AsyncExitStack() as stack:
yield [
await stack.enter_async_context(server.get_async_client())
for server, _ in servers
]
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_hybrid_lb_completion(clients: list[openai.AsyncOpenAI],
servers: list[tuple[RemoteOpenAIServer,
list[str]]],
model_name: str) -> None:
async def make_request(client: openai.AsyncOpenAI):
completion = await client.completions.create(
model=model_name,
prompt="Hello, my name is",
max_tokens=10,
temperature=1.0)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
choice = completion.choices[0]
# The exact number of tokens can vary slightly with temperature=1.0,
# so we check for a reasonable minimum length.
assert len(choice.text) >= 1
# Finish reason might not always be 'length' if the model finishes early
# or due to other reasons, especially with high temperature.
# So, we'll accept 'length' or 'stop'.
assert choice.finish_reason in ("length", "stop")
# Token counts can also vary, so we check they are positive.
assert completion.usage.completion_tokens > 0
assert completion.usage.prompt_tokens > 0
assert completion.usage.total_tokens > 0
return completion
# Test single request to each node
for i, client in enumerate(clients):
result = await make_request(client)
assert result is not None
print(
f"Hybrid LB node {i} handled single completion request successfully"
)
await asyncio.sleep(0.5)
# Send requests to all nodes - each should balance within its local DP ranks
num_requests_per_node = 25 # Total 50 requests across 2 nodes
all_tasks = []
for i, client in enumerate(clients):
tasks = [make_request(client) for _ in range(num_requests_per_node)]
all_tasks.extend(tasks)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests_per_node * len(clients)
assert all(completion is not None for completion in results)
await asyncio.sleep(0.5)
# Second burst of requests
all_tasks = []
for i, client in enumerate(clients):
tasks = [make_request(client) for _ in range(num_requests_per_node)]
all_tasks.extend(tasks)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests_per_node * len(clients)
assert all(completion is not None for completion in results)
_, server_args = servers[0]
api_server_count = (
server_args.count('--api-server-count')
and server_args[server_args.index('--api-server-count') + 1] or 1)
print(
f"Successfully completed hybrid LB test with {len(clients)} nodes "
f"({DP_SIZE_LOCAL} DP ranks each, API server count: {api_server_count})"
)
# Check request balancing within each node
for i, (server, _) in enumerate(servers):
print(f"Checking request balancing for node {i}")
check_request_balancing(server, DP_SIZE_LOCAL)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_hybrid_lb_completion_streaming(clients: list[
openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]],
model_name: str) -> None:
prompt = "What is an LLM?"
async def make_streaming_request(client: openai.AsyncOpenAI):
# Perform a non-streaming request to get the expected full output
single_completion = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
)
single_output = single_completion.choices[0].text
# Perform the streaming request
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True)
chunks: list[str] = []
finish_reason_count = 0
last_chunk = None
async for chunk in stream:
chunks.append(chunk.choices[0].text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
last_chunk = chunk # Keep track of the last chunk
# finish reason should only return in the last block for OpenAI API
assert finish_reason_count == 1, (
"Finish reason should appear exactly once.")
assert last_chunk is not None, (
"Stream should have yielded at least one chunk.")
assert last_chunk.choices[
0].finish_reason == "length", "Finish reason should be 'length'."
# Check that the combined text matches the non-streamed version.
assert "".join(
chunks
) == single_output, "Streamed output should match non-streamed output."
return True # Indicate success for this request
# Test single request to each node
for i, client in enumerate(clients):
result = await make_streaming_request(client)
assert result is not None
print(
f"Hybrid LB node {i} handled single streaming request successfully"
)
await asyncio.sleep(0.5)
# Send streaming requests to all nodes
num_requests_per_node = 25 # Total 50 requests across 2 nodes
all_tasks = []
for i, client in enumerate(clients):
tasks = [
make_streaming_request(client)
for _ in range(num_requests_per_node)
]
all_tasks.extend(tasks)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests_per_node * len(clients)
assert all(results), "Not all streaming requests completed successfully."
await asyncio.sleep(0.5)
# Second burst of streaming requests
all_tasks = []
for i, client in enumerate(clients):
tasks = [
make_streaming_request(client)
for _ in range(num_requests_per_node)
]
all_tasks.extend(tasks)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests_per_node * len(clients)
assert all(results), "Not all streaming requests completed successfully."
_, server_args = servers[0]
api_server_count = (
server_args.count('--api-server-count')
and server_args[server_args.index('--api-server-count') + 1] or 1)
print(f"Successfully completed hybrid LB streaming test with "
f"{len(clients)} nodes ({DP_SIZE_LOCAL} DP ranks each, "
f"API server count: {api_server_count})")
# Check request balancing within each node
for i, (server, _) in enumerate(servers):
print(f"Checking streaming request balancing for node {i}")
check_request_balancing(server, DP_SIZE_LOCAL)
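The new test only exercises the per-node endpoints directly; the other half of hybrid mode is an external balancer spreading traffic across nodes. A minimal, hypothetical sketch of that side, assuming the per-node API servers above listen on ports 8000 and 8001 (taken from the "--port", str(8000 + node_id) arguments in HybridLBServerManager):

# Hypothetical external-LB side of hybrid mode: round-robin across the
# per-node API endpoints started above. URLs/ports are assumptions derived
# from the "--port", str(8000 + node_id) arguments; MODEL_NAME and NUM_NODES
# are the constants defined at the top of this test file.
import itertools
import openai

node_clients = [
    openai.AsyncOpenAI(base_url=f"http://localhost:{8000 + node}/v1",
                       api_key="EMPTY") for node in range(NUM_NODES)
]
next_client = itertools.cycle(node_clients).__next__

async def route_completion(prompt: str):
    # The external LB picks a node; vLLM then balances the request across
    # that node's DP_SIZE_LOCAL local engine ranks.
    client = next_client()
    return await client.completions.create(model=MODEL_NAME,
                                            prompt=prompt,
                                            max_tokens=10)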

View File

@@ -1908,8 +1908,16 @@ class ParallelConfig:
     """Backend to use for data parallel, either "mp" or "ray"."""
     data_parallel_external_lb: bool = False
     """Whether to use "external" DP LB mode. Applies only to online serving
-    and when data_parallel_size > 0. Set implicitly when
-    data_parallel_rank is provided explicitly to vllm serve."""
+    and when data_parallel_size > 0. This is useful for a "one-pod-per-rank"
+    wide-EP setup in Kubernetes. Set implicitly when --data-parallel-rank
+    is provided explicitly to vllm serve."""
+    data_parallel_hybrid_lb: bool = False
+    """Whether to use "hybrid" DP LB mode. Applies only to online serving
+    and when data_parallel_size > 0. Enables running an AsyncLLM
+    and API server on a "per-node" basis where vLLM load balances
+    between local data parallel ranks, but an external LB balances
+    between vLLM nodes/replicas. Set explicitly in conjunction with
+    --data-parallel-start-rank."""
     enable_expert_parallel: bool = False
     """Use expert parallelism instead of tensor parallelism for MoE layers."""
     enable_eplb: bool = False
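For orientation, a hedged sketch of how hybrid LB mode maps onto per-node launch arguments for a 2-node, DP=4 deployment; the values mirror what tests/v1/test_hybrid_lb_dp.py constructs, and the coordinator address and ports are illustrative:

# Illustrative hybrid-LB launch arguments for vllm serve on a 2-node, DP=4
# deployment (2 local ranks per node). Mirrors HybridLBServerManager in
# tests/v1/test_hybrid_lb_dp.py; address/port values are examples only.
common = [
    "--data-parallel-size", "4",
    "--data-parallel-size-local", "2",
    "--data-parallel-hybrid-lb",
    "--data-parallel-address", "10.0.0.1",   # node 0 / rank 0 host
    "--data-parallel-rpc-port", "13345",
]
node0_args = common + ["--data-parallel-start-rank", "0", "--port", "8000"]
node1_args = common + ["--data-parallel-start-rank", "2", "--port", "8001"]
# An external load balancer then spreads traffic over the two node endpoints,
# while each node's API server balances across its own two local DP ranks.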

View File

@@ -295,9 +295,11 @@ class EngineArgs:
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
     data_parallel_rank: Optional[int] = None
+    data_parallel_start_rank: Optional[int] = None
     data_parallel_size_local: Optional[int] = None
     data_parallel_address: Optional[str] = None
     data_parallel_rpc_port: Optional[int] = None
+    data_parallel_hybrid_lb: bool = False
     data_parallel_backend: str = ParallelConfig.data_parallel_backend
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     enable_eplb: bool = ParallelConfig.enable_eplb
@@ -604,6 +606,11 @@
                                     type=int,
                                     help='Data parallel rank of this instance. '
                                     'When set, enables external load balancer mode.')
+        parallel_group.add_argument('--data-parallel-start-rank',
+                                    '-dpr',
+                                    type=int,
+                                    help='Starting data parallel rank '
+                                    'for secondary nodes.')
         parallel_group.add_argument('--data-parallel-size-local',
                                     '-dpl',
                                     type=int,
@@ -625,6 +632,9 @@
             default='mp',
             help='Backend for data parallel, either '
             '"mp" or "ray".')
+        parallel_group.add_argument(
+            "--data-parallel-hybrid-lb",
+            **parallel_kwargs["data_parallel_hybrid_lb"])
         parallel_group.add_argument(
             "--enable-expert-parallel",
             **parallel_kwargs["enable_expert_parallel"])
@@ -972,6 +982,7 @@ class EngineArgs:
     def create_engine_config(
         self,
         usage_context: Optional[UsageContext] = None,
+        headless: bool = False,
     ) -> VllmConfig:
         """
         Create the VllmConfig.
@@ -1060,15 +1071,41 @@ class EngineArgs:
             # but we should not do this here.
             placement_group = ray.util.get_current_placement_group()

+        assert not headless or not self.data_parallel_hybrid_lb, (
+            "data_parallel_hybrid_lb is not applicable in "
+            "headless mode")
+
         data_parallel_external_lb = self.data_parallel_rank is not None
+        # Local DP rank = 1, use pure-external LB.
         if data_parallel_external_lb:
             assert self.data_parallel_size_local in (1, None), (
                 "data_parallel_size_local must be 1 when data_parallel_rank "
                 "is set")
             data_parallel_size_local = 1
+            # Use full external lb if we have local_size of 1.
+            self.data_parallel_hybrid_lb = False
         elif self.data_parallel_size_local is not None:
             data_parallel_size_local = self.data_parallel_size_local
+
+            if self.data_parallel_start_rank and not headless:
+                # Infer hybrid LB mode.
+                self.data_parallel_hybrid_lb = True
+
+            if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
+                # Use full external lb if we have local_size of 1.
+                data_parallel_external_lb = True
+                self.data_parallel_hybrid_lb = False
+
+            if data_parallel_size_local == self.data_parallel_size:
+                # Disable hybrid LB mode if set for a single node
+                self.data_parallel_hybrid_lb = False
+
+            self.data_parallel_rank = self.data_parallel_start_rank or 0
         else:
+            assert not self.data_parallel_hybrid_lb, (
+                "data_parallel_size_local must be set to use "
+                "data_parallel_hybrid_lb.")
+
             # Local DP size defaults to global DP size if not set.
             data_parallel_size_local = self.data_parallel_size
@@ -1125,6 +1162,7 @@ class EngineArgs:
             data_parallel_master_ip=data_parallel_address,
             data_parallel_rpc_port=data_parallel_rpc_port,
             data_parallel_backend=self.data_parallel_backend,
+            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
             enable_expert_parallel=self.enable_expert_parallel,
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.num_redundant_experts,
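Condensing the branches added above, the mode resolution can be read as the following standalone sketch (an illustrative paraphrase of this hunk, not an actual vLLM helper):

# Illustrative paraphrase of the DP LB-mode resolution added to
# create_engine_config above; not a real vLLM function.
def resolve_dp_lb_mode(dp_rank, dp_size, dp_size_local, dp_start_rank,
                       hybrid_lb, headless):
    assert not (headless and hybrid_lb)
    external_lb = dp_rank is not None
    if external_lb:
        # Explicit --data-parallel-rank: one engine per pod, external LB.
        dp_size_local, hybrid_lb = 1, False
    elif dp_size_local is not None:
        if dp_start_rank and not headless:
            hybrid_lb = True            # inferred from --data-parallel-start-rank
        if hybrid_lb and dp_size_local == 1:
            external_lb, hybrid_lb = True, False  # degenerates to external LB
        if dp_size_local == dp_size:
            hybrid_lb = False           # whole DP group on one node
        dp_rank = dp_start_rank or 0
    else:
        assert not hybrid_lb, "requires --data-parallel-size-local"
        dp_size_local = dp_size
    return dp_rank, dp_size_local, external_lb, hybrid_lb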

View File

@@ -45,11 +45,6 @@ class ServeSubcommand(CLISubcommand):
         if args.headless or args.api_server_count < 1:
             run_headless(args)
         else:
-            if args.data_parallel_start_rank:
-                raise ValueError(
-                    "data_parallel_start_rank is only applicable "
-                    "in headless mode. "
-                    "Add --headless flag to enable headless mode.")
             if args.api_server_count > 1:
                 run_multi_api_server(args)
             else:
@@ -86,13 +81,14 @@ def run_headless(args: argparse.Namespace):

     # Create the EngineConfig.
     engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
     usage_context = UsageContext.OPENAI_API_SERVER
-    vllm_config = engine_args.create_engine_config(usage_context=usage_context)
+    vllm_config = engine_args.create_engine_config(usage_context=usage_context,
+                                                   headless=True)

     if not envs.VLLM_USE_V1:
         raise ValueError("Headless mode is only supported for V1")

-    if engine_args.data_parallel_rank is not None:
-        raise ValueError("data_parallel_rank is not applicable in "
+    if engine_args.data_parallel_hybrid_lb:
+        raise ValueError("data_parallel_hybrid_lb is not applicable in "
                          "headless mode")

     parallel_config = vllm_config.parallel_config
@@ -122,7 +118,7 @@ def run_headless(args: argparse.Namespace):
     engine_manager = CoreEngineProcManager(
         target_fn=EngineCoreProc.run_engine_core,
         local_engine_count=local_engine_count,
-        start_index=args.data_parallel_start_rank,
+        start_index=vllm_config.parallel_config.data_parallel_rank,
         local_start_index=0,
         vllm_config=vllm_config,
         local_client=False,
@@ -169,6 +165,11 @@ def run_multi_api_server(args: argparse.Namespace):
                              " api_server_count > 1")
         model_config.disable_mm_preprocessor_cache = True

+    if vllm_config.parallel_config.data_parallel_hybrid_lb:
+        raise NotImplementedError(
+            "Hybrid load balancing with --api-server-count > 0 "
+            "is not yet supported.")
+
     executor_class = Executor.get_class(vllm_config)
     log_stats = not engine_args.disable_log_stats

View File

@@ -222,13 +222,6 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         default=False,
         help="Run in headless mode. See multi-node data parallel "
         "documentation for more details.")
-    parser.add_argument(
-        "--data-parallel-start-rank",
-        "-dpr",
-        type=int,
-        default=0,
-        help="Starting data parallel rank for secondary nodes. "
-        "Requires --headless.")
     parser.add_argument("--api-server-count",
                         "-asc",
                         type=int,

View File

@@ -127,7 +127,7 @@ class AsyncLLM(EngineClient):
         if self.log_stats:
             self.logger_manager = StatLoggerManager(
                 vllm_config=vllm_config,
-                engine_idxs=self.engine_core.engine_ranks,
+                engine_idxs=self.engine_core.engine_ranks_managed,
                 custom_stat_loggers=stat_loggers,
             )
             self.logger_manager.log_engine_initialized()

View File

@@ -61,11 +61,12 @@ class DPCoordinator:
        host = parallel_config.data_parallel_master_ip

        external_lb = parallel_config.data_parallel_external_lb
+       hybrid_lb = parallel_config.data_parallel_hybrid_lb

        # Assume coordinator is colocated with front-end procs when not in
-       # external DP LB mode.
+       # either external or hybrid DP LB mode.
        front_publish_address = get_engine_client_zmq_addr(
-           local_only=not external_lb, host=host)
+           local_only=not external_lb and not hybrid_lb, host=host)

        local_only_eng = dp_size == parallel_config.data_parallel_size_local
        back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)

View File

@@ -467,13 +467,14 @@ class EngineCoreProc(EngineCore):
        For DP>1 with internal loadbalancing this is with the shared front-end
        process which may reside on a different node.

-       For DP>1 with external loadbalancing, two handshakes are performed:
+       For DP>1 with external or hybrid loadbalancing, two handshakes are
+       performed:
            - With the rank 0 front-end process which retrieves the
              DP Coordinator ZMQ addresses and DP process group address.
            - With the colocated front-end process which retrieves the
              client input/output socket addresses.
-       with the exception of the rank 0 engine itself which doesn't require
-       the second handshake.
+       with the exception of the rank 0 and colocated engines themselves which
+       don't require the second handshake.

        Here, "front-end" process can mean the process containing the engine
        core client (which is the API server process in the case the API
@@ -482,15 +483,18 @@ class EngineCoreProc(EngineCore):
        """
        input_ctx = zmq.Context()
        is_local = local_client and client_handshake_address is None
+       headless = not local_client
        handshake = self._perform_handshake(input_ctx, handshake_address,
-                                           identity, is_local, vllm_config,
+                                           identity, is_local, headless,
+                                           vllm_config,
                                            vllm_config.parallel_config)
        if client_handshake_address is None:
            with handshake as addresses:
                yield addresses
        else:
+           assert local_client
            local_handshake = self._perform_handshake(
-               input_ctx, client_handshake_address, identity, local_client,
+               input_ctx, client_handshake_address, identity, True, False,
                vllm_config)
            with handshake as addresses, local_handshake as client_addresses:
                addresses.inputs = client_addresses.inputs
@@ -507,6 +511,7 @@ class EngineCoreProc(EngineCore):
        handshake_address: str,
        identity: bytes,
        local_client: bool,
+       headless: bool,
        vllm_config: VllmConfig,
        parallel_config_to_update: Optional[ParallelConfig] = None,
    ) -> Generator[EngineZmqAddresses, None, None]:
@@ -518,6 +523,7 @@ class EngineCoreProc(EngineCore):
                              bind=False) as handshake_socket:
            # Register engine with front-end.
            addresses = self.startup_handshake(handshake_socket, local_client,
+                                               headless,
                                               parallel_config_to_update)
            yield addresses
@@ -531,6 +537,7 @@ class EngineCoreProc(EngineCore):
            msgspec.msgpack.encode({
                "status": "READY",
                "local": local_client,
+               "headless": headless,
                "num_gpu_blocks": num_gpu_blocks,
                "dp_stats_address": dp_stats_address,
            }))
@@ -539,6 +546,7 @@ class EngineCoreProc(EngineCore):
    def startup_handshake(
        handshake_socket: zmq.Socket,
        local_client: bool,
+       headless: bool,
        parallel_config: Optional[ParallelConfig] = None,
    ) -> EngineZmqAddresses:
@@ -547,6 +555,7 @@ class EngineCoreProc(EngineCore):
            msgspec.msgpack.encode({
                "status": "HELLO",
                "local": local_client,
+               "headless": headless,
            }))

        # Receive initialization message.
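The handshake payloads themselves are small msgpack dicts; a hedged summary using the field names from the hunks above, with illustrative values:

# Illustrative handshake payloads (field names come from the code above;
# the numeric and address values are examples only).
hello_msg = {
    "status": "HELLO",
    "local": False,      # is this engine colocated with the front-end?
    "headless": False,   # False when launched by a hybrid/external-LB node's
                         # API server rather than `vllm serve --headless`
}
ready_msg = {
    "status": "READY",
    "local": False,
    "headless": False,
    "num_gpu_blocks": 8192,                        # example value
    "dp_stats_address": "tcp://10.0.0.1:13346",    # example value
}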

View File

@@ -429,18 +429,23 @@ class MPClient(EngineCoreClient):
        parallel_config = vllm_config.parallel_config
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank
-       external_dp_lb = parallel_config.data_parallel_external_lb
+       dp_local_size = parallel_config.data_parallel_size_local
        offline_mode = parallel_config.data_parallel_rank_local is not None

-       self.engine_ranks = ([dp_rank] if
-                            (offline_mode or external_dp_lb) else list(
-                                range(dp_size)))
+       # Client manages local+remote EngineCores in pure internal LB case.
+       # Client manages local EngineCores in hybrid and external LB case.
+       local_engines_only = (parallel_config.data_parallel_hybrid_lb
+                             or parallel_config.data_parallel_external_lb)
+       num_ranks = dp_local_size if local_engines_only else dp_size
+       self.engine_ranks_managed = [dp_rank] if offline_mode else list(
+           range(dp_rank, dp_rank + num_ranks))
        assert parallel_config.data_parallel_size_local <= len(
-           self.engine_ranks)
+           self.engine_ranks_managed)

        # ZMQ identity of each engine that this client will talk to.
        self.core_engines: list[EngineIdentity] = [
-           index.to_bytes(2, "little") for index in self.engine_ranks
+           rank.to_bytes(2, "little")
+           for rank in self.engine_ranks_managed
        ]

        # Wait for ready messages from each engine on the input socket.
@@ -895,6 +900,12 @@ class DPAsyncMPClient(AsyncMPClient):
            return
        assert self.stats_update_address is not None

+       assert len(self.engine_ranks_managed) > 0
+       # NOTE: running and waiting counts from the Coordinator are global
+       # and include all EngineCores. This slice covers just the cores
+       # managed by this client.
+       count_slice = slice(self.engine_ranks_managed[0],
+                           self.engine_ranks_managed[-1] + 1)

        async def run_engine_stats_update_task():
            with make_zmq_socket(self.ctx, self.stats_update_address,
@@ -959,7 +970,7 @@ class DPAsyncMPClient(AsyncMPClient):
                        counts, wave, running = msgspec.msgpack.decode(buf)
                        self.current_wave = wave
                        self.engines_running = running
-                       self.lb_engines = counts
+                       self.lb_engines = counts[count_slice]

        resources.stats_update_task = asyncio.create_task(
            run_engine_stats_update_task())
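As a worked example of the new bookkeeping, the front-end on the second node of the hybrid-LB test topology (DP_SIZE=4, DP_SIZE_LOCAL=2, so dp_rank=2) would end up with roughly:

# Worked example for node 1 in the 2-node hybrid-LB test topology
# (dp_rank=2 via --data-parallel-start-rank, dp_local_size=2, dp_size=4).
local_engines_only = True                         # hybrid LB
num_ranks = 2                                     # dp_local_size
engine_ranks_managed = list(range(2, 2 + 2))      # [2, 3]
core_engines = [r.to_bytes(2, "little") for r in engine_ranks_managed]
count_slice = slice(2, 4)
# The coordinator publishes global per-engine request counts; this client
# keeps only its own node's slice for load balancing, e.g.:
counts = [5, 7, 4, 6]
lb_engines = counts[count_slice]                  # -> [4, 6]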

View File

@@ -544,7 +544,8 @@ def launch_core_engines(
    local_start_index = parallel_config.data_parallel_rank_local
    dp_rank = parallel_config.data_parallel_rank
    host = parallel_config.data_parallel_master_ip
-   external_dp_lb = parallel_config.data_parallel_external_lb
+   local_engines_only = (parallel_config.data_parallel_hybrid_lb
+                         or parallel_config.data_parallel_external_lb)

    # In offline mode there is an LLM instance per DP rank and
    # one core engine per LLM, see
@@ -553,8 +554,8 @@
    # client_local_only = True for cases where this front-end
    # sends requests only to colocated engines.
-   client_local_only = offline_mode or external_dp_lb or (local_engine_count
-                                                           == dp_size)
+   client_local_only = (offline_mode or local_engines_only
+                        or (local_engine_count == dp_size))

    # Set up input and output addresses.
    addresses = EngineZmqAddresses(
@@ -598,14 +599,27 @@
        yield engine_actor_manager, coordinator, addresses
        return

-   if offline_mode or (external_dp_lb and dp_rank > 0):
+   if offline_mode:
        assert local_engine_count == 1
        engines_to_handshake = [CoreEngine(index=dp_rank, local=True)]
-   else:
+   elif dp_rank == 0:
+       # Rank 0 holds Coordinator, so it handshakes with all Cores
+       # in both external dplb and internal dplb mode.
+       # Note this also covers the case where we have zero local engines
+       # and rank 0 is headless.
        engines_to_handshake = [
            CoreEngine(index=i, local=(i < local_engine_count))
            for i in range(dp_size)
        ]
+   else:
+       # Rank > 0 handshakes with just the local cores it is managing.
+       assert local_engines_only, (
+           "Attempting to launch core_engines from dp_rank > 0, but "
+           "found internal DPLB, which is incompatible.")
+       engines_to_handshake = [
+           CoreEngine(index=i, local=True)
+           for i in range(dp_rank, dp_rank + local_engine_count)
+       ]

    # Whether the started engines will handshake only with co-located
    # front-end processes. In external_dp_lb mode, ranks > 0 handshake with
@@ -616,7 +630,7 @@ def launch_core_engines(
    handshake_address = get_engine_client_zmq_addr(
        handshake_local_only, host, parallel_config.data_parallel_rpc_port)

-   if external_dp_lb and dp_rank > 0:
+   if local_engines_only and dp_rank > 0:
        assert not handshake_local_only
        local_handshake_address = get_open_zmq_ipc_path()
        client_handshake_address = local_handshake_address
@@ -631,8 +645,6 @@ def launch_core_engines(

    # Start local engines.
    if local_engine_count:
-       # In server mode, start_index and local_start_index will
-       # both be 0.
        local_engine_manager = CoreEngineProcManager(
            EngineCoreProc.run_engine_core,
            vllm_config=vllm_config,
@@ -678,6 +690,9 @@ def wait_for_engine_startup(
    poller = zmq.Poller()
    poller.register(handshake_socket, zmq.POLLIN)

+   remote_should_be_headless = not parallel_config.data_parallel_hybrid_lb \
+       and not parallel_config.data_parallel_external_lb
+
    if proc_manager is not None:
        for sentinel in proc_manager.sentinels():
            poller.register(sentinel, zmq.POLLIN)
@@ -713,13 +728,24 @@ def wait_for_engine_startup(
            raise RuntimeError(f"Message from engine with unexpected data "
                               f"parallel rank: {eng_index}")
        msg = msgspec.msgpack.decode(ready_msg_bytes)
-       status, local = msg["status"], msg["local"]
+       status, local, headless = msg["status"], msg["local"], msg["headless"]
        if local != engine.local:
            raise RuntimeError(f"{status} message from "
                               f"{'local' if local else 'remote'} "
                               f"engine {eng_index}, expected it to be "
                               f"{'local' if engine.local else 'remote'}")

+       # Remote engines must be headless iff we aren't in external or hybrid
+       # dp lb mode.
+       if not local and headless != remote_should_be_headless:
+           if headless:
+               raise RuntimeError(f"Remote engine {eng_index} must not use "
+                                  f"--headless in external or hybrid dp lb "
+                                  f"mode")
+           else:
+               raise RuntimeError(f"Remote engine {eng_index} must use "
+                                  f"--headless unless in external or hybrid "
+                                  f"dp lb mode")
+
        if status == "HELLO" and engine.state == CoreEngineState.NEW:
            # Send init message with DP config info.
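Putting the launch and startup changes in this file together, a runnable sketch of which engines each node's front-end handshakes with in the same 2-node hybrid example (CoreEngineExample is a hypothetical stand-in for the real CoreEngine):

from dataclasses import dataclass

@dataclass
class CoreEngineExample:  # stand-in for the real CoreEngine in this module
    index: int
    local: bool

def engines_to_handshake_for(dp_rank, dp_size, local_engine_count,
                             local_engines_only):
    """Mirrors the online-serving branches added to launch_core_engines."""
    if dp_rank == 0:
        # Rank 0 holds the coordinator, so it handshakes with every engine.
        return [CoreEngineExample(i, local=(i < local_engine_count))
                for i in range(dp_size)]
    # Other nodes handshake only with the cores they launch locally,
    # which requires hybrid or external DP LB mode.
    assert local_engines_only
    return [CoreEngineExample(i, local=True)
            for i in range(dp_rank, dp_rank + local_engine_count)]

# 2-node hybrid example: DP=4, 2 local ranks per node.
print(engines_to_handshake_for(0, 4, 2, True))  # engines 0-3; 0 and 1 local
print(engines_to_handshake_for(2, 4, 2, True))  # engines 2 and 3, both local
# Engines 2 and 3 are remote from node 0's perspective but were launched by
# node 1's API server, so they report headless=False; wait_for_engine_startup
# accepts that because remote_should_be_headless is False in hybrid mode.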