From 67d25eca05e7927cd0c0acd16046a3b1318b7fc3 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 3 Jul 2025 07:49:13 +0100
Subject: [PATCH] [Tests] Update online DP tests to verify that requests are balanced (#20157)

Signed-off-by: Nick Hill
---
 .../v1/entrypoints/openai/test_completion.py |   2 +-
 .../openai/test_multi_api_servers.py         | 126 ++++++++++++++++++
 tests/v1/test_async_llm_dp.py                |  51 +++++--
 3 files changed, 170 insertions(+), 9 deletions(-)

diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py
index a7c31c064224..776fd42bbc35 100644
--- a/tests/v1/entrypoints/openai/test_completion.py
+++ b/tests/v1/entrypoints/openai/test_completion.py
@@ -38,7 +38,7 @@ def default_server_args():
 ]])
 def server(default_server_args, request):
     if request.param:
-        default_server_args.extend(request.param)
+        default_server_args = default_server_args + request.param
     with RemoteOpenAIServer(MODEL_NAME,
                             default_server_args) as remote_server:
         yield remote_server
diff --git a/tests/v1/entrypoints/openai/test_multi_api_servers.py b/tests/v1/entrypoints/openai/test_multi_api_servers.py
index ed4ecbe8484c..e84b5e3095d0 100644
--- a/tests/v1/entrypoints/openai/test_multi_api_servers.py
+++ b/tests/v1/entrypoints/openai/test_multi_api_servers.py
@@ -2,10 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import asyncio
 import os
+import re
 
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests
 
 from tests.utils import RemoteOpenAIServer
 
@@ -14,6 +16,122 @@ MODEL_NAME = "ibm-research/PowerMoE-3b"
 DP_SIZE = os.getenv("DP_SIZE", "1")
 
 
+def get_prometheus_metrics(
+        server: RemoteOpenAIServer) -> dict[str, dict[str, float]]:
+    """Fetch and parse Prometheus metrics from the /metrics endpoint.
+
+    Returns:
+        Dict mapping metric names to their values grouped by labels.
+        For example: {"vllm:request_success_total": {
+            '{engine="0"}': 5.0, '{engine="1"}': 3.0}
+        }
+    """
+    try:
+        response = requests.get(server.url_for("metrics"), timeout=10)
+        response.raise_for_status()
+
+        metrics: dict[str, dict[str, float]] = {}
+
+        # Regex patterns for Prometheus metrics
+        metric_with_labels = re.compile(
+            r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([\d\.\-\+e]+)$')
+        metric_simple = re.compile(
+            r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([\d\.\-\+e]+)$')
+
+        for line in response.text.split('\n'):
+            line = line.strip()
+            # Skip comments and empty lines
+            if not line or line.startswith('#'):
+                continue
+
+            # Try to match metric with labels first
+            match = metric_with_labels.match(line)
+            if match:
+                metric_name, labels_part, value_str = match.groups()
+                try:
+                    value = float(value_str)
+                    if metric_name not in metrics:
+                        metrics[metric_name] = {}
+                    metrics[metric_name][f'{{{labels_part}}}'] = value
+                except ValueError:
+                    continue
+            else:
+                # Try simple metric without labels
+                match = metric_simple.match(line)
+                if match:
+                    metric_name, value_str = match.groups()
+                    try:
+                        value = float(value_str)
+                        if metric_name not in metrics:
+                            metrics[metric_name] = {}
+                        metrics[metric_name][''] = value
+                    except ValueError:
+                        continue
+
+        return metrics
+    except Exception as e:
+        pytest.fail(f"Failed to fetch Prometheus metrics: {e}")
+        return {}
+
+
+def get_engine_request_counts(
+        metrics: dict[str, dict[str, float]]) -> dict[str, float]:
+    """Extract request counts per engine from Prometheus metrics.
+
+    Returns:
+        Dict mapping engine indices to request counts.
+        For example: {"0": 15.0, "1": 12.0}
+    """
+    engine_counts: dict[str, float] = {}
+
+    # Look for request success metrics with engine labels
+    success_metrics = metrics.get("vllm:request_success_total", {})
+    engine_pattern = re.compile(r'engine="([^"]*)"')
+
+    for labels, count in success_metrics.items():
+        # Extract engine ID from labels using regex
+        match = engine_pattern.search(labels)
+        if match:
+            engine_id = match.group(1)
+            if engine_id not in engine_counts:
+                engine_counts[engine_id] = 0.0
+            engine_counts[engine_id] += count
+
+    return engine_counts
+
+
+def check_request_balancing(server: RemoteOpenAIServer):
+    """Check request balancing via Prometheus metrics if DP_SIZE > 1.
+
+    Args:
+        server: The RemoteOpenAIServer instance
+    """
+    dp_size = int(DP_SIZE)
+    if dp_size <= 1:
+        return
+
+    # Get metrics after all requests are completed
+    metrics = get_prometheus_metrics(server)
+    engine_counts = get_engine_request_counts(metrics)
+
+    # Check that every engine received requests
+    engines_with_requests = [
+        engine for engine, count in engine_counts.items() if count > 0
+    ]
+    assert len(engines_with_requests) == dp_size, (
+        f"Expected requests to be distributed across all {dp_size} engines,"
+        f" but only engine(s) {engines_with_requests} received "
+        f"requests. Engine counts: {engine_counts}")
+
+    # Verify that the load is reasonably balanced
+    # (each engine should handle a roughly proportional share)
+    total_requests = sum(engine_counts.values())
+
+    for count in engine_counts.values():
+        assert count > total_requests // (dp_size + 1), (
+            f"requests are imbalanced: {engine_counts}")
+
+
 @pytest.fixture(scope="module")
 def default_server_args():
     return [
@@ -50,6 +168,7 @@ async def client(server):
     [MODEL_NAME],
 )
 async def test_single_completion(client: openai.AsyncOpenAI,
+                                 server: RemoteOpenAIServer,
                                  model_name: str) -> None:
 
     async def make_request():
@@ -97,6 +216,9 @@ async def test_single_completion(client: openai.AsyncOpenAI,
     assert len(results) == num_requests
     assert all(completion is not None for completion in results)
 
+    # Check request balancing via Prometheus metrics if DP_SIZE > 1
+    check_request_balancing(server)
+
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
@@ -104,6 +226,7 @@ async def test_single_completion(client: openai.AsyncOpenAI,
     [MODEL_NAME],
 )
 async def test_completion_streaming(client: openai.AsyncOpenAI,
+                                    server: RemoteOpenAIServer,
                                     model_name: str) -> None:
     prompt = "What is an LLM?"
 
@@ -170,3 +293,6 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
         results
     ) == num_requests, f"Expected {num_requests} results, got {len(results)}"
     assert all(results), "Not all streaming requests completed successfully."
+ + # Check request balancing via Prometheus metrics if DP_SIZE > 1 + check_request_balancing(server) diff --git a/tests/v1/test_async_llm_dp.py b/tests/v1/test_async_llm_dp.py index 075ceb257ab7..64a41bec3791 100644 --- a/tests/v1/test_async_llm_dp.py +++ b/tests/v1/test_async_llm_dp.py @@ -4,24 +4,30 @@ import asyncio import os from contextlib import ExitStack +from dataclasses import dataclass from typing import Optional import pytest from vllm import SamplingParams +from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.inputs import PromptType from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.core_client import DPAsyncMPClient +from vllm.v1.metrics.loggers import StatLoggerBase +from vllm.v1.metrics.stats import IterationStats, SchedulerStats + +DP_SIZE = int(os.getenv("DP_SIZE", 2)) engine_args = AsyncEngineArgs( model="ibm-research/PowerMoE-3b", enforce_eager=True, disable_log_requests=True, tensor_parallel_size=int(os.getenv("TP_SIZE", 1)), - data_parallel_size=int(os.getenv("DP_SIZE", 2)), + data_parallel_size=DP_SIZE, ) if not current_platform.supports_v1(engine_args.create_model_config()): @@ -74,12 +80,32 @@ async def generate( async def test_load(output_kind: RequestOutputKind, data_parallel_backend: str): + stats_loggers = {} + + @dataclass + class SimpleStatsLogger(StatLoggerBase): + init_count: int = 0 + finished_req_count: int = 0 + + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): + stats_loggers[engine_index] = self + + def record(self, scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats]): + if iteration_stats: + self.finished_req_count += len( + iteration_stats.finished_requests) + + def log_engine_initialized(self): + self.init_count += 1 + with ExitStack() as after: prompt = "This is a test of data parallel" engine_args.data_parallel_backend = data_parallel_backend - engine = AsyncLLM.from_engine_args(engine_args) + engine = AsyncLLM.from_engine_args(engine_args, + stat_loggers=[SimpleStatsLogger]) after.callback(engine.shutdown) NUM_REQUESTS = 100 @@ -92,12 +118,10 @@ async def test_load(output_kind: RequestOutputKind, for request_id in request_ids: tasks.append( asyncio.create_task( - generate(engine, - request_id, - prompt, - output_kind, - NUM_EXPECTED_TOKENS, - data_parallel_rank=0))) + generate(engine, request_id, prompt, output_kind, + NUM_EXPECTED_TOKENS))) + # Short sleep to ensure that requests are distributed. + await asyncio.sleep(0.01) # Confirm that we got all the EXPECTED tokens from the requests. done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) @@ -122,3 +146,14 @@ async def test_load(output_kind: RequestOutputKind, assert not core_client.engines_running assert not core_client.reqs_in_flight + + # Check that requests were distributed between the engines + print(f"Stats loggers after test: {stats_loggers}") + assert len(stats_loggers) == DP_SIZE + assert stats_loggers[0].init_count == 1 + + for sl in stats_loggers.values(): + slogger: SimpleStatsLogger = sl + + assert slogger.finished_req_count > NUM_REQUESTS // ( + DP_SIZE + 1), f"requests are imbalanced: {stats_loggers}"
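
Note for reviewers (not part of the patch): the new balancing checks key off the engine label that vLLM attaches to its per-engine counters, here vllm:request_success_total. Below is a minimal, self-contained sketch of that counting logic run against a hard-coded /metrics excerpt; the sample lines and label values are illustrative only, not captured from a real server.

import re

# Illustrative /metrics excerpt. Real vLLM output carries additional labels
# (e.g. model_name, finished_reason) on each sample line; only the engine
# label matters for the balancing check.
SAMPLE_METRICS = """\
# HELP vllm:request_success_total Count of successfully processed requests.
# TYPE vllm:request_success_total counter
vllm:request_success_total{engine="0",finished_reason="stop"} 26.0
vllm:request_success_total{engine="1",finished_reason="stop"} 24.0
"""


def engine_request_counts(metrics_text: str) -> dict[str, float]:
    """Sum vllm:request_success_total samples per engine label."""
    counts: dict[str, float] = {}
    sample = re.compile(
        r'^vllm:request_success_total\{([^}]*)\}\s+([\d\.\-\+e]+)$')
    engine = re.compile(r'engine="([^"]*)"')
    for line in metrics_text.splitlines():
        match = sample.match(line.strip())
        if not match:
            continue
        labels, value = match.groups()
        engine_match = engine.search(labels)
        if engine_match:
            engine_id = engine_match.group(1)
            counts[engine_id] = counts.get(engine_id, 0.0) + float(value)
    return counts


if __name__ == "__main__":
    counts = engine_request_counts(SAMPLE_METRICS)
    total, dp_size = sum(counts.values()), len(counts)
    # Same criterion as check_request_balancing(): every engine must have
    # handled more than total // (dp_size + 1) requests.
    assert all(c > total // (dp_size + 1) for c in counts.values()), counts
    print(counts)  # {'0': 26.0, '1': 24.0}

check_request_balancing() in the patch applies the same threshold to counts scraped from server.url_for("metrics"), so a data-parallel engine that never receives work now fails the test instead of passing silently.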