mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-25 23:44:31 +08:00
[Chore] Separate out vllm.utils.func (#26904)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
f57438338d
commit
136a17fe6e
@ -17,7 +17,7 @@ from transformers import (
|
||||
)
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import identity
|
||||
from vllm.utils.func import identity
|
||||
|
||||
from ....conftest import (
|
||||
IMAGE_ASSETS,
|
||||
|
||||
97
tests/utils_/test_func_utils.py
Normal file
97
tests/utils_/test_func_utils.py
Normal file
@ -0,0 +1,97 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.utils.func import deprecate_kwargs, supports_kw
|
||||
|
||||
from ..utils import error_on_warning
|
||||
|
||||
|
||||
def test_deprecate_kwargs_always():
    """With ``is_deprecated=True``, passing ``old_arg`` must emit a warning."""

    @deprecate_kwargs("old_arg", is_deprecated=True)
    def target(*, old_arg: object = None, new_arg: object = None):
        pass

    # The deprecated keyword is named in the warning message.
    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        target(old_arg=1)

    # The replacement keyword is exercised under ``error_on_warning``
    # (presumably turning any DeprecationWarning into a failure).
    with error_on_warning(DeprecationWarning):
        target(new_arg=1)
|
||||
|
||||
|
||||
def test_deprecate_kwargs_never():
    """With ``is_deprecated=False``, neither keyword is expected to warn."""

    @deprecate_kwargs("old_arg", is_deprecated=False)
    def target(*, old_arg: object = None, new_arg: object = None):
        pass

    # Same check for the old and the new keyword, in the original order.
    for call_kwargs in ({"old_arg": 1}, {"new_arg": 1}):
        with error_on_warning(DeprecationWarning):
            target(**call_kwargs)
|
||||
|
||||
|
||||
def test_deprecate_kwargs_dynamic():
    """A callable ``is_deprecated`` is consulted again on each call."""
    flag = True

    @deprecate_kwargs("old_arg", is_deprecated=lambda: flag)
    def target(*, old_arg: object = None, new_arg: object = None):
        pass

    # While the flag is set, only ``old_arg`` warns.
    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        target(old_arg=1)
    with error_on_warning(DeprecationWarning):
        target(new_arg=1)

    # The lambda closes over ``flag``, so flipping it changes later calls.
    flag = False

    with error_on_warning(DeprecationWarning):
        target(old_arg=1)
    with error_on_warning(DeprecationWarning):
        target(new_arg=1)
|
||||
|
||||
|
||||
def test_deprecate_kwargs_additional_message():
    """The ``additional_message`` text must appear in the emitted warning."""
    extra_text = "abcd"

    @deprecate_kwargs("old_arg", is_deprecated=True, additional_message=extra_text)
    def target(*, old_arg: object = None, new_arg: object = None):
        pass

    with pytest.warns(DeprecationWarning, match=extra_text):
        target(old_arg=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("callable", "kw_name", "requires_kw_only", "allow_var_kwargs", "is_supported"),
    [
        # Positional-or-keyword parameter: rejected only when kw-only is required
        (lambda foo: None, "foo", True, True, False),
        (lambda foo: None, "foo", False, True, True),
        # Defaulted positional-or-keyword vs. true keyword-only
        (lambda foo=100: None, "foo", True, True, False),
        (lambda *, foo: None, "foo", False, True, True),
        # The names of the variadic parameters themselves are never supported
        (lambda *args: None, "args", False, True, False),
        (lambda **kwargs: None, "kwargs", False, True, False),
        # Unknown names are supported only through a trailing **kwargs
        (lambda foo: None, "something_else", False, True, False),
        (lambda foo, **kwargs: None, "something_else", False, True, True),
        (lambda foo, **kwargs: None, "kwargs", True, True, False),
        (lambda foo, **kwargs: None, "foo", True, True, False),
    ],
)
def test_supports_kw(
    callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported
):
    """Check ``supports_kw`` against the expected verdict for each case."""
    verdict = supports_kw(
        callable=callable,
        kw_name=kw_name,
        requires_kw_only=requires_kw_only,
        allow_var_kwargs=allow_var_kwargs,
    )
    assert verdict == is_supported
|
||||
32
tests/utils_/test_jsontree.py
Normal file
32
tests/utils_/test_jsontree.py
Normal file
@ -0,0 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.utils.jsontree import json_count_leaves
|
||||
|
||||
|
||||
def test_json_count_leaves():
    """Test json_count_leaves function from jsontree utility."""

    # Single leaf values
    for leaf in (42, "hello", None):
        assert json_count_leaves(leaf) == 1

    # Empty containers
    for empty in ([], {}, ()):
        assert json_count_leaves(empty) == 0

    # Flat structures
    flat_cases = [
        ([1, 2, 3], 3),
        ({"a": 1, "b": 2}, 2),
        ((1, 2, 3), 3),
    ]
    for flat, expected in flat_cases:
        assert json_count_leaves(flat) == expected

    # Nested structures
    nested_cases = [
        ({"a": 1, "b": {"c": 2, "d": 3}}, 3),
        ([1, [2, 3], 4], 4),
        ({"list": [1, 2], "dict": {"x": 3}, "value": 4}, 4),
    ]
    for nested, expected in nested_cases:
        assert json_count_leaves(nested) == expected
|
||||
@ -30,7 +30,6 @@ from vllm.utils import (
|
||||
bind_kv_cache,
|
||||
common_broadcastable_dtype,
|
||||
current_stream,
|
||||
deprecate_kwargs,
|
||||
get_open_port,
|
||||
get_tcp_uri,
|
||||
is_lossless_cast,
|
||||
@ -42,12 +41,11 @@ from vllm.utils import (
|
||||
sha256,
|
||||
split_host_port,
|
||||
split_zmq_path,
|
||||
supports_kw,
|
||||
swap_dict_values,
|
||||
unique_filepath,
|
||||
)
|
||||
|
||||
from ..utils import create_new_process_for_each_test, error_on_warning
|
||||
from ..utils import create_new_process_for_each_test
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -83,61 +81,6 @@ async def test_merge_async_iterators():
|
||||
raise AssertionError() from e
|
||||
|
||||
|
||||
def test_deprecate_kwargs_always():
    """With ``is_deprecated=True``, passing ``old_arg`` must emit a warning."""

    @deprecate_kwargs("old_arg", is_deprecated=True)
    def target(*, old_arg: object = None, new_arg: object = None):
        pass

    # The deprecated keyword is named in the warning message.
    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        target(old_arg=1)

    # The replacement keyword is exercised under ``error_on_warning``
    # (presumably turning any DeprecationWarning into a failure).
    with error_on_warning(DeprecationWarning):
        target(new_arg=1)
|
||||
|
||||
|
||||
def test_deprecate_kwargs_never():
    """With ``is_deprecated=False``, neither keyword is expected to warn."""

    @deprecate_kwargs("old_arg", is_deprecated=False)
    def target(*, old_arg: object = None, new_arg: object = None):
        pass

    # Same check for the old and the new keyword, in the original order.
    for call_kwargs in ({"old_arg": 1}, {"new_arg": 1}):
        with error_on_warning(DeprecationWarning):
            target(**call_kwargs)
|
||||
|
||||
|
||||
def test_deprecate_kwargs_dynamic():
    """A callable ``is_deprecated`` is consulted again on each call."""
    flag = True

    @deprecate_kwargs("old_arg", is_deprecated=lambda: flag)
    def target(*, old_arg: object = None, new_arg: object = None):
        pass

    # While the flag is set, only ``old_arg`` warns.
    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        target(old_arg=1)
    with error_on_warning(DeprecationWarning):
        target(new_arg=1)

    # The lambda closes over ``flag``, so flipping it changes later calls.
    flag = False

    with error_on_warning(DeprecationWarning):
        target(old_arg=1)
    with error_on_warning(DeprecationWarning):
        target(new_arg=1)
|
||||
|
||||
|
||||
def test_deprecate_kwargs_additional_message():
    """The ``additional_message`` text must appear in the emitted warning."""
    extra_text = "abcd"

    @deprecate_kwargs("old_arg", is_deprecated=True, additional_message=extra_text)
    def target(*, old_arg: object = None, new_arg: object = None):
        pass

    with pytest.warns(DeprecationWarning, match=extra_text):
        target(old_arg=1)
|
||||
|
||||
|
||||
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_PORT", "5678")
|
||||
@ -383,39 +326,6 @@ def test_duplicate_dict_args(caplog_vllm, parser):
|
||||
assert "-O.mode" in caplog_vllm.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
    [
        # Positional-or-keyword parameter: rejected only when kw-only is required
        (lambda foo: None, "foo", True, True, False),
        (lambda foo: None, "foo", False, True, True),
        # Defaulted positional-or-keyword vs. true keyword-only
        (lambda foo=100: None, "foo", True, True, False),
        (lambda *, foo: None, "foo", False, True, True),
        # The names of the variadic parameters themselves are never supported
        (lambda *args: None, "args", False, True, False),
        (lambda **kwargs: None, "kwargs", False, True, False),
        # Unknown names are supported only through a trailing **kwargs
        (lambda foo: None, "something_else", False, True, False),
        (lambda foo, **kwargs: None, "something_else", False, True, True),
        (lambda foo, **kwargs: None, "kwargs", True, True, False),
        (lambda foo, **kwargs: None, "foo", True, True, False),
    ],
)
def test_supports_kw(
    callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported
):
    """Check ``supports_kw`` against the expected verdict for each case."""
    verdict = supports_kw(
        callable=callable,
        kw_name=kw_name,
        requires_kw_only=requires_kw_only,
        allow_var_kwargs=allow_var_kwargs,
    )
    assert verdict == is_supported
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
def test_memory_profiling():
|
||||
# Fake out some model loading + inference memory usage to test profiling
|
||||
@ -863,36 +773,6 @@ def test_join_host_port():
|
||||
assert join_host_port("::1", 5555) == "[::1]:5555"
|
||||
|
||||
|
||||
def test_json_count_leaves():
    """Test json_count_leaves function from jsontree utility."""
    from vllm.utils.jsontree import json_count_leaves

    # Single leaf values
    for leaf in (42, "hello", None):
        assert json_count_leaves(leaf) == 1

    # Empty containers
    for empty in ([], {}, ()):
        assert json_count_leaves(empty) == 0

    # Flat structures
    flat_cases = [
        ([1, 2, 3], 3),
        ({"a": 1, "b": 2}, 2),
        ((1, 2, 3), 3),
    ]
    for flat, expected in flat_cases:
        assert json_count_leaves(flat) == expected

    # Nested structures
    nested_cases = [
        ({"a": 1, "b": {"c": 2, "d": 3}}, 3),
        ([1, [2, 3], 4], 4),
        ({"list": [1, 2], "dict": {"x": 3}, "value": 4}, 4),
    ]
    for nested, expected in nested_cases:
        assert json_count_leaves(nested) == expected
|
||||
|
||||
|
||||
def test_convert_ids_list_to_tokens():
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
|
||||
token_ids = tokenizer.encode("Hello, world!")
|
||||
|
||||
@ -50,7 +50,8 @@ from vllm.multimodal.utils import MediaConnector
|
||||
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import random_uuid, supports_kw
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.func import supports_kw
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -94,10 +94,10 @@ from vllm.utils import (
|
||||
AsyncMicrobatchTokenizer,
|
||||
collect_from_async_generator,
|
||||
is_list_of,
|
||||
make_async,
|
||||
merge_async_iterators,
|
||||
random_uuid,
|
||||
)
|
||||
from vllm.utils.func import make_async
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -37,7 +37,8 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import make_async, merge_async_iterators
|
||||
from vllm.utils import merge_async_iterators
|
||||
from vllm.utils.func import make_async
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils import make_async
|
||||
from vllm.utils.func import make_async
|
||||
from vllm.v1.outputs import SamplerOutput
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
|
||||
@ -24,8 +24,8 @@ from vllm.utils import (
|
||||
get_distributed_init_method,
|
||||
get_ip,
|
||||
get_open_port,
|
||||
make_async,
|
||||
)
|
||||
from vllm.utils.func import make_async
|
||||
from vllm.v1.outputs import SamplerOutput
|
||||
|
||||
if ray is not None:
|
||||
|
||||
@ -27,8 +27,9 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8,
|
||||
)
|
||||
from vllm.utils import has_deep_gemm, run_once
|
||||
from vllm.utils import has_deep_gemm
|
||||
from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
|
||||
from vllm.utils.func import run_once
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ from vllm.inputs import TokensPrompt
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.utils import supports_kw
|
||||
from vllm.utils.func import supports_kw
|
||||
|
||||
from .interfaces_base import VllmModel, is_pooling_model
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ import torch.nn as nn
|
||||
from typing_extensions import TypeIs, TypeVar
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import supports_kw
|
||||
from vllm.utils.func import supports_kw
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
@ -25,7 +25,8 @@ from typing_extensions import TypeVar, assert_never
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens
|
||||
from vllm.utils import flatten_2d_lists, full_groupby, get_allowed_kwarg_only_overrides
|
||||
from vllm.utils import flatten_2d_lists, full_groupby
|
||||
from vllm.utils.func import get_allowed_kwarg_only_overrides
|
||||
from vllm.utils.jsontree import JSONTree, json_map_leaves
|
||||
|
||||
from .hasher import MultiModalHasher
|
||||
|
||||
@ -5,7 +5,7 @@ import os
|
||||
from collections.abc import Mapping
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import run_once
|
||||
from vllm.utils.func import run_once
|
||||
|
||||
TRACE_HEADERS = ["traceparent", "tracestate"]
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ from transformers.processing_utils import ProcessorMixin
|
||||
from transformers.video_processing_utils import BaseVideoProcessor
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from vllm.utils import get_allowed_kwarg_only_overrides
|
||||
from vllm.utils.func import get_allowed_kwarg_only_overrides
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import concurrent
|
||||
import contextlib
|
||||
import datetime
|
||||
import enum
|
||||
@ -43,7 +42,6 @@ from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import (
|
||||
AsyncGenerator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Generator,
|
||||
@ -85,7 +83,7 @@ from packaging import version
|
||||
from packaging.version import Version
|
||||
from torch.library import Library
|
||||
from transformers.tokenization_utils_base import BatchEncoding
|
||||
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
|
||||
from typing_extensions import Never, TypeIs, assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import enable_trace_function_call, init_logger
|
||||
@ -174,7 +172,6 @@ def set_default_torch_num_threads(num_threads: int):
|
||||
torch.set_num_threads(old_num_threads)
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
|
||||
@ -452,24 +449,6 @@ def in_loop(event_loop: AbstractEventLoop) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def make_async(
    func: Callable[P, T], executor: concurrent.futures.Executor | None = None
) -> Callable[P, Awaitable[T]]:
    """Wrap a blocking function so that each call runs in an executor.

    The returned wrapper binds its arguments to ``func``, hands the bound
    call to ``run_in_executor`` on the loop returned by
    ``asyncio.get_event_loop()``, and returns the resulting future.
    Awaiting that future therefore does not block the asyncio event loop.

    Because ``func`` runs in the executor, its code needs to be thread safe.
    """

    def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
        loop = asyncio.get_event_loop()
        return loop.run_in_executor(executor, partial(func, *args, **kwargs))

    return _async_wrapper
|
||||
|
||||
|
||||
async def merge_async_iterators(
|
||||
*iterators: AsyncGenerator[T, None],
|
||||
) -> AsyncGenerator[tuple[int, T], None]:
|
||||
@ -1254,90 +1233,6 @@ def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
|
||||
enable_trace_function_call(log_path)
|
||||
|
||||
|
||||
# `functools` helpers
|
||||
def identity(value: T, **kwargs) -> T:
    """Return ``value`` unchanged; keyword arguments are accepted and ignored."""
    return value
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def deprecate_args(
    start_index: int,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """
    Build a decorator that warns about deprecated positional arguments.

    Args:
        start_index: Index into the decorated function's positional
            parameters; parameters from this index onward that are filled
            positionally by a call are reported as deprecated.
        is_deprecated: Either a bool, or a zero-argument callable that is
            evaluated on every call to decide whether to warn.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator whose wrapped function emits a ``DeprecationWarning``
        when applicable and then forwards the call unchanged.
    """
    if callable(is_deprecated):
        should_warn = is_deprecated
    else:
        # Normalize a plain bool into a zero-argument callable.
        should_warn = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        # Names of the parameters that can be filled positionally, in order.
        positional_names = [
            name
            for name, param in inspect.signature(fn).parameters.items()
            if param.kind in positional_kinds
        ]

        @wraps(fn)
        def inner(*args, **kwargs):
            # Only compute the offending names when deprecation is active.
            offending = (
                positional_names[start_index : len(args)] if should_warn() else None
            )
            if offending:
                suffix = "" if additional_message is None else f" {additional_message}"
                warnings.warn(
                    DeprecationWarning(
                        f"The positional arguments {offending} are "
                        "deprecated and will be removed in a future update."
                        f"{suffix}"
                    ),
                    stacklevel=3,  # The inner function takes up one level
                )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
|
||||
|
||||
|
||||
def deprecate_kwargs(
    *kws: str,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """
    Build a decorator that warns about deprecated keyword arguments.

    Args:
        *kws: Names of the keyword arguments considered deprecated.
        is_deprecated: Either a bool, or a zero-argument callable that is
            evaluated on every call to decide whether to warn.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator whose wrapped function emits a ``DeprecationWarning``
        when any deprecated keyword is passed, then forwards the call
        unchanged.
    """
    deprecated_kws = set(kws)

    if callable(is_deprecated):
        should_warn = is_deprecated
    else:
        # Normalize a plain bool into a zero-argument callable.
        should_warn = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:
        @wraps(fn)
        def inner(*args, **kwargs):
            # Only intersect with the deprecated names when deprecation is active.
            offending = kwargs.keys() & deprecated_kws if should_warn() else None
            if offending:
                suffix = "" if additional_message is None else f" {additional_message}"
                warnings.warn(
                    DeprecationWarning(
                        f"The keyword arguments {offending} are "
                        "deprecated and will be removed in a future update."
                        f"{suffix}"
                    ),
                    stacklevel=3,  # The inner function takes up one level
                )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
|
||||
# Note: cuda_visible_devices is not used, but we keep it as an argument for
|
||||
@ -1426,21 +1321,6 @@ def weak_bind(
|
||||
return weak_bound
|
||||
|
||||
|
||||
def run_once(f: Callable[P, None]) -> Callable[P, None]:
    """
    Wrap ``f`` so that it is invoked at most once.

    The wrapper stores a ``has_run`` flag and a ``threading.Lock`` as
    attributes on itself. The first call sets the flag and invokes ``f``
    while holding the lock; later calls return ``None`` without calling
    ``f``.
    """

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        # Fast path: skip the lock entirely once the flag is set.
        if not wrapper.has_run:  # type: ignore[attr-defined]
            with wrapper.lock:  # type: ignore[attr-defined]
                # Re-check under the lock so only one caller invokes ``f``.
                if not wrapper.has_run:  # type: ignore[attr-defined]
                    # The flag is set before ``f`` runs, so a raising ``f``
                    # is not retried on later calls.
                    wrapper.has_run = True  # type: ignore[attr-defined]
                    return f(*args, **kwargs)
        return None

    wrapper.lock = threading.Lock()  # type: ignore[attr-defined]
    wrapper.has_run = False  # type: ignore[attr-defined]
    return wrapper
|
||||
|
||||
|
||||
class StoreBoolean(Action):
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
if values.lower() == "true":
|
||||
@ -1929,122 +1809,6 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwarg
|
||||
return await task(*args, **kwargs)
|
||||
|
||||
|
||||
@lru_cache
def supports_kw(
    callable: Callable[..., object],
    kw_name: str,
    *,
    requires_kw_only: bool = False,
    allow_var_kwargs: bool = True,
) -> bool:
    """Check whether ``kw_name`` can be passed to ``callable`` as a keyword.

    Args:
        callable: The callable whose signature is inspected.
        kw_name: The keyword name to look for.
        requires_kw_only: If True, names that can also be filled
            positionally (positional-or-keyword parameters) are rejected.
        allow_var_kwargs: If True, a name that does not match a
            keyword-capable parameter is still accepted when the last
            parameter is ``**kwargs`` (and is not itself named ``kw_name``).

    Returns:
        True if the keyword is supported under the given constraints.

    Note:
        Positional-only parameters cannot be supplied by keyword, so a
        positional-only parameter named ``kw_name`` never counts as a
        match on its own; such a name is only accepted through
        ``**kwargs``.
    """
    params = inspect.signature(callable).parameters
    if not params:
        return False

    param = params.get(kw_name)
    if param is not None:
        if param.kind == inspect.Parameter.KEYWORD_ONLY:
            return True
        if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
            # Passable by keyword, but also positionally.
            return not requires_kw_only
        # Positional-only and variadic (*args / **kwargs) parameters cannot
        # be addressed by their own name; fall through to the **kwargs check.

    # If we're okay with var-kwargs, it's supported as long as
    # the kw_name isn't something like *args, **kwargs
    if allow_var_kwargs:
        # Get the last param; type is ignored here because params is a proxy
        # mapping, but it wraps an ordered dict, and they appear in order.
        # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
        last_param = params[next(reversed(params))]  # type: ignore
        return (
            last_param.kind == inspect.Parameter.VAR_KEYWORD
            and last_param.name != kw_name
        )

    return False
|
||||
|
||||
|
||||
def get_allowed_kwarg_only_overrides(
    callable: Callable[..., object],
    overrides: Mapping[str, object] | None,
    *,
    requires_kw_only: bool = True,
    allow_var_kwargs: bool = False,
) -> dict[str, Any]:
    """
    Keep only the overrides that ``callable`` accepts as keyword arguments.

    Each override name is checked with ``supports_kw``; names that fail the
    check are dropped and reported in a single warning log. This is used in
    a few places to handle custom processor overrides for multimodal
    models, e.g., for profiling when processor options provided by the
    user may affect the number of mm tokens per instance.

    Args:
        callable: Callable whose signature decides which names are kept.
        overrides: Potential overrides to be used when invoking the
            callable. ``None`` or an empty mapping yields ``{}``.
        requires_kw_only: Forwarded to ``supports_kw``; also selects the
            wording of the warning for dropped names.
        allow_var_kwargs: Allows overrides that are expandable for var kwargs.

    Returns:
        Dictionary containing the overrides that passed the check, which
        may be expanded as keyword arguments when invoking the callable.
    """
    if not overrides:
        return {}

    # Drop any mm_processor_kwargs provided by the user that
    # are not kwargs, unless it can fit it var_kwargs param
    accepted: dict[str, Any] = {}
    for name, value in overrides.items():
        if supports_kw(
            callable,
            name,
            requires_kw_only=requires_kw_only,
            allow_var_kwargs=allow_var_kwargs,
        ):
            accepted[name] = value

    # If anything is dropped, log a warning
    rejected = overrides.keys() - accepted.keys()
    if rejected:
        if requires_kw_only:
            template = (
                "The following intended overrides are not keyword-only args "
                "and will be dropped: %s"
            )
        else:
            template = (
                "The following intended overrides are not keyword args "
                "and will be dropped: %s"
            )
        logger.warning(template, rejected)

    return accepted
|
||||
|
||||
|
||||
# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
|
||||
# In particular, the FakeScalarType is not supported for earlier versions of
|
||||
# PyTorch which breaks dynamo for any ops registered using ScalarType.
|
||||
|
||||
258
vllm/utils/func.py
Normal file
258
vllm/utils/func.py
Normal file
@ -0,0 +1,258 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Contains helpers that are applied to functions.
|
||||
|
||||
This is similar in concept to the `functools` module.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import inspect
|
||||
import threading
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from functools import lru_cache, partial, wraps
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def identity(value: T, **kwargs) -> T:
    """Return ``value`` unchanged; keyword arguments are accepted and ignored."""
    return value
|
||||
|
||||
|
||||
def make_async(
    func: Callable[P, T],
    executor: concurrent.futures.Executor | None = None,
) -> Callable[P, Awaitable[T]]:
    """
    Wrap a blocking function so that each call runs in an executor.

    The returned wrapper binds its arguments to ``func``, hands the bound
    call to ``run_in_executor`` on the loop returned by
    ``asyncio.get_event_loop()``, and returns the resulting future.
    Awaiting that future therefore does not block the asyncio event loop.

    Because ``func`` runs in the executor, its code needs to be thread safe.
    """

    def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future[T]:
        loop = asyncio.get_event_loop()
        return loop.run_in_executor(executor, partial(func, *args, **kwargs))

    return _async_wrapper
|
||||
|
||||
|
||||
def run_once(f: Callable[P, None]) -> Callable[P, None]:
    """
    Wrap ``f`` so that it is invoked at most once.

    The wrapper stores a ``has_run`` flag and a ``threading.Lock`` as
    attributes on itself. The first call sets the flag and invokes ``f``
    while holding the lock; later calls return ``None`` without calling
    ``f``.
    """

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        # Fast path: skip the lock entirely once the flag is set.
        if not wrapper.has_run:  # type: ignore[attr-defined]
            with wrapper.lock:  # type: ignore[attr-defined]
                # Re-check under the lock so only one caller invokes ``f``.
                if not wrapper.has_run:  # type: ignore[attr-defined]
                    # The flag is set before ``f`` runs, so a raising ``f``
                    # is not retried on later calls.
                    wrapper.has_run = True  # type: ignore[attr-defined]
                    return f(*args, **kwargs)
        return None

    wrapper.lock = threading.Lock()  # type: ignore[attr-defined]
    wrapper.has_run = False  # type: ignore[attr-defined]
    return wrapper
|
||||
|
||||
|
||||
def deprecate_args(
    start_index: int,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """
    Build a decorator that warns about deprecated positional arguments.

    Args:
        start_index: Index into the decorated function's positional
            parameters; parameters from this index onward that are filled
            positionally by a call are reported as deprecated.
        is_deprecated: Either a bool, or a zero-argument callable that is
            evaluated on every call to decide whether to warn.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator whose wrapped function emits a ``DeprecationWarning``
        when applicable and then forwards the call unchanged.
    """
    if callable(is_deprecated):
        should_warn = is_deprecated
    else:
        # Normalize a plain bool into a zero-argument callable.
        should_warn = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        # Names of the parameters that can be filled positionally, in order.
        positional_names = [
            name
            for name, param in inspect.signature(fn).parameters.items()
            if param.kind in positional_kinds
        ]

        @wraps(fn)
        def inner(*args, **kwargs):
            # Only compute the offending names when deprecation is active.
            offending = (
                positional_names[start_index : len(args)] if should_warn() else None
            )
            if offending:
                suffix = "" if additional_message is None else f" {additional_message}"
                warnings.warn(
                    DeprecationWarning(
                        f"The positional arguments {offending} are "
                        "deprecated and will be removed in a future update."
                        f"{suffix}"
                    ),
                    stacklevel=3,  # The inner function takes up one level
                )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
|
||||
|
||||
|
||||
def deprecate_kwargs(
    *kws: str,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """
    Build a decorator that warns about deprecated keyword arguments.

    Args:
        *kws: Names of the keyword arguments considered deprecated.
        is_deprecated: Either a bool, or a zero-argument callable that is
            evaluated on every call to decide whether to warn.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator whose wrapped function emits a ``DeprecationWarning``
        when any deprecated keyword is passed, then forwards the call
        unchanged.
    """
    deprecated_kws = set(kws)

    if callable(is_deprecated):
        should_warn = is_deprecated
    else:
        # Normalize a plain bool into a zero-argument callable.
        should_warn = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:
        @wraps(fn)
        def inner(*args, **kwargs):
            # Only intersect with the deprecated names when deprecation is active.
            offending = kwargs.keys() & deprecated_kws if should_warn() else None
            if offending:
                suffix = "" if additional_message is None else f" {additional_message}"
                warnings.warn(
                    DeprecationWarning(
                        f"The keyword arguments {offending} are "
                        "deprecated and will be removed in a future update."
                        f"{suffix}"
                    ),
                    stacklevel=3,  # The inner function takes up one level
                )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
|
||||
|
||||
|
||||
@lru_cache
def supports_kw(
    callable: Callable[..., object],
    kw_name: str,
    *,
    requires_kw_only: bool = False,
    allow_var_kwargs: bool = True,
) -> bool:
    """Check whether ``kw_name`` can be passed to ``callable`` as a keyword.

    Args:
        callable: The callable whose signature is inspected.
        kw_name: The keyword name to look for.
        requires_kw_only: If True, names that can also be filled
            positionally (positional-or-keyword parameters) are rejected.
        allow_var_kwargs: If True, a name that does not match a
            keyword-capable parameter is still accepted when the last
            parameter is ``**kwargs`` (and is not itself named ``kw_name``).

    Returns:
        True if the keyword is supported under the given constraints.

    Note:
        Positional-only parameters cannot be supplied by keyword, so a
        positional-only parameter named ``kw_name`` never counts as a
        match on its own; such a name is only accepted through
        ``**kwargs``.
    """
    params = inspect.signature(callable).parameters
    if not params:
        return False

    param = params.get(kw_name)
    if param is not None:
        if param.kind == inspect.Parameter.KEYWORD_ONLY:
            return True
        if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
            # Passable by keyword, but also positionally.
            return not requires_kw_only
        # Positional-only and variadic (*args / **kwargs) parameters cannot
        # be addressed by their own name; fall through to the **kwargs check.

    # If we're okay with var-kwargs, it's supported as long as
    # the kw_name isn't something like *args, **kwargs
    if allow_var_kwargs:
        # Get the last param; type is ignored here because params is a proxy
        # mapping, but it wraps an ordered dict, and they appear in order.
        # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
        last_param = params[next(reversed(params))]  # type: ignore
        return (
            last_param.kind == inspect.Parameter.VAR_KEYWORD
            and last_param.name != kw_name
        )

    return False
|
||||
|
||||
|
||||
def get_allowed_kwarg_only_overrides(
    callable: Callable[..., object],
    overrides: Mapping[str, object] | None,
    *,
    requires_kw_only: bool = True,
    allow_var_kwargs: bool = False,
) -> dict[str, Any]:
    """
    Keep only the overrides that ``callable`` accepts as keyword arguments.

    Each override name is checked with ``supports_kw``; names that fail the
    check are dropped and reported in a single warning log. This is used in
    a few places to handle custom processor overrides for multimodal
    models, e.g., for profiling when processor options provided by the
    user may affect the number of mm tokens per instance.

    Args:
        callable: Callable whose signature decides which names are kept.
        overrides: Potential overrides to be used when invoking the
            callable. ``None`` or an empty mapping yields ``{}``.
        requires_kw_only: Forwarded to ``supports_kw``; also selects the
            wording of the warning for dropped names.
        allow_var_kwargs: Allows overrides that are expandable for var kwargs.

    Returns:
        Dictionary containing the overrides that passed the check, which
        may be expanded as keyword arguments when invoking the callable.
    """
    if not overrides:
        return {}

    # Drop any mm_processor_kwargs provided by the user that
    # are not kwargs, unless it can fit it var_kwargs param
    accepted: dict[str, Any] = {}
    for name, value in overrides.items():
        if supports_kw(
            callable,
            name,
            requires_kw_only=requires_kw_only,
            allow_var_kwargs=allow_var_kwargs,
        ):
            accepted[name] = value

    # If anything is dropped, log a warning
    rejected = overrides.keys() - accepted.keys()
    if rejected:
        if requires_kw_only:
            template = (
                "The following intended overrides are not keyword-only args "
                "and will be dropped: %s"
            )
        else:
            template = (
                "The following intended overrides are not keyword args "
                "and will be dropped: %s"
            )
        logger.warning(template, rejected)

    return accepted
|
||||
@ -29,7 +29,8 @@ from vllm.tracing import init_tracer
|
||||
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs
|
||||
from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv
|
||||
from vllm.utils.func import deprecate_kwargs
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user