[Frontend] Improve error message for too many mm items (#22114)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-08-02 17:20:38 +08:00 committed by GitHub
parent b690e34824
commit f5d0f4784f
5 changed files with 52 additions and 51 deletions

View File

@@ -579,10 +579,7 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message(
 warnings.filterwarnings(
 "ignore",
 message="coroutine 'async_get_and_parse_image' was never awaited")
-with pytest.raises(
-ValueError,
-match="At most 2 image\\(s\\) may be provided in one request\\."
-):
+with pytest.raises(ValueError, match="At most"):
 parse_chat_messages(
 [{
 "role":
@@ -622,10 +619,7 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
 warnings.filterwarnings(
 "ignore",
 message="coroutine 'async_get_and_parse_image' was never awaited")
-with pytest.raises(
-ValueError,
-match="At most 2 image\\(s\\) may be provided in one request\\."
-):
+with pytest.raises(ValueError, match="At most"):
 parse_chat_messages(
 [{
 "role":

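The two tests above now assert only on the stable prefix of the error text, which works because pytest's `match` argument is a regular expression applied with `re.search` against the exception message, so a prefix is enough to keep the tests robust to wording changes. A minimal standalone illustration of that behaviour (hypothetical `reject` helper and message text, not code from this commit):

import pytest


def reject() -> None:
    # Hypothetical error text mirroring the new wording.
    raise ValueError("At most 2 image(s) may be provided in one prompt.")


def test_reject_message() -> None:
    # `match` is searched (re.search) against str(excinfo.value),
    # so matching only the "At most" prefix is sufficient.
    with pytest.raises(ValueError, match="At most"):
        reject()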
View File

@@ -3,7 +3,6 @@
 from contextlib import nullcontext
 from typing import Optional, cast
-from unittest.mock import MagicMock
 import numpy as np
 import pytest
@@ -957,15 +956,14 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
 )
 processor = MULTIMODAL_REGISTRY.create_processor(model_config)
-profiler = MultiModalProfiler(processor)
+processor._supported_mm_limits = {"image": num_supported}
-mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
-processor.info.get_supported_mm_limits = mock_supported_mm_limits
+profiler = MultiModalProfiler(processor)
 if is_valid:
 exc_ctx = nullcontext()
 else:
-exc_ctx = pytest.raises(ValueError, match="The model only supports")
+exc_ctx = pytest.raises(ValueError, match="At most")
 with exc_ctx:
 profiler.get_decoder_dummy_data(
@@ -1002,7 +1000,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
 if is_valid:
 exc_ctx = nullcontext()
 else:
-exc_ctx = pytest.raises(ValueError, match=f"passed {num_images} image")
+exc_ctx = pytest.raises(ValueError, match="At most")
 with exc_ctx:
 processor.apply(

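These tests switch from mocking `processor.info.get_supported_mm_limits` to writing the private `_supported_mm_limits` attribute because, after this commit, the processor snapshots its limits once in `__init__`, so a mock applied to `info` after construction is never consulted. A small self-contained sketch of that interaction, using hypothetical stand-in classes rather than the real vLLM types:

from unittest.mock import MagicMock


class Info:
    def get_supported_mm_limits(self):
        return {"image": 1}


class Processor:
    def __init__(self, info):
        self.info = info
        # Snapshot taken once at construction, as in the new BaseMultiModalProcessor.
        self._supported_mm_limits = info.get_supported_mm_limits()


proc = Processor(Info())
# Late mocking of the info method is invisible to the cached copy...
proc.info.get_supported_mm_limits = MagicMock(return_value={"image": 5})
assert proc._supported_mm_limits == {"image": 1}
# ...so the test overrides the cached attribute directly instead.
proc._supported_mm_limits = {"image": 5}
assert proc._supported_mm_limits == {"image": 5}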
View File

@@ -535,9 +535,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
 return self._model_config
 @cached_property
-def model_cls(self):
+def model_cls(self) -> type[SupportsMultiModal]:
 from vllm.model_executor.model_loader import get_model_cls
-return get_model_cls(self.model_config)
+model_cls = get_model_cls(self.model_config)
+return cast(type[SupportsMultiModal], model_cls)
 @property
 def allowed_local_media_path(self):
@@ -547,31 +548,23 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
 def mm_registry(self):
 return MULTIMODAL_REGISTRY
+@cached_property
+def mm_processor(self):
+return self.mm_registry.create_processor(self.model_config)
 def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
 """
 Add a multi-modal item to the current prompt and returns the
 placeholder string to use, if any.
 """
-mm_registry = self.mm_registry
-model_config = self.model_config
-model_cls = cast(SupportsMultiModal, self.model_cls)
 input_modality = modality.replace("_embeds", "")
+num_items = len(self._items_by_modality[modality]) + 1
-mm_processor = mm_registry.create_processor(model_config)
-allowed_counts = mm_processor.info.get_allowed_mm_limits()
-allowed_count = allowed_counts.get(input_modality, 0)
-current_count = len(self._items_by_modality[modality]) + 1
-if current_count > allowed_count:
-raise ValueError(
-f"At most {allowed_count} {modality}(s) may be provided in "
-"one request. You can set `--limit-mm-per-prompt` to "
-"increase this limit if the model supports it.")
+self.mm_processor.validate_num_items(input_modality, num_items)
 self._items_by_modality[modality].append(item)
-return model_cls.get_placeholder_str(modality, current_count)
+return self.model_cls.get_placeholder_str(modality, num_items)
 @abstractmethod
 def create_parser(self) -> "BaseMultiModalContentParser":

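In the tracker, `mm_processor` becomes a `cached_property`, so the multi-modal processor is created once per tracker and reused by every `add()` call instead of being rebuilt for each item, and the per-item count check is delegated to `validate_num_items`. A short sketch of that caching pattern with a hypothetical `Tracker` class (not the vLLM implementation):

from functools import cached_property


class Tracker:
    def __init__(self) -> None:
        self.builds = 0

    @cached_property
    def mm_processor(self) -> str:
        # Stand-in for mm_registry.create_processor(model_config),
        # which is costly enough that we only want to do it once.
        self.builds += 1
        return f"processor#{self.builds}"

    def add(self, item: str) -> str:
        # Every add() reuses the same cached processor instance.
        return f"{self.mm_processor}:{item}"


tracker = Tracker()
assert tracker.add("a") == "processor#1:a"
assert tracker.add("b") == "processor#1:b"
assert tracker.builds == 1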
View File

@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import json
 import sys
 from abc import ABC, abstractmethod
 from collections import defaultdict
@@ -1156,6 +1155,18 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
 self.data_parser = self._get_data_parser()
+# Avoid unnecessary recomputation
+self._supported_mm_limits = self.info.get_supported_mm_limits()
+self._allowed_mm_limits = self.info.get_allowed_mm_limits()
+@property
+def supported_mm_limits(self):
+return self._supported_mm_limits
+@property
+def allowed_mm_limits(self):
+return self._allowed_mm_limits
 def __call__(
 self,
 prompt: str,
@@ -1176,6 +1187,28 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
 """
 return MultiModalDataParser()
+def validate_num_items(
+self,
+modality: str,
+num_items: int,
+) -> None:
+supported_limit = self.supported_mm_limits.get(modality, 0)
+allowed_limit = self.allowed_mm_limits.get(modality, 0)
+if supported_limit is None:
+supported_limit = allowed_limit
+limit = min(supported_limit, allowed_limit)
+if num_items > limit:
+msg = (f"At most {limit} {modality}(s) may be provided in "
+"one prompt.")
+if num_items <= supported_limit:
+msg += " Set `--limit-mm-per-prompt` to increase this limit."
+raise ValueError(msg)
 def _to_mm_items(
 self,
 mm_data: MultiModalDataDict,
@@ -1188,26 +1221,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
 [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
 """
 mm_items = self.data_parser.parse_mm_data(mm_data)
-supported_mm_limits = self.info.get_supported_mm_limits()
-allowed_mm_limits = self.info.get_allowed_mm_limits()
 for modality, items in mm_items.items():
-supported_limit = supported_mm_limits.get(modality, 0)
-allowed_limit = allowed_mm_limits.get(modality, 0)
-num_items = len(items)
-if supported_limit is not None and num_items > supported_limit:
-raise ValueError(
-f"The model only supports at most {supported_limit} "
-f"{modality} items, but you passed {num_items} "
-f"{modality} items in the same prompt.")
-if num_items > allowed_limit:
-raise ValueError(
-"You set or defaulted to "
-f"'{json.dumps({modality: allowed_limit})}' in "
-f"`--limit-mm-per-prompt`, but passed {num_items} "
-f"{modality} items in the same prompt.")
+self.validate_num_items(modality, len(items))
 return mm_items

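The new `validate_num_items` merges the two previous error paths into one message: the effective cap is the smaller of the model's supported limit and the `--limit-mm-per-prompt` setting (a supported limit of `None` is treated as unbounded), and the hint to raise `--limit-mm-per-prompt` is appended only when the model itself could accept more items. A standalone sketch of that decision logic, with the limits passed in as plain arguments purely for illustration:

from typing import Optional


def validate_num_items(modality: str, num_items: int,
                       supported_limit: Optional[int],
                       allowed_limit: int) -> None:
    # A supported limit of None means the model imposes no cap of its own.
    if supported_limit is None:
        supported_limit = allowed_limit
    limit = min(supported_limit, allowed_limit)
    if num_items > limit:
        msg = f"At most {limit} {modality}(s) may be provided in one prompt."
        if num_items <= supported_limit:
            # The model could handle more; only the user-facing limit is in the way.
            msg += " Set `--limit-mm-per-prompt` to increase this limit."
        raise ValueError(msg)


validate_num_items("image", 2, supported_limit=None, allowed_limit=2)  # passes
try:
    validate_num_items("image", 3, supported_limit=4, allowed_limit=2)
except ValueError as e:
    # Hint included: the model supports up to 4 images, only the flag limits us to 2.
    print(e)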
View File

@@ -156,7 +156,7 @@ class MultiModalProfiler(Generic[_I]):
 return self.processor.dummy_inputs
 def get_mm_limits(self) -> Mapping[str, int]:
-return self.processing_info.get_allowed_mm_limits()
+return self.processor.allowed_mm_limits
 def _get_dummy_mm_inputs(
 self,