Speed up mm processor kwargs per request by splitting dynamic and static kwargs (#26483)

Signed-off-by: Junhong <liujunhong11@huawei.com>
Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com>
Co-authored-by: Junhong <liujunhong11@huawei.com>
This commit is contained in:
Junhong Liu 2025-11-07 07:51:28 +08:00 committed by GitHub
parent 827e4237bc
commit 59b453eaa2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 155 additions and 3 deletions

View File

@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from transformers.processing_utils import ProcessingKwargs
from typing_extensions import Unpack
from vllm.transformers_utils.processor import (
get_processor_kwargs_from_processor,
)
class _FakeProcessorKwargs(ProcessingKwargs, total=False):  # type: ignore
    """Kwargs schema for the tests; declares no fields beyond its base class."""
def _assert_has_all_expected(keys: set[str]) -> None:
    """Assert that ``keys`` contains a sample of kwargs from every modality.

    Args:
        keys: Kwarg names returned by the function under test.

    Raises:
        AssertionError: If any expected name is absent. The message lists
            every missing name grouped by modality, not only the first one.
    """
    expected_by_modality = {
        "text": ("text_pair", "text_target", "text_pair_target"),
        "image": ("do_convert_rgb", "do_resize"),
        "video": (
            "fps",
            "do_sample_frames",
            "input_data_format",
            "default_to_square",
        ),
        "audio": ("padding", "return_attention_mask"),
    }
    # Collect every gap first so a single failure shows the whole picture.
    missing = {}
    for modality, names in expected_by_modality.items():
        absent = [name for name in names if name not in keys]
        if absent:
            missing[modality] = absent
    assert not missing, f"missing expected processor kwargs: {missing}"
# Path 1: ``__call__`` annotates ``**kwargs`` with ``Unpack[*ProcessingKwargs]``
class _ProcWithUnpack:
    """Processor stub that exposes its kwargs schema through ``Unpack``."""

    def __call__(self, *args, **kwargs: Unpack[_FakeProcessorKwargs]):  # type: ignore
        """Accept any arguments and return ``None``."""
def test_get_processor_kwargs_from_processor_unpack_path_returns_full_union():
    """An ``Unpack``-annotated ``**kwargs`` yields keys for every modality."""
    dynamic_keys = get_processor_kwargs_from_processor(_ProcWithUnpack())
    _assert_has_all_expected(dynamic_keys)
# ---- Path 2: no Unpack annotation, so the lookup falls back to scanning the
# processor's module for *ProcessingKwargs ----
class _ProcWithoutUnpack:
    """Processor stub whose ``__call__`` takes un-annotated ``**kwargs``."""

    def __call__(self, *args, **kwargs):
        """Accept any arguments and return ``None``."""
def test_get_processor_kwargs_from_processor_module_scan_returns_full_union():
    """Un-annotated ``**kwargs`` makes the lookup scan the processor's module."""
    # The fallback inspects the module that defines the processor class, which
    # is this test module.
    module_name = _ProcWithoutUnpack.__module__
    mod = importlib.import_module(module_name)
    assert hasattr(mod, "_FakeProcessorKwargs")
    # The scan only considers names ending in "ProcessingKwargs".
    # `_FakeProcessorKwargs` does not match that suffix; the imported
    # `ProcessingKwargs` base does, so verify the precondition the scan
    # actually depends on.
    assert any(name.endswith("ProcessingKwargs") for name in vars(mod))

    proc = _ProcWithoutUnpack()
    keys = get_processor_kwargs_from_processor(proc)
    _assert_has_all_expected(keys)

View File

@ -1,8 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import inspect
from functools import lru_cache
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, cast, get_args, get_type_hints
from transformers import (
AutoFeatureExtractor,
@ -55,6 +57,23 @@ def _get_processor_factory_fn(processor_cls: type | tuple[type, ...]):
return processor_cls
@lru_cache
def _collect_dynamic_keys_from_processing_kwargs(kwargs_cls: type) -> set[str]:
    """Flatten the per-modality kwarg names declared on ``kwargs_cls``.

    Args:
        kwargs_cls: Class whose type hints may contain ``text_kwargs``,
            ``images_kwargs``, ``videos_kwargs`` and/or ``audio_kwargs``
            fields. ``None`` is tolerated.

    Returns:
        An empty set when ``kwargs_cls`` is ``None``. Otherwise the hinted
        field names of every modality group present on ``kwargs_cls``, plus
        the four group names themselves. The result is memoized per class.
    """
    if kwargs_cls is None:
        return set()

    modality_groups = (
        "text_kwargs",
        "images_kwargs",
        "videos_kwargs",
        "audio_kwargs",
    )
    class_hints = get_type_hints(kwargs_cls)

    collected: set[str] = set()
    for group in modality_groups:
        if group not in class_hints:
            continue
        # Each group is itself an annotated class; its hint names are the
        # individual kwargs.
        collected.update(get_type_hints(class_hints[group]))
    # The group names can be passed as kwargs too, so they count as dynamic.
    collected.update(modality_groups)
    return collected
def _merge_mm_kwargs(
model_config: "ModelConfig",
processor_cls: type | tuple[type, ...],
@ -71,7 +90,6 @@ def _merge_mm_kwargs(
requires_kw_only=False,
allow_var_kwargs=True,
)
# NOTE: Pythonic dict is not hashable and will raise unhashable type
# error when calling `cached_get_processor`, therefore we need to
# wrap it to a hashable dict.
@ -145,12 +163,80 @@ def get_processor(
# Memoized wrapper around `get_processor`. `lru_cache` keys on the call
# arguments, so everything passed through this wrapper must be hashable.
cached_get_processor = lru_cache(get_processor)
@lru_cache
def get_processor_kwargs_from_processor(processor: _P) -> set[str]:
    """Collect the per-call ("dynamic") kwarg names accepted by a processor.

    Two sources are tried, in order:

    1. The annotation on the ``kwargs`` parameter of
       ``type(processor).__call__`` when it is parameterized (for example
       ``Unpack[SomeProcessingKwargs]``). Its first type argument is used as
       the kwargs schema.
    2. Otherwise, every name ending in ``ProcessingKwargs`` found in the
       module that defines the processor class; the results are unioned.

    Args:
        processor: Processor instance to inspect. It is used as an
            ``lru_cache`` key, so it must be hashable.

    Returns:
        The collected kwarg names, or an empty set if inspection fails for
        any reason. The result is cached and shared between callers, so treat
        it as read-only.
    """
    try:
        call_kwargs = inspect.signature(type(processor).__call__).parameters.get(
            "kwargs"
        )
        annotation = (
            call_kwargs.annotation
            if call_kwargs is not None
            else inspect.Parameter.empty
        )

        # `inspect.signature` returns the annotation as written, without
        # resolving forward references, so a name that is not importable here
        # cannot raise at this point. Only a parameterized annotation such as
        # `Unpack[X]` carries a schema; for anything else (`Any`, a string
        # annotation, ...) `get_args` returns an empty tuple.
        type_args = (
            get_args(annotation)
            if annotation not in (None, inspect.Parameter.empty)
            else ()
        )
        if type_args:
            return _collect_dynamic_keys_from_processing_kwargs(type_args[0])

        # No usable annotation: scan the processor's own module for
        # *ProcessingKwargs definitions instead of giving up.
        mod = importlib.import_module(type(processor).__module__)
        processor_kwargs: set[str] = set()
        for name, obj in vars(mod).items():
            if name.endswith("ProcessingKwargs"):
                # `|=` on this local set leaves the helper's cached set intact.
                processor_kwargs |= _collect_dynamic_keys_from_processing_kwargs(obj)
        return processor_kwargs
    except Exception:
        # Best effort: if introspection fails, report no dynamic keys so the
        # caller filters nothing.
        return set()
def cached_get_processor_without_dynamic_kwargs(
    processor_name: str,
    *args: Any,
    revision: str | None = None,
    trust_remote_code: bool = False,
    processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
    **kwargs: Any,
) -> _P:
    """Fetch a cached processor, ignoring kwargs that only matter per call.

    Dynamic kwargs (as reported by ``get_processor_kwargs_from_processor``)
    are dropped before the cache lookup, so requests that differ only in
    those kwargs share one cached processor instance.

    Args:
        processor_name: Forwarded to ``cached_get_processor``.
        *args: Extra positional arguments, forwarded to both
            ``cached_get_processor`` calls.
        revision: Forwarded to ``cached_get_processor``.
        trust_remote_code: Forwarded to ``cached_get_processor``.
        processor_cls: Forwarded to ``cached_get_processor``.
        **kwargs: Candidate keyword arguments; only those that are not
            dynamic keys are forwarded.

    Returns:
        The processor obtained with the non-dynamic kwargs only.
    """
    # Step 1: obtain a processor with no extra kwargs, purely to introspect it.
    # NOTE(review): `*args` is forwarded on the assumption that
    # `get_processor` accepts extra positionals, as this signature mirrors;
    # confirm against its definition.
    probe_processor = cached_get_processor(
        processor_name,
        *args,
        revision=revision,
        trust_remote_code=trust_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
    )

    # Step 2: ask that processor which kwarg names are dynamic.
    dynamic_keys = get_processor_kwargs_from_processor(probe_processor)

    # Step 3: keep only the kwargs that are not dynamic.
    filtered_kwargs = {k: v for k, v in kwargs.items() if k not in dynamic_keys}

    # Step 4: look up (or build) the processor for the remaining kwargs.
    return cached_get_processor(
        processor_name,
        *args,
        revision=revision,
        trust_remote_code=trust_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
        **filtered_kwargs,
    )
def cached_processor_from_config(
model_config: "ModelConfig",
processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
**kwargs: Any,
) -> _P:
return cached_get_processor(
return cached_get_processor_without_dynamic_kwargs(
model_config.model,
revision=model_config.revision,
trust_remote_code=model_config.trust_remote_code,