Fix nullable_kvs fallback (#16837)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-04-18 13:58:39 +01:00 committed by GitHub
parent aadb656562
commit 686623c5e7
8 changed files with 27 additions and 19 deletions

View File

@@ -10,7 +10,7 @@ from vllm.utils import FlexibleArgumentParser
 @pytest.mark.parametrize(("arg", "expected"), [
-    (None, None),
+    (None, dict()),
     ("image=16", {
         "image": 16
     }),
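
The first case now expects an empty dict rather than None for a missing value. A minimal sketch of that idea (the argparse wiring below is an assumption for illustration, not taken from the diff):

import argparse

# Hypothetical wiring: the multimodal limit now defaults to an empty dict,
# so parsing no arguments yields {} instead of None.
parser = argparse.ArgumentParser()
parser.add_argument("--limit-mm-per-prompt", default=dict())

assert parser.parse_args([]).limit_mm_per_prompt == dict()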

View File

@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
+import json
+
 import openai
 import pytest
 import pytest_asyncio
@@ -27,7 +29,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        str({"audio": MAXIMUM_AUDIOS}),
+        json.dumps({"audio": MAXIMUM_AUDIOS}),
     ]
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
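
The str -> json.dumps switch matters because str() renders a Python dict with single quotes, which is not valid JSON, while json.dumps emits the double-quoted form the new JSON-first parser expects. A quick illustration:

import json

limit = {"audio": 2}  # stand-in for {"audio": MAXIMUM_AUDIOS}

print(str(limit))         # {'audio': 2}  -- single quotes; json.loads rejects this
print(json.dumps(limit))  # {"audio": 2}  -- valid JSON
assert json.loads(json.dumps(limit)) == limit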

View File

@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
+import json
+
 import openai
 import pytest
 import pytest_asyncio
@@ -31,7 +33,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        str({"video": MAXIMUM_VIDEOS}),
+        json.dumps({"video": MAXIMUM_VIDEOS}),
     ]
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

View File

@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
+import json
+
 import openai
 import pytest
 import pytest_asyncio
@@ -35,7 +37,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        str({"image": MAXIMUM_IMAGES}),
+        json.dumps({"image": MAXIMUM_IMAGES}),
     ]
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

View File

@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
+import json
+
 import pytest
 import requests
 from PIL import Image
@@ -37,7 +39,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        str({"image": MAXIMUM_IMAGES}),
+        json.dumps({"image": MAXIMUM_IMAGES}),
         "--chat-template",
         str(vlm2vec_jinja_path),
     ]

View File

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
+import json
 from typing import Optional

 import numpy as np
@@ -50,7 +51,7 @@ def server(request, audio_assets):
     args = [
         "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
         "--limit-mm-per-prompt",
-        str({"audio": len(audio_assets)}), "--trust-remote-code"
+        json.dumps({"audio": len(audio_assets)}), "--trust-remote-code"
     ] + [
         f"--{key.replace('_','-')}={value}"
         for key, value in request.param.items()

View File

@@ -10,7 +10,6 @@ import sys
 import textwrap
 import warnings
 from collections import Counter
-from collections.abc import Mapping
 from contextlib import contextmanager
 from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
                          replace)
@@ -355,7 +354,7 @@ class ModelConfig:
         disable_cascade_attn: bool = False,
         skip_tokenizer_init: bool = False,
         served_model_name: Optional[Union[str, list[str]]] = None,
-        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+        limit_mm_per_prompt: Optional[dict[str, int]] = None,
         use_async_output_proc: bool = True,
         config_format: ConfigFormat = ConfigFormat.AUTO,
         hf_token: Optional[Union[bool, str]] = None,
@@ -578,7 +577,7 @@ class ModelConfig:
         self.tokenizer = s3_tokenizer.dir

     def _init_multimodal_config(
-        self, limit_mm_per_prompt: Optional[Mapping[str, int]]
+        self, limit_mm_per_prompt: Optional[dict[str, int]]
     ) -> Optional["MultiModalConfig"]:
         if self.registry.is_multimodal_model(self.architectures):
             return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
@@ -2730,7 +2729,7 @@ class PromptAdapterConfig:
 class MultiModalConfig:
     """Controls the behavior of multimodal models."""

-    limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
+    limit_per_prompt: dict[str, int] = field(default_factory=dict)
     """
     The maximum number of input items allowed per prompt for each modality.
     This should be a JSON string that will be parsed into a dictionary.
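
Narrowing the annotation from Mapping to dict keeps the field's type consistent with its concrete default. Note that a mutable default like this has to go through default_factory; a trimmed sketch of the field above (not the full vLLM class):

from dataclasses import dataclass, field

@dataclass
class MultiModalConfig:
    # A bare `= {}` default would raise ValueError when the class is defined;
    # default_factory builds a fresh dict for each instance instead.
    limit_per_prompt: dict[str, int] = field(default_factory=dict)

assert MultiModalConfig().limit_per_prompt == {}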

View File

@@ -7,7 +7,7 @@ import json
 import re
 import threading
 from dataclasses import MISSING, dataclass, fields
-from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping,
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal,
                     Optional, Tuple, Type, TypeVar, Union, cast, get_args,
                     get_origin)
@@ -112,14 +112,14 @@ def nullable_kvs(val: str) -> Optional[dict[str, int]]:
 def optional_dict(val: str) -> Optional[dict[str, int]]:
-    try:
+    if re.match("^{.*}$", val):
         return optional_arg(val, json.loads)
-    except ValueError:
-        logger.warning(
-            "Failed to parse JSON string. Attempting to parse as "
-            "comma-separated key=value pairs. This will be deprecated in a "
-            "future release.")
-        return nullable_kvs(val)
+
+    logger.warning(
+        "Failed to parse JSON string. Attempting to parse as "
+        "comma-separated key=value pairs. This will be deprecated in a "
+        "future release.")
+    return nullable_kvs(val)

 @dataclass
@@ -191,7 +191,7 @@ class EngineArgs:
         TokenizerPoolConfig.pool_type
     tokenizer_pool_extra_config: dict[str, Any] = \
         get_field(TokenizerPoolConfig, "extra_config")
-    limit_mm_per_prompt: Mapping[str, int] = \
+    limit_mm_per_prompt: dict[str, int] = \
         get_field(MultiModalConfig, "limit_per_prompt")
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
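
Taken together, the new optional_dict only attempts json.loads when the value looks like a JSON object, and otherwise falls back to the deprecated key=value parser; the old try/except fallback was unreliable because a parse failure inside optional_arg could surface as an exception type other than ValueError. A self-contained approximation of the fixed behavior (simplified: the real function routes through optional_arg and vLLM's logger):

import json
import re
from typing import Optional

def optional_dict(val: str) -> Optional[dict[str, int]]:
    # JSON-looking input goes straight to the JSON parser.
    if re.match("^{.*}$", val):
        return json.loads(val)
    # Anything else uses the legacy comma-separated key=value format.
    return {k: int(v) for k, v in (item.split("=") for item in val.split(","))}

assert optional_dict('{"image": 16}') == {"image": 16}
assert optional_dict("image=16,video=2") == {"image": 16, "video": 2}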