From c22990470f09fdeae27c53b175c90ea55816268f Mon Sep 17 00:00:00 2001
From: Robert Shaw
Date: Sun, 20 Jul 2025 20:01:57 +0000
Subject: [PATCH] refactor ux

Signed-off-by: Robert Shaw
---
 vllm/engine/arg_utils.py            | 47 +++++++++++++++++++++--------
 vllm/entrypoints/cli/serve.py       |  7 +----
 vllm/entrypoints/openai/cli_args.py |  7 -----
 3 files changed, 36 insertions(+), 25 deletions(-)

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 31c8a0fcf4129..d15d6755e988b 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -295,6 +295,7 @@ class EngineArgs:
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
     data_parallel_rank: Optional[int] = None
+    data_parallel_start_rank: Optional[int] = None
     data_parallel_size_local: Optional[int] = None
     data_parallel_address: Optional[str] = None
     data_parallel_rpc_port: Optional[int] = None
@@ -606,6 +607,11 @@ class EngineArgs:
                                     type=int,
                                     help='Data parallel rank of this instance. '
                                     'When set, enables external load balancer mode.')
+        parallel_group.add_argument('--data-parallel-start-rank',
+                                    '-dpr',
+                                    type=int,
+                                    help='Starting data parallel rank '
+                                    'for secondary nodes.')
         parallel_group.add_argument('--data-parallel-size-local',
                                     '-dpl',
                                     type=int,
@@ -1091,19 +1097,36 @@
             # but we should not do this here.
             placement_group = ray.util.get_current_placement_group()

-        # data_parallel_external_lb = self.data_parallel_rank is not None
-        # if data_parallel_external_lb:
-        #     assert self.data_parallel_size_local in (1, None), (
-        #         "data_parallel_size_local must be 1 when data_parallel_rank "
-        #         "is set")
-        #     data_parallel_size_local = 1
-        # elif self.data_parallel_size_local is not None:
-        data_parallel_external_lb = False
-        if self.data_parallel_size_local is not None:
+        # Organize --data-parallel-start-rank and --data-parallel-rank.
+        if self.data_parallel_start_rank is not None:
+            if self.data_parallel_rank is not None:
+                raise ValueError(
+                    "Found --data-parallel-rank and --data-parallel-start-rank. "
+                    "Only one should be set (use --data-parallel-start-rank).")
+            else:
+                self.data_parallel_rank = self.data_parallel_start_rank
+
+        # Validate External LB.
+        data_parallel_external_lb = True
+        if data_parallel_external_lb:
+            if self.data_parallel_size_local is None:
+                raise ValueError(
+                    "With external LB, --data-parallel-size-local must be set."
+                )
+            if self.data_parallel_size_local >= self.data_parallel_size:
+                raise ValueError(
+                    "With external LB, --data-parallel-size-local must be less "
+                    "than --data-parallel-size.")
+            if (self.data_parallel_rank is not None
+                    and self.data_parallel_size_local > 1):
+                raise ValueError(
+                    "With --data-parallel-size-local > 1, use "
+                    "--data-parallel-start-rank instead of --data-parallel-rank.")
             data_parallel_size_local = self.data_parallel_size_local
-        else:
-            # Local DP size defaults to global DP size if not set.
-            data_parallel_size_local = self.data_parallel_size
+
+        # Local DP size defaults to global DP size if not set.
+        data_parallel_size_local = (self.data_parallel_size_local
+                                    or self.data_parallel_size)

         # DP address, used in multi-node case for torch distributed group
         # and ZMQ sockets.
diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py
index 1204ccc1c6796..2eaaa9c9c1f1b 100644
--- a/vllm/entrypoints/cli/serve.py
+++ b/vllm/entrypoints/cli/serve.py
@@ -45,11 +45,6 @@ class ServeSubcommand(CLISubcommand):
         if args.headless or args.api_server_count < 1:
             run_headless(args)
         else:
-            if args.data_parallel_start_rank:
-                raise ValueError(
-                    "data_parallel_start_rank is only applicable "
-                    "in headless mode. "
-                    "Add --headless flag to enable headless mode.")
             if args.api_server_count > 1:
                 run_multi_api_server(args)
             else:
@@ -122,7 +117,7 @@ def run_headless(args: argparse.Namespace):
     engine_manager = CoreEngineProcManager(
         target_fn=EngineCoreProc.run_engine_core,
         local_engine_count=local_engine_count,
-        start_index=args.data_parallel_start_rank,
+        start_index=vllm_config.parallel_config.data_parallel_rank,
         local_start_index=0,
         vllm_config=vllm_config,
         local_client=False,
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 28857f8caef85..04775744b8d46 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -253,13 +253,6 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         default=False,
         help="Run in headless mode. See multi-node data parallel "
         "documentation for more details.")
-    parser.add_argument(
-        "--data-parallel-start-rank",
-        "-dpr",
-        type=int,
-        default=0,
-        help="Starting data parallel rank for secondary nodes. "
-        "Requires --headless.")
     parser.add_argument("--api-server-count",
                         "-asc",
                         type=int,