diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py
index 54daf1a91d645..647f1c7b7f34f 100644
--- a/tests/entrypoints/test_chat_utils.py
+++ b/tests/entrypoints/test_chat_utils.py
@@ -579,10 +579,7 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message(
     warnings.filterwarnings(
         "ignore",
         message="coroutine 'async_get_and_parse_image' was never awaited")
-    with pytest.raises(
-            ValueError,
-            match="At most 2 image\\(s\\) may be provided in one request\\."
-    ):
+    with pytest.raises(ValueError, match="At most"):
         parse_chat_messages(
             [{
                 "role":
@@ -622,10 +619,7 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
     warnings.filterwarnings(
         "ignore",
         message="coroutine 'async_get_and_parse_image' was never awaited")
-    with pytest.raises(
-            ValueError,
-            match="At most 2 image\\(s\\) may be provided in one request\\."
-    ):
+    with pytest.raises(ValueError, match="At most"):
         parse_chat_messages(
             [{
                 "role":
diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index 659ee9af9ddec..508c773b8aedf 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -3,7 +3,6 @@
 
 from contextlib import nullcontext
 from typing import Optional, cast
-from unittest.mock import MagicMock
 
 import numpy as np
 import pytest
@@ -957,15 +956,14 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
     )
 
     processor = MULTIMODAL_REGISTRY.create_processor(model_config)
-    profiler = MultiModalProfiler(processor)
+    processor._supported_mm_limits = {"image": num_supported}
 
-    mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
-    processor.info.get_supported_mm_limits = mock_supported_mm_limits
+    profiler = MultiModalProfiler(processor)
 
     if is_valid:
         exc_ctx = nullcontext()
     else:
-        exc_ctx = pytest.raises(ValueError, match="The model only supports")
+        exc_ctx = pytest.raises(ValueError, match="At most")
 
     with exc_ctx:
         profiler.get_decoder_dummy_data(
@@ -1002,7 +1000,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
     if is_valid:
         exc_ctx = nullcontext()
     else:
-        exc_ctx = pytest.raises(ValueError, match=f"passed {num_images} image")
+        exc_ctx = pytest.raises(ValueError, match="At most")
 
     with exc_ctx:
         processor.apply(
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 6485ed6b148b4..a658d97cc8c5e 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -535,9 +535,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
         return self._model_config
 
     @cached_property
-    def model_cls(self):
+    def model_cls(self) -> type[SupportsMultiModal]:
         from vllm.model_executor.model_loader import get_model_cls
-        return get_model_cls(self.model_config)
+        model_cls = get_model_cls(self.model_config)
+        return cast(type[SupportsMultiModal], model_cls)
 
     @property
     def allowed_local_media_path(self):
@@ -547,31 +548,23 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
     def mm_registry(self):
         return MULTIMODAL_REGISTRY
 
+    @cached_property
+    def mm_processor(self):
+        return self.mm_registry.create_processor(self.model_config)
+
     def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
         """
         Add a multi-modal item to the current prompt and returns the
         placeholder string to use, if any.
""" - mm_registry = self.mm_registry - model_config = self.model_config - model_cls = cast(SupportsMultiModal, self.model_cls) - input_modality = modality.replace("_embeds", "") + num_items = len(self._items_by_modality[modality]) + 1 - mm_processor = mm_registry.create_processor(model_config) - allowed_counts = mm_processor.info.get_allowed_mm_limits() - allowed_count = allowed_counts.get(input_modality, 0) - - current_count = len(self._items_by_modality[modality]) + 1 - if current_count > allowed_count: - raise ValueError( - f"At most {allowed_count} {modality}(s) may be provided in " - "one request. You can set `--limit-mm-per-prompt` to " - "increase this limit if the model supports it.") + self.mm_processor.validate_num_items(input_modality, num_items) self._items_by_modality[modality].append(item) - return model_cls.get_placeholder_str(modality, current_count) + return self.model_cls.get_placeholder_str(modality, num_items) @abstractmethod def create_parser(self) -> "BaseMultiModalContentParser": diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 78d244a6b4fc8..46240855d12a2 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import json import sys from abc import ABC, abstractmethod from collections import defaultdict @@ -1156,6 +1155,18 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): self.data_parser = self._get_data_parser() + # Avoid unnecessary recomputation + self._supported_mm_limits = self.info.get_supported_mm_limits() + self._allowed_mm_limits = self.info.get_allowed_mm_limits() + + @property + def supported_mm_limits(self): + return self._supported_mm_limits + + @property + def allowed_mm_limits(self): + return self._allowed_mm_limits + def __call__( self, prompt: str, @@ -1176,6 +1187,28 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): """ return MultiModalDataParser() + def validate_num_items( + self, + modality: str, + num_items: int, + ) -> None: + supported_limit = self.supported_mm_limits.get(modality, 0) + allowed_limit = self.allowed_mm_limits.get(modality, 0) + + if supported_limit is None: + supported_limit = allowed_limit + + limit = min(supported_limit, allowed_limit) + + if num_items > limit: + msg = (f"At most {limit} {modality}(s) may be provided in " + "one prompt.") + + if num_items <= supported_limit: + msg += " Set `--limit-mm-per-prompt` to increase this limit." + + raise ValueError(msg) + def _to_mm_items( self, mm_data: MultiModalDataDict, @@ -1188,26 +1221,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data]. 
""" mm_items = self.data_parser.parse_mm_data(mm_data) - supported_mm_limits = self.info.get_supported_mm_limits() - allowed_mm_limits = self.info.get_allowed_mm_limits() for modality, items in mm_items.items(): - supported_limit = supported_mm_limits.get(modality, 0) - allowed_limit = allowed_mm_limits.get(modality, 0) - num_items = len(items) - - if supported_limit is not None and num_items > supported_limit: - raise ValueError( - f"The model only supports at most {supported_limit} " - f"{modality} items, but you passed {num_items} " - f"{modality} items in the same prompt.") - - if num_items > allowed_limit: - raise ValueError( - "You set or defaulted to " - f"'{json.dumps({modality: allowed_limit})}' in " - f"`--limit-mm-per-prompt`, but passed {num_items} " - f"{modality} items in the same prompt.") + self.validate_num_items(modality, len(items)) return mm_items diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index d96803b643ff2..d876887fc155d 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -156,7 +156,7 @@ class MultiModalProfiler(Generic[_I]): return self.processor.dummy_inputs def get_mm_limits(self) -> Mapping[str, int]: - return self.processing_info.get_allowed_mm_limits() + return self.processor.allowed_mm_limits def _get_dummy_mm_inputs( self,