From d327a6bed573297f81ddc3728e6b2cf53954894e Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Sun, 20 Jul 2025 23:26:48 +0000 Subject: [PATCH] cleanup Signed-off-by: Robert Shaw --- vllm/config.py | 15 +++++++++++---- vllm/engine/arg_utils.py | 3 +-- vllm/entrypoints/cli/serve.py | 5 +++++ vllm/v1/engine/core_client.py | 5 +++-- vllm/v1/engine/utils.py | 12 ++++++------ 5 files changed, 26 insertions(+), 14 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index fa18954104946..04db21766c657 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1891,10 +1891,17 @@ class ParallelConfig: """Backend to use for data parallel, either "mp" or "ray".""" data_parallel_external_lb: bool = False """Whether to use "external" DP LB mode. Applies only to online serving - and when data_parallel_size > 0. Set implicitly when - data_parallel_rank is provided explicitly to vllm serve.""" - data_parallel_rank_0_manage_all: bool = False - """XXX""" + and when data_parallel_size > 0. This is useful for a "one-pod-per-rank" + wide-EP setup in Kubernetes. Set implicitly when data_parallel_rank + is provided explicitly to vllm serve.""" + data_parallel_hybrid_lb: bool = False + """Whether to use "hybrid" DP LB mode. Applies only to online serving + and when data_parallel_size > 0. Enables running an AsyncLLM + and API server on a "per-node" basis where vLLM load balances + between local data parallel ranks, but an external LB balances + between vLLM nodes/replicas. This is useful for a "one-pod-per-node" + wide-EP setup in Kubernetes. Set explicitly by the user. + """ enable_expert_parallel: bool = False """Use expert parallelism instead of tensor parallelism for MoE layers.""" enable_eplb: bool = False diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 51929e123e307..3a542ff78e805 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1108,7 +1108,6 @@ class EngineArgs: # Validate External LB. 
data_parallel_hybrid_lb = True - if data_parallel_hybrid_lb: if self.data_parallel_size_local is None: @@ -1178,7 +1177,7 @@ class EngineArgs: data_parallel_master_ip=data_parallel_address, data_parallel_rpc_port=data_parallel_rpc_port, data_parallel_backend=self.data_parallel_backend, - data_parallel_rank_0_manage_all=False, + data_parallel_hybrid_lb=False, enable_expert_parallel=self.enable_expert_parallel, enable_eplb=self.enable_eplb, num_redundant_experts=self.num_redundant_experts, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 2eaaa9c9c1f1b..e26642b2f65e4 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -164,6 +164,11 @@ def run_multi_api_server(args: argparse.Namespace): " api_server_count > 1") model_config.disable_mm_preprocessor_cache = True + if vllm_config.parallel_config.data_parallel_hybrid_lb: + raise NotImplementedError( + "Hybrid load balancing with --api-server-count > 0 " + "is not yet supported.") + executor_class = Executor.get_class(vllm_config) log_stats = not engine_args.disable_log_stats diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 6c0d73aa26c69..2b5fd986bd9f3 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -431,11 +431,12 @@ class MPClient(EngineCoreClient): dp_rank = parallel_config.data_parallel_rank dp_local_size = parallel_config.data_parallel_size_local offline_mode = parallel_config.data_parallel_rank_local is not None - manage_only_local = not (parallel_config.data_parallel_rank_0_manage_all) + local_engines_only = (parallel_config.data_parallel_hybrid_lb + or parallel_config.data_parallel_external_lb) # If External DPLB, Client manages local EngineCores. # If Internal DPLB, Client manages local+remote EngineCores. 
- num_ranks = dp_local_size if manage_only_local else dp_size + num_ranks = dp_local_size if local_engines_only else dp_size self.engine_ranks_managed = ([dp_rank] if offline_mode else range( dp_rank, dp_rank + num_ranks)) assert parallel_config.data_parallel_size_local <= len( diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index aaf229dbcdc5a..a55b288fbdf9f 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -544,8 +544,8 @@ def launch_core_engines( local_start_index = parallel_config.data_parallel_rank_local dp_rank = parallel_config.data_parallel_rank host = parallel_config.data_parallel_master_ip - # external_dp_lb = parallel_config.data_parallel_external_lb - rank_0_local_only = (not parallel_config.data_parallel_rank_0_manage_all) + local_engines_only = (parallel_config.data_parallel_hybrid_lb + or parallel_config.data_parallel_external_lb) # In offline mode there is an LLM instance per DP rank and # one core engine per LLM, see @@ -554,8 +554,8 @@ def launch_core_engines( # client_local_only = True for cases where this front-end # sends requests only to colocated engines. - client_local_only = (offline_mode or rank_0_local_only or - (local_engine_count == dp_size)) + client_local_only = (offline_mode or local_engines_only + or (local_engine_count == dp_size)) # Set up input and output addresses. addresses = EngineZmqAddresses( @@ -611,7 +611,7 @@ def launch_core_engines( ] else: # Rank > 0 handshakes with just the local cores it is managing. 
- assert rank_0_local_only, ( + assert local_engines_only, ( "Attempting to launch core_engines from dp_rank > 0, but " "found internal DPLB, which is incompatible.") engines_to_handshake = [ @@ -628,7 +628,7 @@ def launch_core_engines( handshake_address = get_engine_client_zmq_addr( handshake_local_only, host, parallel_config.data_parallel_rpc_port) - if rank_0_local_only and dp_rank > 0: + if local_engines_only and dp_rank > 0: assert not handshake_local_only local_handshake_address = get_open_zmq_ipc_path() client_handshake_address = local_handshake_address