diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index f124220bb16d9..af7dad079a9b3 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -17,7 +17,7 @@ from transformers import ( ) from vllm.platforms import current_platform -from vllm.utils import identity +from vllm.utils.func import identity from ....conftest import ( IMAGE_ASSETS, diff --git a/tests/utils_/test_func_utils.py b/tests/utils_/test_func_utils.py new file mode 100644 index 0000000000000..147a396994596 --- /dev/null +++ b/tests/utils_/test_func_utils.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa + +import pytest + +from vllm.utils.func import deprecate_kwargs, supports_kw + +from ..utils import error_on_warning + + +def test_deprecate_kwargs_always(): + @deprecate_kwargs("old_arg", is_deprecated=True) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="'old_arg'"): + dummy(old_arg=1) + + with error_on_warning(DeprecationWarning): + dummy(new_arg=1) + + +def test_deprecate_kwargs_never(): + @deprecate_kwargs("old_arg", is_deprecated=False) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with error_on_warning(DeprecationWarning): + dummy(old_arg=1) + + with error_on_warning(DeprecationWarning): + dummy(new_arg=1) + + +def test_deprecate_kwargs_dynamic(): + is_deprecated = True + + @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="'old_arg'"): + dummy(old_arg=1) + + with error_on_warning(DeprecationWarning): + dummy(new_arg=1) + + is_deprecated = False + + with error_on_warning(DeprecationWarning): + dummy(old_arg=1) + + with error_on_warning(DeprecationWarning): + dummy(new_arg=1) + + +def test_deprecate_kwargs_additional_message(): + @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd") + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="abcd"): + dummy(old_arg=1) + + +@pytest.mark.parametrize( + ("callable", "kw_name", "requires_kw_only", "allow_var_kwargs", "is_supported"), + [ + # Tests for positional argument support + (lambda foo: None, "foo", True, True, False), + (lambda foo: None, "foo", False, True, True), + # Tests for positional or keyword / keyword only + (lambda foo=100: None, "foo", True, True, False), + (lambda *, foo: None, "foo", False, True, True), + # Tests to make sure the names of variadic params are NOT supported + (lambda *args: None, "args", False, True, False), + (lambda **kwargs: None, "kwargs", False, True, False), + # Tests for if we allow var kwargs to add support + (lambda foo: None, "something_else", False, True, False), + (lambda foo, **kwargs: None, "something_else", False, True, True), + (lambda foo, **kwargs: None, "kwargs", True, True, False), + (lambda foo, **kwargs: None, "foo", True, True, False), + ], +) +def test_supports_kw( + callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported +): + assert ( + supports_kw( + callable=callable, + kw_name=kw_name, + requires_kw_only=requires_kw_only, + allow_var_kwargs=allow_var_kwargs, + ) + == is_supported + ) diff --git a/tests/utils_/test_jsontree.py b/tests/utils_/test_jsontree.py new file mode 100644 index 0000000000000..0af2751b2638c --- /dev/null +++ b/tests/utils_/test_jsontree.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.utils.jsontree import json_count_leaves + + +def test_json_count_leaves(): + """Test json_count_leaves function from jsontree utility.""" + + # Single leaf values + assert json_count_leaves(42) == 1 + assert json_count_leaves("hello") == 1 + assert json_count_leaves(None) == 1 + + # Empty containers + assert json_count_leaves([]) == 0 + assert json_count_leaves({}) == 0 + assert json_count_leaves(()) == 0 + + # Flat structures + assert json_count_leaves([1, 2, 3]) == 3 + assert json_count_leaves({"a": 1, "b": 2}) == 2 + assert json_count_leaves((1, 2, 3)) == 3 + + # Nested structures + nested_dict = {"a": 1, "b": {"c": 2, "d": 3}} + assert json_count_leaves(nested_dict) == 3 + + nested_list = [1, [2, 3], 4] + assert json_count_leaves(nested_list) == 4 + + mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4} + assert json_count_leaves(mixed_nested) == 4 diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index af5fc758f2c26..b4883a4fea31a 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -30,7 +30,6 @@ from vllm.utils import ( bind_kv_cache, common_broadcastable_dtype, current_stream, - deprecate_kwargs, get_open_port, get_tcp_uri, is_lossless_cast, @@ -42,12 +41,11 @@ from vllm.utils import ( sha256, split_host_port, split_zmq_path, - supports_kw, swap_dict_values, unique_filepath, ) -from ..utils import create_new_process_for_each_test, error_on_warning +from ..utils import create_new_process_for_each_test @pytest.mark.asyncio @@ -83,61 +81,6 @@ async def test_merge_async_iterators(): raise AssertionError() from e -def test_deprecate_kwargs_always(): - @deprecate_kwargs("old_arg", is_deprecated=True) - def dummy(*, old_arg: object = None, new_arg: object = None): - pass - - with pytest.warns(DeprecationWarning, match="'old_arg'"): - dummy(old_arg=1) - - with error_on_warning(DeprecationWarning): - dummy(new_arg=1) - - -def test_deprecate_kwargs_never(): - @deprecate_kwargs("old_arg", is_deprecated=False) - def dummy(*, old_arg: object = None, new_arg: object = None): - pass - - with error_on_warning(DeprecationWarning): - dummy(old_arg=1) - - with error_on_warning(DeprecationWarning): - dummy(new_arg=1) - - -def test_deprecate_kwargs_dynamic(): - is_deprecated = True - - @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated) - def dummy(*, old_arg: object = None, new_arg: object = None): - pass - - with pytest.warns(DeprecationWarning, match="'old_arg'"): - dummy(old_arg=1) - - with error_on_warning(DeprecationWarning): - dummy(new_arg=1) - - is_deprecated = False - - with error_on_warning(DeprecationWarning): - dummy(old_arg=1) - - with error_on_warning(DeprecationWarning): - dummy(new_arg=1) - - -def test_deprecate_kwargs_additional_message(): - @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd") - def dummy(*, old_arg: object = None, new_arg: object = None): - pass - - with pytest.warns(DeprecationWarning, match="abcd"): - dummy(old_arg=1) - - def test_get_open_port(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: m.setenv("VLLM_PORT", "5678") @@ -383,39 +326,6 @@ def test_duplicate_dict_args(caplog_vllm, parser): assert "-O.mode" in caplog_vllm.text -@pytest.mark.parametrize( - "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported", - [ - # Tests for positional argument support - (lambda foo: None, "foo", True, True, False), - (lambda foo: None, "foo", False, True, True), - # Tests for positional or keyword / keyword only - (lambda foo=100: None, "foo", True, True, False), - (lambda *, foo: None, "foo", False, True, True), - # Tests to make sure the names of variadic params are NOT supported - (lambda *args: None, "args", False, True, False), - (lambda **kwargs: None, "kwargs", False, True, False), - # Tests for if we allow var kwargs to add support - (lambda foo: None, "something_else", False, True, False), - (lambda foo, **kwargs: None, "something_else", False, True, True), - (lambda foo, **kwargs: None, "kwargs", True, True, False), - (lambda foo, **kwargs: None, "foo", True, True, False), - ], -) -def test_supports_kw( - callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported -): - assert ( - supports_kw( - callable=callable, - kw_name=kw_name, - requires_kw_only=requires_kw_only, - allow_var_kwargs=allow_var_kwargs, - ) - == is_supported - ) - - @create_new_process_for_each_test() def test_memory_profiling(): # Fake out some model loading + inference memory usage to test profiling @@ -863,36 +773,6 @@ def test_join_host_port(): assert join_host_port("::1", 5555) == "[::1]:5555" -def test_json_count_leaves(): - """Test json_count_leaves function from jsontree utility.""" - from vllm.utils.jsontree import json_count_leaves - - # Single leaf values - assert json_count_leaves(42) == 1 - assert json_count_leaves("hello") == 1 - assert json_count_leaves(None) == 1 - - # Empty containers - assert json_count_leaves([]) == 0 - assert json_count_leaves({}) == 0 - assert json_count_leaves(()) == 0 - - # Flat structures - assert json_count_leaves([1, 2, 3]) == 3 - assert json_count_leaves({"a": 1, "b": 2}) == 2 - assert json_count_leaves((1, 2, 3)) == 3 - - # Nested structures - nested_dict = {"a": 1, "b": {"c": 2, "d": 3}} - assert json_count_leaves(nested_dict) == 3 - - nested_list = [1, [2, 3], 4] - assert json_count_leaves(nested_list) == 4 - - mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4} - assert json_count_leaves(mixed_nested) == 4 - - def test_convert_ids_list_to_tokens(): tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct") token_ids = tokenizer.encode("Hello, world!") diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 21973018a2b64..0d8b0280d5045 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -50,7 +50,8 @@ from vllm.multimodal.utils import MediaConnector from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import random_uuid, supports_kw +from vllm.utils import random_uuid +from vllm.utils.func import supports_kw logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 3965d2dac0887..c318c0f425bd2 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -94,10 +94,10 @@ from vllm.utils import ( AsyncMicrobatchTokenizer, collect_from_async_generator, is_list_of, - make_async, merge_async_iterators, random_uuid, ) +from vllm.utils.func import make_async from vllm.v1.engine import EngineCoreRequest logger = init_logger(__name__) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 7506e17fe585b..e5c7f80a17533 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -37,7 +37,8 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import make_async, merge_async_iterators +from vllm.utils import merge_async_iterators +from vllm.utils.func import make_async logger = init_logger(__name__) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index a5f83f9040023..093d5e97fd3e4 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -17,7 +17,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import ExecuteModelRequest from vllm.tasks import SupportedTask -from vllm.utils import make_async +from vllm.utils.func import make_async from vllm.v1.outputs import SamplerOutput from vllm.v1.worker.worker_base import WorkerBase diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 59e282ac92b6d..a57b64152f49c 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -24,8 +24,8 @@ from vllm.utils import ( get_distributed_init_method, get_ip, get_open_port, - make_async, ) +from vllm.utils.func import make_async from vllm.v1.outputs import SamplerOutput if ray is not None: diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 350c21e0a95bc..169b14ba46eb9 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -27,8 +27,9 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) -from vllm.utils import has_deep_gemm, run_once +from vllm.utils import has_deep_gemm from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous +from vllm.utils.func import run_once logger = init_logger(__name__) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index d25a0c18d1659..2487d7a691135 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -24,7 +24,7 @@ from vllm.inputs import TokensPrompt from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.utils import supports_kw +from vllm.utils.func import supports_kw from .interfaces_base import VllmModel, is_pooling_model diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index afb94f7c35467..da1ffd2548274 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -15,7 +15,7 @@ import torch.nn as nn from typing_extensions import TypeIs, TypeVar from vllm.logger import init_logger -from vllm.utils import supports_kw +from vllm.utils.func import supports_kw if TYPE_CHECKING: from vllm.config import VllmConfig diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 5d9876539499d..96055551c26ef 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -25,7 +25,8 @@ from typing_extensions import TypeVar, assert_never from vllm.logger import init_logger from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens -from vllm.utils import flatten_2d_lists, full_groupby, get_allowed_kwarg_only_overrides +from vllm.utils import flatten_2d_lists, full_groupby +from vllm.utils.func import get_allowed_kwarg_only_overrides from vllm.utils.jsontree import JSONTree, json_map_leaves from .hasher import MultiModalHasher diff --git a/vllm/tracing.py b/vllm/tracing.py index 7e3e883ca5f2d..b4008064fef0e 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -5,7 +5,7 @@ import os from collections.abc import Mapping from vllm.logger import init_logger -from vllm.utils import run_once +from vllm.utils.func import run_once TRACE_HEADERS = ["traceparent", "tracestate"] diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 0a55ac96ccf89..cdc138064a33c 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -16,7 +16,7 @@ from transformers.processing_utils import ProcessorMixin from transformers.video_processing_utils import BaseVideoProcessor from typing_extensions import TypeVar -from vllm.utils import get_allowed_kwarg_only_overrides +from vllm.utils.func import get_allowed_kwarg_only_overrides if TYPE_CHECKING: from vllm.config import ModelConfig diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 1f01cbeda9686..5fd94b7b40492 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -import concurrent import contextlib import datetime import enum @@ -43,7 +42,6 @@ from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from collections import UserDict, defaultdict from collections.abc import ( AsyncGenerator, - Awaitable, Callable, Collection, Generator, @@ -85,7 +83,7 @@ from packaging import version from packaging.version import Version from torch.library import Library from transformers.tokenization_utils_base import BatchEncoding -from typing_extensions import Never, ParamSpec, TypeIs, assert_never +from typing_extensions import Never, TypeIs, assert_never import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger @@ -174,7 +172,6 @@ def set_default_torch_num_threads(num_threads: int): torch.set_num_threads(old_num_threads) -P = ParamSpec("P") T = TypeVar("T") U = TypeVar("U") @@ -452,24 +449,6 @@ def in_loop(event_loop: AbstractEventLoop) -> bool: return False -def make_async( - func: Callable[P, T], executor: concurrent.futures.Executor | None = None -) -> Callable[P, Awaitable[T]]: - """Take a blocking function, and run it on in an executor thread. - - This function prevents the blocking function from blocking the - asyncio event loop. - The code in this function needs to be thread safe. - """ - - def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future: - loop = asyncio.get_event_loop() - p_func = partial(func, *args, **kwargs) - return loop.run_in_executor(executor=executor, func=p_func) - - return _async_wrapper - - async def merge_async_iterators( *iterators: AsyncGenerator[T, None], ) -> AsyncGenerator[tuple[int, T], None]: @@ -1254,90 +1233,6 @@ def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None: enable_trace_function_call(log_path) -# `functools` helpers -def identity(value: T, **kwargs) -> T: - """Returns the first provided value.""" - return value - - -F = TypeVar("F", bound=Callable[..., Any]) - - -def deprecate_args( - start_index: int, - is_deprecated: bool | Callable[[], bool] = True, - additional_message: str | None = None, -) -> Callable[[F], F]: - if not callable(is_deprecated): - is_deprecated = partial(identity, is_deprecated) - - def wrapper(fn: F) -> F: - params = inspect.signature(fn).parameters - pos_types = ( - inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - ) - pos_kws = [kw for kw, param in params.items() if param.kind in pos_types] - - @wraps(fn) - def inner(*args, **kwargs): - if is_deprecated(): - deprecated_args = pos_kws[start_index : len(args)] - if deprecated_args: - msg = ( - f"The positional arguments {deprecated_args} are " - "deprecated and will be removed in a future update." - ) - if additional_message is not None: - msg += f" {additional_message}" - - warnings.warn( - DeprecationWarning(msg), - stacklevel=3, # The inner function takes up one level - ) - - return fn(*args, **kwargs) - - return inner # type: ignore - - return wrapper - - -def deprecate_kwargs( - *kws: str, - is_deprecated: bool | Callable[[], bool] = True, - additional_message: str | None = None, -) -> Callable[[F], F]: - deprecated_kws = set(kws) - - if not callable(is_deprecated): - is_deprecated = partial(identity, is_deprecated) - - def wrapper(fn: F) -> F: - @wraps(fn) - def inner(*args, **kwargs): - if is_deprecated(): - deprecated_kwargs = kwargs.keys() & deprecated_kws - if deprecated_kwargs: - msg = ( - f"The keyword arguments {deprecated_kwargs} are " - "deprecated and will be removed in a future update." - ) - if additional_message is not None: - msg += f" {additional_message}" - - warnings.warn( - DeprecationWarning(msg), - stacklevel=3, # The inner function takes up one level - ) - - return fn(*args, **kwargs) - - return inner # type: ignore - - return wrapper - - @lru_cache(maxsize=8) def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int: # Note: cuda_visible_devices is not used, but we keep it as an argument for @@ -1426,21 +1321,6 @@ def weak_bind( return weak_bound -def run_once(f: Callable[P, None]) -> Callable[P, None]: - def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: - if wrapper.has_run: # type: ignore[attr-defined] - return - - with wrapper.lock: # type: ignore[attr-defined] - if not wrapper.has_run: # type: ignore[attr-defined] - wrapper.has_run = True # type: ignore[attr-defined] - return f(*args, **kwargs) - - wrapper.has_run = False # type: ignore[attr-defined] - wrapper.lock = threading.Lock() # type: ignore[attr-defined] - return wrapper - - class StoreBoolean(Action): def __call__(self, parser, namespace, values, option_string=None): if values.lower() == "true": @@ -1929,122 +1809,6 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwarg return await task(*args, **kwargs) -@lru_cache -def supports_kw( - callable: Callable[..., object], - kw_name: str, - *, - requires_kw_only: bool = False, - allow_var_kwargs: bool = True, -) -> bool: - """Check if a keyword is a valid kwarg for a callable; if requires_kw_only - disallows kwargs names that can also be positional arguments. - """ - params = inspect.signature(callable).parameters - if not params: - return False - - param_val = params.get(kw_name) - - # Types where the it may be valid, i.e., explicitly defined & nonvariadic - passable_kw_types = set( - ( - inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY, - ) - ) - - if param_val: - is_sig_param = param_val.kind in passable_kw_types - # We want kwargs only, but this is passable as a positional arg - if ( - requires_kw_only - and is_sig_param - and param_val.kind != inspect.Parameter.KEYWORD_ONLY - ): - return False - if (requires_kw_only and param_val.kind == inspect.Parameter.KEYWORD_ONLY) or ( - not requires_kw_only and is_sig_param - ): - return True - - # If we're okay with var-kwargs, it's supported as long as - # the kw_name isn't something like *args, **kwargs - if allow_var_kwargs: - # Get the last param; type is ignored here because params is a proxy - # mapping, but it wraps an ordered dict, and they appear in order. - # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters - last_param = params[next(reversed(params))] # type: ignore - return ( - last_param.kind == inspect.Parameter.VAR_KEYWORD - and last_param.name != kw_name - ) - - return False - - -def get_allowed_kwarg_only_overrides( - callable: Callable[..., object], - overrides: Mapping[str, object] | None, - *, - requires_kw_only: bool = True, - allow_var_kwargs: bool = False, -) -> dict[str, Any]: - """ - Given a callable which has one or more keyword only params and a dict - mapping param names to values, drop values that can be not be kwarg - expanded to overwrite one or more keyword-only args. This is used in a - few places to handle custom processor overrides for multimodal models, - e.g., for profiling when processor options provided by the user - may affect the number of mm tokens per instance. - - Args: - callable: Callable which takes 0 or more keyword only arguments. - If None is provided, all overrides names are allowed. - overrides: Potential overrides to be used when invoking the callable. - allow_var_kwargs: Allows overrides that are expandable for var kwargs. - - Returns: - Dictionary containing the kwargs to be leveraged which may be used - to overwrite one or more keyword only arguments when invoking the - callable. - """ - if not overrides: - return {} - - # Drop any mm_processor_kwargs provided by the user that - # are not kwargs, unless it can fit it var_kwargs param - filtered_overrides = { - kwarg_name: val - for kwarg_name, val in overrides.items() - if supports_kw( - callable, - kwarg_name, - requires_kw_only=requires_kw_only, - allow_var_kwargs=allow_var_kwargs, - ) - } - - # If anything is dropped, log a warning - dropped_keys = overrides.keys() - filtered_overrides.keys() - if dropped_keys: - if requires_kw_only: - logger.warning( - "The following intended overrides are not keyword-only args " - "and will be dropped: %s", - dropped_keys, - ) - else: - logger.warning( - "The following intended overrides are not keyword args " - "and will be dropped: %s", - dropped_keys, - ) - - return filtered_overrides - - # Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. # In particular, the FakeScalarType is not supported for earlier versions of # PyTorch which breaks dynamo for any ops registered using ScalarType. diff --git a/vllm/utils/func.py b/vllm/utils/func.py new file mode 100644 index 0000000000000..bd26b29d5f6dc --- /dev/null +++ b/vllm/utils/func.py @@ -0,0 +1,258 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Contains helpers that are applied to functions. + +This is similar in concept to the `functools` module. +""" + +import asyncio +import concurrent.futures +import inspect +import threading +import warnings +from collections.abc import Awaitable, Callable, Mapping +from functools import lru_cache, partial, wraps +from typing import Any, TypeVar + +from typing_extensions import ParamSpec + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +P = ParamSpec("P") +T = TypeVar("T") +F = TypeVar("F", bound=Callable[..., Any]) + + +def identity(value: T, **kwargs) -> T: + """Returns the first provided value.""" + return value + + +def make_async( + func: Callable[P, T], + executor: concurrent.futures.Executor | None = None, +) -> Callable[P, Awaitable[T]]: + """ + Take a blocking function, and run it on in an executor thread. + + This function prevents the blocking function from blocking the + asyncio event loop. + The code in this function needs to be thread safe. + """ + + def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future[T]: + loop = asyncio.get_event_loop() + p_func = partial(func, *args, **kwargs) + return loop.run_in_executor(executor=executor, func=p_func) + + return _async_wrapper + + +def run_once(f: Callable[P, None]) -> Callable[P, None]: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: + if wrapper.has_run: # type: ignore[attr-defined] + return + + with wrapper.lock: # type: ignore[attr-defined] + if not wrapper.has_run: # type: ignore[attr-defined] + wrapper.has_run = True # type: ignore[attr-defined] + return f(*args, **kwargs) + + wrapper.has_run = False # type: ignore[attr-defined] + wrapper.lock = threading.Lock() # type: ignore[attr-defined] + return wrapper + + +def deprecate_args( + start_index: int, + is_deprecated: bool | Callable[[], bool] = True, + additional_message: str | None = None, +) -> Callable[[F], F]: + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + + def wrapper(fn: F) -> F: + params = inspect.signature(fn).parameters + pos_types = ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + pos_kws = [kw for kw, param in params.items() if param.kind in pos_types] + + @wraps(fn) + def inner(*args, **kwargs): + if is_deprecated(): + deprecated_args = pos_kws[start_index : len(args)] + if deprecated_args: + msg = ( + f"The positional arguments {deprecated_args} are " + "deprecated and will be removed in a future update." + ) + if additional_message is not None: + msg += f" {additional_message}" + + warnings.warn( + DeprecationWarning(msg), + stacklevel=3, # The inner function takes up one level + ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper + + +def deprecate_kwargs( + *kws: str, + is_deprecated: bool | Callable[[], bool] = True, + additional_message: str | None = None, +) -> Callable[[F], F]: + deprecated_kws = set(kws) + + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + + def wrapper(fn: F) -> F: + @wraps(fn) + def inner(*args, **kwargs): + if is_deprecated(): + deprecated_kwargs = kwargs.keys() & deprecated_kws + if deprecated_kwargs: + msg = ( + f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update." + ) + if additional_message is not None: + msg += f" {additional_message}" + + warnings.warn( + DeprecationWarning(msg), + stacklevel=3, # The inner function takes up one level + ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper + + +@lru_cache +def supports_kw( + callable: Callable[..., object], + kw_name: str, + *, + requires_kw_only: bool = False, + allow_var_kwargs: bool = True, +) -> bool: + """Check if a keyword is a valid kwarg for a callable; if requires_kw_only + disallows kwargs names that can also be positional arguments. + """ + params = inspect.signature(callable).parameters + if not params: + return False + + param_val = params.get(kw_name) + + # Types where the it may be valid, i.e., explicitly defined & nonvariadic + passable_kw_types = set( + ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + ) + + if param_val: + is_sig_param = param_val.kind in passable_kw_types + # We want kwargs only, but this is passable as a positional arg + if ( + requires_kw_only + and is_sig_param + and param_val.kind != inspect.Parameter.KEYWORD_ONLY + ): + return False + if (requires_kw_only and param_val.kind == inspect.Parameter.KEYWORD_ONLY) or ( + not requires_kw_only and is_sig_param + ): + return True + + # If we're okay with var-kwargs, it's supported as long as + # the kw_name isn't something like *args, **kwargs + if allow_var_kwargs: + # Get the last param; type is ignored here because params is a proxy + # mapping, but it wraps an ordered dict, and they appear in order. + # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters + last_param = params[next(reversed(params))] # type: ignore + return ( + last_param.kind == inspect.Parameter.VAR_KEYWORD + and last_param.name != kw_name + ) + + return False + + +def get_allowed_kwarg_only_overrides( + callable: Callable[..., object], + overrides: Mapping[str, object] | None, + *, + requires_kw_only: bool = True, + allow_var_kwargs: bool = False, +) -> dict[str, Any]: + """ + Given a callable which has one or more keyword only params and a dict + mapping param names to values, drop values that can be not be kwarg + expanded to overwrite one or more keyword-only args. This is used in a + few places to handle custom processor overrides for multimodal models, + e.g., for profiling when processor options provided by the user + may affect the number of mm tokens per instance. + + Args: + callable: Callable which takes 0 or more keyword only arguments. + If None is provided, all overrides names are allowed. + overrides: Potential overrides to be used when invoking the callable. + allow_var_kwargs: Allows overrides that are expandable for var kwargs. + + Returns: + Dictionary containing the kwargs to be leveraged which may be used + to overwrite one or more keyword only arguments when invoking the + callable. + """ + if not overrides: + return {} + + # Drop any mm_processor_kwargs provided by the user that + # are not kwargs, unless it can fit it var_kwargs param + filtered_overrides = { + kwarg_name: val + for kwarg_name, val in overrides.items() + if supports_kw( + callable, + kwarg_name, + requires_kw_only=requires_kw_only, + allow_var_kwargs=allow_var_kwargs, + ) + } + + # If anything is dropped, log a warning + dropped_keys = overrides.keys() - filtered_overrides.keys() + if dropped_keys: + if requires_kw_only: + logger.warning( + "The following intended overrides are not keyword-only args " + "and will be dropped: %s", + dropped_keys, + ) + else: + logger.warning( + "The following intended overrides are not keyword args " + "and will be dropped: %s", + dropped_keys, + ) + + return filtered_overrides diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 0ec153e233161..c8fb30f96c0a0 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -29,7 +29,8 @@ from vllm.tracing import init_tracer from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs from vllm.usage.usage_lib import UsageContext -from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs +from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv +from vllm.utils.func import deprecate_kwargs from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError