[BugFix] Fix port lookup in internal DP LB tests (#22252)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-08-14 20:17:11 -07:00 committed by GitHub
parent 0933f9d518
commit ae05a6d83d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -4,6 +4,8 @@ import asyncio
import os import os
import threading import threading
import time import time
import traceback
from typing import Optional, cast
import openai # use the official client for correctness check import openai # use the official client for correctness check
import pytest import pytest
@@ -41,12 +43,15 @@ class MultinodeInternalLBServerManager:
self.tp_size = tp_size self.tp_size = tp_size
self.api_server_count = api_server_count self.api_server_count = api_server_count
self.base_server_args = base_server_args self.base_server_args = base_server_args
self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = [] self.servers: list[Optional[tuple[RemoteOpenAIServer,
list[str]]]] = [None] * (dp_size //
dp_per_node)
self.server_threads: list[threading.Thread] = [] self.server_threads: list[threading.Thread] = []
def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
"""Start all server instances for multi-node internal LB mode.""" """Start all server instances for multi-node internal LB mode."""
for rank in range(0, self.dp_size, self.dp_per_node): for server_idx, rank in enumerate(
range(0, self.dp_size, self.dp_per_node)):
# Create server args for this specific rank # Create server args for this specific rank
server_args = self.base_server_args.copy() server_args = self.base_server_args.copy()
@@ -87,7 +92,7 @@ class MultinodeInternalLBServerManager:
]) ])
# Use a thread to start each server to allow parallel initialization # Use a thread to start each server to allow parallel initialization
def start_server(r: int, sargs: list[str]): def start_server(sidx: int, r: int, sargs: list[str]):
gpus_per_node = self.tp_size * self.dp_per_node gpus_per_node = self.tp_size * self.dp_per_node
try: try:
# Start the server # Start the server
@@ -110,13 +115,14 @@ class MultinodeInternalLBServerManager:
f"{self.api_server_count} API servers") f"{self.api_server_count} API servers")
else: else:
print(f"Headless node (rank {r}) started successfully") print(f"Headless node (rank {r}) started successfully")
self.servers.append((server, sargs)) self.servers[sidx] = (server, sargs)
except Exception as e: except Exception as e:
print(f"Failed to start server rank {r}: {e}") print(f"Failed to start server rank {r}: {e}")
traceback.print_exc()
raise raise
thread = threading.Thread(target=start_server, thread = threading.Thread(target=start_server,
args=(rank, server_args)) args=(server_idx, rank, server_args))
thread.start() thread.start()
self.server_threads.append(thread) self.server_threads.append(thread)
@@ -128,18 +134,20 @@ class MultinodeInternalLBServerManager:
# Give servers additional time to fully initialize and coordinate # Give servers additional time to fully initialize and coordinate
time.sleep(3) time.sleep(3)
if len(self.servers) != self.dp_size // self.dp_per_node: if not all(self.servers):
raise Exception("Servers failed to start") raise Exception("Servers failed to start")
return self.servers return cast(list[tuple[RemoteOpenAIServer, list[str]]], self.servers)
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop all server instances.""" """Stop all server instances."""
while self.servers: while self.servers:
try: if server := self.servers.pop():
self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb) try:
except Exception as e: server[0].__exit__(exc_type, exc_val, exc_tb)
print(f"Error stopping server: {e}") except Exception as e:
print(f"Error stopping server: {e}")
traceback.print_exc()
class APIOnlyServerManager: class APIOnlyServerManager:
@@ -157,7 +165,8 @@ class APIOnlyServerManager:
self.tp_size = tp_size self.tp_size = tp_size
self.api_server_count = api_server_count self.api_server_count = api_server_count
self.base_server_args = base_server_args self.base_server_args = base_server_args
self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = [] self.servers: list[Optional[tuple[RemoteOpenAIServer,
list[str]]]] = [None] * 2
self.server_threads: list[threading.Thread] = [] self.server_threads: list[threading.Thread] = []
def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
@@ -209,7 +218,7 @@ class APIOnlyServerManager:
server.__enter__() server.__enter__()
print(f"API-only server started successfully with " print(f"API-only server started successfully with "
f"{self.api_server_count} API servers") f"{self.api_server_count} API servers")
self.servers.append((server, api_server_args)) self.servers[0] = (server, api_server_args)
except Exception as e: except Exception as e:
print(f"Failed to start API-only server: {e}") print(f"Failed to start API-only server: {e}")
raise raise
@@ -231,7 +240,7 @@ class APIOnlyServerManager:
server.__enter__() server.__enter__()
print(f"Headless engines server started successfully with " print(f"Headless engines server started successfully with "
f"{self.dp_size} engines") f"{self.dp_size} engines")
self.servers.append((server, engines_server_args)) self.servers[1] = (server, engines_server_args)
except Exception as e: except Exception as e:
print(f"Failed to start headless engines server: {e}") print(f"Failed to start headless engines server: {e}")
raise raise
@@ -253,18 +262,20 @@ class APIOnlyServerManager:
# Give servers additional time to fully initialize and coordinate # Give servers additional time to fully initialize and coordinate
time.sleep(3) time.sleep(3)
if len(self.servers) != 2: if not all(self.servers):
raise Exception("Both servers failed to start") raise Exception("Both servers failed to start")
return self.servers return cast(list[tuple[RemoteOpenAIServer, list[str]]], self.servers)
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop both server instances.""" """Stop both server instances."""
while self.servers: while self.servers:
try: if server := self.servers.pop():
self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb) try:
except Exception as e: server[0].__exit__(exc_type, exc_val, exc_tb)
print(f"Error stopping server: {e}") except Exception as e:
print(f"Error stopping server: {e}")
traceback.print_exc()
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@@ -560,7 +571,7 @@ async def test_api_only_multinode_dp_completion(
assert len(results) == num_requests assert len(results) == num_requests
assert all(completion is not None for completion in results) assert all(completion is not None for completion in results)
_, api_server_args = api_only_servers[0] api_server, api_server_args = api_only_servers[0]
api_server_count = ( api_server_count = (
api_server_args.count('--api-server-count') api_server_args.count('--api-server-count')
and api_server_args[api_server_args.index('--api-server-count') + 1] and api_server_args[api_server_args.index('--api-server-count') + 1]
@@ -569,7 +580,6 @@ async def test_api_only_multinode_dp_completion(
f"engines on headless server (API server count: {api_server_count})") f"engines on headless server (API server count: {api_server_count})")
# Check request balancing via Prometheus metrics # Check request balancing via Prometheus metrics
api_server = api_only_servers[0][0]
check_request_balancing(api_server, DP_SIZE) check_request_balancing(api_server, DP_SIZE)