mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:05:01 +08:00
Improve conversion from dataclass configs to argparse arguments (#17303)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
72dfe4c74f
commit
f94886946e
@ -1,14 +1,118 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import json
|
||||||
from argparse import ArgumentError, ArgumentTypeError
|
from argparse import ArgumentError, ArgumentTypeError
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.config import PoolerConfig
|
from vllm.config import PoolerConfig, config
|
||||||
from vllm.engine.arg_utils import EngineArgs, nullable_kvs
|
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
|
||||||
|
get_type, is_not_builtin, is_type,
|
||||||
|
nullable_kvs, optional_type)
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("type", "value", "expected"), [
|
||||||
|
(int, "42", 42),
|
||||||
|
(int, "None", None),
|
||||||
|
(float, "3.14", 3.14),
|
||||||
|
(float, "None", None),
|
||||||
|
(str, "Hello World!", "Hello World!"),
|
||||||
|
(str, "None", None),
|
||||||
|
(json.loads, '{"foo":1,"bar":2}', {
|
||||||
|
"foo": 1,
|
||||||
|
"bar": 2
|
||||||
|
}),
|
||||||
|
(json.loads, "foo=1,bar=2", {
|
||||||
|
"foo": 1,
|
||||||
|
"bar": 2
|
||||||
|
}),
|
||||||
|
(json.loads, "None", None),
|
||||||
|
])
|
||||||
|
def test_optional_type(type, value, expected):
|
||||||
|
optional_type_func = optional_type(type)
|
||||||
|
context = nullcontext()
|
||||||
|
if value == "foo=1,bar=2":
|
||||||
|
context = pytest.warns(DeprecationWarning)
|
||||||
|
with context:
|
||||||
|
assert optional_type_func(value) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
|
||||||
|
(int, int, True),
|
||||||
|
(int, float, False),
|
||||||
|
(list[int], list, True),
|
||||||
|
(list[int], tuple, False),
|
||||||
|
(Literal[0, 1], Literal, True),
|
||||||
|
])
|
||||||
|
def test_is_type(type_hint, type, expected):
|
||||||
|
assert is_type(type_hint, type) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
|
||||||
|
({float, int}, int, True),
|
||||||
|
({int, tuple[int]}, int, True),
|
||||||
|
({int, tuple[int]}, float, False),
|
||||||
|
({str, Literal["x", "y"]}, Literal, True),
|
||||||
|
])
|
||||||
|
def test_contains_type(type_hints, type, expected):
|
||||||
|
assert contains_type(type_hints, type) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
|
||||||
|
({int, float}, int, int),
|
||||||
|
({int, float}, str, None),
|
||||||
|
({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
|
||||||
|
])
|
||||||
|
def test_get_type(type_hints, type, expected):
|
||||||
|
assert get_type(type_hints, type) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@config
|
||||||
|
@dataclass
|
||||||
|
class DummyConfigClass:
|
||||||
|
regular_bool: bool = True
|
||||||
|
"""Regular bool with default True"""
|
||||||
|
optional_bool: Optional[bool] = None
|
||||||
|
"""Optional bool with default None"""
|
||||||
|
optional_literal: Optional[Literal["x", "y"]] = None
|
||||||
|
"""Optional literal with default None"""
|
||||||
|
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
|
||||||
|
"""Tuple with default (1, 2, 3)"""
|
||||||
|
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
|
||||||
|
"""Tuple with default (1, 2)"""
|
||||||
|
list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
|
||||||
|
"""List with default [1, 2, 3]"""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("type_hint", "expected"), [
|
||||||
|
(int, False),
|
||||||
|
(DummyConfigClass, True),
|
||||||
|
])
|
||||||
|
def test_is_not_builtin(type_hint, expected):
|
||||||
|
assert is_not_builtin(type_hint) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_kwargs():
|
||||||
|
kwargs = get_kwargs(DummyConfigClass)
|
||||||
|
print(kwargs)
|
||||||
|
|
||||||
|
# bools should not have their type set
|
||||||
|
assert kwargs["regular_bool"].get("type") is None
|
||||||
|
assert kwargs["optional_bool"].get("type") is None
|
||||||
|
# optional literals should have None as a choice
|
||||||
|
assert kwargs["optional_literal"]["choices"] == ["x", "y", "None"]
|
||||||
|
# tuples should have the correct nargs
|
||||||
|
assert kwargs["tuple_n"]["nargs"] == "+"
|
||||||
|
assert kwargs["tuple_2"]["nargs"] == 2
|
||||||
|
# lists should work
|
||||||
|
assert kwargs["list_n"]["type"] is int
|
||||||
|
assert kwargs["list_n"]["nargs"] == "+"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(("arg", "expected"), [
|
@pytest.mark.parametrize(("arg", "expected"), [
|
||||||
(None, dict()),
|
(None, dict()),
|
||||||
("image=16", {
|
("image=16", {
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
|
|||||||
TypeVar, Union, cast, get_args, get_origin)
|
TypeVar, Union, cast, get_args, get_origin)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing_extensions import TypeIs
|
from typing_extensions import TypeIs, deprecated
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm import version
|
from vllm import version
|
||||||
@ -48,33 +48,29 @@ TypeHint = Union[type[Any], object]
|
|||||||
TypeHintT = Union[type[T], object]
|
TypeHintT = Union[type[T], object]
|
||||||
|
|
||||||
|
|
||||||
def optional_arg(val: str, return_type: Callable[[str], T]) -> Optional[T]:
|
def optional_type(
|
||||||
if val == "" or val == "None":
|
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
|
||||||
return None
|
|
||||||
try:
|
def _optional_type(val: str) -> Optional[T]:
|
||||||
return return_type(val)
|
if val == "" or val == "None":
|
||||||
except ValueError as e:
|
return None
|
||||||
raise argparse.ArgumentTypeError(
|
try:
|
||||||
f"Value {val} cannot be converted to {return_type}.") from e
|
if return_type is json.loads and not re.match("^{.*}$", val):
|
||||||
|
return cast(T, nullable_kvs(val))
|
||||||
|
return return_type(val)
|
||||||
|
except ValueError as e:
|
||||||
|
raise argparse.ArgumentTypeError(
|
||||||
|
f"Value {val} cannot be converted to {return_type}.") from e
|
||||||
|
|
||||||
|
return _optional_type
|
||||||
|
|
||||||
|
|
||||||
def optional_str(val: str) -> Optional[str]:
|
@deprecated(
|
||||||
return optional_arg(val, str)
|
"Passing a JSON argument as a string containing comma separated key=value "
|
||||||
|
"pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
|
||||||
|
"string instead.")
|
||||||
def optional_int(val: str) -> Optional[int]:
|
def nullable_kvs(val: str) -> dict[str, int]:
|
||||||
return optional_arg(val, int)
|
"""Parses a string containing comma separated key [str] to value [int]
|
||||||
|
|
||||||
|
|
||||||
def optional_float(val: str) -> Optional[float]:
|
|
||||||
return optional_arg(val, float)
|
|
||||||
|
|
||||||
|
|
||||||
def nullable_kvs(val: str) -> Optional[dict[str, int]]:
|
|
||||||
"""NOTE: This function is deprecated, args should be passed as JSON
|
|
||||||
strings instead.
|
|
||||||
|
|
||||||
Parses a string containing comma separated key [str] to value [int]
|
|
||||||
pairs into a dictionary.
|
pairs into a dictionary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -83,10 +79,7 @@ def nullable_kvs(val: str) -> Optional[dict[str, int]]:
|
|||||||
Returns:
|
Returns:
|
||||||
Dictionary with parsed values.
|
Dictionary with parsed values.
|
||||||
"""
|
"""
|
||||||
if len(val) == 0:
|
out_dict: dict[str, int] = {}
|
||||||
return None
|
|
||||||
|
|
||||||
out_dict: Dict[str, int] = {}
|
|
||||||
for item in val.split(","):
|
for item in val.split(","):
|
||||||
kv_parts = [part.lower().strip() for part in item.split("=")]
|
kv_parts = [part.lower().strip() for part in item.split("=")]
|
||||||
if len(kv_parts) != 2:
|
if len(kv_parts) != 2:
|
||||||
@ -108,15 +101,103 @@ def nullable_kvs(val: str) -> Optional[dict[str, int]]:
|
|||||||
return out_dict
|
return out_dict
|
||||||
|
|
||||||
|
|
||||||
def optional_dict(val: str) -> Optional[dict[str, int]]:
|
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
|
||||||
if re.match("^{.*}$", val):
|
"""Check if the type hint is a specific type."""
|
||||||
return optional_arg(val, json.loads)
|
return type_hint is type or get_origin(type_hint) is type
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
"Failed to parse JSON string. Attempting to parse as "
|
def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
|
||||||
"comma-separated key=value pairs. This will be deprecated in a "
|
"""Check if the type hints contain a specific type."""
|
||||||
"future release.")
|
return any(is_type(type_hint, type) for type_hint in type_hints)
|
||||||
return nullable_kvs(val)
|
|
||||||
|
|
||||||
|
def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
|
||||||
|
"""Get the specific type from the type hints."""
|
||||||
|
return next((th for th in type_hints if is_type(th, type)), None)
|
||||||
|
|
||||||
|
|
||||||
|
def is_not_builtin(type_hint: TypeHint) -> bool:
|
||||||
|
"""Check if the class is not a built-in type."""
|
||||||
|
return type_hint.__module__ != "builtins"
|
||||||
|
|
||||||
|
|
||||||
|
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||||
|
cls_docs = get_attr_docs(cls)
|
||||||
|
kwargs = {}
|
||||||
|
for field in fields(cls):
|
||||||
|
# Get the default value of the field
|
||||||
|
default = field.default
|
||||||
|
if field.default_factory is not MISSING:
|
||||||
|
default = field.default_factory()
|
||||||
|
|
||||||
|
# Get the help text for the field
|
||||||
|
name = field.name
|
||||||
|
help = cls_docs[name]
|
||||||
|
# Escape % for argparse
|
||||||
|
help = help.replace("%", "%%")
|
||||||
|
|
||||||
|
# Initialise the kwargs dictionary for the field
|
||||||
|
kwargs[name] = {"default": default, "help": help}
|
||||||
|
|
||||||
|
# Get the set of possible types for the field
|
||||||
|
type_hints: set[TypeHint] = set()
|
||||||
|
if get_origin(field.type) is Union:
|
||||||
|
type_hints.update(get_args(field.type))
|
||||||
|
else:
|
||||||
|
type_hints.add(field.type)
|
||||||
|
|
||||||
|
# Set other kwargs based on the type hints
|
||||||
|
if contains_type(type_hints, bool):
|
||||||
|
# Creates --no-<name> and --<name> flags
|
||||||
|
kwargs[name]["action"] = argparse.BooleanOptionalAction
|
||||||
|
elif contains_type(type_hints, Literal):
|
||||||
|
# Creates choices from Literal arguments
|
||||||
|
type_hint = get_type(type_hints, Literal)
|
||||||
|
choices = sorted(get_args(type_hint))
|
||||||
|
kwargs[name]["choices"] = choices
|
||||||
|
choice_type = type(choices[0])
|
||||||
|
assert all(type(c) is choice_type for c in choices), (
|
||||||
|
"All choices must be of the same type. "
|
||||||
|
f"Got {choices} with types {[type(c) for c in choices]}")
|
||||||
|
kwargs[name]["type"] = choice_type
|
||||||
|
elif contains_type(type_hints, tuple):
|
||||||
|
type_hint = get_type(type_hints, tuple)
|
||||||
|
types = get_args(type_hint)
|
||||||
|
tuple_type = types[0]
|
||||||
|
assert all(t is tuple_type for t in types if t is not Ellipsis), (
|
||||||
|
"All non-Ellipsis tuple elements must be of the same "
|
||||||
|
f"type. Got {types}.")
|
||||||
|
kwargs[name]["type"] = tuple_type
|
||||||
|
kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types)
|
||||||
|
elif contains_type(type_hints, list):
|
||||||
|
type_hint = get_type(type_hints, list)
|
||||||
|
types = get_args(type_hint)
|
||||||
|
assert len(types) == 1, (
|
||||||
|
"List type must have exactly one type. Got "
|
||||||
|
f"{type_hint} with types {types}")
|
||||||
|
kwargs[name]["type"] = types[0]
|
||||||
|
kwargs[name]["nargs"] = "+"
|
||||||
|
elif contains_type(type_hints, int):
|
||||||
|
kwargs[name]["type"] = int
|
||||||
|
elif contains_type(type_hints, float):
|
||||||
|
kwargs[name]["type"] = float
|
||||||
|
elif contains_type(type_hints, dict):
|
||||||
|
# Dict arguments will always be optional
|
||||||
|
kwargs[name]["type"] = optional_type(json.loads)
|
||||||
|
elif (contains_type(type_hints, str)
|
||||||
|
or any(is_not_builtin(th) for th in type_hints)):
|
||||||
|
kwargs[name]["type"] = str
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported type {type_hints} for argument {name}.")
|
||||||
|
|
||||||
|
# If None is in type_hints, make the argument optional.
|
||||||
|
# But not if it's a bool, argparse will handle this better.
|
||||||
|
if type(None) in type_hints and not contains_type(type_hints, bool):
|
||||||
|
kwargs[name]["type"] = optional_type(kwargs[name]["type"])
|
||||||
|
if kwargs[name].get("choices"):
|
||||||
|
kwargs[name]["choices"].append("None")
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -279,100 +360,6 @@ class EngineArgs:
|
|||||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||||
"""Shared CLI arguments for vLLM engine."""
|
"""Shared CLI arguments for vLLM engine."""
|
||||||
|
|
||||||
def is_type_in_union(cls: TypeHint, type: TypeHint) -> bool:
|
|
||||||
"""Check if the class is a type in a union type."""
|
|
||||||
is_union = get_origin(cls) is Union
|
|
||||||
type_in_union = type in [get_origin(a) or a for a in get_args(cls)]
|
|
||||||
return is_union and type_in_union
|
|
||||||
|
|
||||||
def get_type_from_union(cls: TypeHint, type: TypeHintT) -> TypeHintT:
|
|
||||||
"""Get the type in a union type."""
|
|
||||||
for arg in get_args(cls):
|
|
||||||
if (get_origin(arg) or arg) is type:
|
|
||||||
return arg
|
|
||||||
raise ValueError(f"Type {type} not found in union type {cls}.")
|
|
||||||
|
|
||||||
def is_optional(cls: TypeHint) -> TypeIs[Union[Any, None]]:
|
|
||||||
"""Check if the class is an optional type."""
|
|
||||||
return is_type_in_union(cls, type(None))
|
|
||||||
|
|
||||||
def can_be_type(cls: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
|
|
||||||
"""Check if the class can be of type."""
|
|
||||||
return cls is type or get_origin(cls) is type or is_type_in_union(
|
|
||||||
cls, type)
|
|
||||||
|
|
||||||
def is_custom_type(cls: TypeHint) -> bool:
|
|
||||||
"""Check if the class is a custom type."""
|
|
||||||
return cls.__module__ != "builtins"
|
|
||||||
|
|
||||||
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
|
||||||
cls_docs = get_attr_docs(cls)
|
|
||||||
kwargs = {}
|
|
||||||
for field in fields(cls):
|
|
||||||
# Get the default value of the field
|
|
||||||
default = field.default
|
|
||||||
if field.default_factory is not MISSING:
|
|
||||||
default = field.default_factory()
|
|
||||||
|
|
||||||
# Get the help text for the field
|
|
||||||
name = field.name
|
|
||||||
help = cls_docs[name]
|
|
||||||
# Escape % for argparse
|
|
||||||
help = help.replace("%", "%%")
|
|
||||||
|
|
||||||
# Initialise the kwargs dictionary for the field
|
|
||||||
kwargs[name] = {"default": default, "help": help}
|
|
||||||
|
|
||||||
# Make note of if the field is optional and get the actual
|
|
||||||
# type of the field if it is
|
|
||||||
optional = is_optional(field.type)
|
|
||||||
field_type = get_args(
|
|
||||||
field.type)[0] if optional else field.type
|
|
||||||
|
|
||||||
# Set type, action and choices for the field depending on the
|
|
||||||
# type of the field
|
|
||||||
if can_be_type(field_type, bool):
|
|
||||||
# Creates --no-<name> and --<name> flags
|
|
||||||
kwargs[name]["action"] = argparse.BooleanOptionalAction
|
|
||||||
kwargs[name]["type"] = bool
|
|
||||||
elif can_be_type(field_type, Literal):
|
|
||||||
# Creates choices from Literal arguments
|
|
||||||
if is_type_in_union(field_type, Literal):
|
|
||||||
field_type = get_type_from_union(field_type, Literal)
|
|
||||||
choices = get_args(field_type)
|
|
||||||
kwargs[name]["choices"] = choices
|
|
||||||
choice_type = type(choices[0])
|
|
||||||
assert all(type(c) is choice_type for c in choices), (
|
|
||||||
"All choices must be of the same type. "
|
|
||||||
f"Got {choices} with types {[type(c) for c in choices]}"
|
|
||||||
)
|
|
||||||
kwargs[name]["type"] = choice_type
|
|
||||||
elif can_be_type(field_type, tuple):
|
|
||||||
if is_type_in_union(field_type, tuple):
|
|
||||||
field_type = get_type_from_union(field_type, tuple)
|
|
||||||
dtypes = get_args(field_type)
|
|
||||||
dtype = dtypes[0]
|
|
||||||
assert all(
|
|
||||||
d is dtype for d in dtypes if d is not Ellipsis
|
|
||||||
), ("All non-Ellipsis tuple elements must be of the same "
|
|
||||||
f"type. Got {dtypes}.")
|
|
||||||
kwargs[name]["type"] = dtype
|
|
||||||
kwargs[name]["nargs"] = "+"
|
|
||||||
elif can_be_type(field_type, int):
|
|
||||||
kwargs[name]["type"] = optional_int if optional else int
|
|
||||||
elif can_be_type(field_type, float):
|
|
||||||
kwargs[name][
|
|
||||||
"type"] = optional_float if optional else float
|
|
||||||
elif can_be_type(field_type, dict):
|
|
||||||
kwargs[name]["type"] = optional_dict
|
|
||||||
elif (can_be_type(field_type, str)
|
|
||||||
or is_custom_type(field_type)):
|
|
||||||
kwargs[name]["type"] = optional_str if optional else str
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported type {field.type} for argument {name}. ")
|
|
||||||
return kwargs
|
|
||||||
|
|
||||||
# Model arguments
|
# Model arguments
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--model',
|
'--model',
|
||||||
@ -390,13 +377,13 @@ class EngineArgs:
|
|||||||
'which task to use.')
|
'which task to use.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--tokenizer',
|
'--tokenizer',
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default=EngineArgs.tokenizer,
|
default=EngineArgs.tokenizer,
|
||||||
help='Name or path of the huggingface tokenizer to use. '
|
help='Name or path of the huggingface tokenizer to use. '
|
||||||
'If unspecified, model name or path will be used.')
|
'If unspecified, model name or path will be used.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hf-config-path",
|
"--hf-config-path",
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default=EngineArgs.hf_config_path,
|
default=EngineArgs.hf_config_path,
|
||||||
help='Name or path of the huggingface config to use. '
|
help='Name or path of the huggingface config to use. '
|
||||||
'If unspecified, model name or path will be used.')
|
'If unspecified, model name or path will be used.')
|
||||||
@ -408,21 +395,21 @@ class EngineArgs:
|
|||||||
'the input. The generated output will contain token ids.')
|
'the input. The generated output will contain token ids.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--revision',
|
'--revision',
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default=None,
|
default=None,
|
||||||
help='The specific model version to use. It can be a branch '
|
help='The specific model version to use. It can be a branch '
|
||||||
'name, a tag name, or a commit id. If unspecified, will use '
|
'name, a tag name, or a commit id. If unspecified, will use '
|
||||||
'the default version.')
|
'the default version.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--code-revision',
|
'--code-revision',
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default=None,
|
default=None,
|
||||||
help='The specific revision to use for the model code on '
|
help='The specific revision to use for the model code on '
|
||||||
'Hugging Face Hub. It can be a branch name, a tag name, or a '
|
'Hugging Face Hub. It can be a branch name, a tag name, or a '
|
||||||
'commit id. If unspecified, will use the default version.')
|
'commit id. If unspecified, will use the default version.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--tokenizer-revision',
|
'--tokenizer-revision',
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default=None,
|
default=None,
|
||||||
help='Revision of the huggingface tokenizer to use. '
|
help='Revision of the huggingface tokenizer to use. '
|
||||||
'It can be a branch name, a tag name, or a commit id. '
|
'It can be a branch name, a tag name, or a commit id. '
|
||||||
@ -513,7 +500,7 @@ class EngineArgs:
|
|||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--logits-processor-pattern',
|
'--logits-processor-pattern',
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default=None,
|
default=None,
|
||||||
help='Optional regex pattern specifying valid logits processor '
|
help='Optional regex pattern specifying valid logits processor '
|
||||||
'qualified names that can be passed with the `logits_processors` '
|
'qualified names that can be passed with the `logits_processors` '
|
||||||
@ -612,7 +599,7 @@ class EngineArgs:
|
|||||||
# Quantization settings.
|
# Quantization settings.
|
||||||
parser.add_argument('--quantization',
|
parser.add_argument('--quantization',
|
||||||
'-q',
|
'-q',
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
choices=[*QUANTIZATION_METHODS, None],
|
choices=[*QUANTIZATION_METHODS, None],
|
||||||
default=EngineArgs.quantization,
|
default=EngineArgs.quantization,
|
||||||
help='Method used to quantize the weights. If '
|
help='Method used to quantize the weights. If '
|
||||||
@ -921,7 +908,7 @@ class EngineArgs:
|
|||||||
'class without changing the existing functions.')
|
'class without changing the existing functions.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--generation-config",
|
"--generation-config",
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default="auto",
|
default="auto",
|
||||||
help="The folder path to the generation config. "
|
help="The folder path to the generation config. "
|
||||||
"Defaults to 'auto', the generation config will be loaded from "
|
"Defaults to 'auto', the generation config will be loaded from "
|
||||||
|
|||||||
@ -11,7 +11,7 @@ import ssl
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Optional, Union, get_args
|
from typing import Optional, Union, get_args
|
||||||
|
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_str
|
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||||
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
|
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
|
||||||
validate_chat_template)
|
validate_chat_template)
|
||||||
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
|
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
|
||||||
@ -79,7 +79,7 @@ class PromptAdapterParserAction(argparse.Action):
|
|||||||
|
|
||||||
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||||
parser.add_argument("--host",
|
parser.add_argument("--host",
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default=None,
|
default=None,
|
||||||
help="Host name.")
|
help="Host name.")
|
||||||
parser.add_argument("--port", type=int, default=8000, help="Port number.")
|
parser.add_argument("--port", type=int, default=8000, help="Port number.")
|
||||||
@ -108,13 +108,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
|||||||
default=["*"],
|
default=["*"],
|
||||||
help="Allowed headers.")
|
help="Allowed headers.")
|
||||||
parser.add_argument("--api-key",
|
parser.add_argument("--api-key",
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default=None,
|
default=None,
|
||||||
help="If provided, the server will require this key "
|
help="If provided, the server will require this key "
|
||||||
"to be presented in the header.")
|
"to be presented in the header.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora-modules",
|
"--lora-modules",
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default=None,
|
default=None,
|
||||||
nargs='+',
|
nargs='+',
|
||||||
action=LoRAParserAction,
|
action=LoRAParserAction,
|
||||||
@ -126,14 +126,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
|||||||
"\"base_model_name\": \"id\"}``")
|
"\"base_model_name\": \"id\"}``")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prompt-adapters",
|
"--prompt-adapters",
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default=None,
|
default=None,
|
||||||
nargs='+',
|
nargs='+',
|
||||||
action=PromptAdapterParserAction,
|
action=PromptAdapterParserAction,
|
||||||
help="Prompt adapter configurations in the format name=path. "
|
help="Prompt adapter configurations in the format name=path. "
|
||||||
"Multiple adapters can be specified.")
|
"Multiple adapters can be specified.")
|
||||||
parser.add_argument("--chat-template",
|
parser.add_argument("--chat-template",
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default=None,
|
default=None,
|
||||||
help="The file path to the chat template, "
|
help="The file path to the chat template, "
|
||||||
"or the template in single-line form "
|
"or the template in single-line form "
|
||||||
@ -151,20 +151,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
|||||||
'similar to OpenAI schema. '
|
'similar to OpenAI schema. '
|
||||||
'Example: ``[{"type": "text", "text": "Hello world!"}]``')
|
'Example: ``[{"type": "text", "text": "Hello world!"}]``')
|
||||||
parser.add_argument("--response-role",
|
parser.add_argument("--response-role",
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default="assistant",
|
default="assistant",
|
||||||
help="The role name to return if "
|
help="The role name to return if "
|
||||||
"``request.add_generation_prompt=true``.")
|
"``request.add_generation_prompt=true``.")
|
||||||
parser.add_argument("--ssl-keyfile",
|
parser.add_argument("--ssl-keyfile",
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default=None,
|
default=None,
|
||||||
help="The file path to the SSL key file.")
|
help="The file path to the SSL key file.")
|
||||||
parser.add_argument("--ssl-certfile",
|
parser.add_argument("--ssl-certfile",
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default=None,
|
default=None,
|
||||||
help="The file path to the SSL cert file.")
|
help="The file path to the SSL cert file.")
|
||||||
parser.add_argument("--ssl-ca-certs",
|
parser.add_argument("--ssl-ca-certs",
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default=None,
|
default=None,
|
||||||
help="The CA certificates file.")
|
help="The CA certificates file.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -180,13 +180,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--root-path",
|
"--root-path",
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default=None,
|
default=None,
|
||||||
help="FastAPI root_path when app is behind a path based routing proxy."
|
help="FastAPI root_path when app is behind a path based routing proxy."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--middleware",
|
"--middleware",
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
action="append",
|
action="append",
|
||||||
default=[],
|
default=[],
|
||||||
help="Additional ASGI middleware to apply to the app. "
|
help="Additional ASGI middleware to apply to the app. "
|
||||||
|
|||||||
@ -12,7 +12,7 @@ import torch
|
|||||||
from prometheus_client import start_http_server
|
from prometheus_client import start_http_server
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_str
|
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
from vllm.entrypoints.logger import RequestLogger, logger
|
from vllm.entrypoints.logger import RequestLogger, logger
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
@ -61,7 +61,7 @@ def parse_args():
|
|||||||
"to the output URL.",
|
"to the output URL.",
|
||||||
)
|
)
|
||||||
parser.add_argument("--response-role",
|
parser.add_argument("--response-role",
|
||||||
type=optional_str,
|
type=optional_type(str),
|
||||||
default="assistant",
|
default="assistant",
|
||||||
help="The role name to return if "
|
help="The role name to return if "
|
||||||
"`request.add_generation_prompt=True`.")
|
"`request.add_generation_prompt=True`.")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user