diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 0941cc3f608e..4eb8e0cfaa5d 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -12,11 +12,11 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from PIL import Image
 
 from vllm.config import ModelConfig
-from vllm.inputs import InputProcessingContext
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
 from vllm.multimodal.cache import MultiModalProcessorOnlyCache
 from vllm.multimodal.inputs import MultiModalInputs
-from vllm.multimodal.processing import BaseMultiModalProcessor
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        InputProcessingContext)
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                cached_tokenizer_from_config,
                                                encode_tokens)
diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py
index b678313752d6..d5d5bfaa3b45 100644
--- a/tests/models/multimodal/processing/test_tensor_schema.py
+++ b/tests/models/multimodal/processing/test_tensor_schema.py
@@ -18,10 +18,10 @@ from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
 from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
-from vllm.inputs import InputProcessingContext
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
-from vllm.multimodal.processing import BaseMultiModalProcessor
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        InputProcessingContext)
 from vllm.multimodal.utils import group_mm_kwargs_by_modality
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 from vllm.utils import is_list_of
diff --git a/tests/models/utils.py b/tests/models/utils.py
index 5da2382cef81..f80e92ebb3e2 100644
--- a/tests/models/utils.py
+++ b/tests/models/utils.py
@@ -11,8 +11,9 @@ import torch.nn.functional as F
 from transformers import PretrainedConfig
 
 from vllm.config import ModelConfig, ModelDType, RunnerOption
-from vllm.inputs import InputContext
 from vllm.logprobs import Logprob, PromptLogprobs, SampleLogprobs
+from vllm.multimodal.processing import InputProcessingContext
+from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 
 from .registry import HF_EXAMPLE_MODELS
 
@@ -264,7 +265,7 @@ def build_model_context(
     limit_mm_per_prompt: Optional[dict[str, int]] = None,
     mm_processor_cache_gb: int = 0,
 ):
-    """Creates an InputContext for a given model.
+    """Creates an InputProcessingContext for a given model.
 
     Args:
         model_id: ID of the model being considered.
@@ -273,7 +274,7 @@ def build_model_context(
         limit_mm_per_prompt: Multimodal limits.
 
     Returns:
-        InputContext for the model being considered.
+        InputProcessingContext for the model being considered.
     """
     model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
     model_info.check_available_online(on_fail="skip")
@@ -298,7 +299,11 @@ def build_model_context(
         enforce_eager=model_info.enforce_eager,
         **model_config_kwargs,
     )
-    return InputContext(model_config)
+
+    return InputProcessingContext(
+        model_config,
+        tokenizer=cached_tokenizer_from_config(model_config),
+    )
 
 
 def check_embeddings_close(
diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index 6ce5fcfe644b..352b5b5b4fd4 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -8,11 +8,11 @@ import numpy as np
 import pytest
 
 from vllm.config import ModelConfig
-from vllm.inputs import InputProcessingContext
 from vllm.multimodal import MULTIMODAL_REGISTRY
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
+from vllm.multimodal.processing import (InputProcessingContext,
+                                        PlaceholderFeaturesInfo,
                                         PromptIndexTargets, PromptInsertion,
                                         PromptReplacement, apply_text_matches,
                                         apply_token_matches,
diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py
index 46f49aaa013d..3f1cac531f45 100644
--- a/vllm/inputs/__init__.py
+++ b/vllm/inputs/__init__.py
@@ -7,7 +7,6 @@ from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
                    SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
                    build_explicit_enc_dec_prompt, embeds_inputs,
                    to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
-from .registry import InputContext, InputProcessingContext
 
 __all__ = [
     "DataPrompt",
@@ -28,6 +27,4 @@ __all__ = [
     "build_explicit_enc_dec_prompt",
     "to_enc_dec_tuple_list",
     "zip_enc_dec_prompts",
-    "InputContext",
-    "InputProcessingContext",
 ]
diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
deleted file mode 100644
index 0aad78b04e34..000000000000
--- a/vllm/inputs/registry.py
+++ /dev/null
@@ -1,206 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import time
-from collections.abc import Mapping
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Union
-
-import torch
-from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
-from typing_extensions import TypeVar
-
-from vllm.logger import init_logger
-from vllm.transformers_utils.processor import cached_processor_from_config
-from vllm.utils import get_allowed_kwarg_only_overrides
-from vllm.utils.jsontree import JSONTree, json_map_leaves
-
-if TYPE_CHECKING:
-    from vllm.config import ModelConfig
-    from vllm.transformers_utils.tokenizer import AnyTokenizer
-else:
-    ModelConfig = Any
-    AnyTokenizer = Any
-
-_T = TypeVar("_T")
-_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
-_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
-
-logger = init_logger(__name__)
-
-
-@dataclass(frozen=True)
-class InputContext:
-    """
-    Contains information about the model which may be used to
-    modify the inputs.
-    """
-
-    model_config: ModelConfig
-    """The configuration of the model."""
-
-    def get_hf_config(
-        self,
-        typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig,
-        /,
-    ) -> _C:
-        """
-        Get the HuggingFace configuration
-        (`transformers.PretrainedConfig`) of the model,
-        additionally checking its type.
-
-        Raises:
-            TypeError: If the configuration is not of the specified type.
-        """
-        hf_config = self.model_config.hf_config
-        if not isinstance(hf_config, typ):
-            raise TypeError("Invalid type of HuggingFace config. "
-                            f"Expected type: {typ}, but "
-                            f"found type: {type(hf_config)}")
-
-        return hf_config
-
-    def get_hf_image_processor_config(self) -> dict[str, Any]:
-        """
-        Get the HuggingFace image processor configuration of the model.
-        """
-        return self.model_config.hf_image_processor_config
-
-    def get_mm_config(self):
-        """
-        Get the multimodal config of the model.
-
-        Raises:
-            RuntimeError: If the model is not a multimodal model.
-        """
-        mm_config = self.model_config.multimodal_config
-        if mm_config is None:
-            raise RuntimeError("Not a multimodal model")
-
-        return mm_config
-
-    def get_hf_processor(
-        self,
-        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
-        /,
-        **kwargs: object,
-    ) -> _P:
-        """
-        Get the HuggingFace processor
-        (`transformers.ProcessorMixin`) of the model,
-        additionally checking its type.
-
-        Raises:
-            TypeError: If the processor is not of the specified type.
-        """
-        return cached_processor_from_config(
-            self.model_config,
-            processor_cls=typ,
-            **kwargs,
-        )
-
-    def init_processor(
-        self,
-        typ: type[_T],
-        /,
-        **kwargs: object,
-    ) -> _T:
-        """
-        Initialize a HuggingFace-like processor class, merging the
-        keyword arguments with those in the model's configuration.
-        """
-        mm_config = self.model_config.get_multimodal_config()
-        base_kwargs = mm_config.mm_processor_kwargs
-        if base_kwargs is None:
-            base_kwargs = {}
-
-        merged_kwargs = {**base_kwargs, **kwargs}
-
-        return typ(**merged_kwargs)
-
-
-@dataclass(frozen=True)
-class InputProcessingContext(InputContext):
-    tokenizer: AnyTokenizer
-    """The tokenizer used to tokenize the inputs."""
-
-    def get_hf_processor(
-        self,
-        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
-        /,
-        **kwargs: object,
-    ) -> _P:
-        return super().get_hf_processor(
-            typ,
-            tokenizer=self.tokenizer,
-            **kwargs,
-        )
-
-    def call_hf_processor(
-        self,
-        hf_processor: ProcessorMixin,
-        data: Mapping[str, object],
-        kwargs: Mapping[str, object] = {},
-        *,
-        num_tries: int = 1,
-        max_tries: int = 5,
-    ) -> Union[BatchFeature, JSONTree]:
-        """
-        Call `hf_processor` on the prompt `data`
-        (text, image, audio...) with configurable options `kwargs`.
-        """
-        assert callable(hf_processor)
-
-        mm_config = self.model_config.get_multimodal_config()
-        merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)
-
-        allowed_kwargs = get_allowed_kwarg_only_overrides(
-            hf_processor,
-            merged_kwargs,
-            requires_kw_only=False,
-            allow_var_kwargs=True,
-        )
-
-        def maybe_cast_dtype(x):
-            # This mimics the behavior of transformers.BatchFeature
-            if isinstance(x, torch.Tensor) and x.is_floating_point():
-                return x.to(dtype=self.model_config.dtype)
-            return x
-
-        try:
-            output = hf_processor(**data,
-                                  **allowed_kwargs,
-                                  return_tensors="pt")
-            # this emulates output.to(dtype=self.model_config.dtype)
-            if isinstance(output, BatchFeature):
-                cast_output = json_map_leaves(maybe_cast_dtype, output.data)
-                return BatchFeature(cast_output)
-
-            cast_output = json_map_leaves(maybe_cast_dtype, output)
-
-            logger.warning_once(
-                f"{type(hf_processor).__name__} did not return `BatchFeature`. "
-                "Make sure to match the behaviour of `ProcessorMixin` when "
-                "implementing custom processors.")
-            return cast_output
-
-        except Exception as exc:
-            # See https://github.com/huggingface/tokenizers/issues/537
-            if (isinstance(exc, RuntimeError) and exc
-                    and exc.args[0] == "Already borrowed"
-                    and num_tries < max_tries):
-                logger.warning(
-                    "Failed to acquire tokenizer in current thread. "
-                    "Retrying (%d/%d)...", num_tries, max_tries)
-                time.sleep(0.5)
-                return self.call_hf_processor(
-                    hf_processor,
-                    data,
-                    kwargs,
-                    num_tries=num_tries + 1,
-                    max_tries=max_tries,
-                )
-
-            msg = (f"Failed to apply {type(hf_processor).__name__} "
-                   f"on data={data} with kwargs={allowed_kwargs}")
-
-            raise ValueError(msg) from exc
diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py
index 54167f9f1099..4d39ff9ae79e 100644
--- a/vllm/model_executor/models/hyperclovax_vision.py
+++ b/vllm/model_executor/models/hyperclovax_vision.py
@@ -29,7 +29,6 @@ from transformers import BatchFeature, CLIPVisionConfig, SiglipVisionConfig
 from transformers.modeling_utils import no_init_weights
 
 from vllm.config import VllmConfig
-from vllm.inputs import InputProcessingContext
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.cache import BaseMultiModalProcessorCache
@@ -37,8 +36,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
 from vllm.multimodal.parse import ImageSize, MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate)
+                                        BaseProcessingInfo,
+                                        InputProcessingContext,
+                                        PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index e2d7b9f23b28..8d7feb965e76 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -15,7 +15,6 @@ from transformers.models.llava import LlavaProcessor
 from transformers.models.pixtral import PixtralProcessor
 
 from vllm.config import VllmConfig
-from vllm.inputs import InputProcessingContext
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -28,8 +27,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
+                                        BaseProcessingInfo,
+                                        InputProcessingContext,
+                                        PromptReplacement, PromptUpdate,
+                                        PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils.jsontree import json_map_leaves
diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py
index 94e3d7234b6f..ba6da4403ae1 100644
--- a/vllm/model_executor/models/mistral3.py
+++ b/vllm/model_executor/models/mistral3.py
@@ -13,7 +13,6 @@ from transformers import (BatchFeature, Mistral3Config, PixtralVisionConfig,
 from transformers.models.pixtral import PixtralProcessor
 
 from vllm.config import VllmConfig
-from vllm.inputs import InputProcessingContext
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -27,8 +26,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
+                                        BaseProcessingInfo,
+                                        InputProcessingContext,
+                                        PromptReplacement, PromptUpdate,
+                                        PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index 50521b593786..79e315f79489 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -32,7 +32,6 @@ from transformers.models.llama4.image_processing_llama4_fast import (
 from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.inputs import InputProcessingContext
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
@@ -47,8 +46,10 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
+                                        BaseProcessingInfo,
+                                        InputProcessingContext,
+                                        PromptReplacement, PromptUpdate,
+                                        PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py
index 67cf3ccf315d..b75c858a6480 100644
--- a/vllm/model_executor/models/tarsier.py
+++ b/vllm/model_executor/models/tarsier.py
@@ -17,7 +17,6 @@ from transformers.processing_utils import ProcessingKwargs, Unpack
 from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
 
 from vllm.config import VllmConfig
-from vllm.inputs import InputProcessingContext
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -29,8 +28,9 @@ from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate)
+                                        BaseProcessingInfo,
+                                        InputProcessingContext,
+                                        PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils.jsontree import json_map_leaves
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 7471bfcb4d50..78e2cb7fa733 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import time
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
@@ -7,18 +8,20 @@ from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
 from dataclasses import dataclass, field, replace
 from enum import Enum
 from functools import lru_cache
-from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
-                    TypeVar, Union, cast)
+from typing import (TYPE_CHECKING, Any, Generic, NamedTuple, Optional,
+                    Protocol, Union, cast, overload)
 
 import regex as re
 import torch
-from typing_extensions import assert_never
+from typing_extensions import TypeVar, assert_never
 
-from vllm.inputs import InputProcessingContext
 from vllm.logger import init_logger
+from vllm.transformers_utils.processor import cached_processor_from_config
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                                encode_tokens)
-from vllm.utils import flatten_2d_lists, full_groupby
+from vllm.utils import (flatten_2d_lists, full_groupby,
+                        get_allowed_kwarg_only_overrides)
+from vllm.utils.jsontree import JSONTree, json_map_leaves
 
 from .hasher import MultiModalHasher
 from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
@@ -34,6 +37,8 @@ if TYPE_CHECKING:
     from transformers.feature_extraction_utils import BatchFeature
     from transformers.processing_utils import ProcessorMixin
 
+    from vllm.config import ModelConfig
+
     from .cache import BaseMultiModalProcessorCache
     from .profiling import BaseDummyInputsBuilder
 
@@ -875,6 +880,222 @@ def find_mm_placeholders(
     return dict(full_groupby_modality(it))
 
 
+_T = TypeVar("_T")
+_C = TypeVar("_C", bound="PretrainedConfig", default="PretrainedConfig")
+_P = TypeVar("_P", bound="ProcessorMixin", default="ProcessorMixin")
+
+
+@dataclass(frozen=True)
+class InputProcessingContext:
+    """
+    Contains information about the model which may be used to
+    modify the inputs.
+    """
+
+    model_config: "ModelConfig"
+    """The configuration of the model."""
+
+    tokenizer: AnyTokenizer
+    """The tokenizer used to tokenize the inputs."""
+
+    @overload
+    def get_hf_config(self, /) -> "PretrainedConfig":
+        ...
+
+    @overload
+    def get_hf_config(
+        self,
+        typ: Union[type[_C], tuple[type[_C], ...]],
+        /,
+    ) -> _C:
+        ...
+
+    def get_hf_config(
+        self,
+        typ: Optional[Union[type[Any], tuple[type[Any], ...]]] = None,
+        /,
+    ) -> Any:
+        """
+        Get the HuggingFace configuration
+        (`transformers.PretrainedConfig`) of the model,
+        additionally checking its type.
+
+        Raises:
+            TypeError: If the configuration is not of the specified type.
+        """
+        if typ is None:
+            from transformers.configuration_utils import PretrainedConfig
+
+            typ = PretrainedConfig
+
+        hf_config = self.model_config.hf_config
+        if not isinstance(hf_config, typ):
+            raise TypeError("Invalid type of HuggingFace config. "
+                            f"Expected type: {typ}, but "
+                            f"found type: {type(hf_config)}")
+
+        return hf_config
+
+    def get_hf_image_processor_config(self) -> dict[str, Any]:
+        """
+        Get the HuggingFace image processor configuration of the model.
+        """
+        return self.model_config.hf_image_processor_config
+
+    def get_mm_config(self):
+        """
+        Get the multimodal config of the model.
+
+        Raises:
+            RuntimeError: If the model is not a multimodal model.
+        """
+        mm_config = self.model_config.multimodal_config
+        if mm_config is None:
+            raise RuntimeError("Not a multimodal model")
+
+        return mm_config
+
+    @overload
+    def get_hf_processor(self, /, **kwargs: object) -> "ProcessorMixin":
+        ...
+
+    @overload
+    def get_hf_processor(
+        self,
+        typ: Union[type[_P], tuple[type[_P], ...]],
+        /,
+        **kwargs: object,
+    ) -> _P:
+        ...
+
+    def get_hf_processor(
+        self,
+        typ: Optional[Union[type[Any], tuple[type[Any], ...]]] = None,
+        /,
+        **kwargs: object,
+    ) -> Any:
+        """
+        Get the HuggingFace processor
+        (`transformers.ProcessorMixin`) of the model,
+        additionally checking its type.
+
+        Raises:
+            TypeError: If the processor is not of the specified type.
+        """
+        if typ is None:
+            from transformers.processing_utils import ProcessorMixin
+
+            typ = ProcessorMixin
+
+        return cached_processor_from_config(
+            self.model_config,
+            processor_cls=typ,
+            tokenizer=self.tokenizer,
+            **kwargs,
+        )
+
+    def init_processor(
+        self,
+        typ: type[_T],
+        /,
+        **kwargs: object,
+    ) -> _T:
+        """
+        Initialize a HuggingFace-like processor class, merging the
+        keyword arguments with those in the model's configuration.
+        """
+        mm_config = self.model_config.get_multimodal_config()
+        base_kwargs = mm_config.mm_processor_kwargs
+        if base_kwargs is None:
+            base_kwargs = {}
+
+        merged_kwargs = {**base_kwargs, **kwargs}
+
+        return typ(**merged_kwargs)
+
+    def _postprocess_output(
+        self,
+        output: JSONTree,
+    ) -> JSONTree:
+
+        def _postprocess_one(x: object):
+            if isinstance(x, torch.Tensor):  # noqa: SIM102
+                # This mimics the behavior of transformers.BatchFeature
+                if x.is_floating_point():
+                    x = x.to(dtype=self.model_config.dtype)
+
+            return x
+
+        return json_map_leaves(_postprocess_one, output)
+
+    def call_hf_processor(
+        self,
+        hf_processor: "ProcessorMixin",
+        data: Mapping[str, object],
+        kwargs: Mapping[str, object] = {},
+        *,
+        num_tries: int = 1,
+        max_tries: int = 5,
+    ) -> Union["BatchFeature", JSONTree]:
+        """
+        Call `hf_processor` on the prompt `data`
+        (text, image, audio...) with configurable options `kwargs`.
+        """
+        assert callable(hf_processor)
+
+        mm_config = self.model_config.get_multimodal_config()
+        merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)
+
+        allowed_kwargs = get_allowed_kwarg_only_overrides(
+            hf_processor,
+            merged_kwargs,
+            requires_kw_only=False,
+            allow_var_kwargs=True,
+        )
+
+        try:
+            output = hf_processor(**data,
+                                  **allowed_kwargs,
+                                  return_tensors="pt")
+        except Exception as exc:
+            # See https://github.com/huggingface/tokenizers/issues/537
+            if (isinstance(exc, RuntimeError) and exc
+                    and exc.args[0] == "Already borrowed"
+                    and num_tries < max_tries):
+                logger.warning(
+                    "Failed to acquire tokenizer in current thread. "
+                    "Retrying (%d/%d)...", num_tries, max_tries)
+                time.sleep(0.5)
+                return self.call_hf_processor(
+                    hf_processor,
+                    data,
+                    kwargs,
+                    num_tries=num_tries + 1,
+                    max_tries=max_tries,
+                )
+
+            msg = (f"Failed to apply {type(hf_processor).__name__} "
+                   f"on data={data} with kwargs={allowed_kwargs}")
+
+            raise ValueError(msg) from exc
+
+        # this emulates output.to(dtype=self.model_config.dtype)
+        from transformers.feature_extraction_utils import BatchFeature
+
+        if isinstance(output, BatchFeature):
+            output_ = self._postprocess_output(output.data)
+            return BatchFeature(output_)
+
+        logger.warning_once(
+            "%s did not return `BatchFeature`. "
+            "Make sure to match the behaviour of `ProcessorMixin` when "
+            "implementing custom processors.",
+            type(hf_processor).__name__,
+        )
+
+        return self._postprocess_output(output)
+
+
 class BaseProcessingInfo:
     """Base class to provide the information necessary for data processing."""
 
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index 5d485bc361d1..2bbc0078ad13 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -6,14 +6,14 @@ from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
 
 import torch.nn as nn
 
-from vllm.inputs import InputProcessingContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import (AnyTokenizer,
                                                cached_tokenizer_from_config)
 from vllm.utils import ClassRegistry
 
 from .cache import BaseMultiModalProcessorCache
-from .processing import BaseMultiModalProcessor, BaseProcessingInfo
+from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
+                         InputProcessingContext)
 from .profiling import (BaseDummyInputsBuilder, DummyDecoderData,
                         DummyEncoderData, MultiModalProfiler)
 
@@ -41,7 +41,7 @@ class ProcessingInfoFactory(Protocol[_I_co]):
         ...
 
 
-class DummyInputsBuilderFactory(Protocol[_I]):
+class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
     """
     Constructs a
     [`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]