mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:15:57 +08:00
[BugFix] Fix port lookup in internal DP LB tests (#22252)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
0933f9d518
commit
ae05a6d83d
@ -4,6 +4,8 @@ import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from typing import Optional, cast
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
@ -41,12 +43,15 @@ class MultinodeInternalLBServerManager:
|
||||
self.tp_size = tp_size
|
||||
self.api_server_count = api_server_count
|
||||
self.base_server_args = base_server_args
|
||||
self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
|
||||
self.servers: list[Optional[tuple[RemoteOpenAIServer,
|
||||
list[str]]]] = [None] * (dp_size //
|
||||
dp_per_node)
|
||||
self.server_threads: list[threading.Thread] = []
|
||||
|
||||
def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
|
||||
"""Start all server instances for multi-node internal LB mode."""
|
||||
for rank in range(0, self.dp_size, self.dp_per_node):
|
||||
for server_idx, rank in enumerate(
|
||||
range(0, self.dp_size, self.dp_per_node)):
|
||||
# Create server args for this specific rank
|
||||
server_args = self.base_server_args.copy()
|
||||
|
||||
@ -87,7 +92,7 @@ class MultinodeInternalLBServerManager:
|
||||
])
|
||||
|
||||
# Use a thread to start each server to allow parallel initialization
|
||||
def start_server(r: int, sargs: list[str]):
|
||||
def start_server(sidx: int, r: int, sargs: list[str]):
|
||||
gpus_per_node = self.tp_size * self.dp_per_node
|
||||
try:
|
||||
# Start the server
|
||||
@ -110,13 +115,14 @@ class MultinodeInternalLBServerManager:
|
||||
f"{self.api_server_count} API servers")
|
||||
else:
|
||||
print(f"Headless node (rank {r}) started successfully")
|
||||
self.servers.append((server, sargs))
|
||||
self.servers[sidx] = (server, sargs)
|
||||
except Exception as e:
|
||||
print(f"Failed to start server rank {r}: {e}")
|
||||
traceback.print_exc()
|
||||
raise
|
||||
|
||||
thread = threading.Thread(target=start_server,
|
||||
args=(rank, server_args))
|
||||
args=(server_idx, rank, server_args))
|
||||
thread.start()
|
||||
|
||||
self.server_threads.append(thread)
|
||||
@ -128,18 +134,20 @@ class MultinodeInternalLBServerManager:
|
||||
# Give servers additional time to fully initialize and coordinate
|
||||
time.sleep(3)
|
||||
|
||||
if len(self.servers) != self.dp_size // self.dp_per_node:
|
||||
if not all(self.servers):
|
||||
raise Exception("Servers failed to start")
|
||||
|
||||
return self.servers
|
||||
return cast(list[tuple[RemoteOpenAIServer, list[str]]], self.servers)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Stop all server instances."""
|
||||
while self.servers:
|
||||
try:
|
||||
self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
|
||||
except Exception as e:
|
||||
print(f"Error stopping server: {e}")
|
||||
if server := self.servers.pop():
|
||||
try:
|
||||
server[0].__exit__(exc_type, exc_val, exc_tb)
|
||||
except Exception as e:
|
||||
print(f"Error stopping server: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
class APIOnlyServerManager:
|
||||
@ -157,7 +165,8 @@ class APIOnlyServerManager:
|
||||
self.tp_size = tp_size
|
||||
self.api_server_count = api_server_count
|
||||
self.base_server_args = base_server_args
|
||||
self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
|
||||
self.servers: list[Optional[tuple[RemoteOpenAIServer,
|
||||
list[str]]]] = [None] * 2
|
||||
self.server_threads: list[threading.Thread] = []
|
||||
|
||||
def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
|
||||
@ -209,7 +218,7 @@ class APIOnlyServerManager:
|
||||
server.__enter__()
|
||||
print(f"API-only server started successfully with "
|
||||
f"{self.api_server_count} API servers")
|
||||
self.servers.append((server, api_server_args))
|
||||
self.servers[0] = (server, api_server_args)
|
||||
except Exception as e:
|
||||
print(f"Failed to start API-only server: {e}")
|
||||
raise
|
||||
@ -231,7 +240,7 @@ class APIOnlyServerManager:
|
||||
server.__enter__()
|
||||
print(f"Headless engines server started successfully with "
|
||||
f"{self.dp_size} engines")
|
||||
self.servers.append((server, engines_server_args))
|
||||
self.servers[1] = (server, engines_server_args)
|
||||
except Exception as e:
|
||||
print(f"Failed to start headless engines server: {e}")
|
||||
raise
|
||||
@ -253,18 +262,20 @@ class APIOnlyServerManager:
|
||||
# Give servers additional time to fully initialize and coordinate
|
||||
time.sleep(3)
|
||||
|
||||
if len(self.servers) != 2:
|
||||
if not all(self.servers):
|
||||
raise Exception("Both servers failed to start")
|
||||
|
||||
return self.servers
|
||||
return cast(list[tuple[RemoteOpenAIServer, list[str]]], self.servers)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Stop both server instances."""
|
||||
while self.servers:
|
||||
try:
|
||||
self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
|
||||
except Exception as e:
|
||||
print(f"Error stopping server: {e}")
|
||||
if server := self.servers.pop():
|
||||
try:
|
||||
server[0].__exit__(exc_type, exc_val, exc_tb)
|
||||
except Exception as e:
|
||||
print(f"Error stopping server: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@ -560,7 +571,7 @@ async def test_api_only_multinode_dp_completion(
|
||||
assert len(results) == num_requests
|
||||
assert all(completion is not None for completion in results)
|
||||
|
||||
_, api_server_args = api_only_servers[0]
|
||||
api_server, api_server_args = api_only_servers[0]
|
||||
api_server_count = (
|
||||
api_server_args.count('--api-server-count')
|
||||
and api_server_args[api_server_args.index('--api-server-count') + 1]
|
||||
@ -569,7 +580,6 @@ async def test_api_only_multinode_dp_completion(
|
||||
f"engines on headless server (API server count: {api_server_count})")
|
||||
|
||||
# Check request balancing via Prometheus metrics
|
||||
api_server = api_only_servers[0][0]
|
||||
check_request_balancing(api_server, DP_SIZE)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user