vllm/vllm/utils/func_utils.py
Cyrus Leung d31f7844f8
[Misc] Move utils to avoid conflicts with stdlib, and move tests (#27169)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-10-19 05:20:55 -07:00

237 lines
7.5 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Contains helpers that are applied to functions.
This is similar in concept to the `functools` module.
"""
import inspect
import threading
import warnings
from collections.abc import Callable, Mapping
from functools import lru_cache, partial, wraps
from typing import Any, TypeVar
from typing_extensions import ParamSpec
from vllm.logger import init_logger
logger = init_logger(__name__)
P = ParamSpec("P")
T = TypeVar("T")
F = TypeVar("F", bound=Callable[..., Any])
def identity(value: T, **kwargs) -> T:
    """Return `value` exactly as it was passed in.

    Args:
        value: The object to hand back; it is neither copied nor modified.
        **kwargs: Accepted for call-signature compatibility and ignored.

    Returns:
        The very same object that was given as `value`.
    """
    # `kwargs` is deliberately unused: only the positional value matters.
    return value
def run_once(f: Callable[P, None]) -> Callable[P, None]:
    """Wrap `f` so that it is invoked at most once.

    The first call forwards its arguments to `f`. Every later call returns
    `None` without calling `f` again. The "has run" flag is set *before* `f`
    is invoked, so a first call that raises still counts and is not retried.

    The returned wrapper carries two attributes: `has_run` (bool) and `lock`
    (a `threading.Lock`). The flag is re-checked while holding the lock, so
    concurrent first calls result in a single invocation of `f`.

    Args:
        f: The function to guard.

    Returns:
        A wrapper with the metadata of `f` (name, docstring, `__wrapped__`).
    """

    @wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        # Fast path: once the function has run there is no need to lock.
        if wrapper.has_run:  # type: ignore[attr-defined]
            return

        with wrapper.lock:  # type: ignore[attr-defined]
            # Re-check under the lock; another thread may have run `f`
            # between the unlocked check above and acquiring the lock.
            if not wrapper.has_run:  # type: ignore[attr-defined]
                wrapper.has_run = True  # type: ignore[attr-defined]
                return f(*args, **kwargs)

    # Assigned after `wraps` has copied `f.__dict__`, so attributes of `f`
    # with the same names can never shadow the bookkeeping state.
    wrapper.has_run = False  # type: ignore[attr-defined]
    wrapper.lock = threading.Lock()  # type: ignore[attr-defined]
    return wrapper
def deprecate_args(
    start_index: int,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """Build a decorator that warns when trailing arguments are passed
    positionally.

    Args:
        start_index: Index (in the decorated function's positional
            parameters) of the first parameter whose positional use is
            deprecated.
        is_deprecated: Either a flag, or a zero-argument callable that is
            evaluated on every call to decide whether to warn.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator. The decorated function emits a `DeprecationWarning`
        naming the deprecated parameters that were supplied positionally,
        then calls the original function with the unchanged arguments.
    """
    if not callable(is_deprecated):
        # Normalize the flag to a zero-argument callable.
        is_deprecated = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:
        params = inspect.signature(fn).parameters
        pos_types = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        # Names of the parameters that can be filled positionally, in order.
        pos_kws = [kw for kw, param in params.items() if param.kind in pos_types]

        @wraps(fn)
        def inner(*args, **kwargs):
            if is_deprecated():
                # Parameters at or after `start_index` that this call
                # actually filled with a positional argument.
                deprecated_args = pos_kws[start_index : len(args)]
                if deprecated_args:
                    msg = (
                        f"The positional arguments {deprecated_args} are "
                        "deprecated and will be removed in a future update."
                    )
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        # Level 1 is `inner` itself; level 2 is the code that
                        # called the decorated function, which is where the
                        # warning should be attributed.
                        stacklevel=2,
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
def deprecate_kwargs(
    *kws: str,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """Build a decorator that warns when deprecated keyword arguments are
    used.

    Args:
        *kws: Names of the keyword arguments that are deprecated.
        is_deprecated: Either a flag, or a zero-argument callable that is
            evaluated on every call to decide whether to warn.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator. The decorated function emits a `DeprecationWarning`
        naming the deprecated keywords present in the call, then calls the
        original function with the unchanged arguments.
    """
    deprecated_kws = set(kws)

    if not callable(is_deprecated):
        # Normalize the flag to a zero-argument callable.
        is_deprecated = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:
        @wraps(fn)
        def inner(*args, **kwargs):
            if is_deprecated():
                # Deprecated names that were actually passed by keyword.
                deprecated_kwargs = kwargs.keys() & deprecated_kws
                if deprecated_kwargs:
                    msg = (
                        f"The keyword arguments {deprecated_kwargs} are "
                        "deprecated and will be removed in a future update."
                    )
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        # Level 1 is `inner` itself; level 2 is the code that
                        # called the decorated function, which is where the
                        # warning should be attributed.
                        stacklevel=2,
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
@lru_cache
def supports_kw(
    callable: Callable[..., object],
    kw_name: str,
    *,
    requires_kw_only: bool = False,
    allow_var_kwargs: bool = True,
) -> bool:
    """Check whether `kw_name` can be passed to `callable` as a keyword.

    Results are memoized with `lru_cache`, so `callable` must be hashable.

    Args:
        callable: The callable whose signature is inspected.
        kw_name: The keyword name to test.
        requires_kw_only: If True, only keyword-only parameters count as a
            match; a name that belongs to an explicit positional parameter
            is rejected outright.
        allow_var_kwargs: If True, a name that is not directly passable is
            still supported when the signature ends in a `**kwargs`
            parameter (and `kw_name` is not the name of that parameter).

    Returns:
        True if the keyword is supported under the given rules.
    """
    params = inspect.signature(callable).parameters
    if not params:
        return False

    param_val = params.get(kw_name)
    if param_val is not None:
        kind = param_val.kind
        if kind == inspect.Parameter.KEYWORD_ONLY:
            return True
        if kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
            # Passable by keyword, but it can also be positional.
            return not requires_kw_only
        if kind == inspect.Parameter.POSITIONAL_ONLY and requires_kw_only:
            # The name belongs to a positional parameter; strict mode
            # rejects it without consulting **kwargs.
            return False
        # Otherwise the name belongs to a positional-only or variadic
        # parameter, none of which can be bound by keyword. The keyword is
        # then only usable if a **kwargs parameter can absorb it.

    # If we're okay with var-kwargs, it's supported as long as
    # the kw_name isn't something like *args, **kwargs
    if allow_var_kwargs:
        # Get the last param; type is ignored here because params is a proxy
        # mapping, but it wraps an ordered dict, and they appear in order.
        # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
        last_param = params[next(reversed(params))]  # type: ignore
        return (
            last_param.kind == inspect.Parameter.VAR_KEYWORD
            and last_param.name != kw_name
        )

    return False
def get_allowed_kwarg_only_overrides(
callable: Callable[..., object],
overrides: Mapping[str, object] | None,
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False,
) -> dict[str, Any]:
"""
Given a callable which has one or more keyword only params and a dict
mapping param names to values, drop values that can be not be kwarg
expanded to overwrite one or more keyword-only args. This is used in a
few places to handle custom processor overrides for multimodal models,
e.g., for profiling when processor options provided by the user
may affect the number of mm tokens per instance.
Args:
callable: Callable which takes 0 or more keyword only arguments.
If None is provided, all overrides names are allowed.
overrides: Potential overrides to be used when invoking the callable.
allow_var_kwargs: Allows overrides that are expandable for var kwargs.
Returns:
Dictionary containing the kwargs to be leveraged which may be used
to overwrite one or more keyword only arguments when invoking the
callable.
"""
if not overrides:
return {}
# Drop any mm_processor_kwargs provided by the user that
# are not kwargs, unless it can fit it var_kwargs param
filtered_overrides = {
kwarg_name: val
for kwarg_name, val in overrides.items()
if supports_kw(
callable,
kwarg_name,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
}
# If anything is dropped, log a warning
dropped_keys = overrides.keys() - filtered_overrides.keys()
if dropped_keys:
if requires_kw_only:
logger.warning(
"The following intended overrides are not keyword-only args "
"and will be dropped: %s",
dropped_keys,
)
else:
logger.warning(
"The following intended overrides are not keyword args "
"and will be dropped: %s",
dropped_keys,
)
return filtered_overrides