Fix nullable_kvs fallback (#16837)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-04-18 13:58:39 +01:00 committed by GitHub
parent aadb656562
commit 686623c5e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 27 additions and 19 deletions

View File

@ -10,7 +10,7 @@ from vllm.utils import FlexibleArgumentParser
@pytest.mark.parametrize(("arg", "expected"), [
(None, None),
(None, dict()),
("image=16", {
"image": 16
}),

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import json
import openai
import pytest
import pytest_asyncio
@ -27,7 +29,7 @@ def server():
"--enforce-eager",
"--trust-remote-code",
"--limit-mm-per-prompt",
str({"audio": MAXIMUM_AUDIOS}),
json.dumps({"audio": MAXIMUM_AUDIOS}),
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import json
import openai
import pytest
import pytest_asyncio
@ -31,7 +33,7 @@ def server():
"--enforce-eager",
"--trust-remote-code",
"--limit-mm-per-prompt",
str({"video": MAXIMUM_VIDEOS}),
json.dumps({"video": MAXIMUM_VIDEOS}),
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import json
import openai
import pytest
import pytest_asyncio
@ -35,7 +37,7 @@ def server():
"--enforce-eager",
"--trust-remote-code",
"--limit-mm-per-prompt",
str({"image": MAXIMUM_IMAGES}),
json.dumps({"image": MAXIMUM_IMAGES}),
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import json
import pytest
import requests
from PIL import Image
@ -37,7 +39,7 @@ def server():
"--enforce-eager",
"--trust-remote-code",
"--limit-mm-per-prompt",
str({"image": MAXIMUM_IMAGES}),
json.dumps({"image": MAXIMUM_IMAGES}),
"--chat-template",
str(vlm2vec_jinja_path),
]

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import json
from typing import Optional
import numpy as np
@ -50,7 +51,7 @@ def server(request, audio_assets):
args = [
"--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
"--limit-mm-per-prompt",
str({"audio": len(audio_assets)}), "--trust-remote-code"
json.dumps({"audio": len(audio_assets)}), "--trust-remote-code"
] + [
f"--{key.replace('_','-')}={value}"
for key, value in request.param.items()

View File

@ -10,7 +10,6 @@ import sys
import textwrap
import warnings
from collections import Counter
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
replace)
@ -355,7 +354,7 @@ class ModelConfig:
disable_cascade_attn: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, list[str]]] = None,
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
limit_mm_per_prompt: Optional[dict[str, int]] = None,
use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_token: Optional[Union[bool, str]] = None,
@ -578,7 +577,7 @@ class ModelConfig:
self.tokenizer = s3_tokenizer.dir
def _init_multimodal_config(
self, limit_mm_per_prompt: Optional[Mapping[str, int]]
self, limit_mm_per_prompt: Optional[dict[str, int]]
) -> Optional["MultiModalConfig"]:
if self.registry.is_multimodal_model(self.architectures):
return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
@ -2730,7 +2729,7 @@ class PromptAdapterConfig:
class MultiModalConfig:
"""Controls the behavior of multimodal models."""
limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
limit_per_prompt: dict[str, int] = field(default_factory=dict)
"""
The maximum number of input items allowed per prompt for each modality.
This should be a JSON string that will be parsed into a dictionary.

View File

@ -7,7 +7,7 @@ import json
import re
import threading
from dataclasses import MISSING, dataclass, fields
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal, Mapping,
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal,
Optional, Tuple, Type, TypeVar, Union, cast, get_args,
get_origin)
@ -112,14 +112,14 @@ def nullable_kvs(val: str) -> Optional[dict[str, int]]:
def optional_dict(val: str) -> Optional[dict[str, int]]:
try:
if re.match("^{.*}$", val):
return optional_arg(val, json.loads)
except ValueError:
logger.warning(
"Failed to parse JSON string. Attempting to parse as "
"comma-separated key=value pairs. This will be deprecated in a "
"future release.")
return nullable_kvs(val)
logger.warning(
"Failed to parse JSON string. Attempting to parse as "
"comma-separated key=value pairs. This will be deprecated in a "
"future release.")
return nullable_kvs(val)
@dataclass
@ -191,7 +191,7 @@ class EngineArgs:
TokenizerPoolConfig.pool_type
tokenizer_pool_extra_config: dict[str, Any] = \
get_field(TokenizerPoolConfig, "extra_config")
limit_mm_per_prompt: Mapping[str, int] = \
limit_mm_per_prompt: dict[str, int] = \
get_field(MultiModalConfig, "limit_per_prompt")
mm_processor_kwargs: Optional[Dict[str, Any]] = None
disable_mm_preprocessor_cache: bool = False