mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 01:15:31 +08:00
[Bugfix] Fix load config when using bools (#9533)
This commit is contained in:
parent
e130c40e4e
commit
34a9941620
@ -1,3 +1,5 @@
|
|||||||
port: 12312
|
port: 12312
|
||||||
served_model_name: mymodel
|
served_model_name: mymodel
|
||||||
tensor_parallel_size: 2
|
tensor_parallel_size: 2
|
||||||
|
trust_remote_code: true
|
||||||
|
multi_step_stream_outputs: false
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from typing import AsyncIterator, Tuple
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm.utils import (FlexibleArgumentParser, deprecate_kwargs,
|
from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
|
||||||
get_open_port, merge_async_iterators, supports_kw)
|
get_open_port, merge_async_iterators, supports_kw)
|
||||||
|
|
||||||
from .utils import error_on_warning
|
from .utils import error_on_warning
|
||||||
@ -141,6 +141,8 @@ def parser_with_config():
|
|||||||
parser.add_argument('--config', type=str)
|
parser.add_argument('--config', type=str)
|
||||||
parser.add_argument('--port', type=int)
|
parser.add_argument('--port', type=int)
|
||||||
parser.add_argument('--tensor-parallel-size', type=int)
|
parser.add_argument('--tensor-parallel-size', type=int)
|
||||||
|
parser.add_argument('--trust-remote-code', action='store_true')
|
||||||
|
parser.add_argument('--multi-step-stream-outputs', action=StoreBoolean)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -214,6 +216,8 @@ def test_config_args(parser_with_config):
|
|||||||
args = parser_with_config.parse_args(
|
args = parser_with_config.parse_args(
|
||||||
['serve', 'mymodel', '--config', './data/test_config.yaml'])
|
['serve', 'mymodel', '--config', './data/test_config.yaml'])
|
||||||
assert args.tensor_parallel_size == 2
|
assert args.tensor_parallel_size == 2
|
||||||
|
assert args.trust_remote_code
|
||||||
|
assert not args.multi_step_stream_outputs
|
||||||
|
|
||||||
|
|
||||||
def test_config_file(parser_with_config):
|
def test_config_file(parser_with_config):
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
|||||||
from vllm.transformers_utils.config import (
|
from vllm.transformers_utils.config import (
|
||||||
maybe_register_config_serialize_by_value)
|
maybe_register_config_serialize_by_value)
|
||||||
from vllm.transformers_utils.utils import check_gguf_file
|
from vllm.transformers_utils.utils import check_gguf_file
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser, StoreBoolean
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
||||||
@ -1144,18 +1144,6 @@ class AsyncEngineArgs(EngineArgs):
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
class StoreBoolean(argparse.Action):
|
|
||||||
|
|
||||||
def __call__(self, parser, namespace, values, option_string=None):
|
|
||||||
if values.lower() == "true":
|
|
||||||
setattr(namespace, self.dest, True)
|
|
||||||
elif values.lower() == "false":
|
|
||||||
setattr(namespace, self.dest, False)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid boolean value: {values}. "
|
|
||||||
"Expected 'true' or 'false'.")
|
|
||||||
|
|
||||||
|
|
||||||
# These functions are used by sphinx to build the documentation
|
# These functions are used by sphinx to build the documentation
|
||||||
def _engine_args_parser():
|
def _engine_args_parser():
|
||||||
return EngineArgs.add_cli_args(FlexibleArgumentParser())
|
return EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||||
|
|||||||
@ -1155,6 +1155,18 @@ def run_once(f: Callable[P, None]) -> Callable[P, None]:
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class StoreBoolean(argparse.Action):
|
||||||
|
|
||||||
|
def __call__(self, parser, namespace, values, option_string=None):
|
||||||
|
if values.lower() == "true":
|
||||||
|
setattr(namespace, self.dest, True)
|
||||||
|
elif values.lower() == "false":
|
||||||
|
setattr(namespace, self.dest, False)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid boolean value: {values}. "
|
||||||
|
"Expected 'true' or 'false'.")
|
||||||
|
|
||||||
|
|
||||||
class FlexibleArgumentParser(argparse.ArgumentParser):
|
class FlexibleArgumentParser(argparse.ArgumentParser):
|
||||||
"""ArgumentParser that allows both underscore and dash in names."""
|
"""ArgumentParser that allows both underscore and dash in names."""
|
||||||
|
|
||||||
@ -1163,7 +1175,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
|||||||
args = sys.argv[1:]
|
args = sys.argv[1:]
|
||||||
|
|
||||||
if '--config' in args:
|
if '--config' in args:
|
||||||
args = FlexibleArgumentParser._pull_args_from_config(args)
|
args = self._pull_args_from_config(args)
|
||||||
|
|
||||||
# Convert underscores to dashes and vice versa in argument names
|
# Convert underscores to dashes and vice versa in argument names
|
||||||
processed_args = []
|
processed_args = []
|
||||||
@ -1181,8 +1193,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
|||||||
|
|
||||||
return super().parse_args(processed_args, namespace)
|
return super().parse_args(processed_args, namespace)
|
||||||
|
|
||||||
@staticmethod
|
def _pull_args_from_config(self, args: List[str]) -> List[str]:
|
||||||
def _pull_args_from_config(args: List[str]) -> List[str]:
|
|
||||||
"""Method to pull arguments specified in the config file
|
"""Method to pull arguments specified in the config file
|
||||||
into the command-line args variable.
|
into the command-line args variable.
|
||||||
|
|
||||||
@ -1226,7 +1237,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
|||||||
|
|
||||||
file_path = args[index + 1]
|
file_path = args[index + 1]
|
||||||
|
|
||||||
config_args = FlexibleArgumentParser._load_config_file(file_path)
|
config_args = self._load_config_file(file_path)
|
||||||
|
|
||||||
# 0th index is for {serve,chat,complete}
|
# 0th index is for {serve,chat,complete}
|
||||||
# followed by model_tag (only for serve)
|
# followed by model_tag (only for serve)
|
||||||
@ -1247,8 +1258,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
|||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@staticmethod
|
def _load_config_file(self, file_path: str) -> List[str]:
|
||||||
def _load_config_file(file_path: str) -> List[str]:
|
|
||||||
"""Loads a yaml file and returns the key value pairs as a
|
"""Loads a yaml file and returns the key value pairs as a
|
||||||
flattened list with argparse like pattern
|
flattened list with argparse like pattern
|
||||||
```yaml
|
```yaml
|
||||||
@ -1282,7 +1292,16 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
|
|||||||
Make sure path is correct", file_path)
|
Make sure path is correct", file_path)
|
||||||
raise ex
|
raise ex
|
||||||
|
|
||||||
|
store_boolean_arguments = [
|
||||||
|
action.dest for action in self._actions
|
||||||
|
if isinstance(action, StoreBoolean)
|
||||||
|
]
|
||||||
|
|
||||||
for key, value in config.items():
|
for key, value in config.items():
|
||||||
|
if isinstance(value, bool) and key not in store_boolean_arguments:
|
||||||
|
if value:
|
||||||
|
processed_args.append('--' + key)
|
||||||
|
else:
|
||||||
processed_args.append('--' + key)
|
processed_args.append('--' + key)
|
||||||
processed_args.append(str(value))
|
processed_args.append(str(value))
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user