mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 17:05:37 +08:00
Speed up mm processor kwargs per request by splitting dynamic and static kwargs (#26483)
Signed-off-by: Junhong <liujunhong11@huawei.com> Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com> Co-authored-by: Junhong <liujunhong11@huawei.com>
This commit is contained in:
parent
827e4237bc
commit
59b453eaa2
@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import importlib
|
||||
|
||||
from transformers.processing_utils import ProcessingKwargs
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from vllm.transformers_utils.processor import (
|
||||
get_processor_kwargs_from_processor,
|
||||
)
|
||||
|
||||
|
||||
class _FakeProcessorKwargs(ProcessingKwargs, total=False):  # type: ignore
    """Empty ``ProcessingKwargs`` subclass used as an ``Unpack`` target in tests."""
|
||||
|
||||
|
||||
def _assert_has_all_expected(keys: set[str]) -> None:
    """Assert that ``keys`` contains every expected processor kwarg name.

    Args:
        keys: Collection of kwarg names to check (only membership is used).

    Raises:
        AssertionError: If any expected name is absent. The message lists
            each missing name prefixed with its modality group.
    """
    # Representative kwarg names per modality group.
    expected_by_modality = {
        "text": ("text_pair", "text_target", "text_pair_target"),
        "image": ("do_convert_rgb", "do_resize"),
        "video": (
            "fps",
            "do_sample_frames",
            "input_data_format",
            "default_to_square",
        ),
        "audio": ("padding", "return_attention_mask"),
    }
    # Collect every missing key so a failure reports all of them at once.
    missing = [
        f"{modality}:{name}"
        for modality, names in expected_by_modality.items()
        for name in names
        if name not in keys
    ]
    assert not missing, f"missing expected processor kwargs: {missing}"
|
||||
|
||||
|
||||
# Path 1: __call__ method has kwargs: Unpack[*ProcessingKwargs]
|
||||
class _ProcWithUnpack:
    """Stub processor whose ``__call__`` annotates ``**kwargs`` with ``Unpack``."""

    def __call__(self, *args, **kwargs: Unpack[_FakeProcessorKwargs]):  # type: ignore
        # Calling the stub does nothing and yields None.
        return
|
||||
|
||||
|
||||
def test_get_processor_kwargs_from_processor_unpack_path_returns_full_union():
    """Keys found via an ``Unpack``-annotated ``__call__`` cover all expected names."""
    discovered = get_processor_kwargs_from_processor(_ProcWithUnpack())
    _assert_has_all_expected(discovered)
|
||||
|
||||
|
||||
# ---- Path 2: No Unpack, fallback to scanning *ProcessingKwargs in module ----
|
||||
|
||||
|
||||
class _ProcWithoutUnpack:
    """Stub processor whose ``__call__`` leaves ``**kwargs`` unannotated."""

    def __call__(self, *args, **kwargs):
        # Calling the stub does nothing and yields None.
        return
|
||||
|
||||
|
||||
def test_get_processor_kwargs_from_processor_module_scan_returns_full_union():
    """Keys found via the module-scan fallback cover all expected names.

    ``_ProcWithoutUnpack.__call__`` has no annotation on ``**kwargs``, so the
    function under test has to look at the module that defines the class.
    """
    # ensure the module scanned by fallback is this test module
    mod = importlib.import_module(_ProcWithoutUnpack.__module__)
    assert hasattr(mod, "_FakeProcessorKwargs")
    # The fallback only considers module attributes whose name ends with
    # "ProcessingKwargs". "_FakeProcessorKwargs" does not end with that
    # suffix, so also confirm the module exposes at least one name that the
    # scan can actually pick up.
    assert any(name.endswith("ProcessingKwargs") for name in vars(mod))

    keys = get_processor_kwargs_from_processor(_ProcWithoutUnpack())
    _assert_has_all_expected(keys)
|
||||
@ -1,8 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, cast, get_args, get_type_hints
|
||||
|
||||
from transformers import (
|
||||
AutoFeatureExtractor,
|
||||
@ -55,6 +57,23 @@ def _get_processor_factory_fn(processor_cls: type | tuple[type, ...]):
|
||||
return processor_cls
|
||||
|
||||
|
||||
@lru_cache
def _collect_dynamic_keys_from_processing_kwargs(kwargs_cls: type) -> set[str]:
    """Collect the kwarg names declared by a ``*ProcessingKwargs``-style class.

    The class's resolved type hints are inspected for the four per-modality
    groups (``text_kwargs``, ``images_kwargs``, ``videos_kwargs``,
    ``audio_kwargs``). For every group that is present, the names declared by
    the group's own type are gathered. The four group names themselves are
    always part of the result.

    Args:
        kwargs_cls: Class whose type hints describe the kwargs groups. ``None``
            is tolerated and yields an empty set.

    Returns:
        The union of all nested kwarg names plus the four group names, or an
        empty set when ``kwargs_cls`` is ``None``.
    """
    if kwargs_cls is None:
        return set()

    modality_groups = ("text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs")
    top_level_hints = get_type_hints(kwargs_cls)

    # Flatten the field names of every modality group the class declares.
    collected: set[str] = {
        field_name
        for group in modality_groups
        if group in top_level_hints
        for field_name in get_type_hints(top_level_hints[group])
    }
    collected.update(modality_groups)
    return collected
|
||||
|
||||
|
||||
def _merge_mm_kwargs(
|
||||
model_config: "ModelConfig",
|
||||
processor_cls: type | tuple[type, ...],
|
||||
@ -71,7 +90,6 @@ def _merge_mm_kwargs(
|
||||
requires_kw_only=False,
|
||||
allow_var_kwargs=True,
|
||||
)
|
||||
|
||||
# NOTE: Pythonic dict is not hashable and will raise unhashable type
|
||||
# error when calling `cached_get_processor`, therefore we need to
|
||||
# wrap it to a hashable dict.
|
||||
@ -145,12 +163,80 @@ def get_processor(
|
||||
cached_get_processor = lru_cache(get_processor)
|
||||
|
||||
|
||||
@lru_cache
def get_processor_kwargs_from_processor(processor: _P) -> set[str]:
    """Return the kwarg names that ``processor``'s ``__call__`` may receive.

    Two strategies are tried in order:

    1. If the ``kwargs`` parameter of ``type(processor).__call__`` carries a
       parameterised annotation (e.g. ``Unpack[SomeProcessingKwargs]``), the
       first type argument is inspected.
    2. Otherwise the module that defines ``type(processor)`` is imported and
       every class whose name ends with ``"ProcessingKwargs"`` is inspected;
       the results are unioned.

    Args:
        processor: Processor instance. It must be hashable because the result
            is memoised with ``lru_cache``.

    Returns:
        The set of discovered kwarg names. An empty set is returned when
        nothing is found or when any error occurs during discovery.
    """
    try:
        # Read the raw annotation from the signature. `inspect.signature`
        # does not resolve annotations, so names that are not importable at
        # runtime do not raise here (unlike `get_type_hints` on `__call__`).
        call_kwargs = inspect.signature(type(processor).__call__).parameters.get(
            "kwargs"
        )
        annotation = call_kwargs.annotation if call_kwargs else None

        if annotation is not None and annotation is not inspect.Parameter.empty:
            # `Unpack[X]` exposes X as its first type argument. Annotations
            # without type arguments (e.g. `Any` or a string annotation)
            # yield an empty tuple and fall through to the module scan.
            type_args = get_args(annotation)
            if type_args:
                return _collect_dynamic_keys_from_processing_kwargs(type_args[0])

        # Fallback: find *ProcessingKwargs classes in the defining module.
        mod = importlib.import_module(type(processor).__module__)
        processor_kwargs: set[str] = set()
        for name, obj in vars(mod).items():
            # Only classes can carry the expected type hints; skipping other
            # objects avoids discarding the whole scan because of one of them.
            if name.endswith("ProcessingKwargs") and isinstance(obj, type):
                processor_kwargs |= _collect_dynamic_keys_from_processing_kwargs(obj)
        return processor_kwargs
    except Exception:
        # Discovery is best-effort: on failure report no known kwargs.
        return set()
|
||||
|
||||
|
||||
def cached_get_processor_without_dynamic_kwargs(
    processor_name: str,
    *args: Any,
    revision: str | None = None,
    trust_remote_code: bool = False,
    processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
    **kwargs: Any,
) -> _P:
    """Get a cached processor, ignoring kwargs that are per-call ("dynamic").

    A processor is first obtained without any extra kwargs and used to discover
    which kwarg names ``get_processor_kwargs_from_processor`` reports. Those
    names are removed from ``kwargs`` before the final cached lookup, so the
    lookup key depends only on the remaining kwargs.

    Args:
        processor_name: Forwarded to ``cached_get_processor``.
        *args: Extra positional arguments, forwarded to ``cached_get_processor``.
        revision: Forwarded to ``cached_get_processor``.
        trust_remote_code: Forwarded to ``cached_get_processor``.
        processor_cls: Forwarded to ``cached_get_processor``.
        **kwargs: Candidate kwargs; those reported as dynamic are dropped.

    Returns:
        The processor returned by ``cached_get_processor`` for the filtered
        kwargs.
    """
    # Step 1: use default kwargs to get a temporary processor instance
    processor = cached_get_processor(
        processor_name,
        *args,
        revision=revision,
        trust_remote_code=trust_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
    )

    # Nothing to filter: the final lookup would repeat the call above.
    if not kwargs:
        return processor

    # Step 2: use the temporary processor to collect dynamic keys
    dynamic_keys = get_processor_kwargs_from_processor(processor)

    # Step 3: drop the dynamic keys from kwargs
    filtered_kwargs = {k: v for k, v in kwargs.items() if k not in dynamic_keys}

    # Step 4: use the filtered kwargs to get the final processor instance
    return cached_get_processor(
        processor_name,
        *args,
        revision=revision,
        trust_remote_code=trust_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
        **filtered_kwargs,
    )
|
||||
|
||||
|
||||
def cached_processor_from_config(
|
||||
model_config: "ModelConfig",
|
||||
processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
|
||||
**kwargs: Any,
|
||||
) -> _P:
|
||||
return cached_get_processor(
|
||||
return cached_get_processor_without_dynamic_kwargs(
|
||||
model_config.model,
|
||||
revision=model_config.revision,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user