[Chore] Separate out vllm.utils.func (#26904)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-10-15 21:03:58 +08:00 committed by GitHub
parent f57438338d
commit 136a17fe6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 407 additions and 371 deletions

View File

@ -17,7 +17,7 @@ from transformers import (
)
from vllm.platforms import current_platform
from vllm.utils import identity
from vllm.utils.func import identity
from ....conftest import (
IMAGE_ASSETS,

View File

@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
import pytest
from vllm.utils.func import deprecate_kwargs, supports_kw
from ..utils import error_on_warning
def test_deprecate_kwargs_always():
    """A kwarg that is always deprecated warns only when it is passed."""

    @deprecate_kwargs("old_arg", is_deprecated=True)
    def target(*, old_arg: object = None, new_arg: object = None):
        return None

    # Passing the deprecated keyword must emit a warning that names it.
    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        target(old_arg=1)

    # Passing only the other keyword is expected to stay silent.
    with error_on_warning(DeprecationWarning):
        target(new_arg=1)
def test_deprecate_kwargs_never():
    """With `is_deprecated=False`, neither keyword is expected to warn."""

    @deprecate_kwargs("old_arg", is_deprecated=False)
    def target(*, old_arg: object = None, new_arg: object = None):
        return None

    # The listed keyword is exercised with the deprecation switched off.
    with error_on_warning(DeprecationWarning):
        target(old_arg=1)

    # The unlisted keyword is exercised as well.
    with error_on_warning(DeprecationWarning):
        target(new_arg=1)
def test_deprecate_kwargs_dynamic():
    """The deprecation flag is a callable that is re-evaluated on each call."""
    flag = True

    @deprecate_kwargs("old_arg", is_deprecated=lambda: flag)
    def target(*, old_arg: object = None, new_arg: object = None):
        return None

    # While the flag is on, only the deprecated keyword warns.
    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        target(old_arg=1)

    with error_on_warning(DeprecationWarning):
        target(new_arg=1)

    # Flip the flag after decoration; the lambda reads the new value.
    flag = False

    with error_on_warning(DeprecationWarning):
        target(old_arg=1)

    with error_on_warning(DeprecationWarning):
        target(new_arg=1)
def test_deprecate_kwargs_additional_message():
    """The `additional_message` text shows up in the emitted warning."""

    @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
    def target(*, old_arg: object = None, new_arg: object = None):
        return None

    with pytest.warns(DeprecationWarning, match="abcd"):
        target(old_arg=1)
@pytest.mark.parametrize(
    ("callable", "kw_name", "requires_kw_only", "allow_var_kwargs", "is_supported"),
    [
        # Tests for positional argument support
        (lambda foo: None, "foo", True, True, False),
        (lambda foo: None, "foo", False, True, True),
        # Tests for positional or keyword / keyword only
        (lambda foo=100: None, "foo", True, True, False),
        (lambda *, foo: None, "foo", False, True, True),
        # Tests to make sure the names of variadic params are NOT supported
        (lambda *args: None, "args", False, True, False),
        (lambda **kwargs: None, "kwargs", False, True, False),
        # Tests for if we allow var kwargs to add support
        (lambda foo: None, "something_else", False, True, False),
        (lambda foo, **kwargs: None, "something_else", False, True, True),
        (lambda foo, **kwargs: None, "kwargs", True, True, False),
        (lambda foo, **kwargs: None, "foo", True, True, False),
    ],
)
def test_supports_kw(
    callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported
):
    """Compare the `supports_kw` verdict with the expectation of each case."""
    verdict = supports_kw(
        callable=callable,
        kw_name=kw_name,
        requires_kw_only=requires_kw_only,
        allow_var_kwargs=allow_var_kwargs,
    )
    assert verdict == is_supported

View File

@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.utils.jsontree import json_count_leaves
def test_json_count_leaves():
    """Test json_count_leaves function from jsontree utility."""
    # Each entry pairs an input tree with its expected number of leaves.
    cases = [
        # Single leaf values
        (42, 1),
        ("hello", 1),
        (None, 1),
        # Empty containers
        ([], 0),
        ({}, 0),
        ((), 0),
        # Flat structures
        ([1, 2, 3], 3),
        ({"a": 1, "b": 2}, 2),
        ((1, 2, 3), 3),
        # Nested structures
        ({"a": 1, "b": {"c": 2, "d": 3}}, 3),
        ([1, [2, 3], 4], 4),
        ({"list": [1, 2], "dict": {"x": 3}, "value": 4}, 4),
    ]

    for tree, expected in cases:
        assert json_count_leaves(tree) == expected

View File

@ -30,7 +30,6 @@ from vllm.utils import (
bind_kv_cache,
common_broadcastable_dtype,
current_stream,
deprecate_kwargs,
get_open_port,
get_tcp_uri,
is_lossless_cast,
@ -42,12 +41,11 @@ from vllm.utils import (
sha256,
split_host_port,
split_zmq_path,
supports_kw,
swap_dict_values,
unique_filepath,
)
from ..utils import create_new_process_for_each_test, error_on_warning
from ..utils import create_new_process_for_each_test
@pytest.mark.asyncio
@ -83,61 +81,6 @@ async def test_merge_async_iterators():
raise AssertionError() from e
def test_deprecate_kwargs_always():
    """A kwarg that is always deprecated warns only when it is passed."""

    @deprecate_kwargs("old_arg", is_deprecated=True)
    def target(*, old_arg: object = None, new_arg: object = None):
        return None

    # Passing the deprecated keyword must emit a warning that names it.
    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        target(old_arg=1)

    # Passing only the other keyword is expected to stay silent.
    with error_on_warning(DeprecationWarning):
        target(new_arg=1)
def test_deprecate_kwargs_never():
    """With `is_deprecated=False`, neither keyword is expected to warn."""

    @deprecate_kwargs("old_arg", is_deprecated=False)
    def target(*, old_arg: object = None, new_arg: object = None):
        return None

    # The listed keyword is exercised with the deprecation switched off.
    with error_on_warning(DeprecationWarning):
        target(old_arg=1)

    # The unlisted keyword is exercised as well.
    with error_on_warning(DeprecationWarning):
        target(new_arg=1)
def test_deprecate_kwargs_dynamic():
    """The deprecation flag is a callable that is re-evaluated on each call."""
    flag = True

    @deprecate_kwargs("old_arg", is_deprecated=lambda: flag)
    def target(*, old_arg: object = None, new_arg: object = None):
        return None

    # While the flag is on, only the deprecated keyword warns.
    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        target(old_arg=1)

    with error_on_warning(DeprecationWarning):
        target(new_arg=1)

    # Flip the flag after decoration; the lambda reads the new value.
    flag = False

    with error_on_warning(DeprecationWarning):
        target(old_arg=1)

    with error_on_warning(DeprecationWarning):
        target(new_arg=1)
def test_deprecate_kwargs_additional_message():
    """The `additional_message` text shows up in the emitted warning."""

    @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
    def target(*, old_arg: object = None, new_arg: object = None):
        return None

    with pytest.warns(DeprecationWarning, match="abcd"):
        target(old_arg=1)
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_PORT", "5678")
@ -383,39 +326,6 @@ def test_duplicate_dict_args(caplog_vllm, parser):
assert "-O.mode" in caplog_vllm.text
@pytest.mark.parametrize(
    "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
    [
        # Tests for positional argument support
        (lambda foo: None, "foo", True, True, False),
        (lambda foo: None, "foo", False, True, True),
        # Tests for positional or keyword / keyword only
        (lambda foo=100: None, "foo", True, True, False),
        (lambda *, foo: None, "foo", False, True, True),
        # Tests to make sure the names of variadic params are NOT supported
        (lambda *args: None, "args", False, True, False),
        (lambda **kwargs: None, "kwargs", False, True, False),
        # Tests for if we allow var kwargs to add support
        (lambda foo: None, "something_else", False, True, False),
        (lambda foo, **kwargs: None, "something_else", False, True, True),
        (lambda foo, **kwargs: None, "kwargs", True, True, False),
        (lambda foo, **kwargs: None, "foo", True, True, False),
    ],
)
def test_supports_kw(
    callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported
):
    """Compare the `supports_kw` verdict with the expectation of each case."""
    verdict = supports_kw(
        callable=callable,
        kw_name=kw_name,
        requires_kw_only=requires_kw_only,
        allow_var_kwargs=allow_var_kwargs,
    )
    assert verdict == is_supported
@create_new_process_for_each_test()
def test_memory_profiling():
# Fake out some model loading + inference memory usage to test profiling
@ -863,36 +773,6 @@ def test_join_host_port():
assert join_host_port("::1", 5555) == "[::1]:5555"
def test_json_count_leaves():
    """Test json_count_leaves function from jsontree utility."""
    from vllm.utils.jsontree import json_count_leaves

    # Each entry pairs an input tree with its expected number of leaves.
    cases = [
        # Single leaf values
        (42, 1),
        ("hello", 1),
        (None, 1),
        # Empty containers
        ([], 0),
        ({}, 0),
        ((), 0),
        # Flat structures
        ([1, 2, 3], 3),
        ({"a": 1, "b": 2}, 2),
        ((1, 2, 3), 3),
        # Nested structures
        ({"a": 1, "b": {"c": 2, "d": 3}}, 3),
        ([1, [2, 3], 4], 4),
        ({"list": [1, 2], "dict": {"x": 3}, "value": 4}, 4),
    ]

    for tree, expected in cases:
        assert json_count_leaves(tree) == expected
def test_convert_ids_list_to_tokens():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
token_ids = tokenizer.encode("Hello, world!")

View File

@ -50,7 +50,8 @@ from vllm.multimodal.utils import MediaConnector
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid, supports_kw
from vllm.utils import random_uuid
from vllm.utils.func import supports_kw
logger = init_logger(__name__)

View File

@ -94,10 +94,10 @@ from vllm.utils import (
AsyncMicrobatchTokenizer,
collect_from_async_generator,
is_list_of,
make_async,
merge_async_iterators,
random_uuid,
)
from vllm.utils.func import make_async
from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__)

View File

@ -37,7 +37,8 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import make_async, merge_async_iterators
from vllm.utils import merge_async_iterators
from vllm.utils.func import make_async
logger = init_logger(__name__)

View File

@ -17,7 +17,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest
from vllm.tasks import SupportedTask
from vllm.utils import make_async
from vllm.utils.func import make_async
from vllm.v1.outputs import SamplerOutput
from vllm.v1.worker.worker_base import WorkerBase

View File

@ -24,8 +24,8 @@ from vllm.utils import (
get_distributed_init_method,
get_ip,
get_open_port,
make_async,
)
from vllm.utils.func import make_async
from vllm.v1.outputs import SamplerOutput
if ray is not None:

View File

@ -27,8 +27,9 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.utils import has_deep_gemm, run_once
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
from vllm.utils.func import run_once
logger = init_logger(__name__)

View File

@ -24,7 +24,7 @@ from vllm.inputs import TokensPrompt
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils import supports_kw
from vllm.utils.func import supports_kw
from .interfaces_base import VllmModel, is_pooling_model

View File

@ -15,7 +15,7 @@ import torch.nn as nn
from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger
from vllm.utils import supports_kw
from vllm.utils.func import supports_kw
if TYPE_CHECKING:
from vllm.config import VllmConfig

View File

@ -25,7 +25,8 @@ from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens
from vllm.utils import flatten_2d_lists, full_groupby, get_allowed_kwarg_only_overrides
from vllm.utils import flatten_2d_lists, full_groupby
from vllm.utils.func import get_allowed_kwarg_only_overrides
from vllm.utils.jsontree import JSONTree, json_map_leaves
from .hasher import MultiModalHasher

View File

@ -5,7 +5,7 @@ import os
from collections.abc import Mapping
from vllm.logger import init_logger
from vllm.utils import run_once
from vllm.utils.func import run_once
TRACE_HEADERS = ["traceparent", "tracestate"]

View File

@ -16,7 +16,7 @@ from transformers.processing_utils import ProcessorMixin
from transformers.video_processing_utils import BaseVideoProcessor
from typing_extensions import TypeVar
from vllm.utils import get_allowed_kwarg_only_overrides
from vllm.utils.func import get_allowed_kwarg_only_overrides
if TYPE_CHECKING:
from vllm.config import ModelConfig

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import concurrent
import contextlib
import datetime
import enum
@ -43,7 +42,6 @@ from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import UserDict, defaultdict
from collections.abc import (
AsyncGenerator,
Awaitable,
Callable,
Collection,
Generator,
@ -85,7 +83,7 @@ from packaging import version
from packaging.version import Version
from torch.library import Library
from transformers.tokenization_utils_base import BatchEncoding
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
from typing_extensions import Never, TypeIs, assert_never
import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
@ -174,7 +172,6 @@ def set_default_torch_num_threads(num_threads: int):
torch.set_num_threads(old_num_threads)
P = ParamSpec("P")
T = TypeVar("T")
U = TypeVar("U")
@ -452,24 +449,6 @@ def in_loop(event_loop: AbstractEventLoop) -> bool:
return False
def make_async(
    func: Callable[P, T], executor: concurrent.futures.Executor | None = None
) -> Callable[P, Awaitable[T]]:
    """Wrap a blocking function so that each call runs in an executor.

    The returned wrapper does not run `func` itself. It binds the call
    arguments, hands the call to `loop.run_in_executor` and returns the
    resulting future, so the blocking work does not block the asyncio
    event loop.

    The wrapped function needs to be thread safe.

    Args:
        func: The blocking function to wrap.
        executor: Executor forwarded to `loop.run_in_executor`.

    Returns:
        A callable taking the same arguments as `func` and returning an
        awaitable for its result.
    """

    def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
        bound_call = partial(func, *args, **kwargs)
        return asyncio.get_event_loop().run_in_executor(executor, bound_call)

    return _async_wrapper
async def merge_async_iterators(
*iterators: AsyncGenerator[T, None],
) -> AsyncGenerator[tuple[int, T], None]:
@ -1254,90 +1233,6 @@ def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
enable_trace_function_call(log_path)
# `functools` helpers
def identity(value: T, **kwargs) -> T:
    """
    Return `value` unchanged.

    Any keyword arguments are accepted and ignored.
    """
    return value
F = TypeVar("F", bound=Callable[..., Any])
def deprecate_args(
    start_index: int,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """
    Build a decorator that warns when deprecated positional arguments are used.

    Positional parameters of the decorated function from `start_index`
    onwards are treated as deprecated. A `DeprecationWarning` is emitted
    whenever a call supplies any of them positionally.

    Args:
        start_index: Index of the first deprecated positional parameter.
        is_deprecated: Either a fixed flag, or a zero-argument callable that
            is evaluated on every call to decide whether to warn.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator that adds the deprecation check to a function.
    """
    if callable(is_deprecated):
        should_warn = is_deprecated
    else:
        # Normalize the fixed flag into a zero-argument callable.
        should_warn = partial(identity, is_deprecated)

    suffix = "" if additional_message is None else f" {additional_message}"

    def wrapper(fn: F) -> F:
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        positional_names = [
            param.name
            for param in inspect.signature(fn).parameters.values()
            if param.kind in positional_kinds
        ]

        @wraps(fn)
        def inner(*args, **kwargs):
            # Only parameters actually filled positionally by this call,
            # starting at `start_index`, count as deprecated usage.
            used = positional_names[start_index : len(args)] if should_warn() else []
            if used:
                warnings.warn(
                    DeprecationWarning(
                        f"The positional arguments {used} are "
                        "deprecated and will be removed in a future update."
                        f"{suffix}"
                    ),
                    stacklevel=3,  # The inner function takes up one level
                )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
def deprecate_kwargs(
    *kws: str,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """
    Build a decorator that warns when deprecated keyword arguments are used.

    A `DeprecationWarning` is emitted whenever a call to the decorated
    function passes any of the names in `kws` as a keyword argument.

    Args:
        *kws: Names of the deprecated keyword arguments.
        is_deprecated: Either a fixed flag, or a zero-argument callable that
            is evaluated on every call to decide whether to warn.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator that adds the deprecation check to a function.
    """
    watched = set(kws)

    if callable(is_deprecated):
        should_warn = is_deprecated
    else:
        # Normalize the fixed flag into a zero-argument callable.
        should_warn = partial(identity, is_deprecated)

    suffix = "" if additional_message is None else f" {additional_message}"

    def wrapper(fn: F) -> F:
        @wraps(fn)
        def inner(*args, **kwargs):
            # Deprecated names that this particular call actually passed.
            used = (kwargs.keys() & watched) if should_warn() else None
            if used:
                warnings.warn(
                    DeprecationWarning(
                        f"The keyword arguments {used} are "
                        "deprecated and will be removed in a future update."
                        f"{suffix}"
                    ),
                    stacklevel=3,  # The inner function takes up one level
                )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
# Note: cuda_visible_devices is not used, but we keep it as an argument for
@ -1426,21 +1321,6 @@ def weak_bind(
return weak_bound
def run_once(f: Callable[P, None]) -> Callable[P, None]:
    """
    Decorate `f` so that it is executed at most once.

    The first call runs `f`; every later call returns `None` without
    invoking it. The "has run" flag is set before `f` is invoked, so `f`
    is not retried if its first invocation raises.

    State is stored on the returned wrapper as the attributes `has_run`
    (a bool) and `lock` (a `threading.Lock`).

    Args:
        f: The function to guard.

    Returns:
        A wrapper carrying the metadata of `f` (name, docstring, ...) that
        forwards only its first call to `f`.
    """

    @wraps(f)  # keep the name/docstring of `f` on the wrapper
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        # Fast path: skip the lock once the function has already run.
        if wrapper.has_run:  # type: ignore[attr-defined]
            return

        with wrapper.lock:  # type: ignore[attr-defined]
            # Re-check under the lock in case another caller got here first.
            if not wrapper.has_run:  # type: ignore[attr-defined]
                wrapper.has_run = True  # type: ignore[attr-defined]
                return f(*args, **kwargs)

    # Assigned after `wraps` so attributes copied from `f` cannot shadow them.
    wrapper.has_run = False  # type: ignore[attr-defined]
    wrapper.lock = threading.Lock()  # type: ignore[attr-defined]
    return wrapper
class StoreBoolean(Action):
def __call__(self, parser, namespace, values, option_string=None):
if values.lower() == "true":
@ -1929,122 +1809,6 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwarg
return await task(*args, **kwargs)
@lru_cache
def supports_kw(
    callable: Callable[..., object],
    kw_name: str,
    *,
    requires_kw_only: bool = False,
    allow_var_kwargs: bool = True,
) -> bool:
    """
    Check whether `kw_name` can be passed to `callable` as a keyword argument.

    Results are memoized with `lru_cache`, so `callable` must be hashable.

    Args:
        callable: The callable whose signature is inspected.
        kw_name: Name of the keyword argument to look for.
        requires_kw_only: If True, a parameter named `kw_name` only counts
            when it is keyword-only; names that can also be supplied
            positionally are rejected.
        allow_var_kwargs: If True, a name that is not accepted through an
            explicit parameter is still supported when the last parameter
            is `**kwargs`. The name of that variadic parameter itself is
            never supported.

    Returns:
        True if `kw_name` is accepted as a keyword argument under the given
        constraints, otherwise False.
    """
    params = inspect.signature(callable).parameters
    if not params:
        return False

    param = params.get(kw_name)
    if param is not None:
        if param.kind == inspect.Parameter.KEYWORD_ONLY:
            return True

        if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
            # Valid as a keyword, but it may also be given positionally.
            return not requires_kw_only

        if param.kind == inspect.Parameter.POSITIONAL_ONLY and requires_kw_only:
            return False

        # Positional-only parameters cannot be bound by keyword, and the
        # names of `*args` / `**kwargs` are not keywords either. Such a name
        # is only usable if a trailing `**kwargs` can absorb it (below).

    # If we're okay with var-kwargs, it's supported as long as
    # the kw_name isn't something like *args, **kwargs
    if allow_var_kwargs:
        # Get the last param; type is ignored here because params is a proxy
        # mapping, but it wraps an ordered dict, and they appear in order.
        # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
        last_param = params[next(reversed(params))]  # type: ignore
        return (
            last_param.kind == inspect.Parameter.VAR_KEYWORD
            and last_param.name != kw_name
        )

    return False
def get_allowed_kwarg_only_overrides(
    callable: Callable[..., object],
    overrides: Mapping[str, object] | None,
    *,
    requires_kw_only: bool = True,
    allow_var_kwargs: bool = False,
) -> dict[str, Any]:
    """
    Filter `overrides` down to the entries `callable` accepts as keywords.

    Each override name is checked with `supports_kw`; names that fail the
    check are dropped and reported in a single warning log. This is used in
    a few places to handle custom processor overrides for multimodal models,
    e.g., for profiling when processor options provided by the user
    may affect the number of mm tokens per instance.

    Args:
        callable: Callable whose signature decides which overrides are kept.
        overrides: Potential overrides to be used when invoking the callable.
            `None` or an empty mapping yields an empty result.
        requires_kw_only: Forwarded to `supports_kw`; if True, only names
            of keyword-only parameters are kept.
        allow_var_kwargs: Allows overrides that are expandable for var kwargs.

    Returns:
        A new dictionary with the overrides that can be expanded as keyword
        arguments when invoking the callable.
    """
    if not overrides:
        return {}

    # Keep only the user-provided names that can be passed as kwargs,
    # including names absorbed by a var-kwargs parameter when allowed.
    accepted: dict[str, Any] = {}
    for name, value in overrides.items():
        if supports_kw(
            callable,
            name,
            requires_kw_only=requires_kw_only,
            allow_var_kwargs=allow_var_kwargs,
        ):
            accepted[name] = value

    # Report everything that was filtered out in one warning.
    rejected = overrides.keys() - accepted.keys()
    if rejected:
        if requires_kw_only:
            template = (
                "The following intended overrides are not keyword-only args "
                "and will be dropped: %s"
            )
        else:
            template = (
                "The following intended overrides are not keyword args "
                "and will be dropped: %s"
            )
        logger.warning(template, rejected)

    return accepted
# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
# In particular, the FakeScalarType is not supported for earlier versions of
# PyTorch which breaks dynamo for any ops registered using ScalarType.

258
vllm/utils/func.py Normal file
View File

@ -0,0 +1,258 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Contains helpers that are applied to functions.
This is similar in concept to the `functools` module.
"""
import asyncio
import concurrent.futures
import inspect
import threading
import warnings
from collections.abc import Awaitable, Callable, Mapping
from functools import lru_cache, partial, wraps
from typing import Any, TypeVar
from typing_extensions import ParamSpec
from vllm.logger import init_logger
logger = init_logger(__name__)
P = ParamSpec("P")
T = TypeVar("T")
F = TypeVar("F", bound=Callable[..., Any])
def identity(value: T, **kwargs) -> T:
    """
    Return `value` unchanged.

    Any keyword arguments are accepted and ignored.
    """
    return value
def make_async(
    func: Callable[P, T],
    executor: concurrent.futures.Executor | None = None,
) -> Callable[P, Awaitable[T]]:
    """
    Wrap a blocking function so that each call runs in an executor.

    The returned wrapper does not run `func` itself. It binds the call
    arguments, hands the call to `loop.run_in_executor` and returns the
    resulting future, so the blocking work does not block the asyncio
    event loop.

    The wrapped function needs to be thread safe.

    Args:
        func: The blocking function to wrap.
        executor: Executor forwarded to `loop.run_in_executor`.

    Returns:
        A callable taking the same arguments as `func` and returning an
        awaitable for its result.
    """

    def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future[T]:
        bound_call = partial(func, *args, **kwargs)
        return asyncio.get_event_loop().run_in_executor(executor, bound_call)

    return _async_wrapper
def run_once(f: Callable[P, None]) -> Callable[P, None]:
    """
    Decorate `f` so that it is executed at most once.

    The first call runs `f`; every later call returns `None` without
    invoking it. The "has run" flag is set before `f` is invoked, so `f`
    is not retried if its first invocation raises.

    State is stored on the returned wrapper as the attributes `has_run`
    (a bool) and `lock` (a `threading.Lock`).

    Args:
        f: The function to guard.

    Returns:
        A wrapper carrying the metadata of `f` (name, docstring, ...) that
        forwards only its first call to `f`.
    """

    @wraps(f)  # keep the name/docstring of `f` on the wrapper
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        # Fast path: skip the lock once the function has already run.
        if wrapper.has_run:  # type: ignore[attr-defined]
            return

        with wrapper.lock:  # type: ignore[attr-defined]
            # Re-check under the lock in case another caller got here first.
            if not wrapper.has_run:  # type: ignore[attr-defined]
                wrapper.has_run = True  # type: ignore[attr-defined]
                return f(*args, **kwargs)

    # Assigned after `wraps` so attributes copied from `f` cannot shadow them.
    wrapper.has_run = False  # type: ignore[attr-defined]
    wrapper.lock = threading.Lock()  # type: ignore[attr-defined]
    return wrapper
def deprecate_args(
    start_index: int,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """
    Build a decorator that warns when deprecated positional arguments are used.

    Positional parameters of the decorated function from `start_index`
    onwards are treated as deprecated. A `DeprecationWarning` is emitted
    whenever a call supplies any of them positionally.

    Args:
        start_index: Index of the first deprecated positional parameter.
        is_deprecated: Either a fixed flag, or a zero-argument callable that
            is evaluated on every call to decide whether to warn.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator that adds the deprecation check to a function.
    """
    if callable(is_deprecated):
        should_warn = is_deprecated
    else:
        # Normalize the fixed flag into a zero-argument callable.
        should_warn = partial(identity, is_deprecated)

    suffix = "" if additional_message is None else f" {additional_message}"

    def wrapper(fn: F) -> F:
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        positional_names = [
            param.name
            for param in inspect.signature(fn).parameters.values()
            if param.kind in positional_kinds
        ]

        @wraps(fn)
        def inner(*args, **kwargs):
            # Only parameters actually filled positionally by this call,
            # starting at `start_index`, count as deprecated usage.
            used = positional_names[start_index : len(args)] if should_warn() else []
            if used:
                warnings.warn(
                    DeprecationWarning(
                        f"The positional arguments {used} are "
                        "deprecated and will be removed in a future update."
                        f"{suffix}"
                    ),
                    stacklevel=3,  # The inner function takes up one level
                )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
def deprecate_kwargs(
    *kws: str,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """
    Build a decorator that warns when deprecated keyword arguments are used.

    A `DeprecationWarning` is emitted whenever a call to the decorated
    function passes any of the names in `kws` as a keyword argument.

    Args:
        *kws: Names of the deprecated keyword arguments.
        is_deprecated: Either a fixed flag, or a zero-argument callable that
            is evaluated on every call to decide whether to warn.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator that adds the deprecation check to a function.
    """
    watched = set(kws)

    if callable(is_deprecated):
        should_warn = is_deprecated
    else:
        # Normalize the fixed flag into a zero-argument callable.
        should_warn = partial(identity, is_deprecated)

    suffix = "" if additional_message is None else f" {additional_message}"

    def wrapper(fn: F) -> F:
        @wraps(fn)
        def inner(*args, **kwargs):
            # Deprecated names that this particular call actually passed.
            used = (kwargs.keys() & watched) if should_warn() else None
            if used:
                warnings.warn(
                    DeprecationWarning(
                        f"The keyword arguments {used} are "
                        "deprecated and will be removed in a future update."
                        f"{suffix}"
                    ),
                    stacklevel=3,  # The inner function takes up one level
                )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
@lru_cache
def supports_kw(
    callable: Callable[..., object],
    kw_name: str,
    *,
    requires_kw_only: bool = False,
    allow_var_kwargs: bool = True,
) -> bool:
    """
    Check whether `kw_name` can be passed to `callable` as a keyword argument.

    Results are memoized with `lru_cache`, so `callable` must be hashable.

    Args:
        callable: The callable whose signature is inspected.
        kw_name: Name of the keyword argument to look for.
        requires_kw_only: If True, a parameter named `kw_name` only counts
            when it is keyword-only; names that can also be supplied
            positionally are rejected.
        allow_var_kwargs: If True, a name that is not accepted through an
            explicit parameter is still supported when the last parameter
            is `**kwargs`. The name of that variadic parameter itself is
            never supported.

    Returns:
        True if `kw_name` is accepted as a keyword argument under the given
        constraints, otherwise False.
    """
    params = inspect.signature(callable).parameters
    if not params:
        return False

    param = params.get(kw_name)
    if param is not None:
        if param.kind == inspect.Parameter.KEYWORD_ONLY:
            return True

        if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
            # Valid as a keyword, but it may also be given positionally.
            return not requires_kw_only

        if param.kind == inspect.Parameter.POSITIONAL_ONLY and requires_kw_only:
            return False

        # Positional-only parameters cannot be bound by keyword, and the
        # names of `*args` / `**kwargs` are not keywords either. Such a name
        # is only usable if a trailing `**kwargs` can absorb it (below).

    # If we're okay with var-kwargs, it's supported as long as
    # the kw_name isn't something like *args, **kwargs
    if allow_var_kwargs:
        # Get the last param; type is ignored here because params is a proxy
        # mapping, but it wraps an ordered dict, and they appear in order.
        # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
        last_param = params[next(reversed(params))]  # type: ignore
        return (
            last_param.kind == inspect.Parameter.VAR_KEYWORD
            and last_param.name != kw_name
        )

    return False
def get_allowed_kwarg_only_overrides(
    callable: Callable[..., object],
    overrides: Mapping[str, object] | None,
    *,
    requires_kw_only: bool = True,
    allow_var_kwargs: bool = False,
) -> dict[str, Any]:
    """
    Filter `overrides` down to the entries `callable` accepts as keywords.

    Each override name is checked with `supports_kw`; names that fail the
    check are dropped and reported in a single warning log. This is used in
    a few places to handle custom processor overrides for multimodal models,
    e.g., for profiling when processor options provided by the user
    may affect the number of mm tokens per instance.

    Args:
        callable: Callable whose signature decides which overrides are kept.
        overrides: Potential overrides to be used when invoking the callable.
            `None` or an empty mapping yields an empty result.
        requires_kw_only: Forwarded to `supports_kw`; if True, only names
            of keyword-only parameters are kept.
        allow_var_kwargs: Allows overrides that are expandable for var kwargs.

    Returns:
        A new dictionary with the overrides that can be expanded as keyword
        arguments when invoking the callable.
    """
    if not overrides:
        return {}

    # Keep only the user-provided names that can be passed as kwargs,
    # including names absorbed by a var-kwargs parameter when allowed.
    accepted: dict[str, Any] = {}
    for name, value in overrides.items():
        if supports_kw(
            callable,
            name,
            requires_kw_only=requires_kw_only,
            allow_var_kwargs=allow_var_kwargs,
        ):
            accepted[name] = value

    # Report everything that was filtered out in one warning.
    rejected = overrides.keys() - accepted.keys()
    if rejected:
        if requires_kw_only:
            template = (
                "The following intended overrides are not keyword-only args "
                "and will be dropped: %s"
            )
        else:
            template = (
                "The following intended overrides are not keyword args "
                "and will be dropped: %s"
            )
        logger.warning(template, rejected)

    return accepted

View File

@ -29,7 +29,8 @@ from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs
from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv
from vllm.utils.func import deprecate_kwargs
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError