Automatically tell users that dict args must be valid JSON in CLI (#17577)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-05-02 13:24:55 +01:00 committed by GitHub
parent 6d1479ca4b
commit 785d75a03b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 13 deletions

View File

@@ -106,6 +106,8 @@ class DummyConfigClass:
"""List with literal choices""" """List with literal choices"""
literal_literal: Literal[Literal[1], Literal[2]] = 1 literal_literal: Literal[Literal[1], Literal[2]] = 1
"""Literal of literals with default 1""" """Literal of literals with default 1"""
json_tip: dict = field(default_factory=dict)
"""Dict which will be JSON in CLI"""
@pytest.mark.parametrize(("type_hint", "expected"), [ @pytest.mark.parametrize(("type_hint", "expected"), [
@@ -137,6 +139,9 @@ def test_get_kwargs():
assert kwargs["list_literal"]["choices"] == [1, 2] assert kwargs["list_literal"]["choices"] == [1, 2]
# literals of literals should have merged choices # literals of literals should have merged choices
assert kwargs["literal_literal"]["choices"] == [1, 2] assert kwargs["literal_literal"]["choices"] == [1, 2]
# dict should have json tip in help
json_tip = "\n\nShould be a valid JSON string."
assert kwargs["json_tip"]["help"].endswith(json_tip)
@pytest.mark.parametrize(("arg", "expected"), [ @pytest.mark.parametrize(("arg", "expected"), [

View File

@@ -268,7 +268,7 @@ class ModelConfig:
It can be a branch name, a tag name, or a commit id. If unspecified, will It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version.""" use the default version."""
rope_scaling: dict[str, Any] = field(default_factory=dict) rope_scaling: dict[str, Any] = field(default_factory=dict)
"""RoPE scaling configuration in JSON format. For example, """RoPE scaling configuration. For example,
`{"rope_type":"dynamic","factor":2.0}`.""" `{"rope_type":"dynamic","factor":2.0}`."""
rope_theta: Optional[float] = None rope_theta: Optional[float] = None
"""RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE """RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE
@@ -346,14 +346,13 @@ class ModelConfig:
(stored in `~/.huggingface`).""" (stored in `~/.huggingface`)."""
hf_overrides: HfOverrides = field(default_factory=dict) hf_overrides: HfOverrides = field(default_factory=dict)
"""If a dictionary, contains arguments to be forwarded to the Hugging Face """If a dictionary, contains arguments to be forwarded to the Hugging Face
config. If a callable, it is called to update the HuggingFace config. When config. If a callable, it is called to update the HuggingFace config."""
specified via CLI, the argument must be a valid JSON string."""
mm_processor_kwargs: Optional[dict[str, Any]] = None mm_processor_kwargs: Optional[dict[str, Any]] = None
"""Arguments to be forwarded to the model's processor for multi-modal data, """Arguments to be forwarded to the model's processor for multi-modal data,
e.g., image processor. Overrides for the multi-modal processor obtained e.g., image processor. Overrides for the multi-modal processor obtained
from `AutoProcessor.from_pretrained`. The available overrides depend on the from `AutoProcessor.from_pretrained`. The available overrides depend on the
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
When specified via CLI, the argument must be a valid JSON string.""" """
disable_mm_preprocessor_cache: bool = False disable_mm_preprocessor_cache: bool = False
"""If `True`, disable caching of the multi-modal preprocessor/mapper (not """If `True`, disable caching of the multi-modal preprocessor/mapper (not
recommended).""" recommended)."""
@@ -361,15 +360,14 @@ class ModelConfig:
"""Initialize non-default neuron config or override default neuron config """Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to that are specific to Neuron devices, this argument will be used to
configure the neuron config that can not be gathered from the vllm configure the neuron config that can not be gathered from the vllm
arguments. e.g. `{"cast_logits_dtype": "bloat16"}`. When specified via CLI, arguments. e.g. `{"cast_logits_dtype": "bloat16"}`."""
the argument must be a valid JSON string."""
pooler_config: Optional["PoolerConfig"] = field(init=False) pooler_config: Optional["PoolerConfig"] = field(init=False)
"""Pooler config which controls the behaviour of output pooling in pooling """Pooler config which controls the behaviour of output pooling in pooling
models.""" models."""
override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
"""Initialize non-default pooling config or override default pooling config """Initialize non-default pooling config or override default pooling config
for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`. for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`.
When specified via CLI, the argument must be a valid JSON string.""" """
logits_processor_pattern: Optional[str] = None logits_processor_pattern: Optional[str] = None
"""Optional regex pattern specifying valid logits processor qualified names """Optional regex pattern specifying valid logits processor qualified names
that can be passed with the `logits_processors` extra completion argument. that can be passed with the `logits_processors` extra completion argument.
@@ -385,8 +383,7 @@ class ModelConfig:
"""Overrides or sets generation config. e.g. `{"temperature": 0.5}`. If """Overrides or sets generation config. e.g. `{"temperature": 0.5}`. If
used with `--generation-config auto`, the override parameters will be used with `--generation-config auto`, the override parameters will be
merged with the default config from the model. If used with merged with the default config from the model. If used with
`--generation-config vllm`, only the override parameters are used. `--generation-config vllm`, only the override parameters are used."""
When specified via CLI, the argument must be a valid JSON string."""
enable_sleep_mode: bool = False enable_sleep_mode: bool = False
"""Enable sleep mode for the engine (only cuda platform is supported).""" """Enable sleep mode for the engine (only cuda platform is supported)."""
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO.value model_impl: Union[str, ModelImpl] = ModelImpl.AUTO.value
@@ -1556,8 +1553,7 @@ class LoadConfig:
cache directory of Hugging Face.""" cache directory of Hugging Face."""
model_loader_extra_config: dict = field(default_factory=dict) model_loader_extra_config: dict = field(default_factory=dict)
"""Extra config for model loader. This will be passed to the model loader """Extra config for model loader. This will be passed to the model loader
corresponding to the chosen load_format. This should be a JSON string that corresponding to the chosen load_format."""
will be parsed into a dictionary."""
ignore_patterns: Optional[Union[list[str], str]] = None ignore_patterns: Optional[Union[list[str], str]] = None
"""The list of patterns to ignore when loading the model. Default to """The list of patterns to ignore when loading the model. Default to
"original/**/*" to avoid repeated loading of llama's checkpoints.""" "original/**/*" to avoid repeated loading of llama's checkpoints."""
@@ -2826,7 +2822,6 @@ class MultiModalConfig:
"limit_mm_per_prompt") "limit_mm_per_prompt")
""" """
The maximum number of input items allowed per prompt for each modality. The maximum number of input items allowed per prompt for each modality.
This should be a JSON string that will be parsed into a dictionary.
Defaults to 1 (V0) or 999 (V1) for each modality. Defaults to 1 (V0) or 999 (V1) for each modality.
For example, to allow up to 16 images and 2 videos per prompt: For example, to allow up to 16 images and 2 videos per prompt:

View File

@@ -150,7 +150,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
# Get the help text for the field # Get the help text for the field
name = field.name name = field.name
help = cls_docs[name] help = cls_docs[name].strip()
# Escape % for argparse # Escape % for argparse
help = help.replace("%", "%%") help = help.replace("%", "%%")
@@ -165,6 +165,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
type_hints.add(field.type) type_hints.add(field.type)
# Set other kwargs based on the type hints # Set other kwargs based on the type hints
json_tip = "\n\nShould be a valid JSON string."
if contains_type(type_hints, bool): if contains_type(type_hints, bool):
# Creates --no-<name> and --<name> flags # Creates --no-<name> and --<name> flags
kwargs[name]["action"] = argparse.BooleanOptionalAction kwargs[name]["action"] = argparse.BooleanOptionalAction
@@ -201,6 +202,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
elif contains_type(type_hints, dict): elif contains_type(type_hints, dict):
# Dict arguments will always be optional # Dict arguments will always be optional
kwargs[name]["type"] = optional_type(json.loads) kwargs[name]["type"] = optional_type(json.loads)
kwargs[name]["help"] += json_tip
elif (contains_type(type_hints, str) elif (contains_type(type_hints, str)
or any(is_not_builtin(th) for th in type_hints)): or any(is_not_builtin(th) for th in type_hints)):
kwargs[name]["type"] = str kwargs[name]["type"] = str