infer hybrid lb mode on secondary nodes

and update some comments

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-07-22 13:24:34 +01:00
parent 36ed9f3467
commit 82f9292b84
4 changed files with 31 additions and 20 deletions

View File

@ -1013,6 +1013,7 @@ class EngineArgs:
def create_engine_config(
self,
usage_context: Optional[UsageContext] = None,
headless: bool = False,
) -> VllmConfig:
"""
Create the VllmConfig.
@ -1101,6 +1102,10 @@ class EngineArgs:
# but we should not do this here.
placement_group = ray.util.get_current_placement_group()
assert not headless or not self.data_parallel_hybrid_lb, (
"data_parallel_hybrid_lb is not applicable in "
"headless mode")
data_parallel_external_lb = self.data_parallel_rank is not None
# Local DP rank = 1, use pure-external LB.
if data_parallel_external_lb:
@ -1110,24 +1115,25 @@ class EngineArgs:
data_parallel_size_local = 1
# Use full external lb if we have local_size of 1.
self.data_parallel_hybrid_lb = False
# Local DP rank > 1, use hybrid LB.
elif self.data_parallel_hybrid_lb:
assert self.data_parallel_start_rank is not None, (
"data_parallel_start_rank must be set to use "
"data_parallel_hybrid_lb.")
assert self.data_parallel_size_local is not None, (
"data_parallel_size_local must be set to use "
"data_parallel_hybrid_lb.")
# Use full external lb if we have local_size of 1.
if self.data_parallel_size_local == 1:
elif self.data_parallel_size_local is not None and (
self.data_parallel_size_local != self.data_parallel_size):
data_parallel_size_local = self.data_parallel_size_local
if self.data_parallel_start_rank and not headless:
# Infer hybrid LB mode.
self.data_parallel_hybrid_lb = True
if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
# Use full external lb if we have local_size of 1.
data_parallel_external_lb = True
self.data_parallel_hybrid_lb = False
data_parallel_size_local = self.data_parallel_size_local
self.data_parallel_rank = self.data_parallel_start_rank
elif self.data_parallel_size_local is not None:
data_parallel_size_local = self.data_parallel_size_local
self.data_parallel_rank = self.data_parallel_start_rank
self.data_parallel_rank = self.data_parallel_start_rank or 0
else:
assert self.data_parallel_hybrid_lb is None, (
"data_parallel_size_local must be set to use "
"data_parallel_hybrid_lb.")
# Local DP size defaults to global DP size if not set.
data_parallel_size_local = self.data_parallel_size

View File

@ -81,7 +81,8 @@ def run_headless(args: argparse.Namespace):
# Create the EngineConfig.
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
vllm_config = engine_args.create_engine_config(usage_context=usage_context,
headless=True)
if not envs.VLLM_USE_V1:
raise ValueError("Headless mode is only supported for V1")

View File

@ -467,13 +467,14 @@ class EngineCoreProc(EngineCore):
For DP>1 with internal loadbalancing this is with the shared front-end
process which may reside on a different node.
For DP>1 with external loadbalancing, two handshakes are performed:
For DP>1 with external or hybrid loadbalancing, two handshakes are
performed:
- With the rank 0 front-end process which retrieves the
DP Coordinator ZMQ addresses and DP process group address.
- With the colocated front-end process which retrieves the
client input/output socket addresses.
with the exception of the rank 0 engine itself which doesn't require
the second handshake.
with the exception of the rank 0 and colocated engines themselves which
don't require the second handshake.
Here, "front-end" process can mean the process containing the engine
core client (which is the API server process in the case the API
@ -489,8 +490,9 @@ class EngineCoreProc(EngineCore):
with handshake as addresses:
yield addresses
else:
assert local_client
local_handshake = self._perform_handshake(
input_ctx, client_handshake_address, identity, local_client,
input_ctx, client_handshake_address, identity, True,
vllm_config)
with handshake as addresses, local_handshake as client_addresses:
addresses.inputs = client_addresses.inputs

View File

@ -605,6 +605,8 @@ def launch_core_engines(
elif dp_rank == 0:
# Rank 0 holds Coordinator, so it handshakes with all Cores
# in both external dplb and internal dplb mode.
# Note this also covers the case where we have zero local engines
# and rank 0 is headless.
engines_to_handshake = [
CoreEngine(index=i, local=(i < local_engine_count))
for i in range(dp_size)