[Frontend] Pass API server count to each process (#23717)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-09-20 01:15:19 +08:00; committed by GitHub
parent 7ac67ea525
commit 6c117cff7d
12 changed files with 221 additions and 51 deletions

View File

@ -11,13 +11,13 @@ from datetime import datetime
from typing import Any
import torch
import triton
from tqdm import tqdm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_w8a8_block_fp8_matmul,
)
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser
mp.set_start_method("spawn", force=True)
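For context: as of this change the benchmark pulls `triton` from `vllm.triton_utils` rather than importing it directly, so platforms without Triton don't fail at import time. A minimal sketch of that guarded-import pattern (names simplified; not the actual vLLM implementation):

```python
# Guarded import: downstream code can still do `from ... import triton`
# even when Triton is not installed.
try:
    import triton
    HAS_TRITON = True
except ImportError:
    triton = None  # placeholder object/module on unsupported platforms
    HAS_TRITON = False
```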

View File

@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import dataclasses
import json
import logging
import os
@ -327,12 +325,7 @@ def main():
if args.command == "serialize":
eng_args_dict = {f.name: getattr(args, f.name) for f in
dataclasses.fields(EngineArgs)}
engine_args = EngineArgs.from_cli_args(
argparse.Namespace(**eng_args_dict)
)
engine_args = EngineArgs.from_cli_args(args)
input_dir = tensorizer_dir.rstrip('/')
suffix = args.suffix if args.suffix else uuid.uuid4().hex

View File

@ -60,7 +60,7 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update):
global WORKER_RUNTIME_SECONDS
WORKER_RUNTIME_SECONDS = 0.5
# Copy the args to avoid mutating the
# Copy the args to avoid mutating them
args = api_server_args.copy()
if not with_stats_update:

View File

@ -9,6 +9,7 @@ from contextlib import AsyncExitStack
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
import requests
from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform
@ -70,6 +71,8 @@ class ExternalLBServerManager:
sargs,
auto_port=False,
env_dict={
"VLLM_SERVER_DEV_MODE":
"1",
current_platform.device_control_env_var:
",".join(
str(
@ -127,11 +130,19 @@ def default_server_args():
@pytest.fixture(scope="module", params=[1, 4])
def servers(request, default_server_args):
def server_manager(request, default_server_args):
api_server_count = request.param
with ExternalLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
default_server_args) as server_list:
yield server_list
server_manager = ExternalLBServerManager(MODEL_NAME, DP_SIZE,
api_server_count,
default_server_args)
with server_manager:
yield server_manager
@pytest.fixture
def servers(server_manager):
return server_manager.servers
@pytest_asyncio.fixture
@ -144,6 +155,39 @@ async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
]
def _get_parallel_config(server: RemoteOpenAIServer):
response = requests.get(server.url_for("server_info?config_format=json"))
response.raise_for_status()
vllm_config = response.json()["vllm_config"]
return vllm_config["parallel_config"]
def test_external_lb_server_info(server_manager):
servers = server_manager.servers
api_server_count = server_manager.api_server_count
for i, (server, _) in enumerate(servers):
print(f"Testing {i=}")
# Each request will hit one of the API servers
# `n_reqs` is set so that there is a good chance each server
# receives at least one request
n_reqs = 2 * api_server_count * api_server_count
parallel_configs = [
_get_parallel_config(server) for _ in range(n_reqs)
]
api_process_counts = [
c["_api_process_count"] for c in parallel_configs
]
api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
assert all(c == api_server_count
for c in api_process_counts), api_process_counts
assert all(0 <= r < api_server_count
for r in api_process_ranks), api_process_ranks
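The choice of `n_reqs = 2 * api_server_count * api_server_count` can be sanity-checked with a union bound. Assuming each request independently lands on one of `k = api_server_count` servers uniformly (an idealization of the real load balancer), the probability that some server receives no request is at most `k * (1 - 1/k) ** n`:

```python
# Back-of-the-envelope check (assumes uniform, independent routing,
# which only approximates the real load balancer):
k = 4                    # api_server_count
n = 2 * k * k            # n_reqs used by the test
p_missed = k * (1 - 1 / k) ** n
print(f"{p_missed:.4f}")  # 0.0004 -- every server is almost surely hit
```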
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",

View File

@ -9,6 +9,7 @@ from contextlib import AsyncExitStack
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
import requests
from tests.utils import RemoteOpenAIServer
from tests.v1.test_utils import check_request_balancing
@ -92,6 +93,8 @@ class HybridLBServerManager:
sargs,
auto_port=False,
env_dict={
"VLLM_SERVER_DEV_MODE":
"1",
current_platform.device_control_env_var:
",".join(
str(
@ -150,12 +153,20 @@ def default_server_args():
@pytest.fixture(scope="module", params=[1, 4])
def servers(request, default_server_args):
def server_manager(request, default_server_args):
api_server_count = request.param
with HybridLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
default_server_args, DP_SIZE_LOCAL,
TP_SIZE) as server_list:
yield server_list
server_manager = HybridLBServerManager(MODEL_NAME, DP_SIZE,
api_server_count,
default_server_args, DP_SIZE_LOCAL,
TP_SIZE)
with server_manager:
yield server_manager
@pytest.fixture
def servers(server_manager):
return server_manager.servers
@pytest_asyncio.fixture
@ -168,6 +179,39 @@ async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
]
def _get_parallel_config(server: RemoteOpenAIServer):
response = requests.get(server.url_for("server_info?config_format=json"))
response.raise_for_status()
vllm_config = response.json()["vllm_config"]
return vllm_config["parallel_config"]
def test_hybrid_dp_server_info(server_manager):
servers = server_manager.servers
api_server_count = server_manager.api_server_count
for i, (server, _) in enumerate(servers):
print(f"Testing {i=}")
# Each request will hit one of the API servers
# `n_reqs` is set so that there is a good chance each server
# receives at least one request
n_reqs = 2 * api_server_count * api_server_count
parallel_configs = [
_get_parallel_config(server) for _ in range(n_reqs)
]
api_process_counts = [
c["_api_process_count"] for c in parallel_configs
]
api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
assert all(c == api_server_count
for c in api_process_counts), api_process_counts
assert all(0 <= r < api_server_count
for r in api_process_ranks), api_process_ranks

@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",

View File

@ -10,6 +10,7 @@ from typing import Optional, cast
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
import requests
from tests.utils import RemoteOpenAIServer
from tests.v1.test_utils import check_request_balancing
@ -101,6 +102,8 @@ class MultinodeInternalLBServerManager:
sargs,
auto_port=False,
env_dict={
"VLLM_SERVER_DEV_MODE":
"1",
current_platform.device_control_env_var:
",".join(
str(
@ -214,7 +217,10 @@ class APIOnlyServerManager:
self.model_name,
api_server_args,
auto_port=False,
env_dict={}) # No GPUs needed for API-only server
env_dict={
"VLLM_SERVER_DEV_MODE": "1",
# No GPUs needed for API-only server
})
server.__enter__()
print(f"API-only server started successfully with "
f"{self.api_server_count} API servers")
@ -293,14 +299,21 @@ def default_server_args():
@pytest.fixture(scope="module", params=[1, 4])
def servers(request, default_server_args):
def server_manager(request, default_server_args):
api_server_count = request.param
with MultinodeInternalLBServerManager(MODEL_NAME, DP_SIZE,
api_server_count,
default_server_args,
DP_SIZE // NUM_NODES,
TP_SIZE) as server_list:
yield server_list
server_manager = MultinodeInternalLBServerManager(MODEL_NAME, DP_SIZE,
api_server_count,
default_server_args,
DP_SIZE // NUM_NODES,
TP_SIZE)
with server_manager:
yield server_manager
@pytest.fixture
def servers(server_manager):
return server_manager.servers
@pytest.fixture(scope="module", params=[1, 4])
@ -331,6 +344,34 @@ async def api_only_client(api_only_servers: list[tuple[RemoteOpenAIServer,
yield client
def _get_parallel_config(server: RemoteOpenAIServer):
response = requests.get(server.url_for("server_info?config_format=json"))
response.raise_for_status()
vllm_config = response.json()["vllm_config"]
return vllm_config["parallel_config"]
def test_multinode_dp_server_info(server_manager):
head_server = server_manager.servers[0][0]
api_server_count = server_manager.api_server_count
# Each request will hit one of the API servers
# `n_reqs` is set so that there is a good chance each server
# receives at least one request
n_reqs = 2 * api_server_count * api_server_count
parallel_configs = [
_get_parallel_config(head_server) for _ in range(n_reqs)
]
api_process_counts = [c["_api_process_count"] for c in parallel_configs]
api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
assert all(c == api_server_count
for c in api_process_counts), api_process_counts
assert all(0 <= r < api_server_count
for r in api_process_ranks), api_process_ranks
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",

View File

@ -193,6 +193,25 @@ class ParallelConfig:
not changed by DCP; it simply reuses the GPUs of the TP group, and tp_size
needs to be divisible by dcp_size."""
_api_process_count: int = 1
"""
The number of API processes initialized.
Note:
This is an internal config that is only valid for and
should only be set by API server scale-out.
"""
_api_process_rank: int = 0
"""
The rank of this API process, or `-1` for engine core processes
under API server scale-out.
Note:
This is an internal config that is only valid for and
should only be set by API server scale-out.
"""
@property
def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world
@ -428,6 +447,12 @@ class ParallelConfig:
if self.distributed_executor_backend is None and self.world_size == 1:
self.distributed_executor_backend = "uni"
if not -1 <= self._api_process_rank < self._api_process_count:
raise ValueError(
"Invalid value of `_api_process_rank`. "
f"Expected to be `-1` or `[0, {self._api_process_count})`, "
f"but found: {self._api_process_rank}")
@property
def use_ray(self) -> bool:
return self.distributed_executor_backend == "ray" or (

View File

@ -333,6 +333,8 @@ class EngineArgs:
enable_eplb: bool = ParallelConfig.enable_eplb
expert_placement_strategy: ExpertPlacementStrategy = \
ParallelConfig.expert_placement_strategy
_api_process_count: int = ParallelConfig._api_process_count
_api_process_rank: int = ParallelConfig._api_process_rank
num_redundant_experts: int = EPLBConfig.num_redundant_experts
eplb_window_size: int = EPLBConfig.window_size
eplb_step_interval: int = EPLBConfig.step_interval
@ -952,7 +954,10 @@ class EngineArgs:
# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments.
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
engine_args = cls(**{
attr: getattr(args, attr)
for attr in attrs if hasattr(args, attr)
})
return engine_args
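Filtering on `hasattr` makes `from_cli_args` tolerant of namespaces that define only a subset of the dataclass fields, which is what lets the tensorizer example above pass its raw `args` directly instead of rebuilding a `Namespace`. A minimal sketch of the idea, using a simplified dataclass with hypothetical fields:

```python
import argparse
import dataclasses

@dataclasses.dataclass
class Args:
    model: str = "facebook/opt-125m"
    tensor_parallel_size: int = 1

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "Args":
        attrs = [f.name for f in dataclasses.fields(cls)]
        # Missing attributes now fall back to the dataclass defaults.
        return cls(**{a: getattr(args, a) for a in attrs if hasattr(args, a)})

# A namespace that only defines `model` no longer raises AttributeError:
print(Args.from_cli_args(argparse.Namespace(model="my-model")))
```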
def create_model_config(self) -> ModelConfig:
@ -1366,6 +1371,8 @@ class EngineArgs:
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
decode_context_parallel_size=self.decode_context_parallel_size,
_api_process_count=self._api_process_count,
_api_process_rank=self._api_process_rank,
)
speculative_config = self.create_speculative_config(

View File

@ -135,23 +135,20 @@ def run_headless(args: argparse.Namespace):
def run_multi_api_server(args: argparse.Namespace):
assert not args.headless
num_api_servers = args.api_server_count
num_api_servers: int = args.api_server_count
assert num_api_servers > 0
orig_mm_processor_cache_gb = args.mm_processor_cache_gb
if num_api_servers > 1:
setup_multiprocess_prometheus()
# Not compatible with API server scale-out
args.mm_processor_cache_gb = 0
listen_address, sock = setup_server(args)
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
engine_args._api_process_count = num_api_servers
engine_args._api_process_rank = -1
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
model_config = vllm_config.model_config
if num_api_servers > 1:
if not envs.VLLM_USE_V1:
@ -161,10 +158,6 @@ def run_multi_api_server(args: argparse.Namespace):
raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
"with api_server_count > 1")
if model_config.is_multimodal_model and orig_mm_processor_cache_gb > 0:
logger.warning("Multi-modal processor cache is disabled because "
"it is not compatible with `api_server_count > 1`.")
executor_class = Executor.get_class(vllm_config)
log_stats = not engine_args.disable_log_stats
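Taken together, the parent process stamps the config before any engine core is created: engine cores see the full API server count and the sentinel rank `-1`, while each API worker's own rank is filled in later from its `client_config` (see `build_async_engine_client` below). A rough, self-contained sketch of that wiring, with `SimpleNamespace` standing in for the real `AsyncEngineArgs`:

```python
# Rough wiring sketch; not the actual vLLM entrypoint code.
from types import SimpleNamespace

num_api_servers = 4
engine_args = SimpleNamespace(_api_process_count=1, _api_process_rank=0)

# Parent process: engine cores get the full count and the sentinel rank.
engine_args._api_process_count = num_api_servers
engine_args._api_process_rank = -1

# Each API worker later receives its own rank via client_config:
for client_index in range(num_api_servers):
    worker = SimpleNamespace(
        _api_process_count=num_api_servers,  # client_config["client_count"]
        _api_process_rank=client_index,      # client_config["client_index"]
    )
    print(worker._api_process_rank, "of", worker._api_process_count)
```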
@ -221,9 +214,10 @@ def run_api_server_worker_proc(listen_address,
client_config=None,
**uvicorn_kwargs) -> None:
"""Entrypoint for individual API server worker processes."""
client_config = client_config or {}
server_index = client_config.get("client_index", 0)
# Set process title and add process-specific prefix to stdout and stderr.
server_index = client_config.get("client_index", 0) if client_config else 0
set_process_title("APIServer", str(server_index))
decorate_logs()

View File

@ -17,13 +17,14 @@ from argparse import Namespace
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Annotated, Any, Callable, Optional
from typing import Annotated, Any, Callable, Literal, Optional
import prometheus_client
import pydantic
import regex as re
import uvloop
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi import (APIRouter, Depends, FastAPI, Form, HTTPException, Query,
Request)
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
@ -166,6 +167,9 @@ async def build_async_engine_client(
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args)
if client_config:
engine_args._api_process_count = client_config.get("client_count", 1)
engine_args._api_process_rank = client_config.get("client_index", 0)
if disable_frontend_multiprocessing is None:
disable_frontend_multiprocessing = bool(
@ -209,8 +213,12 @@ async def build_async_engine_client_from_engine_args(
from vllm.v1.engine.async_llm import AsyncLLM
async_llm: Optional[AsyncLLM] = None
client_count = client_config.pop("client_count") if client_config else 1
client_index = client_config.pop("client_index") if client_config else 0
# Don't mutate the input client_config
client_config = dict(client_config) if client_config else {}
client_count = client_config.pop("client_count", 1)
client_index = client_config.pop("client_index", 0)
try:
async_llm = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
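The copy-then-pop pattern above avoids two pitfalls of the previous code: `pop` with a default no longer raises `KeyError` when a key is absent, and copying first leaves the caller's `client_config` intact for later readers. A small illustration with hypothetical dict contents:

```python
# Hypothetical client_config as passed to an API worker:
cfg = {"client_count": 4, "client_index": 2, "extra": "..."}

local = dict(cfg)                            # shallow copy; caller unaffected
client_count = local.pop("client_count", 1)  # no KeyError if missing
client_index = local.pop("client_index", 0)

assert cfg == {"client_count": 4, "client_index": 2, "extra": "..."}
print(client_count, client_index)            # 4 2
```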
@ -956,9 +964,22 @@ if envs.VLLM_SERVER_DEV_MODE:
logger.warning("SECURITY WARNING: Development endpoints are enabled! "
"This should NOT be used in production!")
PydanticVllmConfig = pydantic.TypeAdapter(VllmConfig)
@router.get("/server_info")
async def show_server_info(raw_request: Request):
server_info = {"vllm_config": str(raw_request.app.state.vllm_config)}
async def show_server_info(
raw_request: Request,
config_format: Annotated[Literal["text", "json"],
Query()] = "text",
):
vllm_config: VllmConfig = raw_request.app.state.vllm_config
server_info = {
"vllm_config":
str(vllm_config)
if config_format == "text" else PydanticVllmConfig.dump_python(
vllm_config, mode="json", fallback=str)
# fallback=str is needed to handle e.g. torch.dtype
}
return JSONResponse(content=server_info)
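With the dev-mode endpoint extended this way, the new tests can fetch the structured config over HTTP. A usage sketch, assuming a server launched with `VLLM_SERVER_DEV_MODE=1` on a hypothetical local port:

```python
import requests

base = "http://localhost:8000"  # hypothetical address

# Default text form (the previous behavior, unchanged):
text_cfg = requests.get(f"{base}/server_info").json()["vllm_config"]

# Structured form added by this PR:
resp = requests.get(f"{base}/server_info?config_format=json")
resp.raise_for_status()
parallel = resp.json()["vllm_config"]["parallel_config"]
print(parallel["_api_process_count"], parallel["_api_process_rank"])
```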
@router.post("/reset_prefix_cache")
@ -1856,8 +1877,6 @@ async def run_server_worker(listen_address,
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
server_index = client_config.get("client_index", 0) if client_config else 0
# Load logging config for uvicorn if specified
log_config = load_log_config(args.log_config_file)
if log_config is not None:
@ -1873,7 +1892,8 @@ async def run_server_worker(listen_address,
vllm_config = await engine_client.get_vllm_config()
await init_app_state(engine_client, vllm_config, app.state, args)
logger.info("Starting vLLM API server %d on %s", server_index,
logger.info("Starting vLLM API server %d on %s",
vllm_config.parallel_config._api_process_rank,
listen_address)
shutdown_task = await serve_http(
app,

View File

@ -494,7 +494,8 @@ def _enable_processor_cache(
def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
parallel_config = vllm_config.parallel_config
supports_ipc_cache = (parallel_config.data_parallel_size == 1
supports_ipc_cache = ((parallel_config._api_process_count == 1
and parallel_config.data_parallel_size == 1)
or parallel_config.data_parallel_external_lb)
return supports_ipc_cache
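With API server scale-out, the processor caches no longer live in a single frontend process, so the IPC cache is now additionally gated on `_api_process_count == 1`. A truth-table sketch of the new condition (simplified signature, not the real helper):

```python
# Simplified restatement of the condition above:
def supports_ipc_cache(api_count: int, dp_size: int, external_lb: bool) -> bool:
    return (api_count == 1 and dp_size == 1) or external_lb

assert supports_ipc_cache(1, 1, False)      # single API server, single DP rank
assert not supports_ipc_cache(4, 1, False)  # scale-out disables the IPC cache
assert supports_ipc_cache(4, 2, True)       # external LB mode still allows it
```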

View File

@ -437,7 +437,7 @@ class MPClient(EngineCoreClient):
self.engines_running = False
self.stats_update_address: Optional[str] = None
if client_addresses is not None:
if client_addresses:
# Engines are managed externally to this client.
input_address = client_addresses["input_address"]
output_address = client_addresses["output_address"]
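A subtle but deliberate change: `if client_addresses:` is falsy for an empty dict as well as `None`, so a present-but-empty mapping no longer falls through to the indexing below and raises `KeyError`:

```python
# Both None and {} now take the "no external engines" path:
for client_addresses in (None, {}):
    if client_addresses:
        _ = client_addresses["input_address"]  # never reached here
```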
@ -774,6 +774,7 @@ class AsyncMPClient(MPClient):
client_addresses=client_addresses,
)
self.client_count = client_count
self.client_index = client_index
self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs,
Exception]]()