Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-03-19 08:07:07 +08:00)

[Frontend] Pass API server count to each process (#23717)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

Parent: 7ac67ea525
Commit: 6c117cff7d
@@ -11,13 +11,13 @@ from datetime import datetime
 from typing import Any

 import torch
-import triton
 from tqdm import tqdm

 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     _w8a8_block_fp8_matmul,
 )
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton
 from vllm.utils import FlexibleArgumentParser

 mp.set_start_method("spawn", force=True)
@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import argparse
-import dataclasses
 import json
 import logging
 import os
@@ -327,12 +325,7 @@ def main():

     if args.command == "serialize":
-        eng_args_dict = {f.name: getattr(args, f.name) for f in
-                         dataclasses.fields(EngineArgs)}
-
-        engine_args = EngineArgs.from_cli_args(
-            argparse.Namespace(**eng_args_dict)
-        )
+        engine_args = EngineArgs.from_cli_args(args)

         input_dir = tensorizer_dir.rstrip('/')
         suffix = args.suffix if args.suffix else uuid.uuid4().hex
@@ -60,7 +60,7 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update):
     global WORKER_RUNTIME_SECONDS
     WORKER_RUNTIME_SECONDS = 0.5

-    # Copy the args to avoid mutating the
+    # Copy the args to avoid mutating them
     args = api_server_args.copy()

     if not with_stats_update:
@@ -9,6 +9,7 @@ from contextlib import AsyncExitStack
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests

 from tests.utils import RemoteOpenAIServer
 from vllm.platforms import current_platform
@@ -70,6 +71,8 @@ class ExternalLBServerManager:
                     sargs,
                     auto_port=False,
                     env_dict={
+                        "VLLM_SERVER_DEV_MODE":
+                        "1",
                         current_platform.device_control_env_var:
                         ",".join(
                             str(
@@ -127,11 +130,19 @@ def default_server_args():


 @pytest.fixture(scope="module", params=[1, 4])
-def servers(request, default_server_args):
+def server_manager(request, default_server_args):
     api_server_count = request.param
-    with ExternalLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
-                                 default_server_args) as server_list:
-        yield server_list
+    server_manager = ExternalLBServerManager(MODEL_NAME, DP_SIZE,
+                                             api_server_count,
+                                             default_server_args)
+
+    with server_manager:
+        yield server_manager
+
+
+@pytest.fixture
+def servers(server_manager):
+    return server_manager.servers


 @pytest_asyncio.fixture
@@ -144,6 +155,39 @@ async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
     ]


+def _get_parallel_config(server: RemoteOpenAIServer):
+    response = requests.get(server.url_for("server_info?config_format=json"))
+    response.raise_for_status()
+
+    vllm_config = response.json()["vllm_config"]
+    return vllm_config["parallel_config"]
+
+
+def test_external_lb_server_info(server_manager):
+    servers = server_manager.servers
+    api_server_count = server_manager.api_server_count
+
+    for i, (server, _) in enumerate(servers):
+        print(f"Testing {i=}")
+
+        # Each request will hit one of the API servers
+        # `n_reqs` is set so that there is a good chance each server
+        # receives at least one request
+        n_reqs = 2 * api_server_count * api_server_count
+        parallel_configs = [
+            _get_parallel_config(server) for _ in range(n_reqs)
+        ]
+        api_process_counts = [
+            c["_api_process_count"] for c in parallel_configs
+        ]
+        api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
+
+        assert all(c == api_server_count
+                   for c in api_process_counts), api_process_counts
+        assert all(0 <= r < api_server_count
+                   for r in api_process_ranks), api_process_ranks
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
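Note on the `n_reqs` choice in the test above: a rough back-of-envelope sketch (mine, not from the patch) of why `2 * api_server_count ** 2` requests make it very likely that every API server process handles at least one request, assuming requests land on the worker processes roughly uniformly at random:

    # Hypothetical estimate, assuming uniform random routing across workers.
    # By the union bound: P(some server gets no request) <= k * (1 - 1/k) ** n.
    k = 4               # api_server_count
    n = 2 * k * k       # n_reqs = 32
    p_miss = k * (1 - 1 / k) ** n
    print(f"{p_miss:.1e}")  # ~4.0e-04, so each server is almost surely hit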
@@ -9,6 +9,7 @@ from contextlib import AsyncExitStack
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests

 from tests.utils import RemoteOpenAIServer
 from tests.v1.test_utils import check_request_balancing
@@ -92,6 +93,8 @@ class HybridLBServerManager:
                     sargs,
                     auto_port=False,
                     env_dict={
+                        "VLLM_SERVER_DEV_MODE":
+                        "1",
                         current_platform.device_control_env_var:
                         ",".join(
                             str(
@@ -150,12 +153,20 @@ def default_server_args():


 @pytest.fixture(scope="module", params=[1, 4])
-def servers(request, default_server_args):
+def server_manager(request, default_server_args):
     api_server_count = request.param
-    with HybridLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
-                               default_server_args, DP_SIZE_LOCAL,
-                               TP_SIZE) as server_list:
-        yield server_list
+    server_manager = HybridLBServerManager(MODEL_NAME, DP_SIZE,
+                                           api_server_count,
+                                           default_server_args, DP_SIZE_LOCAL,
+                                           TP_SIZE)
+
+    with server_manager:
+        yield server_manager
+
+
+@pytest.fixture
+def servers(server_manager):
+    return server_manager.servers


 @pytest_asyncio.fixture
@@ -168,6 +179,39 @@ async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
     ]


+def _get_parallel_config(server: RemoteOpenAIServer):
+    response = requests.get(server.url_for("server_info?config_format=json"))
+    response.raise_for_status()
+
+    vllm_config = response.json()["vllm_config"]
+    return vllm_config["parallel_config"]
+
+
+def test_hybrid_dp_server_info(server_manager):
+    servers = server_manager.servers
+    api_server_count = server_manager.api_server_count
+
+    for i, (server, _) in enumerate(servers):
+        print(f"Testing {i=}")
+
+        # Each request will hit one of the API servers
+        # `n_reqs` is set so that there is a good chance each server
+        # receives at least one request
+        n_reqs = 2 * api_server_count * api_server_count
+        parallel_configs = [
+            _get_parallel_config(server) for _ in range(n_reqs)
+        ]
+        api_process_counts = [
+            c["_api_process_count"] for c in parallel_configs
+        ]
+        api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
+
+        assert all(c == api_server_count
+                   for c in api_process_counts), api_process_counts
+        assert all(0 <= r < api_server_count
+                   for r in api_process_ranks), api_process_ranks
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
@@ -10,6 +10,7 @@ from typing import Optional, cast
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import requests

 from tests.utils import RemoteOpenAIServer
 from tests.v1.test_utils import check_request_balancing
@@ -101,6 +102,8 @@ class MultinodeInternalLBServerManager:
                     sargs,
                     auto_port=False,
                     env_dict={
+                        "VLLM_SERVER_DEV_MODE":
+                        "1",
                         current_platform.device_control_env_var:
                         ",".join(
                             str(
@@ -214,7 +217,10 @@ class APIOnlyServerManager:
                 self.model_name,
                 api_server_args,
                 auto_port=False,
-                env_dict={})  # No GPUs needed for API-only server
+                env_dict={
+                    "VLLM_SERVER_DEV_MODE": "1",
+                    # No GPUs needed for API-only server
+                })
             server.__enter__()
             print(f"API-only server started successfully with "
                   f"{self.api_server_count} API servers")
@@ -293,14 +299,21 @@ def default_server_args():


 @pytest.fixture(scope="module", params=[1, 4])
-def servers(request, default_server_args):
+def server_manager(request, default_server_args):
     api_server_count = request.param
-    with MultinodeInternalLBServerManager(MODEL_NAME, DP_SIZE,
-                                          api_server_count,
-                                          default_server_args,
-                                          DP_SIZE // NUM_NODES,
-                                          TP_SIZE) as server_list:
-        yield server_list
+    server_manager = MultinodeInternalLBServerManager(MODEL_NAME, DP_SIZE,
+                                                      api_server_count,
+                                                      default_server_args,
+                                                      DP_SIZE // NUM_NODES,
+                                                      TP_SIZE)
+
+    with server_manager:
+        yield server_manager
+
+
+@pytest.fixture
+def servers(server_manager):
+    return server_manager.servers


 @pytest.fixture(scope="module", params=[1, 4])
@@ -331,6 +344,34 @@ async def api_only_client(api_only_servers: list[tuple[RemoteOpenAIServer,
     yield client


+def _get_parallel_config(server: RemoteOpenAIServer):
+    response = requests.get(server.url_for("server_info?config_format=json"))
+    response.raise_for_status()
+
+    vllm_config = response.json()["vllm_config"]
+    return vllm_config["parallel_config"]
+
+
+def test_multinode_dp_server_info(server_manager):
+    head_server = server_manager.servers[0][0]
+    api_server_count = server_manager.api_server_count
+
+    # Each request will hit one of the API servers
+    # `n_reqs` is set so that there is a good chance each server
+    # receives at least one request
+    n_reqs = 2 * api_server_count * api_server_count
+    parallel_configs = [
+        _get_parallel_config(head_server) for _ in range(n_reqs)
+    ]
+    api_process_counts = [c["_api_process_count"] for c in parallel_configs]
+    api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
+
+    assert all(c == api_server_count
+               for c in api_process_counts), api_process_counts
+    assert all(0 <= r < api_server_count
+               for r in api_process_ranks), api_process_ranks
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
@@ -193,6 +193,25 @@ class ParallelConfig:
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
     needs to be divisible by dcp_size."""

+    _api_process_count: int = 1
+    """
+    The number of API processes initialized.
+
+    Note:
+        This is an internal config that is only valid for and
+        should only be set by API server scale-out.
+    """
+
+    _api_process_rank: int = 0
+    """
+    The rank of this API process, or `-1` for engine core processes
+    under API server scale-out.
+
+    Note:
+        This is an internal config that is only valid for and
+        should only be set by API server scale-out.
+    """
+
     @property
     def world_size_across_dp(self) -> int:
         """world_size_across_dp is TPxPPxDP, it is the size of the world
@@ -428,6 +447,12 @@ class ParallelConfig:
         if self.distributed_executor_backend is None and self.world_size == 1:
             self.distributed_executor_backend = "uni"

+        if not -1 <= self._api_process_rank < self._api_process_count:
+            raise ValueError(
+                "Invalid value of `_api_process_rank`. "
+                f"Expected to be `-1` or `[0, {self._api_process_count})`, "
+                f"but found: {self._api_process_rank}")
+
     @property
     def use_ray(self) -> bool:
         return self.distributed_executor_backend == "ray" or (
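A minimal standalone illustration (mine, not part of the patch) of the range check added above: `_api_process_rank` may be `-1` (the engine core processes under API server scale-out) or any rank in `[0, _api_process_count)`.

    # Hypothetical helper mirroring the chained comparison in __post_init__.
    def _rank_is_valid(api_process_rank: int, api_process_count: int) -> bool:
        return -1 <= api_process_rank < api_process_count

    assert _rank_is_valid(-1, 4)      # engine core process under scale-out
    assert _rank_is_valid(3, 4)       # highest valid API server rank
    assert not _rank_is_valid(4, 4)   # rank must be strictly below the count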
@@ -333,6 +333,8 @@ class EngineArgs:
     enable_eplb: bool = ParallelConfig.enable_eplb
     expert_placement_strategy: ExpertPlacementStrategy = \
         ParallelConfig.expert_placement_strategy
+    _api_process_count: int = ParallelConfig._api_process_count
+    _api_process_rank: int = ParallelConfig._api_process_rank
     num_redundant_experts: int = EPLBConfig.num_redundant_experts
     eplb_window_size: int = EPLBConfig.window_size
     eplb_step_interval: int = EPLBConfig.step_interval
@@ -952,7 +954,10 @@ class EngineArgs:
         # Get the list of attributes of this dataclass.
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         # Set the attributes from the parsed arguments.
-        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+        engine_args = cls(**{
+            attr: getattr(args, attr)
+            for attr in attrs if hasattr(args, attr)
+        })
         return engine_args

     def create_model_config(self) -> ModelConfig:
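A small sketch (mine, not from the patch) of what the `hasattr` filter above changes: dataclass fields that are missing from the parsed namespace, presumably including internal fields such as `_api_process_count`/`_api_process_rank` that have no CLI flags, are simply skipped and keep their defaults instead of raising `AttributeError`.

    import argparse
    import dataclasses

    @dataclasses.dataclass
    class _ToyArgs:  # hypothetical stand-in for EngineArgs
        api_server_count: int = 1      # has a CLI flag in this toy example
        _api_process_rank: int = 0     # internal field, no CLI flag

        @classmethod
        def from_cli_args(cls, args: argparse.Namespace) -> "_ToyArgs":
            attrs = [f.name for f in dataclasses.fields(cls)]
            # Only copy attributes that the namespace actually has.
            return cls(**{a: getattr(args, a) for a in attrs if hasattr(args, a)})

    ns = argparse.Namespace(api_server_count=4)  # no `_api_process_rank` here
    toy = _ToyArgs.from_cli_args(ns)
    assert toy.api_server_count == 4 and toy._api_process_rank == 0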
@@ -1366,6 +1371,8 @@ class EngineArgs:
             worker_cls=self.worker_cls,
             worker_extension_cls=self.worker_extension_cls,
             decode_context_parallel_size=self.decode_context_parallel_size,
+            _api_process_count=self._api_process_count,
+            _api_process_rank=self._api_process_rank,
         )

         speculative_config = self.create_speculative_config(
@@ -135,23 +135,20 @@ def run_headless(args: argparse.Namespace):

 def run_multi_api_server(args: argparse.Namespace):

     assert not args.headless
-    num_api_servers = args.api_server_count
+    num_api_servers: int = args.api_server_count
     assert num_api_servers > 0

-    orig_mm_processor_cache_gb = args.mm_processor_cache_gb
-
     if num_api_servers > 1:
         setup_multiprocess_prometheus()

-        # Not compatible with API server scale-out
-        args.mm_processor_cache_gb = 0
-
     listen_address, sock = setup_server(args)

     engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
+    engine_args._api_process_count = num_api_servers
+    engine_args._api_process_rank = -1

     usage_context = UsageContext.OPENAI_API_SERVER
     vllm_config = engine_args.create_engine_config(usage_context=usage_context)
     model_config = vllm_config.model_config

     if num_api_servers > 1:
         if not envs.VLLM_USE_V1:
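For orientation, a summary sketch of mine rather than text from the patch: the launcher marks the shared engine-core config with `_api_process_rank = -1`, while each API server worker later gets its own rank from its `client_config` (see the `build_async_engine_client` hunk below), so every process knows both how many API servers exist and which one it is.

    # Hypothetical values seen by each process when --api-server-count=4
    # (names mirror the fields added in this commit).
    engine_core_view = {"_api_process_count": 4, "_api_process_rank": -1}
    api_worker_views = [
        {"_api_process_count": 4, "_api_process_rank": rank} for rank in range(4)
    ]
    assert all(0 <= v["_api_process_rank"] < 4 for v in api_worker_views)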
@@ -161,10 +158,6 @@ def run_multi_api_server(args: argparse.Namespace):
             raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
                              "with api_server_count > 1")

-        if model_config.is_multimodal_model and orig_mm_processor_cache_gb > 0:
-            logger.warning("Multi-modal processor cache is disabled because "
-                           "it is not compatible with `api_server_count > 1`.")
-
     executor_class = Executor.get_class(vllm_config)
     log_stats = not engine_args.disable_log_stats

@@ -221,9 +214,10 @@ def run_api_server_worker_proc(listen_address,
                                client_config=None,
                                **uvicorn_kwargs) -> None:
     """Entrypoint for individual API server worker processes."""
+    client_config = client_config or {}
+    server_index = client_config.get("client_index", 0)

     # Set process title and add process-specific prefix to stdout and stderr.
-    server_index = client_config.get("client_index", 0) if client_config else 0
     set_process_title("APIServer", str(server_index))
     decorate_logs()
@@ -17,13 +17,14 @@ from argparse import Namespace
 from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
 from contextlib import asynccontextmanager
 from http import HTTPStatus
-from typing import Annotated, Any, Callable, Optional
+from typing import Annotated, Any, Callable, Literal, Optional

 import prometheus_client
+import pydantic
 import regex as re
 import uvloop
-from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
+from fastapi import (APIRouter, Depends, FastAPI, Form, HTTPException, Query,
+                     Request)
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -166,6 +167,9 @@ async def build_async_engine_client(
     # Context manager to handle engine_client lifecycle
     # Ensures everything is shutdown and cleaned up on error/exit
     engine_args = AsyncEngineArgs.from_cli_args(args)
+    if client_config:
+        engine_args._api_process_count = client_config.get("client_count", 1)
+        engine_args._api_process_rank = client_config.get("client_index", 0)

     if disable_frontend_multiprocessing is None:
         disable_frontend_multiprocessing = bool(
@@ -209,8 +213,12 @@ async def build_async_engine_client_from_engine_args(
     from vllm.v1.engine.async_llm import AsyncLLM
     async_llm: Optional[AsyncLLM] = None
-    client_count = client_config.pop("client_count") if client_config else 1
-    client_index = client_config.pop("client_index") if client_config else 0
+
+    # Don't mutate the input client_config
+    client_config = dict(client_config) if client_config else {}
+    client_count = client_config.pop("client_count", 1)
+    client_index = client_config.pop("client_index", 0)

     try:
         async_llm = AsyncLLM.from_vllm_config(
             vllm_config=vllm_config,
@@ -956,9 +964,22 @@ if envs.VLLM_SERVER_DEV_MODE:
     logger.warning("SECURITY WARNING: Development endpoints are enabled! "
                    "This should NOT be used in production!")

+    PydanticVllmConfig = pydantic.TypeAdapter(VllmConfig)
+
     @router.get("/server_info")
-    async def show_server_info(raw_request: Request):
-        server_info = {"vllm_config": str(raw_request.app.state.vllm_config)}
+    async def show_server_info(
+            raw_request: Request,
+            config_format: Annotated[Literal["text", "json"],
+                                     Query()] = "text",
+    ):
+        vllm_config: VllmConfig = raw_request.app.state.vllm_config
+        server_info = {
+            "vllm_config":
+            str(vllm_config)
+            if config_format == "text" else PydanticVllmConfig.dump_python(
+                vllm_config, mode="json", fallback=str)
+            # fallback=str is needed to handle e.g. torch.dtype
+        }
         return JSONResponse(content=server_info)

     @router.post("/reset_prefix_cache")
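A quick usage sketch of the extended dev-mode endpoint above, mirroring the `_get_parallel_config` helper added to the tests; it assumes a local server started with `VLLM_SERVER_DEV_MODE=1` on port 8000 (both assumptions are mine, not from the patch):

    import requests

    base_url = "http://localhost:8000"  # hypothetical local server

    # Structured JSON dump of the VllmConfig (new in this commit)...
    cfg = requests.get(f"{base_url}/server_info",
                       params={"config_format": "json"}).json()
    parallel = cfg["vllm_config"]["parallel_config"]
    print(parallel["_api_process_count"], parallel["_api_process_rank"])

    # ...while the default format remains the plain-text repr.
    text = requests.get(f"{base_url}/server_info").json()["vllm_config"]
    print(type(text))  # str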
@@ -1856,8 +1877,6 @@ async def run_server_worker(listen_address,
     if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
         ToolParserManager.import_tool_parser(args.tool_parser_plugin)

-    server_index = client_config.get("client_index", 0) if client_config else 0
-
     # Load logging config for uvicorn if specified
     log_config = load_log_config(args.log_config_file)
     if log_config is not None:
@@ -1873,7 +1892,8 @@ async def run_server_worker(listen_address,
     vllm_config = await engine_client.get_vllm_config()
     await init_app_state(engine_client, vllm_config, app.state, args)

-    logger.info("Starting vLLM API server %d on %s", server_index,
+    logger.info("Starting vLLM API server %d on %s",
+                vllm_config.parallel_config._api_process_rank,
                 listen_address)
     shutdown_task = await serve_http(
         app,
@@ -494,7 +494,8 @@ def _enable_processor_cache(

 def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
     parallel_config = vllm_config.parallel_config
-    supports_ipc_cache = (parallel_config.data_parallel_size == 1
+    supports_ipc_cache = ((parallel_config._api_process_count == 1
+                           and parallel_config.data_parallel_size == 1)
                           or parallel_config.data_parallel_external_lb)

     return supports_ipc_cache
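A small truth-table sketch (mine, under the assumption that these three flags fully determine the result) of the tightened `_enable_ipc_cache` condition above: API server scale-out now disables the IPC cache unless external data-parallel load balancing is in use.

    from itertools import product

    def supports_ipc_cache(api_process_count: int, data_parallel_size: int,
                           data_parallel_external_lb: bool) -> bool:
        # Mirrors the updated condition in _enable_ipc_cache.
        return ((api_process_count == 1 and data_parallel_size == 1)
                or data_parallel_external_lb)

    for api, dp, ext in product((1, 4), (1, 2), (False, True)):
        print(f"api={api} dp={dp} external_lb={ext} -> "
              f"{supports_ipc_cache(api, dp, ext)}")
    # e.g. api=4, dp=1, external_lb=False -> False (new behaviour with scale-out)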
@@ -437,7 +437,7 @@ class MPClient(EngineCoreClient):
         self.engines_running = False

         self.stats_update_address: Optional[str] = None
-        if client_addresses is not None:
+        if client_addresses:
             # Engines are managed externally to this client.
             input_address = client_addresses["input_address"]
             output_address = client_addresses["output_address"]
@@ -774,6 +774,7 @@ class AsyncMPClient(MPClient):
             client_addresses=client_addresses,
         )

+        self.client_count = client_count
         self.client_index = client_index
         self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs,
                                                  Exception]]()