diff --git a/tests/transformers_utils/test_get_processor_kwargs_from_processor.py b/tests/transformers_utils/test_get_processor_kwargs_from_processor.py new file mode 100644 index 000000000000..95ff9a557fa0 --- /dev/null +++ b/tests/transformers_utils/test_get_processor_kwargs_from_processor.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import importlib + +from transformers.processing_utils import ProcessingKwargs +from typing_extensions import Unpack + +from vllm.transformers_utils.processor import ( + get_processor_kwargs_from_processor, +) + + +class _FakeProcessorKwargs(ProcessingKwargs, total=False): # type: ignore + pass + + +def _assert_has_all_expected(keys: set[str]) -> None: + # text + for k in ("text_pair", "text_target", "text_pair_target"): + assert k in keys + # image + for k in ("do_convert_rgb", "do_resize"): + assert k in keys + # audio + for k in ( + "fps", + "do_sample_frames", + "input_data_format", + "default_to_square", + ): + assert k in keys + # audio + for k in ("padding", "return_attention_mask"): + assert k in keys + + +# Path 1: __call__ method has kwargs: Unpack[*ProcessingKwargs] +class _ProcWithUnpack: + def __call__(self, *args, **kwargs: Unpack[_FakeProcessorKwargs]): # type: ignore + return None + + +def test_get_processor_kwargs_from_processor_unpack_path_returns_full_union(): + proc = _ProcWithUnpack() + keys = get_processor_kwargs_from_processor(proc) + _assert_has_all_expected(keys) + + +# ---- Path 2: No Unpack, fallback to scanning *ProcessingKwargs in module ---- + + +class _ProcWithoutUnpack: + def __call__(self, *args, **kwargs): + return None + + +def test_get_processor_kwargs_from_processor_module_scan_returns_full_union(): + # ensure the module scanned by fallback is this test module + module_name = _ProcWithoutUnpack.__module__ + mod = importlib.import_module(module_name) + assert hasattr(mod, "_FakeProcessorKwargs") + + proc = _ProcWithoutUnpack() + keys = get_processor_kwargs_from_processor(proc) + _assert_has_all_expected(keys) diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 8ba3aec454ad..b3469c1b18f2 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib +import inspect from functools import lru_cache -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, cast, get_args, get_type_hints from transformers import ( AutoFeatureExtractor, @@ -55,6 +57,23 @@ def _get_processor_factory_fn(processor_cls: type | tuple[type, ...]): return processor_cls +@lru_cache +def _collect_dynamic_keys_from_processing_kwargs(kwargs_cls: type) -> set[str]: + dynamic_kwargs: set[str] = set() + if kwargs_cls is None: + return dynamic_kwargs + # get kwargs annotations in processor + # merge text_kwargs / images_kwargs / videos_kwargs / audio_kwargs + kwargs_type_annotations = get_type_hints(kwargs_cls) + for kw_type in ("text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs"): + if kw_type in kwargs_type_annotations: + kw_annotations = get_type_hints(kwargs_type_annotations[kw_type]) + for kw_name in kw_annotations: + dynamic_kwargs.add(kw_name) + dynamic_kwargs |= {"text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs"} + return dynamic_kwargs + + def _merge_mm_kwargs( model_config: "ModelConfig", processor_cls: type | tuple[type, ...], @@ -71,7 +90,6 @@ def _merge_mm_kwargs( requires_kw_only=False, allow_var_kwargs=True, ) - # NOTE: Pythonic dict is not hashable and will raise unhashable type # error when calling `cached_get_processor`, therefore we need to # wrap it to a hashable dict. @@ -145,12 +163,80 @@ def get_processor( cached_get_processor = lru_cache(get_processor) +@lru_cache +def get_processor_kwargs_from_processor(processor: _P) -> set[str]: + try: + # get kwargs annotations in processor + call_kwargs = inspect.signature(type(processor).__call__).parameters.get( + "kwargs" + ) + call_kwargs_annotations = call_kwargs.annotation if call_kwargs else None + # if the processor has explicit kwargs annotation, use it + if call_kwargs_annotations not in (None, inspect._empty): + # get_type_hints will parse all type annotations at runtime, + # and if an annotation refers to a type or + # name that hasn’t been imported or defined, it will raise an error. + # So we use __annotations__ to get the raw annotations directly. + return _collect_dynamic_keys_from_processing_kwargs( + get_args(call_kwargs_annotations)[0] + ) + # otherwise, try to get from ProcessingKwargs + else: + module_name = type(processor).__module__ + mod = importlib.import_module(module_name) + # find *ProcessingKwargs in the module + processor_kwargs: set[str] = set() + for name, obj in vars(mod).items(): + if name.endswith("ProcessingKwargs"): + processor_kwargs = ( + processor_kwargs + | _collect_dynamic_keys_from_processing_kwargs(obj) + ) + return processor_kwargs + except Exception: + return set() + + +def cached_get_processor_without_dynamic_kwargs( + processor_name: str, + *args: Any, + revision: str | None = None, + trust_remote_code: bool = False, + processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin, + **kwargs: Any, +) -> _P: + # Step 1: use default kwargs to get a temporary processor instance + processor = cached_get_processor( + processor_name, + revision=revision, + trust_remote_code=trust_remote_code, + processor_cls=processor_cls, # type: ignore[arg-type] + ) + + # Step 2: use temporary processor collect dynamic keys + dynamic_keys = get_processor_kwargs_from_processor(processor) + + # Step 3: use dynamic_keys filter kwargs + filtered_kwargs = {k: v for k, v in kwargs.items() if k not in dynamic_keys} + + # Step 4: use filtered kwargs to get final processor instance + final_processor = cached_get_processor( + processor_name, + revision=revision, + trust_remote_code=trust_remote_code, + processor_cls=processor_cls, # type: ignore[arg-type] + **filtered_kwargs, + ) + + return final_processor + + def cached_processor_from_config( model_config: "ModelConfig", processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin, **kwargs: Any, ) -> _P: - return cached_get_processor( + return cached_get_processor_without_dynamic_kwargs( model_config.model, revision=model_config.revision, trust_remote_code=model_config.trust_remote_code,