[Tests] Update online DP tests to verify that requests are balanced (#20157)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-07-03 07:49:13 +01:00 committed by GitHub
parent 363528de27
commit 67d25eca05
3 changed files with 170 additions and 9 deletions

View File

@@ -38,7 +38,7 @@ def default_server_args():
     ]])
 def server(default_server_args, request):
     if request.param:
-        default_server_args.extend(request.param)
+        default_server_args = default_server_args + request.param
 
     with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
         yield remote_server

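Note on the one-line change above: list.extend mutates the list object in place, so if default_server_args is a fixture value shared across parametrized runs, the per-param extra arguments would accumulate from one run to the next; building a new list with + leaves the shared value untouched. A minimal sketch of the difference, using placeholder argument strings rather than the real server flags:

# Placeholder args, not the actual vLLM server flags.
shared_fixture_value = ["--enforce-eager"]
extra = ["--api-server-count", "4"]

# list.extend mutates the shared list in place...
args = shared_fixture_value
args.extend(extra)
assert shared_fixture_value == ["--enforce-eager", "--api-server-count", "4"]

# ...whereas concatenation builds a new list and leaves the original alone.
shared_fixture_value = ["--enforce-eager"]
args = shared_fixture_value + extra
assert shared_fixture_value == ["--enforce-eager"]
assert args == ["--enforce-eager", "--api-server-count", "4"]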
View File

@@ -2,10 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import asyncio
 import os
+import re
 
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests
 
 from tests.utils import RemoteOpenAIServer
@@ -14,6 +16,122 @@ MODEL_NAME = "ibm-research/PowerMoE-3b"
 DP_SIZE = os.getenv("DP_SIZE", "1")
 
 
+def get_prometheus_metrics(
+        server: RemoteOpenAIServer) -> dict[str, dict[str, float]]:
+    """Fetch and parse Prometheus metrics from the /metrics endpoint.
+
+    Returns:
+        Dict mapping metric names to their values grouped by labels.
+        For example: {"vllm:request_success": {
+            "engine=0": 5.0, "engine=1": 3.0}
+        }
+    """
+    try:
+        response = requests.get(server.url_for("metrics"), timeout=10)
+        response.raise_for_status()
+
+        metrics: dict[str, dict[str, float]] = {}
+
+        # Regex patterns for Prometheus metrics
+        metric_with_labels = re.compile(
+            r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([\d\.\-\+e]+)$')
+        metric_simple = re.compile(
+            r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([\d\.\-\+e]+)$')
+
+        for line in response.text.split('\n'):
+            line = line.strip()
+
+            # Skip comments and empty lines
+            if not line or line.startswith('#'):
+                continue
+
+            # Try to match metric with labels first
+            match = metric_with_labels.match(line)
+            if match:
+                metric_name, labels_part, value_str = match.groups()
+                try:
+                    value = float(value_str)
+                    if metric_name not in metrics:
+                        metrics[metric_name] = {}
+                    metrics[metric_name][f'{{{labels_part}}}'] = value
+                except ValueError:
+                    continue
+            else:
+                # Try simple metric without labels
+                match = metric_simple.match(line)
+                if match:
+                    metric_name, value_str = match.groups()
+                    try:
+                        value = float(value_str)
+                        if metric_name not in metrics:
+                            metrics[metric_name] = {}
+                        metrics[metric_name][''] = value
+                    except ValueError:
+                        continue
+
+        return metrics
+    except Exception as e:
+        pytest.fail(f"Failed to fetch Prometheus metrics: {e}")
+        return {}
+
+
+def get_engine_request_counts(
+        metrics: dict[str, dict[str, float]]) -> dict[str, float]:
+    """Extract request counts per engine from Prometheus metrics.
+
+    Returns:
+        Dict mapping engine indices to request counts.
+        For example: {"0": 15.0, "1": 12.0}
+    """
+    engine_counts = {}
+
+    # Look for request success metrics with engine labels
+    success_metrics = metrics.get("vllm:request_success_total", {})
+    engine_pattern = re.compile(r'engine="([^"]*)"')
+    for labels, count in success_metrics.items():
+        # Extract engine ID from labels using regex
+        match = engine_pattern.search(labels)
+        if match:
+            engine_id = match.group(1)
+            if engine_id not in engine_counts:
+                engine_counts[engine_id] = 0.0
+            engine_counts[engine_id] += count
+
+    return engine_counts
+
+
+def check_request_balancing(server: RemoteOpenAIServer):
+    """Check request balancing via Prometheus metrics if DP_SIZE > 1.
+
+    Args:
+        server: The RemoteOpenAIServer instance
+    """
+    dp_size = int(DP_SIZE)
+    if dp_size <= 1:
+        return
+
+    # Get metrics after all requests are completed
+    metrics = get_prometheus_metrics(server)
+    engine_counts = get_engine_request_counts(metrics)
+
+    # Check that multiple engines received requests
+    engines_with_requests = [
+        engine for engine, count in engine_counts.items() if count > 0
+    ]
+    assert len(engines_with_requests) == dp_size, (
+        f"Expected requests to be distributed across multiple engines,"
+        f" but only engine(s) {engines_with_requests} received "
+        f"requests. Engine counts: {engine_counts}")
+
+    # Verify that the load is reasonably balanced
+    # (no engine should handle all requests)
+    total_requests = sum(engine_counts.values())
+
+    for count in engine_counts.values():
+        assert count > total_requests // (dp_size + 1), (
+            f"requests are imbalanced: {engine_counts}")
+
+
 @pytest.fixture(scope="module")
 def default_server_args():
     return [
@@ -50,6 +168,7 @@ async def client(server):
     [MODEL_NAME],
 )
 async def test_single_completion(client: openai.AsyncOpenAI,
+                                 server: RemoteOpenAIServer,
                                  model_name: str) -> None:
 
     async def make_request():
@@ -97,6 +216,9 @@ async def test_single_completion(client: openai.AsyncOpenAI,
     assert len(results) == num_requests
     assert all(completion is not None for completion in results)
 
+    # Check request balancing via Prometheus metrics if DP_SIZE > 1
+    check_request_balancing(server)
+
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
@@ -104,6 +226,7 @@ async def test_single_completion(client: openai.AsyncOpenAI,
     [MODEL_NAME],
 )
 async def test_completion_streaming(client: openai.AsyncOpenAI,
+                                    server: RemoteOpenAIServer,
                                     model_name: str) -> None:
 
     prompt = "What is an LLM?"
@@ -170,3 +293,6 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
         results
     ) == num_requests, f"Expected {num_requests} results, got {len(results)}"
     assert all(results), "Not all streaming requests completed successfully."
+
+    # Check request balancing via Prometheus metrics if DP_SIZE > 1
+    check_request_balancing(server)

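For context, the balancing check added above keys off the per-engine vllm:request_success_total counters scraped from /metrics: parse lines of the form name{labels} value, group the counts by the engine label, and require every engine to exceed total // (dp_size + 1). A minimal standalone sketch of that idea follows; the sample /metrics excerpt, its extra label, and the counts are illustrative only, not actual vLLM output.

import re

# Hypothetical /metrics excerpt; real vLLM output carries more labels and metrics.
SAMPLE = """\
# HELP vllm:request_success_total Count of successfully processed requests.
vllm:request_success_total{engine="0",finished_reason="stop"} 26.0
vllm:request_success_total{engine="1",finished_reason="stop"} 24.0
"""


def per_engine_counts(text: str) -> dict[str, float]:
    counts: dict[str, float] = {}
    pattern = re.compile(
        r'^vllm:request_success_total\{[^}]*engine="([^"]*)"[^}]*\}\s+([\d.eE+-]+)$')
    for line in text.splitlines():
        match = pattern.match(line.strip())
        if match:
            engine_id, value = match.group(1), float(match.group(2))
            counts[engine_id] = counts.get(engine_id, 0.0) + value
    return counts


counts = per_engine_counts(SAMPLE)
total, dp_size = sum(counts.values()), 2
# With 50 total requests and dp_size=2, each engine must exceed 50 // 3 == 16.
assert all(count > total // (dp_size + 1) for count in counts.values())
print(counts)  # {'0': 26.0, '1': 24.0}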
View File

@@ -4,24 +4,30 @@
 import asyncio
 import os
 from contextlib import ExitStack
+from dataclasses import dataclass
 from typing import Optional
 
 import pytest
 
 from vllm import SamplingParams
+from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.inputs import PromptType
 from vllm.platforms import current_platform
 from vllm.sampling_params import RequestOutputKind
 from vllm.v1.engine.async_llm import AsyncLLM
 from vllm.v1.engine.core_client import DPAsyncMPClient
+from vllm.v1.metrics.loggers import StatLoggerBase
+from vllm.v1.metrics.stats import IterationStats, SchedulerStats
+
+DP_SIZE = int(os.getenv("DP_SIZE", 2))
 
 engine_args = AsyncEngineArgs(
     model="ibm-research/PowerMoE-3b",
     enforce_eager=True,
     disable_log_requests=True,
     tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
-    data_parallel_size=int(os.getenv("DP_SIZE", 2)),
+    data_parallel_size=DP_SIZE,
 )
 
 if not current_platform.supports_v1(engine_args.create_model_config()):
@@ -74,12 +80,32 @@ async def generate(
 async def test_load(output_kind: RequestOutputKind,
                     data_parallel_backend: str):
 
+    stats_loggers = {}
+
+    @dataclass
+    class SimpleStatsLogger(StatLoggerBase):
+        init_count: int = 0
+        finished_req_count: int = 0
+
+        def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
+            stats_loggers[engine_index] = self
+
+        def record(self, scheduler_stats: Optional[SchedulerStats],
+                   iteration_stats: Optional[IterationStats]):
+            if iteration_stats:
+                self.finished_req_count += len(
+                    iteration_stats.finished_requests)
+
+        def log_engine_initialized(self):
+            self.init_count += 1
+
     with ExitStack() as after:
 
         prompt = "This is a test of data parallel"
 
         engine_args.data_parallel_backend = data_parallel_backend
-        engine = AsyncLLM.from_engine_args(engine_args)
+        engine = AsyncLLM.from_engine_args(engine_args,
+                                           stat_loggers=[SimpleStatsLogger])
         after.callback(engine.shutdown)
 
         NUM_REQUESTS = 100
@@ -92,12 +118,10 @@ async def test_load(output_kind: RequestOutputKind,
         for request_id in request_ids:
             tasks.append(
                 asyncio.create_task(
-                    generate(engine,
-                             request_id,
-                             prompt,
-                             output_kind,
-                             NUM_EXPECTED_TOKENS,
-                             data_parallel_rank=0)))
+                    generate(engine, request_id, prompt, output_kind,
+                             NUM_EXPECTED_TOKENS)))
+            # Short sleep to ensure that requests are distributed.
+            await asyncio.sleep(0.01)
 
         # Confirm that we got all the EXPECTED tokens from the requests.
         done, pending = await asyncio.wait(tasks,
                                            return_when=asyncio.FIRST_EXCEPTION)
@@ -122,3 +146,14 @@ async def test_load(output_kind: RequestOutputKind,
 
         assert not core_client.engines_running
         assert not core_client.reqs_in_flight
+
+        # Check that requests were distributed between the engines
+        print(f"Stats loggers after test: {stats_loggers}")
+        assert len(stats_loggers) == DP_SIZE
+        assert stats_loggers[0].init_count == 1
+
+        for sl in stats_loggers.values():
+            slogger: SimpleStatsLogger = sl
+
+            assert slogger.finished_req_count > NUM_REQUESTS // (
+                DP_SIZE + 1), f"requests are imbalanced: {stats_loggers}"
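As a quick sanity check on the threshold used in the final assertion, with the values from this test (NUM_REQUESTS = 100 and the default DP_SIZE = 2), each engine's finished_req_count must exceed 100 // (2 + 1) == 33. The numbers below are illustrative only.

# Worked example of the imbalance threshold (illustrative splits, not test output).
NUM_REQUESTS, DP_SIZE = 100, 2
threshold = NUM_REQUESTS // (DP_SIZE + 1)  # 100 // 3 == 33

assert all(count > threshold for count in (34, 66))       # tolerably uneven: passes
assert not all(count > threshold for count in (20, 80))   # one starved engine: fails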