[Perf][CLI] Improve overall startup time (#19941)

This commit is contained in:
Aaron Pham 2025-06-22 19:11:22 -04:00 committed by GitHub
parent 33d51f599e
commit c4cf260677
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 293 additions and 103 deletions

View File

@ -115,6 +115,11 @@ repos:
entry: python tools/check_spdx_header.py entry: python tools/check_spdx_header.py
language: python language: python
types: [python] types: [python]
- id: check-root-lazy-imports
name: Check root lazy imports
entry: python tools/check_init_lazy_imports.py
language: python
types: [python]
- id: check-filenames - id: check-filenames
name: Check for spaces in all filenames name: Check for spaces in all filenames
entry: bash entry: bash

View File

@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Ensure we perform lazy loading in vllm/__init__.py.
i.e. every import of an internal ``vllm`` module appears only within the
``if typing.TYPE_CHECKING:`` guard, **except** for a short whitelist.
"""
from __future__ import annotations
import ast
import pathlib
import sys
from collections.abc import Iterable
from typing import Final
# Parent of the directory that contains this script (path fully resolved).
REPO_ROOT: Final = pathlib.Path(__file__).resolve().parent.parent
# The package initializer whose import statements are checked.
INIT_PATH: Final = REPO_ROOT.joinpath("vllm", "__init__.py")

# If you need to add items to whitelist, do it here.
# Module names that may appear in a plain ``import <name>`` statement.
ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset(("vllm.env_override", ))
# Module strings, exactly as written (leading dots included), that may
# appear in a ``from <module> import ...`` statement.
ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset((".version", ))
def _is_internal(name: str | None, *, level: int = 0) -> bool:
if level > 0:
return True
if name is None:
return False
return name.startswith("vllm.") or name == "vllm"
def _fail(violations: Iterable[tuple[int, str]]) -> None:
print("ERROR: Disallowed eager imports in vllm/__init__.py:\n",
file=sys.stderr)
for lineno, msg in violations:
print(f" Line {lineno}: {msg}", file=sys.stderr)
sys.exit(1)
def main() -> None:
    """Check ``INIT_PATH`` for eager internal imports and fail if any exist.

    The file is parsed with ``ast`` and walked. Every ``import`` /
    ``from ... import`` statement that is not inside the body of a
    ``TYPE_CHECKING`` guard, targets an internal module, and is not
    whitelisted is recorded. If anything was recorded, ``_fail`` reports it.
    """
    init_source = INIT_PATH.read_text(encoding="utf-8")
    module_ast = ast.parse(init_source, filename=str(INIT_PATH))

    violations: list[tuple[int, str]] = []

    def _guards_type_checking(test: ast.expr) -> bool:
        # Recognizes exactly ``TYPE_CHECKING`` and ``typing.TYPE_CHECKING``.
        if isinstance(test, ast.Name):
            return test.id == "TYPE_CHECKING"
        if isinstance(test, ast.Attribute) and isinstance(
                test.value, ast.Name):
            return test.value.id == "typing" and test.attr == "TYPE_CHECKING"
        return False

    class Visitor(ast.NodeVisitor):

        def __init__(self) -> None:
            super().__init__()
            # True only while walking the body of a TYPE_CHECKING guard.
            self._in_type_checking = False

        def visit_If(self, node: ast.If) -> None:
            if not _guards_type_checking(node.test):
                self.generic_visit(node)
                return

            # Only the guard's body counts as type-checking-only; the
            # ``else`` branch is walked with the surrounding state restored.
            outer_state = self._in_type_checking
            self._in_type_checking = True
            for stmt in node.body:
                self.visit(stmt)
            self._in_type_checking = outer_state
            for stmt in node.orelse:
                self.visit(stmt)

        def visit_Import(self, node: ast.Import) -> None:
            if self._in_type_checking:
                return
            for alias in node.names:
                imported = alias.name
                if not _is_internal(imported) or imported in ALLOWED_IMPORTS:
                    continue
                violations.append((
                    node.lineno,
                    f"import '{imported}' must be inside typing.TYPE_CHECKING",  # noqa: E501
                ))

        def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
            if self._in_type_checking:
                return
            # Rebuild the module text as written, e.g. ".version".
            written = "." * node.level + (node.module or "")
            if not _is_internal(node.module, level=node.level):
                return
            if written in ALLOWED_FROM_MODULES:
                return
            violations.append((
                node.lineno,
                f"from '{written}' import ... must be inside typing.TYPE_CHECKING",  # noqa: E501
            ))

    Visitor().visit(module_ast)

    if violations:
        _fail(violations)
# Run the check only when this file is executed as a script.
if __name__ == "__main__":
    main()

View File

@ -1,29 +1,72 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs""" """vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
# The version.py should be independent library, and we always import the # The version.py should be independent library, and we always import the
# version library first. Such assumption is critical for some customization. # version library first. Such assumption is critical for some customization.
from .version import __version__, __version_tuple__ # isort:skip from .version import __version__, __version_tuple__ # isort:skip
import typing
# The environment variables override should be imported before any other # The environment variables override should be imported before any other
# modules to ensure that the environment variables are set before any # modules to ensure that the environment variables are set before any
# other modules are imported. # other modules are imported.
import vllm.env_override # isort:skip # noqa: F401 import vllm.env_override # noqa: F401
# Maps each lazily exported public name to "<relative module>:<attribute>".
MODULE_ATTRS = {
    "AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
    "EngineArgs": ".engine.arg_utils:EngineArgs",
    "AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
    "LLMEngine": ".engine.llm_engine:LLMEngine",
    "LLM": ".entrypoints.llm:LLM",
    "initialize_ray_cluster": ".executor.ray_utils:initialize_ray_cluster",
    "PromptType": ".inputs:PromptType",
    "TextPrompt": ".inputs:TextPrompt",
    "TokensPrompt": ".inputs:TokensPrompt",
    "ModelRegistry": ".model_executor.models:ModelRegistry",
    "SamplingParams": ".sampling_params:SamplingParams",
    "PoolingParams": ".pooling_params:PoolingParams",
    "ClassificationOutput": ".outputs:ClassificationOutput",
    "ClassificationRequestOutput": ".outputs:ClassificationRequestOutput",
    "CompletionOutput": ".outputs:CompletionOutput",
    "EmbeddingOutput": ".outputs:EmbeddingOutput",
    "EmbeddingRequestOutput": ".outputs:EmbeddingRequestOutput",
    "PoolingOutput": ".outputs:PoolingOutput",
    "PoolingRequestOutput": ".outputs:PoolingRequestOutput",
    "RequestOutput": ".outputs:RequestOutput",
    "ScoringOutput": ".outputs:ScoringOutput",
    "ScoringRequestOutput": ".outputs:ScoringRequestOutput",
}

if typing.TYPE_CHECKING:
    # Seen by static type checkers only; never executed at runtime.
    from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.engine.llm_engine import LLMEngine
    from vllm.entrypoints.llm import LLM
    from vllm.executor.ray_utils import initialize_ray_cluster
    from vllm.inputs import PromptType, TextPrompt, TokensPrompt
    from vllm.model_executor.models import ModelRegistry
    from vllm.outputs import (ClassificationOutput,
                              ClassificationRequestOutput, CompletionOutput,
                              EmbeddingOutput, EmbeddingRequestOutput,
                              PoolingOutput, PoolingRequestOutput,
                              RequestOutput, ScoringOutput,
                              ScoringRequestOutput)
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams
else:

    def __getattr__(name: str) -> typing.Any:
        """Resolve a lazily exported attribute on first access.

        Python calls this module-level hook only when ``name`` is not found
        in the module namespace. Names listed in ``MODULE_ATTRS`` are
        imported on demand and then stored in the module globals, so later
        accesses bypass this function entirely.

        Args:
            name: The attribute requested on this module.

        Returns:
            The object named by the ``MODULE_ATTRS`` entry for ``name``.

        Raises:
            AttributeError: If ``name`` is not a key of ``MODULE_ATTRS``.
        """
        from importlib import import_module

        if name not in MODULE_ATTRS:
            raise AttributeError(
                f'module {__package__} has no attribute {name}')

        module_name, attr_name = MODULE_ATTRS[name].split(":")
        module = import_module(module_name, __package__)
        value = getattr(module, attr_name)
        # Bind the result once, like the eager ``from x import y`` it
        # replaces, so repeated ``vllm.<name>`` lookups cost nothing extra.
        globals()[name] = value
        return value
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput,
CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
__all__ = [ __all__ = [
"__version__", "__version__",

View File

@ -28,7 +28,7 @@ from pydantic.dataclasses import dataclass
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.distributed import ProcessGroup, ReduceOp from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig from transformers import PretrainedConfig
from typing_extensions import deprecated, runtime_checkable from typing_extensions import Self, deprecated, runtime_checkable
import vllm.envs as envs import vllm.envs as envs
from vllm import version from vllm import version
@ -1537,7 +1537,6 @@ class CacheConfig:
def __post_init__(self) -> None: def __post_init__(self) -> None:
self.swap_space_bytes = self.swap_space * GiB_bytes self.swap_space_bytes = self.swap_space * GiB_bytes
self._verify_args()
self._verify_cache_dtype() self._verify_cache_dtype()
self._verify_prefix_caching() self._verify_prefix_caching()
@ -1546,7 +1545,8 @@ class CacheConfig:
# metrics info # metrics info
return {key: str(value) for key, value in self.__dict__.items()} return {key: str(value) for key, value in self.__dict__.items()}
def _verify_args(self) -> None: @model_validator(mode='after')
def _verify_args(self) -> Self:
if self.cpu_offload_gb < 0: if self.cpu_offload_gb < 0:
raise ValueError("CPU offload space must be non-negative" raise ValueError("CPU offload space must be non-negative"
f", but got {self.cpu_offload_gb}") f", but got {self.cpu_offload_gb}")
@ -1556,6 +1556,8 @@ class CacheConfig:
"GPU memory utilization must be less than 1.0. Got " "GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.") f"{self.gpu_memory_utilization}.")
return self
def _verify_cache_dtype(self) -> None: def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto": if self.cache_dtype == "auto":
pass pass
@ -1942,15 +1944,14 @@ class ParallelConfig:
if self.distributed_executor_backend is None and self.world_size == 1: if self.distributed_executor_backend is None and self.world_size == 1:
self.distributed_executor_backend = "uni" self.distributed_executor_backend = "uni"
self._verify_args()
@property @property
def use_ray(self) -> bool: def use_ray(self) -> bool:
return self.distributed_executor_backend == "ray" or ( return self.distributed_executor_backend == "ray" or (
isinstance(self.distributed_executor_backend, type) isinstance(self.distributed_executor_backend, type)
and self.distributed_executor_backend.uses_ray) and self.distributed_executor_backend.uses_ray)
def _verify_args(self) -> None: @model_validator(mode='after')
def _verify_args(self) -> Self:
# Lazy import to avoid circular import # Lazy import to avoid circular import
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -1977,8 +1978,7 @@ class ParallelConfig:
raise ValueError("Unable to use nsight profiling unless workers " raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.") "run with Ray.")
assert isinstance(self.worker_extension_cls, str), ( return self
"worker_extension_cls must be a string (qualified class name).")
PreemptionMode = Literal["swap", "recompute"] PreemptionMode = Literal["swap", "recompute"]
@ -2202,9 +2202,8 @@ class SchedulerConfig:
self.max_num_partial_prefills, self.max_long_partial_prefills, self.max_num_partial_prefills, self.max_long_partial_prefills,
self.long_prefill_token_threshold) self.long_prefill_token_threshold)
self._verify_args() @model_validator(mode='after')
def _verify_args(self) -> Self:
def _verify_args(self) -> None:
if (self.max_num_batched_tokens < self.max_model_len if (self.max_num_batched_tokens < self.max_model_len
and not self.chunked_prefill_enabled): and not self.chunked_prefill_enabled):
raise ValueError( raise ValueError(
@ -2263,6 +2262,8 @@ class SchedulerConfig:
"must be greater than or equal to 1 and less than or equal to " "must be greater than or equal to 1 and less than or equal to "
f"max_num_partial_prefills ({self.max_num_partial_prefills}).") f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
return self
@property @property
def is_multi_step(self) -> bool: def is_multi_step(self) -> bool:
return self.num_scheduler_steps > 1 return self.num_scheduler_steps > 1
@ -2669,8 +2670,6 @@ class SpeculativeConfig:
if self.posterior_alpha is None: if self.posterior_alpha is None:
self.posterior_alpha = 0.3 self.posterior_alpha = 0.3
self._verify_args()
@staticmethod @staticmethod
def _maybe_override_draft_max_model_len( def _maybe_override_draft_max_model_len(
speculative_max_model_len: Optional[int], speculative_max_model_len: Optional[int],
@ -2761,7 +2760,8 @@ class SpeculativeConfig:
return draft_parallel_config return draft_parallel_config
def _verify_args(self) -> None: @model_validator(mode='after')
def _verify_args(self) -> Self:
if self.num_speculative_tokens is None: if self.num_speculative_tokens is None:
raise ValueError( raise ValueError(
"num_speculative_tokens must be provided with " "num_speculative_tokens must be provided with "
@ -2812,6 +2812,8 @@ class SpeculativeConfig:
"Eagle3 is only supported for Llama models. " "Eagle3 is only supported for Llama models. "
f"Got {self.target_model_config.hf_text_config.model_type=}") f"Got {self.target_model_config.hf_text_config.model_type=}")
return self
@property @property
def num_lookahead_slots(self) -> int: def num_lookahead_slots(self) -> int:
"""The number of additional slots the scheduler should allocate per """The number of additional slots the scheduler should allocate per

View File

@ -3,7 +3,9 @@
# yapf: disable # yapf: disable
import argparse import argparse
import copy
import dataclasses import dataclasses
import functools
import json import json
import sys import sys
import threading import threading
@ -168,7 +170,8 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
return type_hints return type_hints
def get_kwargs(cls: ConfigType) -> dict[str, Any]: @functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls) cls_docs = get_attr_docs(cls)
kwargs = {} kwargs = {}
for field in fields(cls): for field in fields(cls):
@ -269,6 +272,16 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
return kwargs return kwargs
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Return argparse kwargs for the given Config dataclass.

    The expensive computation lives in ``_compute_kwargs``, which is wrapped
    in ``functools.lru_cache``. The cached dictionary is deep-copied before
    being returned, so callers may mutate the result without affecting the
    cached version.
    """
    cached_kwargs = _compute_kwargs(cls)
    return copy.deepcopy(cached_kwargs)
@dataclass @dataclass
class EngineArgs: class EngineArgs:
"""Arguments for vLLM engine.""" """Arguments for vLLM engine."""

View File

@ -1,10 +1,16 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import argparse import argparse
import typing
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.cli.types import CLISubcommand
from vllm.utils import FlexibleArgumentParser
if typing.TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser
class BenchmarkSubcommand(CLISubcommand): class BenchmarkSubcommand(CLISubcommand):
@ -23,7 +29,6 @@ class BenchmarkSubcommand(CLISubcommand):
def subparser_init( def subparser_init(
self, self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
bench_parser = subparsers.add_parser( bench_parser = subparsers.add_parser(
self.name, self.name,
help=self.help, help=self.help,

View File

@ -1,19 +1,21 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import argparse import argparse
import typing
from vllm.collect_env import main as collect_env_main from vllm.collect_env import main as collect_env_main
from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.cli.types import CLISubcommand
from vllm.utils import FlexibleArgumentParser
if typing.TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser
class CollectEnvSubcommand(CLISubcommand): class CollectEnvSubcommand(CLISubcommand):
"""The `collect-env` subcommand for the vLLM CLI. """ """The `collect-env` subcommand for the vLLM CLI. """
name = "collect-env"
def __init__(self):
self.name = "collect-env"
super().__init__()
@staticmethod @staticmethod
def cmd(args: argparse.Namespace) -> None: def cmd(args: argparse.Namespace) -> None:
@ -23,12 +25,11 @@ class CollectEnvSubcommand(CLISubcommand):
def subparser_init( def subparser_init(
self, self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
collect_env_parser = subparsers.add_parser( return subparsers.add_parser(
"collect-env", "collect-env",
help="Start collecting environment information.", help="Start collecting environment information.",
description="Start collecting environment information.", description="Start collecting environment information.",
usage="vllm collect-env") usage="vllm collect-env")
return collect_env_parser
def cmd_init() -> list[CLISubcommand]: def cmd_init() -> list[CLISubcommand]:

View File

@ -1,27 +1,15 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
'''The CLI entrypoints of vLLM
# The CLI entrypoint to vLLM. Note that all future modules must be lazily loaded within main
to avoid certain eager import breakage.'''
from __future__ import annotations
import importlib.metadata
import signal import signal
import sys import sys
import vllm.entrypoints.cli.benchmark.main
import vllm.entrypoints.cli.collect_env
import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.run_batch
import vllm.entrypoints.cli.serve
import vllm.version
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG, cli_env_setup
from vllm.utils import FlexibleArgumentParser
CMD_MODULES = [
vllm.entrypoints.cli.openai,
vllm.entrypoints.cli.serve,
vllm.entrypoints.cli.benchmark.main,
vllm.entrypoints.cli.collect_env,
vllm.entrypoints.cli.run_batch,
]
def register_signal_handlers(): def register_signal_handlers():
@ -33,16 +21,34 @@ def register_signal_handlers():
def main(): def main():
import vllm.entrypoints.cli.benchmark.main
import vllm.entrypoints.cli.collect_env
import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.run_batch
import vllm.entrypoints.cli.serve
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG, cli_env_setup
from vllm.utils import FlexibleArgumentParser
CMD_MODULES = [
vllm.entrypoints.cli.openai,
vllm.entrypoints.cli.serve,
vllm.entrypoints.cli.benchmark.main,
vllm.entrypoints.cli.collect_env,
vllm.entrypoints.cli.run_batch,
]
cli_env_setup() cli_env_setup()
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="vLLM CLI", description="vLLM CLI",
epilog=VLLM_SUBCMD_PARSER_EPILOG, epilog=VLLM_SUBCMD_PARSER_EPILOG,
) )
parser.add_argument('-v', parser.add_argument(
'--version', '-v',
action='version', '--version',
version=vllm.version.__version__) action='version',
version=importlib.metadata.version('vllm'),
)
subparsers = parser.add_subparsers(required=False, dest="subparser") subparsers = parser.add_subparsers(required=False, dest="subparser")
cmds = {} cmds = {}
for cmd_module in CMD_MODULES: for cmd_module in CMD_MODULES:

View File

@ -1,18 +1,21 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Commands that act as an interactive OpenAI API client
from __future__ import annotations
import argparse import argparse
import os import os
import signal import signal
import sys import sys
from typing import Optional from typing import TYPE_CHECKING
from openai import OpenAI from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam from openai.types.chat import ChatCompletionMessageParam
from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.cli.types import CLISubcommand
from vllm.utils import FlexibleArgumentParser
if TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser
def _register_signal_handlers(): def _register_signal_handlers():
@ -42,8 +45,7 @@ def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]:
return model_name, openai_client return model_name, openai_client
def chat(system_prompt: Optional[str], model_name: str, def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None:
client: OpenAI) -> None:
conversation: list[ChatCompletionMessageParam] = [] conversation: list[ChatCompletionMessageParam] = []
if system_prompt is not None: if system_prompt is not None:
conversation.append({"role": "system", "content": system_prompt}) conversation.append({"role": "system", "content": system_prompt})
@ -92,10 +94,7 @@ def _add_query_options(
class ChatCommand(CLISubcommand): class ChatCommand(CLISubcommand):
"""The `chat` subcommand for the vLLM CLI. """ """The `chat` subcommand for the vLLM CLI. """
name = "chat"
def __init__(self):
self.name = "chat"
super().__init__()
@staticmethod @staticmethod
def cmd(args: argparse.Namespace) -> None: def cmd(args: argparse.Namespace) -> None:
@ -157,10 +156,7 @@ class ChatCommand(CLISubcommand):
class CompleteCommand(CLISubcommand): class CompleteCommand(CLISubcommand):
"""The `complete` subcommand for the vLLM CLI. """ """The `complete` subcommand for the vLLM CLI. """
name = 'complete'
def __init__(self):
self.name = "complete"
super().__init__()
@staticmethod @staticmethod
def cmd(args: argparse.Namespace) -> None: def cmd(args: argparse.Namespace) -> None:

View File

@ -1,37 +1,42 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import argparse import argparse
import asyncio import asyncio
import importlib.metadata
from prometheus_client import start_http_server import typing
from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.logger import logger
from vllm.entrypoints.openai.run_batch import main as run_batch_main
from vllm.entrypoints.openai.run_batch import make_arg_parser
from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG, from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG,
show_filtered_argument_or_group_from_help) show_filtered_argument_or_group_from_help)
from vllm.utils import FlexibleArgumentParser from vllm.logger import init_logger
from vllm.version import __version__ as VLLM_VERSION
if typing.TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser
logger = init_logger(__name__)
class RunBatchSubcommand(CLISubcommand): class RunBatchSubcommand(CLISubcommand):
"""The `run-batch` subcommand for vLLM CLI.""" """The `run-batch` subcommand for vLLM CLI."""
name = "run-batch"
def __init__(self):
self.name = "run-batch"
super().__init__()
@staticmethod @staticmethod
def cmd(args: argparse.Namespace) -> None: def cmd(args: argparse.Namespace) -> None:
logger.info("vLLM batch processing API version %s", VLLM_VERSION) from vllm.entrypoints.openai.run_batch import main as run_batch_main
logger.info("vLLM batch processing API version %s",
importlib.metadata.version("vllm"))
logger.info("args: %s", args) logger.info("args: %s", args)
# Start the Prometheus metrics server. # Start the Prometheus metrics server.
# LLMEngine uses the Prometheus client # LLMEngine uses the Prometheus client
# to publish metrics at the /metrics endpoint. # to publish metrics at the /metrics endpoint.
if args.enable_metrics: if args.enable_metrics:
from prometheus_client import start_http_server
logger.info("Prometheus metrics enabled") logger.info("Prometheus metrics enabled")
start_http_server(port=args.port, addr=args.url) start_http_server(port=args.port, addr=args.url)
else: else:
@ -42,6 +47,8 @@ class RunBatchSubcommand(CLISubcommand):
def subparser_init( def subparser_init(
self, self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
from vllm.entrypoints.openai.run_batch import make_arg_parser
run_batch_parser = subparsers.add_parser( run_batch_parser = subparsers.add_parser(
"run-batch", "run-batch",
help="Run batch prompts and write results to file.", help="Run batch prompts and write results to file.",

View File

@ -9,8 +9,8 @@ import sys
import uvloop import uvloop
import zmq import zmq
import vllm
import vllm.envs as envs import vllm.envs as envs
from vllm import AsyncEngineArgs
from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.api_server import (run_server, run_server_worker, from vllm.entrypoints.openai.api_server import (run_server, run_server_worker,
setup_server) setup_server)
@ -38,10 +38,7 @@ logger = init_logger(__name__)
class ServeSubcommand(CLISubcommand): class ServeSubcommand(CLISubcommand):
"""The `serve` subcommand for the vLLM CLI. """ """The `serve` subcommand for the vLLM CLI. """
name = "serve"
def __init__(self):
self.name = "serve"
super().__init__()
@staticmethod @staticmethod
def cmd(args: argparse.Namespace) -> None: def cmd(args: argparse.Namespace) -> None:
@ -115,7 +112,7 @@ def run_headless(args: argparse.Namespace):
raise ValueError("api_server_count can't be set in headless mode") raise ValueError("api_server_count can't be set in headless mode")
# Create the EngineConfig. # Create the EngineConfig.
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context) vllm_config = engine_args.create_engine_config(usage_context=usage_context)
@ -175,7 +172,7 @@ def run_multi_api_server(args: argparse.Namespace):
listen_address, sock = setup_server(args) listen_address, sock = setup_server(args)
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context) vllm_config = engine_args.create_engine_config(usage_context=usage_context)
model_config = vllm_config.model_config model_config = vllm_config.model_config

View File

@ -1,9 +1,13 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse from __future__ import annotations
from vllm.utils import FlexibleArgumentParser import argparse
import typing
if typing.TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser
class CLISubcommand: class CLISubcommand:

View File

@ -15,7 +15,7 @@ from tqdm import tqdm
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger, logger from vllm.entrypoints.logger import RequestLogger
# yapf: disable # yapf: disable
from vllm.entrypoints.openai.protocol import (BatchRequestInput, from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput, BatchRequestOutput,
@ -29,10 +29,13 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_models import (BaseModelPath, from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels) OpenAIServingModels)
from vllm.entrypoints.openai.serving_score import ServingScores from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
def make_arg_parser(parser: FlexibleArgumentParser): def make_arg_parser(parser: FlexibleArgumentParser):
parser.add_argument( parser.add_argument(
@ -201,13 +204,16 @@ async def upload_data(output_url: str, data_or_file: str,
except Exception as e: except Exception as e:
if attempt < max_retries: if attempt < max_retries:
logger.error( logger.error(
f"Failed to upload data (attempt {attempt}). " "Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...", # noqa: E501
f"Error message: {str(e)}.\nRetrying in {delay} seconds..." attempt,
e,
delay,
) )
await asyncio.sleep(delay) await asyncio.sleep(delay)
else: else:
raise Exception(f"Failed to upload data (attempt {attempt}). " raise Exception(
f"Error message: {str(e)}.") from e f"Failed to upload data (attempt {attempt}). Error message: {str(e)}." # noqa: E501
) from e
async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput], async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput],

View File

@ -67,9 +67,6 @@ from torch.library import Library
from typing_extensions import Never, ParamSpec, TypeIs, assert_never from typing_extensions import Never, ParamSpec, TypeIs, assert_never
import vllm.envs as envs import vllm.envs as envs
# NOTE: import triton_utils to make TritonPlaceholderModule work
# if triton is unavailable
import vllm.triton_utils # noqa: F401
from vllm.logger import enable_trace_function_call, init_logger from vllm.logger import enable_trace_function_call, init_logger
if TYPE_CHECKING: if TYPE_CHECKING: