From d5b981f8b1de31a54d55f9c0ead977dbf4d6b987 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Wed, 23 Jul 2025 23:57:32 -0400 Subject: [PATCH] [DP] Internal Load Balancing Per Node [`one-pod-per-node`] (#21238) Signed-off-by: Robert Shaw Signed-off-by: Nick Hill Signed-off-by: Tyler Michael Smith Co-authored-by: Robert Shaw Co-authored-by: Nick Hill Co-authored-by: Tyler Michael Smith --- .buildkite/test-pipeline.yaml | 2 + tests/v1/engine/test_engine_core_client.py | 4 +- tests/v1/test_hybrid_lb_dp.py | 352 +++++++++++++++++++++ vllm/config.py | 12 +- vllm/engine/arg_utils.py | 38 +++ vllm/entrypoints/cli/serve.py | 19 +- vllm/entrypoints/openai/cli_args.py | 7 - vllm/v1/engine/async_llm.py | 2 +- vllm/v1/engine/coordinator.py | 5 +- vllm/v1/engine/core.py | 19 +- vllm/v1/engine/core_client.py | 27 +- vllm/v1/engine/utils.py | 44 ++- 12 files changed, 486 insertions(+), 45 deletions(-) create mode 100644 tests/v1/test_hybrid_lb_dp.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index c2e56557ba9b..948ce9e8667f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -166,6 +166,7 @@ steps: - tests/v1/test_async_llm_dp.py - tests/v1/test_external_lb_dp.py - tests/v1/test_internal_lb_dp.py + - tests/v1/test_hybrid_lb_dp.py - tests/v1/engine/test_engine_core_client.py commands: # test with tp=2 and external_dp=2 @@ -178,6 +179,7 @@ steps: - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_internal_lb_dp.py + - TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_hybrid_lb_dp.py - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp - pytest -v -s distributed/test_utils.py - pytest -v -s compile/test_basic_correctness.py diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 65f1da803fb2..2ac6dc796bd1 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -565,8 +565,8 @@ def test_engine_core_proc_instantiation_cuda_empty( from vllm.v1.engine.utils import EngineZmqAddresses - def mock_startup_handshake(self, handshake_socket, on_head_node, - parallel_config): + def mock_startup_handshake(self, handshake_socket, local_client, + headless, parallel_config): return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"], outputs=["tcp://127.0.0.1:5556"], coordinator_input=None, diff --git a/tests/v1/test_hybrid_lb_dp.py b/tests/v1/test_hybrid_lb_dp.py new file mode 100644 index 000000000000..08336489abee --- /dev/null +++ b/tests/v1/test_hybrid_lb_dp.py @@ -0,0 +1,352 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import os +import threading +import time +from contextlib import AsyncExitStack + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio + +from tests.utils import RemoteOpenAIServer +from tests.v1.test_utils import check_request_balancing +from vllm.platforms import Platform + +MODEL_NAME = "ibm-research/PowerMoE-3b" + +# Number of data parallel ranks for hybrid LB testing (4 total) +DP_SIZE = int(os.getenv("DP_SIZE", "4")) +# Default tensor parallel size to use +TP_SIZE = int(os.getenv("TP_SIZE", "1")) + +# Number of nodes (2 nodes, each with 2 DP ranks) +NUM_NODES = 2 +DP_SIZE_LOCAL = DP_SIZE // NUM_NODES # 2 ranks per node + + +class HybridLBServerManager: + """Manages hybrid data parallel vLLM server instances where each node + runs a single logical API server that balances requests only to the + DP engines running on that same node.""" + + def __init__(self, + model_name: str, + dp_size: int, + api_server_count: int, + base_server_args: list, + dp_size_local: int = DP_SIZE_LOCAL, + tp_size: int = TP_SIZE): + self.model_name = model_name + self.dp_size = dp_size + self.dp_size_local = dp_size_local + self.tp_size = tp_size + self.api_server_count = api_server_count + self.base_server_args = base_server_args + self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = [] + self.server_threads: list[threading.Thread] = [] + self.num_nodes = dp_size // dp_size_local + + def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: + """Start all server instances for hybrid LB mode.""" + for node_id in range(self.num_nodes): + # Create server args for this specific node + server_args = self.base_server_args.copy() + + # Calculate start rank for this node + start_rank = node_id * self.dp_size_local + + # Add hybrid LB specific arguments + server_args.extend([ + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-size-local", + str(self.dp_size_local), + "--data-parallel-start-rank", + str(start_rank), + "--data-parallel-hybrid-lb", # Enable hybrid LB mode + "--tensor-parallel-size", + str(self.tp_size), + "--port", + str(8000 + node_id), # Different port for each node + "--api-server-count", + str(self.api_server_count), + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", + "13345", + ]) + + # Use a thread to start each server to allow parallel initialization + def start_server(node: int, sargs: list[str]): + try: + # Calculate GPU devices for this node + gpus_per_node = self.dp_size_local * self.tp_size + gpu_start = node * gpus_per_node + gpu_end = gpu_start + gpus_per_node + + # Start the server + server = RemoteOpenAIServer( + self.model_name, + sargs, + auto_port=False, + env_dict={ + "CUDA_VISIBLE_DEVICES": + ",".join( + str(Platform.device_id_to_physical_device_id( + i)) for i in range(gpu_start, gpu_end)) + }) + server.__enter__() + print(f"Hybrid LB node {node} started successfully with " + f"{self.dp_size_local} local DP ranks and " + f"{self.api_server_count} API servers") + self.servers.append((server, sargs)) + except Exception as e: + print(f"Failed to start hybrid LB node {node}: {e}") + raise + + thread = threading.Thread(target=start_server, + args=(node_id, server_args)) + thread.start() + + self.server_threads.append(thread) + + # Wait for all servers to start + for thread in self.server_threads: + thread.join() + + # Give servers additional time to fully initialize and coordinate + time.sleep(3) + + if len(self.servers) != self.num_nodes: + raise Exception("Servers failed to start") + + return self.servers + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stop all server instances.""" + while self.servers: + try: + self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb) + except Exception as e: + print(f"Error stopping server: {e}") + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager", + ] + + +@pytest.fixture(scope="module", params=[1]) # Only 1 API server for now +def servers(request, default_server_args): + api_server_count = request.param + with HybridLBServerManager(MODEL_NAME, DP_SIZE, api_server_count, + default_server_args, DP_SIZE_LOCAL, + TP_SIZE) as server_list: + yield server_list + + +@pytest_asyncio.fixture +async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]): + # Create a client for each node (each node has its own API endpoint) + async with AsyncExitStack() as stack: + yield [ + await stack.enter_async_context(server.get_async_client()) + for server, _ in servers + ] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_hybrid_lb_completion(clients: list[openai.AsyncOpenAI], + servers: list[tuple[RemoteOpenAIServer, + list[str]]], + model_name: str) -> None: + + async def make_request(client: openai.AsyncOpenAI): + completion = await client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=10, + temperature=1.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + # The exact number of tokens can vary slightly with temperature=1.0, + # so we check for a reasonable minimum length. + assert len(choice.text) >= 1 + # Finish reason might not always be 'length' if the model finishes early + # or due to other reasons, especially with high temperature. + # So, we'll accept 'length' or 'stop'. + assert choice.finish_reason in ("length", "stop") + + # Token counts can also vary, so we check they are positive. + assert completion.usage.completion_tokens > 0 + assert completion.usage.prompt_tokens > 0 + assert completion.usage.total_tokens > 0 + return completion + + # Test single request to each node + for i, client in enumerate(clients): + result = await make_request(client) + assert result is not None + print( + f"Hybrid LB node {i} handled single completion request successfully" + ) + + await asyncio.sleep(0.5) + + # Send requests to all nodes - each should balance within its local DP ranks + num_requests_per_node = 25 # Total 50 requests across 2 nodes + all_tasks = [] + + for i, client in enumerate(clients): + tasks = [make_request(client) for _ in range(num_requests_per_node)] + all_tasks.extend(tasks) + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests_per_node * len(clients) + assert all(completion is not None for completion in results) + + await asyncio.sleep(0.5) + + # Second burst of requests + all_tasks = [] + for i, client in enumerate(clients): + tasks = [make_request(client) for _ in range(num_requests_per_node)] + all_tasks.extend(tasks) + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests_per_node * len(clients) + assert all(completion is not None for completion in results) + + _, server_args = servers[0] + api_server_count = ( + server_args.count('--api-server-count') + and server_args[server_args.index('--api-server-count') + 1] or 1) + print( + f"Successfully completed hybrid LB test with {len(clients)} nodes " + f"({DP_SIZE_LOCAL} DP ranks each, API server count: {api_server_count})" + ) + + # Check request balancing within each node + for i, (server, _) in enumerate(servers): + print(f"Checking request balancing for node {i}") + check_request_balancing(server, DP_SIZE_LOCAL) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_hybrid_lb_completion_streaming(clients: list[ + openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str) -> None: + prompt = "What is an LLM?" + + async def make_streaming_request(client: openai.AsyncOpenAI): + # Perform a non-streaming request to get the expected full output + single_completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + ) + single_output = single_completion.choices[0].text + + # Perform the streaming request + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True) + chunks: list[str] = [] + finish_reason_count = 0 + last_chunk = None + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + last_chunk = chunk # Keep track of the last chunk + + # finish reason should only return in the last block for OpenAI API + assert finish_reason_count == 1, ( + "Finish reason should appear exactly once.") + assert last_chunk is not None, ( + "Stream should have yielded at least one chunk.") + assert last_chunk.choices[ + 0].finish_reason == "length", "Finish reason should be 'length'." + # Check that the combined text matches the non-streamed version. + assert "".join( + chunks + ) == single_output, "Streamed output should match non-streamed output." + return True # Indicate success for this request + + # Test single request to each node + for i, client in enumerate(clients): + result = await make_streaming_request(client) + assert result is not None + print( + f"Hybrid LB node {i} handled single streaming request successfully" + ) + + await asyncio.sleep(0.5) + + # Send streaming requests to all nodes + num_requests_per_node = 25 # Total 50 requests across 2 nodes + all_tasks = [] + + for i, client in enumerate(clients): + tasks = [ + make_streaming_request(client) + for _ in range(num_requests_per_node) + ] + all_tasks.extend(tasks) + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests_per_node * len(clients) + assert all(results), "Not all streaming requests completed successfully." + + await asyncio.sleep(0.5) + + # Second burst of streaming requests + all_tasks = [] + for i, client in enumerate(clients): + tasks = [ + make_streaming_request(client) + for _ in range(num_requests_per_node) + ] + all_tasks.extend(tasks) + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests_per_node * len(clients) + assert all(results), "Not all streaming requests completed successfully." + + _, server_args = servers[0] + api_server_count = ( + server_args.count('--api-server-count') + and server_args[server_args.index('--api-server-count') + 1] or 1) + print(f"Successfully completed hybrid LB streaming test with " + f"{len(clients)} nodes ({DP_SIZE_LOCAL} DP ranks each, " + f"API server count: {api_server_count})") + + # Check request balancing within each node + for i, (server, _) in enumerate(servers): + print(f"Checking streaming request balancing for node {i}") + check_request_balancing(server, DP_SIZE_LOCAL) diff --git a/vllm/config.py b/vllm/config.py index 7593b1c3e27a..f038cdd64c67 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1908,8 +1908,16 @@ class ParallelConfig: """Backend to use for data parallel, either "mp" or "ray".""" data_parallel_external_lb: bool = False """Whether to use "external" DP LB mode. Applies only to online serving - and when data_parallel_size > 0. Set implicitly when - data_parallel_rank is provided explicitly to vllm serve.""" + and when data_parallel_size > 0. This is useful for a "one-pod-per-rank" + wide-EP setup in Kuberentes. Set implicitly when --data-parallel-rank + is provided explicitly to vllm serve.""" + data_parallel_hybrid_lb: bool = False + """Whether to use "hybrid" DP LB mode. Applies only to online serving + and when data_parallel_size > 0. Enables running an AsyncLLM + and API server on a "per-node" basis where vLLM load balances + between local data parallel ranks, but an external LB balances + between vLLM nodes/replicas. Set explicitly in conjunction with + --data-parallel-start-rank.""" enable_expert_parallel: bool = False """Use expert parallelism instead of tensor parallelism for MoE layers.""" enable_eplb: bool = False diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 62792fade4ed..aec75f82631a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -295,9 +295,11 @@ class EngineArgs: tensor_parallel_size: int = ParallelConfig.tensor_parallel_size data_parallel_size: int = ParallelConfig.data_parallel_size data_parallel_rank: Optional[int] = None + data_parallel_start_rank: Optional[int] = None data_parallel_size_local: Optional[int] = None data_parallel_address: Optional[str] = None data_parallel_rpc_port: Optional[int] = None + data_parallel_hybrid_lb: bool = False data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel enable_eplb: bool = ParallelConfig.enable_eplb @@ -604,6 +606,11 @@ class EngineArgs: type=int, help='Data parallel rank of this instance. ' 'When set, enables external load balancer mode.') + parallel_group.add_argument('--data-parallel-start-rank', + '-dpr', + type=int, + help='Starting data parallel rank ' + 'for secondary nodes.') parallel_group.add_argument('--data-parallel-size-local', '-dpl', type=int, @@ -625,6 +632,9 @@ class EngineArgs: default='mp', help='Backend for data parallel, either ' '"mp" or "ray".') + parallel_group.add_argument( + "--data-parallel-hybrid-lb", + **parallel_kwargs["data_parallel_hybrid_lb"]) parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]) @@ -972,6 +982,7 @@ class EngineArgs: def create_engine_config( self, usage_context: Optional[UsageContext] = None, + headless: bool = False, ) -> VllmConfig: """ Create the VllmConfig. @@ -1060,15 +1071,41 @@ class EngineArgs: # but we should not do this here. placement_group = ray.util.get_current_placement_group() + assert not headless or not self.data_parallel_hybrid_lb, ( + "data_parallel_hybrid_lb is not applicable in " + "headless mode") + data_parallel_external_lb = self.data_parallel_rank is not None + # Local DP rank = 1, use pure-external LB. if data_parallel_external_lb: assert self.data_parallel_size_local in (1, None), ( "data_parallel_size_local must be 1 when data_parallel_rank " "is set") data_parallel_size_local = 1 + # Use full external lb if we have local_size of 1. + self.data_parallel_hybrid_lb = False elif self.data_parallel_size_local is not None: data_parallel_size_local = self.data_parallel_size_local + + if self.data_parallel_start_rank and not headless: + # Infer hybrid LB mode. + self.data_parallel_hybrid_lb = True + + if self.data_parallel_hybrid_lb and data_parallel_size_local == 1: + # Use full external lb if we have local_size of 1. + data_parallel_external_lb = True + self.data_parallel_hybrid_lb = False + + if data_parallel_size_local == self.data_parallel_size: + # Disable hybrid LB mode if set for a single node + self.data_parallel_hybrid_lb = False + + self.data_parallel_rank = self.data_parallel_start_rank or 0 else: + assert not self.data_parallel_hybrid_lb, ( + "data_parallel_size_local must be set to use " + "data_parallel_hybrid_lb.") + # Local DP size defaults to global DP size if not set. data_parallel_size_local = self.data_parallel_size @@ -1125,6 +1162,7 @@ class EngineArgs: data_parallel_master_ip=data_parallel_address, data_parallel_rpc_port=data_parallel_rpc_port, data_parallel_backend=self.data_parallel_backend, + data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, enable_expert_parallel=self.enable_expert_parallel, enable_eplb=self.enable_eplb, num_redundant_experts=self.num_redundant_experts, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 1204ccc1c679..72460c2d91c7 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -45,11 +45,6 @@ class ServeSubcommand(CLISubcommand): if args.headless or args.api_server_count < 1: run_headless(args) else: - if args.data_parallel_start_rank: - raise ValueError( - "data_parallel_start_rank is only applicable " - "in headless mode. " - "Add --headless flag to enable headless mode.") if args.api_server_count > 1: run_multi_api_server(args) else: @@ -86,13 +81,14 @@ def run_headless(args: argparse.Namespace): # Create the EngineConfig. engine_args = vllm.AsyncEngineArgs.from_cli_args(args) usage_context = UsageContext.OPENAI_API_SERVER - vllm_config = engine_args.create_engine_config(usage_context=usage_context) + vllm_config = engine_args.create_engine_config(usage_context=usage_context, + headless=True) if not envs.VLLM_USE_V1: raise ValueError("Headless mode is only supported for V1") - if engine_args.data_parallel_rank is not None: - raise ValueError("data_parallel_rank is not applicable in " + if engine_args.data_parallel_hybrid_lb: + raise ValueError("data_parallel_hybrid_lb is not applicable in " "headless mode") parallel_config = vllm_config.parallel_config @@ -122,7 +118,7 @@ def run_headless(args: argparse.Namespace): engine_manager = CoreEngineProcManager( target_fn=EngineCoreProc.run_engine_core, local_engine_count=local_engine_count, - start_index=args.data_parallel_start_rank, + start_index=vllm_config.parallel_config.data_parallel_rank, local_start_index=0, vllm_config=vllm_config, local_client=False, @@ -169,6 +165,11 @@ def run_multi_api_server(args: argparse.Namespace): " api_server_count > 1") model_config.disable_mm_preprocessor_cache = True + if vllm_config.parallel_config.data_parallel_hybrid_lb: + raise NotImplementedError( + "Hybrid load balancing with --api-server-count > 0" + "is not yet supported.") + executor_class = Executor.get_class(vllm_config) log_stats = not engine_args.disable_log_stats diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index b18148666648..3025a6263682 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -222,13 +222,6 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=False, help="Run in headless mode. See multi-node data parallel " "documentation for more details.") - parser.add_argument( - "--data-parallel-start-rank", - "-dpr", - type=int, - default=0, - help="Starting data parallel rank for secondary nodes. " - "Requires --headless.") parser.add_argument("--api-server-count", "-asc", type=int, diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 66e76777d75e..02cb80197fa4 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -127,7 +127,7 @@ class AsyncLLM(EngineClient): if self.log_stats: self.logger_manager = StatLoggerManager( vllm_config=vllm_config, - engine_idxs=self.engine_core.engine_ranks, + engine_idxs=self.engine_core.engine_ranks_managed, custom_stat_loggers=stat_loggers, ) self.logger_manager.log_engine_initialized() diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index 005e71647aae..c0decd6ffa2c 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -61,11 +61,12 @@ class DPCoordinator: host = parallel_config.data_parallel_master_ip external_lb = parallel_config.data_parallel_external_lb + hybrid_lb = parallel_config.data_parallel_hybrid_lb # Assume coordinator is colocated with front-end procs when not in - # external DP LB mode. + # either external or hybrid DP LB mode. front_publish_address = get_engine_client_zmq_addr( - local_only=not external_lb, host=host) + local_only=not external_lb and not hybrid_lb, host=host) local_only_eng = dp_size == parallel_config.data_parallel_size_local back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index ca636bf5a6f7..4a971e0b3120 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -467,13 +467,14 @@ class EngineCoreProc(EngineCore): For DP>1 with internal loadbalancing this is with the shared front-end process which may reside on a different node. - For DP>1 with external loadbalancing, two handshakes are performed: + For DP>1 with external or hybrid loadbalancing, two handshakes are + performed: - With the rank 0 front-end process which retrieves the DP Coordinator ZMQ addresses and DP process group address. - With the colocated front-end process which retrieves the client input/output socket addresses. - with the exception of the rank 0 engine itself which doesn't require - the second handshake. + with the exception of the rank 0 and colocated engines themselves which + don't require the second handshake. Here, "front-end" process can mean the process containing the engine core client (which is the API server process in the case the API @@ -482,15 +483,18 @@ class EngineCoreProc(EngineCore): """ input_ctx = zmq.Context() is_local = local_client and client_handshake_address is None + headless = not local_client handshake = self._perform_handshake(input_ctx, handshake_address, - identity, is_local, vllm_config, + identity, is_local, headless, + vllm_config, vllm_config.parallel_config) if client_handshake_address is None: with handshake as addresses: yield addresses else: + assert local_client local_handshake = self._perform_handshake( - input_ctx, client_handshake_address, identity, local_client, + input_ctx, client_handshake_address, identity, True, False, vllm_config) with handshake as addresses, local_handshake as client_addresses: addresses.inputs = client_addresses.inputs @@ -507,6 +511,7 @@ class EngineCoreProc(EngineCore): handshake_address: str, identity: bytes, local_client: bool, + headless: bool, vllm_config: VllmConfig, parallel_config_to_update: Optional[ParallelConfig] = None, ) -> Generator[EngineZmqAddresses, None, None]: @@ -518,6 +523,7 @@ class EngineCoreProc(EngineCore): bind=False) as handshake_socket: # Register engine with front-end. addresses = self.startup_handshake(handshake_socket, local_client, + headless, parallel_config_to_update) yield addresses @@ -531,6 +537,7 @@ class EngineCoreProc(EngineCore): msgspec.msgpack.encode({ "status": "READY", "local": local_client, + "headless": headless, "num_gpu_blocks": num_gpu_blocks, "dp_stats_address": dp_stats_address, })) @@ -539,6 +546,7 @@ class EngineCoreProc(EngineCore): def startup_handshake( handshake_socket: zmq.Socket, local_client: bool, + headless: bool, parallel_config: Optional[ParallelConfig] = None, ) -> EngineZmqAddresses: @@ -547,6 +555,7 @@ class EngineCoreProc(EngineCore): msgspec.msgpack.encode({ "status": "HELLO", "local": local_client, + "headless": headless, })) # Receive initialization message. diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 2ebb76a97ebe..69ae3690d00e 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -429,18 +429,23 @@ class MPClient(EngineCoreClient): parallel_config = vllm_config.parallel_config dp_size = parallel_config.data_parallel_size dp_rank = parallel_config.data_parallel_rank - external_dp_lb = parallel_config.data_parallel_external_lb - + dp_local_size = parallel_config.data_parallel_size_local offline_mode = parallel_config.data_parallel_rank_local is not None - self.engine_ranks = ([dp_rank] if - (offline_mode or external_dp_lb) else list( - range(dp_size))) + # Client manages local+remote EngineCores in pure internal LB case. + # Client manages local EngineCores in hybrid and external LB case. + local_engines_only = (parallel_config.data_parallel_hybrid_lb + or parallel_config.data_parallel_external_lb) + + num_ranks = dp_local_size if local_engines_only else dp_size + self.engine_ranks_managed = [dp_rank] if offline_mode else list( + range(dp_rank, dp_rank + num_ranks)) assert parallel_config.data_parallel_size_local <= len( - self.engine_ranks) + self.engine_ranks_managed) # ZMQ identity of each engine that this client will talk to. self.core_engines: list[EngineIdentity] = [ - index.to_bytes(2, "little") for index in self.engine_ranks + rank.to_bytes(2, "little") + for rank in self.engine_ranks_managed ] # Wait for ready messages from each engine on the input socket. @@ -895,6 +900,12 @@ class DPAsyncMPClient(AsyncMPClient): return assert self.stats_update_address is not None + assert len(self.engine_ranks_managed) > 0 + # NOTE: running and waiting counts are all global from + # the Coordinator include all global EngineCores. This + # slice includes just the cores managed by this client. + count_slice = slice(self.engine_ranks_managed[0], + self.engine_ranks_managed[-1] + 1) async def run_engine_stats_update_task(): with make_zmq_socket(self.ctx, self.stats_update_address, @@ -959,7 +970,7 @@ class DPAsyncMPClient(AsyncMPClient): counts, wave, running = msgspec.msgpack.decode(buf) self.current_wave = wave self.engines_running = running - self.lb_engines = counts + self.lb_engines = counts[count_slice] resources.stats_update_task = asyncio.create_task( run_engine_stats_update_task()) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 6dde477576b8..092b5b90bb57 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -544,7 +544,8 @@ def launch_core_engines( local_start_index = parallel_config.data_parallel_rank_local dp_rank = parallel_config.data_parallel_rank host = parallel_config.data_parallel_master_ip - external_dp_lb = parallel_config.data_parallel_external_lb + local_engines_only = (parallel_config.data_parallel_hybrid_lb + or parallel_config.data_parallel_external_lb) # In offline mode there is an LLM instance per DP rank and # one core engine per LLM, see @@ -553,8 +554,8 @@ def launch_core_engines( # client_local_only = True for cases where this front-end # sends requests only to colocated engines. - client_local_only = offline_mode or external_dp_lb or (local_engine_count - == dp_size) + client_local_only = (offline_mode or local_engines_only + or (local_engine_count == dp_size)) # Set up input and output addresses. addresses = EngineZmqAddresses( @@ -598,14 +599,27 @@ def launch_core_engines( yield engine_actor_manager, coordinator, addresses return - if offline_mode or (external_dp_lb and dp_rank > 0): + if offline_mode: assert local_engine_count == 1 engines_to_handshake = [CoreEngine(index=dp_rank, local=True)] - else: + elif dp_rank == 0: + # Rank 0 holds Coordinator, so it handshakes with all Cores + # in both external dplb and internal dplb mode. + # Note this also covers the case where we have zero local engines + # and rank 0 is headless. engines_to_handshake = [ CoreEngine(index=i, local=(i < local_engine_count)) for i in range(dp_size) ] + else: + # Rank > 0 handshakes with just the local cores it is managing. + assert local_engines_only, ( + "Attempting to launch core_engines from dp_rank > 0, but " + "found internal DPLB, which is incompatible.") + engines_to_handshake = [ + CoreEngine(index=i, local=True) + for i in range(dp_rank, dp_rank + local_engine_count) + ] # Whether the started engines will handshake only with co-located # front-end processes. In external_dp_lb mode, ranks > 0 handshake with @@ -616,7 +630,7 @@ def launch_core_engines( handshake_address = get_engine_client_zmq_addr( handshake_local_only, host, parallel_config.data_parallel_rpc_port) - if external_dp_lb and dp_rank > 0: + if local_engines_only and dp_rank > 0: assert not handshake_local_only local_handshake_address = get_open_zmq_ipc_path() client_handshake_address = local_handshake_address @@ -631,8 +645,6 @@ def launch_core_engines( # Start local engines. if local_engine_count: - # In server mode, start_index and local_start_index will - # both be 0. local_engine_manager = CoreEngineProcManager( EngineCoreProc.run_engine_core, vllm_config=vllm_config, @@ -678,6 +690,9 @@ def wait_for_engine_startup( poller = zmq.Poller() poller.register(handshake_socket, zmq.POLLIN) + remote_should_be_headless = not parallel_config.data_parallel_hybrid_lb \ + and not parallel_config.data_parallel_external_lb + if proc_manager is not None: for sentinel in proc_manager.sentinels(): poller.register(sentinel, zmq.POLLIN) @@ -713,13 +728,24 @@ def wait_for_engine_startup( raise RuntimeError(f"Message from engine with unexpected data " f"parallel rank: {eng_index}") msg = msgspec.msgpack.decode(ready_msg_bytes) - status, local = msg["status"], msg["local"] + status, local, headless = msg["status"], msg["local"], msg["headless"] if local != engine.local: raise RuntimeError(f"{status} message from " f"{'local' if local else 'remote'} " f"engine {eng_index}, expected it to be " f"{'local' if engine.local else 'remote'}") + # Remote engines must be headless iff we aren't in hybrid dp lb mode. + if not local and headless != remote_should_be_headless: + if headless: + raise RuntimeError(f"Remote engine {eng_index} must not use " + f"--headless in external or hybrid dp lb " + f"mode") + else: + raise RuntimeError(f"Remote engine {eng_index} must use " + f"--headless unless in external or hybrid " + f"dp lb mode") + if status == "HELLO" and engine.state == CoreEngineState.NEW: # Send init message with DP config info.