diff --git a/tests/v1/test_internal_lb_dp.py b/tests/v1/test_internal_lb_dp.py
index ca80d3a4949d..2b031865cad7 100644
--- a/tests/v1/test_internal_lb_dp.py
+++ b/tests/v1/test_internal_lb_dp.py
@@ -4,6 +4,8 @@ import asyncio
 import os
 import threading
 import time
+import traceback
+from typing import Optional, cast
 
 import openai  # use the official client for correctness check
 import pytest
@@ -41,12 +43,15 @@ class MultinodeInternalLBServerManager:
         self.tp_size = tp_size
         self.api_server_count = api_server_count
         self.base_server_args = base_server_args
-        self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
+        self.servers: list[Optional[tuple[RemoteOpenAIServer,
+                                          list[str]]]] = [None] * (dp_size //
+                                                                   dp_per_node)
         self.server_threads: list[threading.Thread] = []
 
     def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
         """Start all server instances for multi-node internal LB mode."""
-        for rank in range(0, self.dp_size, self.dp_per_node):
+        for server_idx, rank in enumerate(
+                range(0, self.dp_size, self.dp_per_node)):
             # Create server args for this specific rank
             server_args = self.base_server_args.copy()
 
@@ -87,7 +92,7 @@ class MultinodeInternalLBServerManager:
             ])
 
             # Use a thread to start each server to allow parallel initialization
-            def start_server(r: int, sargs: list[str]):
+            def start_server(sidx: int, r: int, sargs: list[str]):
                 gpus_per_node = self.tp_size * self.dp_per_node
                 try:
                     # Start the server
@@ -110,13 +115,14 @@ class MultinodeInternalLBServerManager:
                               f"{self.api_server_count} API servers")
                     else:
                         print(f"Headless node (rank {r}) started successfully")
-                    self.servers.append((server, sargs))
+                    self.servers[sidx] = (server, sargs)
                 except Exception as e:
                     print(f"Failed to start server rank {r}: {e}")
+                    traceback.print_exc()
                     raise
 
             thread = threading.Thread(target=start_server,
-                                      args=(rank, server_args))
+                                      args=(server_idx, rank, server_args))
             thread.start()
             self.server_threads.append(thread)
 
@@ -128,18 +134,20 @@ class MultinodeInternalLBServerManager:
         # Give servers additional time to fully initialize and coordinate
        time.sleep(3)
 
-        if len(self.servers) != self.dp_size // self.dp_per_node:
+        if not all(self.servers):
             raise Exception("Servers failed to start")
 
-        return self.servers
+        return cast(list[tuple[RemoteOpenAIServer, list[str]]], self.servers)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         """Stop all server instances."""
         while self.servers:
-            try:
-                self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
-            except Exception as e:
-                print(f"Error stopping server: {e}")
+            if server := self.servers.pop():
+                try:
+                    server[0].__exit__(exc_type, exc_val, exc_tb)
+                except Exception as e:
+                    print(f"Error stopping server: {e}")
+                    traceback.print_exc()
 
 
 class APIOnlyServerManager:
@@ -157,7 +165,8 @@ class APIOnlyServerManager:
         self.tp_size = tp_size
         self.api_server_count = api_server_count
         self.base_server_args = base_server_args
-        self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
+        self.servers: list[Optional[tuple[RemoteOpenAIServer,
+                                          list[str]]]] = [None] * 2
         self.server_threads: list[threading.Thread] = []
 
     def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
@@ -209,7 +218,7 @@ class APIOnlyServerManager:
                 server.__enter__()
                 print(f"API-only server started successfully with "
                       f"{self.api_server_count} API servers")
-                self.servers.append((server, api_server_args))
+                self.servers[0] = (server, api_server_args)
             except Exception as e:
                 print(f"Failed to start API-only server: {e}")
                 raise
@@ -231,7 +240,7 @@ class APIOnlyServerManager:
                 server.__enter__()
                 print(f"Headless engines server started successfully with "
                       f"{self.dp_size} engines")
-                self.servers.append((server, engines_server_args))
+                self.servers[1] = (server, engines_server_args)
             except Exception as e:
                 print(f"Failed to start headless engines server: {e}")
                 raise
@@ -253,18 +262,20 @@ class APIOnlyServerManager:
         # Give servers additional time to fully initialize and coordinate
         time.sleep(3)
 
-        if len(self.servers) != 2:
+        if not all(self.servers):
             raise Exception("Both servers failed to start")
 
-        return self.servers
+        return cast(list[tuple[RemoteOpenAIServer, list[str]]], self.servers)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         """Stop both server instances."""
         while self.servers:
-            try:
-                self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
-            except Exception as e:
-                print(f"Error stopping server: {e}")
+            if server := self.servers.pop():
+                try:
+                    server[0].__exit__(exc_type, exc_val, exc_tb)
+                except Exception as e:
+                    print(f"Error stopping server: {e}")
+                    traceback.print_exc()
 
 
 @pytest.fixture(scope="module")
@@ -560,7 +571,7 @@ async def test_api_only_multinode_dp_completion(
     assert len(results) == num_requests
     assert all(completion is not None for completion in results)
 
-    _, api_server_args = api_only_servers[0]
+    api_server, api_server_args = api_only_servers[0]
     api_server_count = (
         api_server_args.count('--api-server-count') and
         api_server_args[api_server_args.index('--api-server-count') + 1]
@@ -569,7 +580,6 @@ async def test_api_only_multinode_dp_completion(
           f"engines on headless server (API server count: {api_server_count})")
 
     # Check request balancing via Prometheus metrics
-    api_server = api_only_servers[0][0]
     check_request_balancing(api_server, DP_SIZE)
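Note on the pattern (it applies identically to both manager classes): the managers previously append()ed to self.servers from concurrently running startup threads, so list order depended on thread completion order and a failed start simply left the list shorter. The diff instead preallocates one Optional slot per server, has each thread write to its own index, checks startup with all(), and uses typing.cast to drop the Optional once every slot is known to be filled. A minimal standalone sketch of that pattern, with hypothetical names (not vLLM code):

import random
import threading
import time
from typing import Optional, cast


def start_workers(n: int) -> list[str]:
    # One fixed slot per worker: threads write by index instead of
    # append(), so slot i always belongs to worker i regardless of
    # which thread finishes first.
    slots: list[Optional[str]] = [None] * n
    threads: list[threading.Thread] = []

    def start_worker(idx: int) -> None:
        time.sleep(random.uniform(0, 0.05))  # simulate nondeterministic startup
        slots[idx] = f"worker-{idx}"

    for idx in range(n):
        thread = threading.Thread(target=start_worker, args=(idx,))
        thread.start()
        threads.append(thread)

    for thread in threads:
        thread.join()

    # Any slot still None means that worker never started.
    if not all(slots):
        raise Exception("Workers failed to start")

    # Every slot is filled, so narrowing away the Optional is safe.
    return cast(list[str], slots)


print(start_workers(4))  # always ['worker-0', ..., 'worker-3'], in order

The same reasoning drives the __exit__ changes: with preallocated slots, pop() can now return None for a server that never started, so the walrus-guarded "if server := self.servers.pop():" skips empty slots rather than raising TypeError on subscripting None.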