[BugFix] Fix port lookup in internal DP LB tests (#22252)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill authored 2025-08-14 20:17:11 -07:00, committed by GitHub
parent 0933f9d518
commit ae05a6d83d

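What the diff below changes, in short: both server managers used to collect their (server, args) tuples with list.append() from startup threads that run in parallel, so the order of self.servers, and therefore which server's port a test looked up, was nondeterministic; test_api_only_multinode_dp_completion also re-read the API server from index 0 of that list. The fix pre-sizes the list with None slots and has each startup thread write to its own fixed index. A minimal sketch of that fixed-slot pattern, using hypothetical names (slots, start) rather than anything from the test file:

import threading
from typing import Optional

# One pre-allocated slot per server; index 0 always refers to the same
# logical server no matter which startup thread finishes first.
slots: list[Optional[str]] = [None] * 3


def start(idx: int) -> None:
    # A real manager would start a server here; we just record a marker
    # under the thread's own index instead of appending to a shared list.
    slots[idx] = f"server-{idx}"


threads = [threading.Thread(target=start, args=(i,)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# A None slot means that server never registered itself.
assert all(slots), "a server failed to start"
print(slots)  # ['server-0', 'server-1', 'server-2'], in a fixed order

Because every slot is checked with all(...) before the list is handed back, the cast(...) that narrows away the Optional in the real managers is safe.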

@@ -4,6 +4,8 @@ import asyncio
 import os
 import threading
 import time
+import traceback
+from typing import Optional, cast
 
 import openai  # use the official client for correctness check
 import pytest
@@ -41,12 +43,15 @@ class MultinodeInternalLBServerManager:
         self.tp_size = tp_size
         self.api_server_count = api_server_count
         self.base_server_args = base_server_args
-        self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
+        self.servers: list[Optional[tuple[RemoteOpenAIServer,
+                                          list[str]]]] = [None] * (dp_size //
+                                                                   dp_per_node)
         self.server_threads: list[threading.Thread] = []
 
     def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
         """Start all server instances for multi-node internal LB mode."""
-        for rank in range(0, self.dp_size, self.dp_per_node):
+        for server_idx, rank in enumerate(
+                range(0, self.dp_size, self.dp_per_node)):
             # Create server args for this specific rank
             server_args = self.base_server_args.copy()
@@ -87,7 +92,7 @@
             ])
 
             # Use a thread to start each server to allow parallel initialization
-            def start_server(r: int, sargs: list[str]):
+            def start_server(sidx: int, r: int, sargs: list[str]):
                 gpus_per_node = self.tp_size * self.dp_per_node
                 try:
                     # Start the server
@@ -110,13 +115,14 @@
                               f"{self.api_server_count} API servers")
                     else:
                         print(f"Headless node (rank {r}) started successfully")
-                    self.servers.append((server, sargs))
+                    self.servers[sidx] = (server, sargs)
                 except Exception as e:
                     print(f"Failed to start server rank {r}: {e}")
+                    traceback.print_exc()
                     raise
 
             thread = threading.Thread(target=start_server,
-                                      args=(rank, server_args))
+                                      args=(server_idx, rank, server_args))
             thread.start()
             self.server_threads.append(thread)
@@ -128,18 +134,20 @@
         # Give servers additional time to fully initialize and coordinate
         time.sleep(3)
 
-        if len(self.servers) != self.dp_size // self.dp_per_node:
+        if not all(self.servers):
             raise Exception("Servers failed to start")
 
-        return self.servers
+        return cast(list[tuple[RemoteOpenAIServer, list[str]]], self.servers)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         """Stop all server instances."""
         while self.servers:
-            try:
-                self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
-            except Exception as e:
-                print(f"Error stopping server: {e}")
+            if server := self.servers.pop():
+                try:
+                    server[0].__exit__(exc_type, exc_val, exc_tb)
+                except Exception as e:
+                    print(f"Error stopping server: {e}")
+                    traceback.print_exc()
 
 
 class APIOnlyServerManager:
@@ -157,7 +165,8 @@ class APIOnlyServerManager:
         self.tp_size = tp_size
         self.api_server_count = api_server_count
         self.base_server_args = base_server_args
-        self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
+        self.servers: list[Optional[tuple[RemoteOpenAIServer,
+                                          list[str]]]] = [None] * 2
         self.server_threads: list[threading.Thread] = []
 
     def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
@@ -209,7 +218,7 @@
                 server.__enter__()
                 print(f"API-only server started successfully with "
                       f"{self.api_server_count} API servers")
-                self.servers.append((server, api_server_args))
+                self.servers[0] = (server, api_server_args)
             except Exception as e:
                 print(f"Failed to start API-only server: {e}")
                 raise
@@ -231,7 +240,7 @@
                 server.__enter__()
                 print(f"Headless engines server started successfully with "
                       f"{self.dp_size} engines")
-                self.servers.append((server, engines_server_args))
+                self.servers[1] = (server, engines_server_args)
             except Exception as e:
                 print(f"Failed to start headless engines server: {e}")
                 raise
@@ -253,18 +262,20 @@
         # Give servers additional time to fully initialize and coordinate
         time.sleep(3)
 
-        if len(self.servers) != 2:
+        if not all(self.servers):
             raise Exception("Both servers failed to start")
 
-        return self.servers
+        return cast(list[tuple[RemoteOpenAIServer, list[str]]], self.servers)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         """Stop both server instances."""
         while self.servers:
-            try:
-                self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
-            except Exception as e:
-                print(f"Error stopping server: {e}")
+            if server := self.servers.pop():
+                try:
+                    server[0].__exit__(exc_type, exc_val, exc_tb)
+                except Exception as e:
+                    print(f"Error stopping server: {e}")
+                    traceback.print_exc()
 
 
 @pytest.fixture(scope="module")
@@ -560,7 +571,7 @@ async def test_api_only_multinode_dp_completion(
     assert len(results) == num_requests
     assert all(completion is not None for completion in results)
 
-    _, api_server_args = api_only_servers[0]
+    api_server, api_server_args = api_only_servers[0]
     api_server_count = (
         api_server_args.count('--api-server-count')
         and api_server_args[api_server_args.index('--api-server-count') + 1]
@@ -569,7 +580,6 @@
           f"engines on headless server (API server count: {api_server_count})")
 
     # Check request balancing via Prometheus metrics
-    api_server = api_only_servers[0][0]
     check_request_balancing(api_server, DP_SIZE)
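For completeness, a minimal sketch of the guarded teardown pattern both __exit__ methods adopt above (FakeServer and the sample servers list are made up for illustration, not taken from the fixtures): the walrus-guarded pop skips None slots left by servers that never started, while errors during shutdown are still only logged, and now traced, so the remaining servers are stopped regardless.

import traceback
from typing import Optional


class FakeServer:
    """Stand-in for RemoteOpenAIServer; only the context-manager exit matters here."""

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        print("server stopped")


servers: list[Optional[tuple[FakeServer, list[str]]]] = [
    (FakeServer(), ["--api-server-count", "2"]),
    None,  # a server that never started leaves its slot empty
]

while servers:
    if server := servers.pop():
        try:
            server[0].__exit__(None, None, None)
        except Exception as e:
            # Log (and trace) the failure, then keep stopping the rest.
            print(f"Error stopping server: {e}")
            traceback.print_exc()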