mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 07:06:59 +08:00
[Frontend] Pass API server count to each process (#23717)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
7ac67ea525
commit
6c117cff7d
@ -11,13 +11,13 @@ from datetime import datetime
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
_w8a8_block_fp8_matmul,
|
_w8a8_block_fp8_matmul,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import triton
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
mp.set_start_method("spawn", force=True)
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|||||||
@ -1,8 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import argparse
|
|
||||||
import dataclasses
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@ -327,12 +325,7 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if args.command == "serialize":
|
if args.command == "serialize":
|
||||||
eng_args_dict = {f.name: getattr(args, f.name) for f in
|
engine_args = EngineArgs.from_cli_args(args)
|
||||||
dataclasses.fields(EngineArgs)}
|
|
||||||
|
|
||||||
engine_args = EngineArgs.from_cli_args(
|
|
||||||
argparse.Namespace(**eng_args_dict)
|
|
||||||
)
|
|
||||||
|
|
||||||
input_dir = tensorizer_dir.rstrip('/')
|
input_dir = tensorizer_dir.rstrip('/')
|
||||||
suffix = args.suffix if args.suffix else uuid.uuid4().hex
|
suffix = args.suffix if args.suffix else uuid.uuid4().hex
|
||||||
|
|||||||
@ -60,7 +60,7 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update):
|
|||||||
global WORKER_RUNTIME_SECONDS
|
global WORKER_RUNTIME_SECONDS
|
||||||
WORKER_RUNTIME_SECONDS = 0.5
|
WORKER_RUNTIME_SECONDS = 0.5
|
||||||
|
|
||||||
# Copy the args to avoid mutating the
|
# Copy the args to avoid mutating them
|
||||||
args = api_server_args.copy()
|
args = api_server_args.copy()
|
||||||
|
|
||||||
if not with_stats_update:
|
if not with_stats_update:
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from contextlib import AsyncExitStack
|
|||||||
import openai # use the official client for correctness check
|
import openai # use the official client for correctness check
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
import requests
|
||||||
|
|
||||||
from tests.utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -70,6 +71,8 @@ class ExternalLBServerManager:
|
|||||||
sargs,
|
sargs,
|
||||||
auto_port=False,
|
auto_port=False,
|
||||||
env_dict={
|
env_dict={
|
||||||
|
"VLLM_SERVER_DEV_MODE":
|
||||||
|
"1",
|
||||||
current_platform.device_control_env_var:
|
current_platform.device_control_env_var:
|
||||||
",".join(
|
",".join(
|
||||||
str(
|
str(
|
||||||
@ -127,11 +130,19 @@ def default_server_args():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=[1, 4])
|
@pytest.fixture(scope="module", params=[1, 4])
|
||||||
def servers(request, default_server_args):
|
def server_manager(request, default_server_args):
|
||||||
api_server_count = request.param
|
api_server_count = request.param
|
||||||
with ExternalLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
|
server_manager = ExternalLBServerManager(MODEL_NAME, DP_SIZE,
|
||||||
default_server_args) as server_list:
|
api_server_count,
|
||||||
yield server_list
|
default_server_args)
|
||||||
|
|
||||||
|
with server_manager:
|
||||||
|
yield server_manager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def servers(server_manager):
|
||||||
|
return server_manager.servers
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest_asyncio.fixture
|
||||||
@ -144,6 +155,39 @@ async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_parallel_config(server: RemoteOpenAIServer):
|
||||||
|
response = requests.get(server.url_for("server_info?config_format=json"))
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
vllm_config = response.json()["vllm_config"]
|
||||||
|
return vllm_config["parallel_config"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_external_lb_server_info(server_manager):
|
||||||
|
servers = server_manager.servers
|
||||||
|
api_server_count = server_manager.api_server_count
|
||||||
|
|
||||||
|
for i, (server, _) in enumerate(servers):
|
||||||
|
print(f"Testing {i=}")
|
||||||
|
|
||||||
|
# Each request will hit one of the API servers
|
||||||
|
# `n_reqs` is set so that there is a good chance each server
|
||||||
|
# receives at least one request
|
||||||
|
n_reqs = 2 * api_server_count * api_server_count
|
||||||
|
parallel_configs = [
|
||||||
|
_get_parallel_config(server) for _ in range(n_reqs)
|
||||||
|
]
|
||||||
|
api_process_counts = [
|
||||||
|
c["_api_process_count"] for c in parallel_configs
|
||||||
|
]
|
||||||
|
api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
|
||||||
|
|
||||||
|
assert all(c == api_server_count
|
||||||
|
for c in api_process_counts), api_process_counts
|
||||||
|
assert all(0 <= r < api_server_count
|
||||||
|
for r in api_process_ranks), api_process_ranks
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model_name",
|
"model_name",
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from contextlib import AsyncExitStack
|
|||||||
import openai # use the official client for correctness check
|
import openai # use the official client for correctness check
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
import requests
|
||||||
|
|
||||||
from tests.utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
from tests.v1.test_utils import check_request_balancing
|
from tests.v1.test_utils import check_request_balancing
|
||||||
@ -92,6 +93,8 @@ class HybridLBServerManager:
|
|||||||
sargs,
|
sargs,
|
||||||
auto_port=False,
|
auto_port=False,
|
||||||
env_dict={
|
env_dict={
|
||||||
|
"VLLM_SERVER_DEV_MODE":
|
||||||
|
"1",
|
||||||
current_platform.device_control_env_var:
|
current_platform.device_control_env_var:
|
||||||
",".join(
|
",".join(
|
||||||
str(
|
str(
|
||||||
@ -150,12 +153,20 @@ def default_server_args():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=[1, 4])
|
@pytest.fixture(scope="module", params=[1, 4])
|
||||||
def servers(request, default_server_args):
|
def server_manager(request, default_server_args):
|
||||||
api_server_count = request.param
|
api_server_count = request.param
|
||||||
with HybridLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
|
server_manager = HybridLBServerManager(MODEL_NAME, DP_SIZE,
|
||||||
default_server_args, DP_SIZE_LOCAL,
|
api_server_count,
|
||||||
TP_SIZE) as server_list:
|
default_server_args, DP_SIZE_LOCAL,
|
||||||
yield server_list
|
TP_SIZE)
|
||||||
|
|
||||||
|
with server_manager:
|
||||||
|
yield server_manager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def servers(server_manager):
|
||||||
|
return server_manager.servers
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
@pytest_asyncio.fixture
|
||||||
@ -168,6 +179,39 @@ async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_parallel_config(server: RemoteOpenAIServer):
|
||||||
|
response = requests.get(server.url_for("server_info?config_format=json"))
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
vllm_config = response.json()["vllm_config"]
|
||||||
|
return vllm_config["parallel_config"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_hybrid_dp_server_info(server_manager):
|
||||||
|
servers = server_manager.servers
|
||||||
|
api_server_count = server_manager.api_server_count
|
||||||
|
|
||||||
|
for i, (server, _) in enumerate(servers):
|
||||||
|
print(f"Testing {i=}")
|
||||||
|
|
||||||
|
# Each request will hit one of the API servers
|
||||||
|
# `n_reqs` is set so that there is a good chance each server
|
||||||
|
# receives at least one request
|
||||||
|
n_reqs = 2 * api_server_count * api_server_count
|
||||||
|
parallel_configs = [
|
||||||
|
_get_parallel_config(server) for _ in range(n_reqs)
|
||||||
|
]
|
||||||
|
api_process_counts = [
|
||||||
|
c["_api_process_count"] for c in parallel_configs
|
||||||
|
]
|
||||||
|
api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
|
||||||
|
|
||||||
|
assert all(c == api_server_count
|
||||||
|
for c in api_process_counts), api_process_counts
|
||||||
|
assert all(0 <= r < api_server_count
|
||||||
|
for r in api_process_ranks), api_process_ranks
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model_name",
|
"model_name",
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from typing import Optional, cast
|
|||||||
import openai # use the official client for correctness check
|
import openai # use the official client for correctness check
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
import requests
|
||||||
|
|
||||||
from tests.utils import RemoteOpenAIServer
|
from tests.utils import RemoteOpenAIServer
|
||||||
from tests.v1.test_utils import check_request_balancing
|
from tests.v1.test_utils import check_request_balancing
|
||||||
@ -101,6 +102,8 @@ class MultinodeInternalLBServerManager:
|
|||||||
sargs,
|
sargs,
|
||||||
auto_port=False,
|
auto_port=False,
|
||||||
env_dict={
|
env_dict={
|
||||||
|
"VLLM_SERVER_DEV_MODE":
|
||||||
|
"1",
|
||||||
current_platform.device_control_env_var:
|
current_platform.device_control_env_var:
|
||||||
",".join(
|
",".join(
|
||||||
str(
|
str(
|
||||||
@ -214,7 +217,10 @@ class APIOnlyServerManager:
|
|||||||
self.model_name,
|
self.model_name,
|
||||||
api_server_args,
|
api_server_args,
|
||||||
auto_port=False,
|
auto_port=False,
|
||||||
env_dict={}) # No GPUs needed for API-only server
|
env_dict={
|
||||||
|
"VLLM_SERVER_DEV_MODE": "1",
|
||||||
|
# No GPUs needed for API-only server
|
||||||
|
})
|
||||||
server.__enter__()
|
server.__enter__()
|
||||||
print(f"API-only server started successfully with "
|
print(f"API-only server started successfully with "
|
||||||
f"{self.api_server_count} API servers")
|
f"{self.api_server_count} API servers")
|
||||||
@ -293,14 +299,21 @@ def default_server_args():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=[1, 4])
|
@pytest.fixture(scope="module", params=[1, 4])
|
||||||
def servers(request, default_server_args):
|
def server_manager(request, default_server_args):
|
||||||
api_server_count = request.param
|
api_server_count = request.param
|
||||||
with MultinodeInternalLBServerManager(MODEL_NAME, DP_SIZE,
|
server_manager = MultinodeInternalLBServerManager(MODEL_NAME, DP_SIZE,
|
||||||
api_server_count,
|
api_server_count,
|
||||||
default_server_args,
|
default_server_args,
|
||||||
DP_SIZE // NUM_NODES,
|
DP_SIZE // NUM_NODES,
|
||||||
TP_SIZE) as server_list:
|
TP_SIZE)
|
||||||
yield server_list
|
|
||||||
|
with server_manager:
|
||||||
|
yield server_manager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def servers(server_manager):
|
||||||
|
return server_manager.servers
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module", params=[1, 4])
|
@pytest.fixture(scope="module", params=[1, 4])
|
||||||
@ -331,6 +344,34 @@ async def api_only_client(api_only_servers: list[tuple[RemoteOpenAIServer,
|
|||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
def _get_parallel_config(server: RemoteOpenAIServer):
|
||||||
|
response = requests.get(server.url_for("server_info?config_format=json"))
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
vllm_config = response.json()["vllm_config"]
|
||||||
|
return vllm_config["parallel_config"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_multinode_dp_server_info(server_manager):
|
||||||
|
head_server = server_manager.servers[0][0]
|
||||||
|
api_server_count = server_manager.api_server_count
|
||||||
|
|
||||||
|
# Each request will hit one of the API servers
|
||||||
|
# `n_reqs` is set so that there is a good chance each server
|
||||||
|
# receives at least one request
|
||||||
|
n_reqs = 2 * api_server_count * api_server_count
|
||||||
|
parallel_configs = [
|
||||||
|
_get_parallel_config(head_server) for _ in range(n_reqs)
|
||||||
|
]
|
||||||
|
api_process_counts = [c["_api_process_count"] for c in parallel_configs]
|
||||||
|
api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
|
||||||
|
|
||||||
|
assert all(c == api_server_count
|
||||||
|
for c in api_process_counts), api_process_counts
|
||||||
|
assert all(0 <= r < api_server_count
|
||||||
|
for r in api_process_ranks), api_process_ranks
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model_name",
|
"model_name",
|
||||||
|
|||||||
@ -193,6 +193,25 @@ class ParallelConfig:
|
|||||||
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
|
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
|
||||||
needs to be divisible by dcp_size."""
|
needs to be divisible by dcp_size."""
|
||||||
|
|
||||||
|
_api_process_count: int = 1
|
||||||
|
"""
|
||||||
|
The number of API processes initialized.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This is an internal config that is only valid for and
|
||||||
|
should only be set by API server scale-out.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_api_process_rank: int = 0
|
||||||
|
"""
|
||||||
|
The rank of this API process, or `-1` for engine core processes
|
||||||
|
under API server scale-out.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This is an internal config that is only valid for and
|
||||||
|
should only be set by API server scale-out.
|
||||||
|
"""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def world_size_across_dp(self) -> int:
|
def world_size_across_dp(self) -> int:
|
||||||
"""world_size_across_dp is TPxPPxDP, it is the size of the world
|
"""world_size_across_dp is TPxPPxDP, it is the size of the world
|
||||||
@ -428,6 +447,12 @@ class ParallelConfig:
|
|||||||
if self.distributed_executor_backend is None and self.world_size == 1:
|
if self.distributed_executor_backend is None and self.world_size == 1:
|
||||||
self.distributed_executor_backend = "uni"
|
self.distributed_executor_backend = "uni"
|
||||||
|
|
||||||
|
if not -1 <= self._api_process_rank < self._api_process_count:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid value of `_api_process_rank`. "
|
||||||
|
f"Expected to be `-1` or `[0, {self._api_process_count})`, "
|
||||||
|
f"but found: {self._api_process_rank}")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def use_ray(self) -> bool:
|
def use_ray(self) -> bool:
|
||||||
return self.distributed_executor_backend == "ray" or (
|
return self.distributed_executor_backend == "ray" or (
|
||||||
|
|||||||
@ -333,6 +333,8 @@ class EngineArgs:
|
|||||||
enable_eplb: bool = ParallelConfig.enable_eplb
|
enable_eplb: bool = ParallelConfig.enable_eplb
|
||||||
expert_placement_strategy: ExpertPlacementStrategy = \
|
expert_placement_strategy: ExpertPlacementStrategy = \
|
||||||
ParallelConfig.expert_placement_strategy
|
ParallelConfig.expert_placement_strategy
|
||||||
|
_api_process_count: int = ParallelConfig._api_process_count
|
||||||
|
_api_process_rank: int = ParallelConfig._api_process_rank
|
||||||
num_redundant_experts: int = EPLBConfig.num_redundant_experts
|
num_redundant_experts: int = EPLBConfig.num_redundant_experts
|
||||||
eplb_window_size: int = EPLBConfig.window_size
|
eplb_window_size: int = EPLBConfig.window_size
|
||||||
eplb_step_interval: int = EPLBConfig.step_interval
|
eplb_step_interval: int = EPLBConfig.step_interval
|
||||||
@ -952,7 +954,10 @@ class EngineArgs:
|
|||||||
# Get the list of attributes of this dataclass.
|
# Get the list of attributes of this dataclass.
|
||||||
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||||
# Set the attributes from the parsed arguments.
|
# Set the attributes from the parsed arguments.
|
||||||
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
|
engine_args = cls(**{
|
||||||
|
attr: getattr(args, attr)
|
||||||
|
for attr in attrs if hasattr(args, attr)
|
||||||
|
})
|
||||||
return engine_args
|
return engine_args
|
||||||
|
|
||||||
def create_model_config(self) -> ModelConfig:
|
def create_model_config(self) -> ModelConfig:
|
||||||
@ -1366,6 +1371,8 @@ class EngineArgs:
|
|||||||
worker_cls=self.worker_cls,
|
worker_cls=self.worker_cls,
|
||||||
worker_extension_cls=self.worker_extension_cls,
|
worker_extension_cls=self.worker_extension_cls,
|
||||||
decode_context_parallel_size=self.decode_context_parallel_size,
|
decode_context_parallel_size=self.decode_context_parallel_size,
|
||||||
|
_api_process_count=self._api_process_count,
|
||||||
|
_api_process_rank=self._api_process_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
speculative_config = self.create_speculative_config(
|
speculative_config = self.create_speculative_config(
|
||||||
|
|||||||
@ -135,23 +135,20 @@ def run_headless(args: argparse.Namespace):
|
|||||||
def run_multi_api_server(args: argparse.Namespace):
|
def run_multi_api_server(args: argparse.Namespace):
|
||||||
|
|
||||||
assert not args.headless
|
assert not args.headless
|
||||||
num_api_servers = args.api_server_count
|
num_api_servers: int = args.api_server_count
|
||||||
assert num_api_servers > 0
|
assert num_api_servers > 0
|
||||||
|
|
||||||
orig_mm_processor_cache_gb = args.mm_processor_cache_gb
|
|
||||||
|
|
||||||
if num_api_servers > 1:
|
if num_api_servers > 1:
|
||||||
setup_multiprocess_prometheus()
|
setup_multiprocess_prometheus()
|
||||||
|
|
||||||
# Not compatible with API server scale-out
|
|
||||||
args.mm_processor_cache_gb = 0
|
|
||||||
|
|
||||||
listen_address, sock = setup_server(args)
|
listen_address, sock = setup_server(args)
|
||||||
|
|
||||||
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
|
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
|
||||||
|
engine_args._api_process_count = num_api_servers
|
||||||
|
engine_args._api_process_rank = -1
|
||||||
|
|
||||||
usage_context = UsageContext.OPENAI_API_SERVER
|
usage_context = UsageContext.OPENAI_API_SERVER
|
||||||
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
|
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
|
||||||
model_config = vllm_config.model_config
|
|
||||||
|
|
||||||
if num_api_servers > 1:
|
if num_api_servers > 1:
|
||||||
if not envs.VLLM_USE_V1:
|
if not envs.VLLM_USE_V1:
|
||||||
@ -161,10 +158,6 @@ def run_multi_api_server(args: argparse.Namespace):
|
|||||||
raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
|
raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
|
||||||
"with api_server_count > 1")
|
"with api_server_count > 1")
|
||||||
|
|
||||||
if model_config.is_multimodal_model and orig_mm_processor_cache_gb > 0:
|
|
||||||
logger.warning("Multi-modal processor cache is disabled because "
|
|
||||||
"it is not compatible with `api_server_count > 1`.")
|
|
||||||
|
|
||||||
executor_class = Executor.get_class(vllm_config)
|
executor_class = Executor.get_class(vllm_config)
|
||||||
log_stats = not engine_args.disable_log_stats
|
log_stats = not engine_args.disable_log_stats
|
||||||
|
|
||||||
@ -221,9 +214,10 @@ def run_api_server_worker_proc(listen_address,
|
|||||||
client_config=None,
|
client_config=None,
|
||||||
**uvicorn_kwargs) -> None:
|
**uvicorn_kwargs) -> None:
|
||||||
"""Entrypoint for individual API server worker processes."""
|
"""Entrypoint for individual API server worker processes."""
|
||||||
|
client_config = client_config or {}
|
||||||
|
server_index = client_config.get("client_index", 0)
|
||||||
|
|
||||||
# Set process title and add process-specific prefix to stdout and stderr.
|
# Set process title and add process-specific prefix to stdout and stderr.
|
||||||
server_index = client_config.get("client_index", 0) if client_config else 0
|
|
||||||
set_process_title("APIServer", str(server_index))
|
set_process_title("APIServer", str(server_index))
|
||||||
decorate_logs()
|
decorate_logs()
|
||||||
|
|
||||||
|
|||||||
@ -17,13 +17,14 @@ from argparse import Namespace
|
|||||||
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
|
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Annotated, Any, Callable, Optional
|
from typing import Annotated, Any, Callable, Literal, Optional
|
||||||
|
|
||||||
import prometheus_client
|
import prometheus_client
|
||||||
import pydantic
|
import pydantic
|
||||||
import regex as re
|
import regex as re
|
||||||
import uvloop
|
import uvloop
|
||||||
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
|
from fastapi import (APIRouter, Depends, FastAPI, Form, HTTPException, Query,
|
||||||
|
Request)
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||||
@ -166,6 +167,9 @@ async def build_async_engine_client(
|
|||||||
# Context manager to handle engine_client lifecycle
|
# Context manager to handle engine_client lifecycle
|
||||||
# Ensures everything is shutdown and cleaned up on error/exit
|
# Ensures everything is shutdown and cleaned up on error/exit
|
||||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||||
|
if client_config:
|
||||||
|
engine_args._api_process_count = client_config.get("client_count", 1)
|
||||||
|
engine_args._api_process_rank = client_config.get("client_index", 0)
|
||||||
|
|
||||||
if disable_frontend_multiprocessing is None:
|
if disable_frontend_multiprocessing is None:
|
||||||
disable_frontend_multiprocessing = bool(
|
disable_frontend_multiprocessing = bool(
|
||||||
@ -209,8 +213,12 @@ async def build_async_engine_client_from_engine_args(
|
|||||||
|
|
||||||
from vllm.v1.engine.async_llm import AsyncLLM
|
from vllm.v1.engine.async_llm import AsyncLLM
|
||||||
async_llm: Optional[AsyncLLM] = None
|
async_llm: Optional[AsyncLLM] = None
|
||||||
client_count = client_config.pop("client_count") if client_config else 1
|
|
||||||
client_index = client_config.pop("client_index") if client_config else 0
|
# Don't mutate the input client_config
|
||||||
|
client_config = dict(client_config) if client_config else {}
|
||||||
|
client_count = client_config.pop("client_count", 1)
|
||||||
|
client_index = client_config.pop("client_index", 0)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async_llm = AsyncLLM.from_vllm_config(
|
async_llm = AsyncLLM.from_vllm_config(
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
@ -956,9 +964,22 @@ if envs.VLLM_SERVER_DEV_MODE:
|
|||||||
logger.warning("SECURITY WARNING: Development endpoints are enabled! "
|
logger.warning("SECURITY WARNING: Development endpoints are enabled! "
|
||||||
"This should NOT be used in production!")
|
"This should NOT be used in production!")
|
||||||
|
|
||||||
|
PydanticVllmConfig = pydantic.TypeAdapter(VllmConfig)
|
||||||
|
|
||||||
@router.get("/server_info")
|
@router.get("/server_info")
|
||||||
async def show_server_info(raw_request: Request):
|
async def show_server_info(
|
||||||
server_info = {"vllm_config": str(raw_request.app.state.vllm_config)}
|
raw_request: Request,
|
||||||
|
config_format: Annotated[Literal["text", "json"],
|
||||||
|
Query()] = "text",
|
||||||
|
):
|
||||||
|
vllm_config: VllmConfig = raw_request.app.state.vllm_config
|
||||||
|
server_info = {
|
||||||
|
"vllm_config":
|
||||||
|
str(vllm_config)
|
||||||
|
if config_format == "text" else PydanticVllmConfig.dump_python(
|
||||||
|
vllm_config, mode="json", fallback=str)
|
||||||
|
# fallback=str is needed to handle e.g. torch.dtype
|
||||||
|
}
|
||||||
return JSONResponse(content=server_info)
|
return JSONResponse(content=server_info)
|
||||||
|
|
||||||
@router.post("/reset_prefix_cache")
|
@router.post("/reset_prefix_cache")
|
||||||
@ -1856,8 +1877,6 @@ async def run_server_worker(listen_address,
|
|||||||
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
|
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
|
||||||
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
||||||
|
|
||||||
server_index = client_config.get("client_index", 0) if client_config else 0
|
|
||||||
|
|
||||||
# Load logging config for uvicorn if specified
|
# Load logging config for uvicorn if specified
|
||||||
log_config = load_log_config(args.log_config_file)
|
log_config = load_log_config(args.log_config_file)
|
||||||
if log_config is not None:
|
if log_config is not None:
|
||||||
@ -1873,7 +1892,8 @@ async def run_server_worker(listen_address,
|
|||||||
vllm_config = await engine_client.get_vllm_config()
|
vllm_config = await engine_client.get_vllm_config()
|
||||||
await init_app_state(engine_client, vllm_config, app.state, args)
|
await init_app_state(engine_client, vllm_config, app.state, args)
|
||||||
|
|
||||||
logger.info("Starting vLLM API server %d on %s", server_index,
|
logger.info("Starting vLLM API server %d on %s",
|
||||||
|
vllm_config.parallel_config._api_process_rank,
|
||||||
listen_address)
|
listen_address)
|
||||||
shutdown_task = await serve_http(
|
shutdown_task = await serve_http(
|
||||||
app,
|
app,
|
||||||
|
|||||||
@ -494,7 +494,8 @@ def _enable_processor_cache(
|
|||||||
|
|
||||||
def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
|
def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
|
||||||
parallel_config = vllm_config.parallel_config
|
parallel_config = vllm_config.parallel_config
|
||||||
supports_ipc_cache = (parallel_config.data_parallel_size == 1
|
supports_ipc_cache = ((parallel_config._api_process_count == 1
|
||||||
|
and parallel_config.data_parallel_size == 1)
|
||||||
or parallel_config.data_parallel_external_lb)
|
or parallel_config.data_parallel_external_lb)
|
||||||
|
|
||||||
return supports_ipc_cache
|
return supports_ipc_cache
|
||||||
|
|||||||
@ -437,7 +437,7 @@ class MPClient(EngineCoreClient):
|
|||||||
self.engines_running = False
|
self.engines_running = False
|
||||||
|
|
||||||
self.stats_update_address: Optional[str] = None
|
self.stats_update_address: Optional[str] = None
|
||||||
if client_addresses is not None:
|
if client_addresses:
|
||||||
# Engines are managed externally to this client.
|
# Engines are managed externally to this client.
|
||||||
input_address = client_addresses["input_address"]
|
input_address = client_addresses["input_address"]
|
||||||
output_address = client_addresses["output_address"]
|
output_address = client_addresses["output_address"]
|
||||||
@ -774,6 +774,7 @@ class AsyncMPClient(MPClient):
|
|||||||
client_addresses=client_addresses,
|
client_addresses=client_addresses,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.client_count = client_count
|
||||||
self.client_index = client_index
|
self.client_index = client_index
|
||||||
self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs,
|
self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs,
|
||||||
Exception]]()
|
Exception]]()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user