mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-10 17:30:12 +08:00
[Chore] Separate out vllm.utils.func (#26904)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
f57438338d
commit
136a17fe6e
@ -17,7 +17,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import identity
|
from vllm.utils.func import identity
|
||||||
|
|
||||||
from ....conftest import (
|
from ....conftest import (
|
||||||
IMAGE_ASSETS,
|
IMAGE_ASSETS,
|
||||||
|
|||||||
97
tests/utils_/test_func_utils.py
Normal file
97
tests/utils_/test_func_utils.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# ruff: noqa
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.utils.func import deprecate_kwargs, supports_kw
|
||||||
|
|
||||||
|
from ..utils import error_on_warning
|
||||||
|
|
||||||
|
|
||||||
|
def test_deprecate_kwargs_always():
|
||||||
|
@deprecate_kwargs("old_arg", is_deprecated=True)
|
||||||
|
def dummy(*, old_arg: object = None, new_arg: object = None):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.warns(DeprecationWarning, match="'old_arg'"):
|
||||||
|
dummy(old_arg=1)
|
||||||
|
|
||||||
|
with error_on_warning(DeprecationWarning):
|
||||||
|
dummy(new_arg=1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_deprecate_kwargs_never():
|
||||||
|
@deprecate_kwargs("old_arg", is_deprecated=False)
|
||||||
|
def dummy(*, old_arg: object = None, new_arg: object = None):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with error_on_warning(DeprecationWarning):
|
||||||
|
dummy(old_arg=1)
|
||||||
|
|
||||||
|
with error_on_warning(DeprecationWarning):
|
||||||
|
dummy(new_arg=1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_deprecate_kwargs_dynamic():
|
||||||
|
is_deprecated = True
|
||||||
|
|
||||||
|
@deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
|
||||||
|
def dummy(*, old_arg: object = None, new_arg: object = None):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.warns(DeprecationWarning, match="'old_arg'"):
|
||||||
|
dummy(old_arg=1)
|
||||||
|
|
||||||
|
with error_on_warning(DeprecationWarning):
|
||||||
|
dummy(new_arg=1)
|
||||||
|
|
||||||
|
is_deprecated = False
|
||||||
|
|
||||||
|
with error_on_warning(DeprecationWarning):
|
||||||
|
dummy(old_arg=1)
|
||||||
|
|
||||||
|
with error_on_warning(DeprecationWarning):
|
||||||
|
dummy(new_arg=1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_deprecate_kwargs_additional_message():
|
||||||
|
@deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
|
||||||
|
def dummy(*, old_arg: object = None, new_arg: object = None):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.warns(DeprecationWarning, match="abcd"):
|
||||||
|
dummy(old_arg=1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("callable", "kw_name", "requires_kw_only", "allow_var_kwargs", "is_supported"),
|
||||||
|
[
|
||||||
|
# Tests for positional argument support
|
||||||
|
(lambda foo: None, "foo", True, True, False),
|
||||||
|
(lambda foo: None, "foo", False, True, True),
|
||||||
|
# Tests for positional or keyword / keyword only
|
||||||
|
(lambda foo=100: None, "foo", True, True, False),
|
||||||
|
(lambda *, foo: None, "foo", False, True, True),
|
||||||
|
# Tests to make sure the names of variadic params are NOT supported
|
||||||
|
(lambda *args: None, "args", False, True, False),
|
||||||
|
(lambda **kwargs: None, "kwargs", False, True, False),
|
||||||
|
# Tests for if we allow var kwargs to add support
|
||||||
|
(lambda foo: None, "something_else", False, True, False),
|
||||||
|
(lambda foo, **kwargs: None, "something_else", False, True, True),
|
||||||
|
(lambda foo, **kwargs: None, "kwargs", True, True, False),
|
||||||
|
(lambda foo, **kwargs: None, "foo", True, True, False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_supports_kw(
|
||||||
|
callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
supports_kw(
|
||||||
|
callable=callable,
|
||||||
|
kw_name=kw_name,
|
||||||
|
requires_kw_only=requires_kw_only,
|
||||||
|
allow_var_kwargs=allow_var_kwargs,
|
||||||
|
)
|
||||||
|
== is_supported
|
||||||
|
)
|
||||||
32
tests/utils_/test_jsontree.py
Normal file
32
tests/utils_/test_jsontree.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from vllm.utils.jsontree import json_count_leaves
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_count_leaves():
|
||||||
|
"""Test json_count_leaves function from jsontree utility."""
|
||||||
|
|
||||||
|
# Single leaf values
|
||||||
|
assert json_count_leaves(42) == 1
|
||||||
|
assert json_count_leaves("hello") == 1
|
||||||
|
assert json_count_leaves(None) == 1
|
||||||
|
|
||||||
|
# Empty containers
|
||||||
|
assert json_count_leaves([]) == 0
|
||||||
|
assert json_count_leaves({}) == 0
|
||||||
|
assert json_count_leaves(()) == 0
|
||||||
|
|
||||||
|
# Flat structures
|
||||||
|
assert json_count_leaves([1, 2, 3]) == 3
|
||||||
|
assert json_count_leaves({"a": 1, "b": 2}) == 2
|
||||||
|
assert json_count_leaves((1, 2, 3)) == 3
|
||||||
|
|
||||||
|
# Nested structures
|
||||||
|
nested_dict = {"a": 1, "b": {"c": 2, "d": 3}}
|
||||||
|
assert json_count_leaves(nested_dict) == 3
|
||||||
|
|
||||||
|
nested_list = [1, [2, 3], 4]
|
||||||
|
assert json_count_leaves(nested_list) == 4
|
||||||
|
|
||||||
|
mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4}
|
||||||
|
assert json_count_leaves(mixed_nested) == 4
|
||||||
@ -30,7 +30,6 @@ from vllm.utils import (
|
|||||||
bind_kv_cache,
|
bind_kv_cache,
|
||||||
common_broadcastable_dtype,
|
common_broadcastable_dtype,
|
||||||
current_stream,
|
current_stream,
|
||||||
deprecate_kwargs,
|
|
||||||
get_open_port,
|
get_open_port,
|
||||||
get_tcp_uri,
|
get_tcp_uri,
|
||||||
is_lossless_cast,
|
is_lossless_cast,
|
||||||
@ -42,12 +41,11 @@ from vllm.utils import (
|
|||||||
sha256,
|
sha256,
|
||||||
split_host_port,
|
split_host_port,
|
||||||
split_zmq_path,
|
split_zmq_path,
|
||||||
supports_kw,
|
|
||||||
swap_dict_values,
|
swap_dict_values,
|
||||||
unique_filepath,
|
unique_filepath,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..utils import create_new_process_for_each_test, error_on_warning
|
from ..utils import create_new_process_for_each_test
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -83,61 +81,6 @@ async def test_merge_async_iterators():
|
|||||||
raise AssertionError() from e
|
raise AssertionError() from e
|
||||||
|
|
||||||
|
|
||||||
def test_deprecate_kwargs_always():
|
|
||||||
@deprecate_kwargs("old_arg", is_deprecated=True)
|
|
||||||
def dummy(*, old_arg: object = None, new_arg: object = None):
|
|
||||||
pass
|
|
||||||
|
|
||||||
with pytest.warns(DeprecationWarning, match="'old_arg'"):
|
|
||||||
dummy(old_arg=1)
|
|
||||||
|
|
||||||
with error_on_warning(DeprecationWarning):
|
|
||||||
dummy(new_arg=1)
|
|
||||||
|
|
||||||
|
|
||||||
def test_deprecate_kwargs_never():
|
|
||||||
@deprecate_kwargs("old_arg", is_deprecated=False)
|
|
||||||
def dummy(*, old_arg: object = None, new_arg: object = None):
|
|
||||||
pass
|
|
||||||
|
|
||||||
with error_on_warning(DeprecationWarning):
|
|
||||||
dummy(old_arg=1)
|
|
||||||
|
|
||||||
with error_on_warning(DeprecationWarning):
|
|
||||||
dummy(new_arg=1)
|
|
||||||
|
|
||||||
|
|
||||||
def test_deprecate_kwargs_dynamic():
|
|
||||||
is_deprecated = True
|
|
||||||
|
|
||||||
@deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
|
|
||||||
def dummy(*, old_arg: object = None, new_arg: object = None):
|
|
||||||
pass
|
|
||||||
|
|
||||||
with pytest.warns(DeprecationWarning, match="'old_arg'"):
|
|
||||||
dummy(old_arg=1)
|
|
||||||
|
|
||||||
with error_on_warning(DeprecationWarning):
|
|
||||||
dummy(new_arg=1)
|
|
||||||
|
|
||||||
is_deprecated = False
|
|
||||||
|
|
||||||
with error_on_warning(DeprecationWarning):
|
|
||||||
dummy(old_arg=1)
|
|
||||||
|
|
||||||
with error_on_warning(DeprecationWarning):
|
|
||||||
dummy(new_arg=1)
|
|
||||||
|
|
||||||
|
|
||||||
def test_deprecate_kwargs_additional_message():
|
|
||||||
@deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
|
|
||||||
def dummy(*, old_arg: object = None, new_arg: object = None):
|
|
||||||
pass
|
|
||||||
|
|
||||||
with pytest.warns(DeprecationWarning, match="abcd"):
|
|
||||||
dummy(old_arg=1)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
|
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
m.setenv("VLLM_PORT", "5678")
|
m.setenv("VLLM_PORT", "5678")
|
||||||
@ -383,39 +326,6 @@ def test_duplicate_dict_args(caplog_vllm, parser):
|
|||||||
assert "-O.mode" in caplog_vllm.text
|
assert "-O.mode" in caplog_vllm.text
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
|
|
||||||
[
|
|
||||||
# Tests for positional argument support
|
|
||||||
(lambda foo: None, "foo", True, True, False),
|
|
||||||
(lambda foo: None, "foo", False, True, True),
|
|
||||||
# Tests for positional or keyword / keyword only
|
|
||||||
(lambda foo=100: None, "foo", True, True, False),
|
|
||||||
(lambda *, foo: None, "foo", False, True, True),
|
|
||||||
# Tests to make sure the names of variadic params are NOT supported
|
|
||||||
(lambda *args: None, "args", False, True, False),
|
|
||||||
(lambda **kwargs: None, "kwargs", False, True, False),
|
|
||||||
# Tests for if we allow var kwargs to add support
|
|
||||||
(lambda foo: None, "something_else", False, True, False),
|
|
||||||
(lambda foo, **kwargs: None, "something_else", False, True, True),
|
|
||||||
(lambda foo, **kwargs: None, "kwargs", True, True, False),
|
|
||||||
(lambda foo, **kwargs: None, "foo", True, True, False),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_supports_kw(
|
|
||||||
callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported
|
|
||||||
):
|
|
||||||
assert (
|
|
||||||
supports_kw(
|
|
||||||
callable=callable,
|
|
||||||
kw_name=kw_name,
|
|
||||||
requires_kw_only=requires_kw_only,
|
|
||||||
allow_var_kwargs=allow_var_kwargs,
|
|
||||||
)
|
|
||||||
== is_supported
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
def test_memory_profiling():
|
def test_memory_profiling():
|
||||||
# Fake out some model loading + inference memory usage to test profiling
|
# Fake out some model loading + inference memory usage to test profiling
|
||||||
@ -863,36 +773,6 @@ def test_join_host_port():
|
|||||||
assert join_host_port("::1", 5555) == "[::1]:5555"
|
assert join_host_port("::1", 5555) == "[::1]:5555"
|
||||||
|
|
||||||
|
|
||||||
def test_json_count_leaves():
|
|
||||||
"""Test json_count_leaves function from jsontree utility."""
|
|
||||||
from vllm.utils.jsontree import json_count_leaves
|
|
||||||
|
|
||||||
# Single leaf values
|
|
||||||
assert json_count_leaves(42) == 1
|
|
||||||
assert json_count_leaves("hello") == 1
|
|
||||||
assert json_count_leaves(None) == 1
|
|
||||||
|
|
||||||
# Empty containers
|
|
||||||
assert json_count_leaves([]) == 0
|
|
||||||
assert json_count_leaves({}) == 0
|
|
||||||
assert json_count_leaves(()) == 0
|
|
||||||
|
|
||||||
# Flat structures
|
|
||||||
assert json_count_leaves([1, 2, 3]) == 3
|
|
||||||
assert json_count_leaves({"a": 1, "b": 2}) == 2
|
|
||||||
assert json_count_leaves((1, 2, 3)) == 3
|
|
||||||
|
|
||||||
# Nested structures
|
|
||||||
nested_dict = {"a": 1, "b": {"c": 2, "d": 3}}
|
|
||||||
assert json_count_leaves(nested_dict) == 3
|
|
||||||
|
|
||||||
nested_list = [1, [2, 3], 4]
|
|
||||||
assert json_count_leaves(nested_list) == 4
|
|
||||||
|
|
||||||
mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4}
|
|
||||||
assert json_count_leaves(mixed_nested) == 4
|
|
||||||
|
|
||||||
|
|
||||||
def test_convert_ids_list_to_tokens():
|
def test_convert_ids_list_to_tokens():
|
||||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
|
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
|
||||||
token_ids = tokenizer.encode("Hello, world!")
|
token_ids = tokenizer.encode("Hello, world!")
|
||||||
|
|||||||
@ -50,7 +50,8 @@ from vllm.multimodal.utils import MediaConnector
|
|||||||
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
||||||
from vllm.transformers_utils.processor import cached_get_processor
|
from vllm.transformers_utils.processor import cached_get_processor
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
from vllm.utils import random_uuid, supports_kw
|
from vllm.utils import random_uuid
|
||||||
|
from vllm.utils.func import supports_kw
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -94,10 +94,10 @@ from vllm.utils import (
|
|||||||
AsyncMicrobatchTokenizer,
|
AsyncMicrobatchTokenizer,
|
||||||
collect_from_async_generator,
|
collect_from_async_generator,
|
||||||
is_list_of,
|
is_list_of,
|
||||||
make_async,
|
|
||||||
merge_async_iterators,
|
merge_async_iterators,
|
||||||
random_uuid,
|
random_uuid,
|
||||||
)
|
)
|
||||||
|
from vllm.utils.func import make_async
|
||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|||||||
@ -37,7 +37,8 @@ from vllm.logger import init_logger
|
|||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
from vllm.utils import make_async, merge_async_iterators
|
from vllm.utils import merge_async_iterators
|
||||||
|
from vllm.utils.func import make_async
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.sequence import ExecuteModelRequest
|
from vllm.sequence import ExecuteModelRequest
|
||||||
from vllm.tasks import SupportedTask
|
from vllm.tasks import SupportedTask
|
||||||
from vllm.utils import make_async
|
from vllm.utils.func import make_async
|
||||||
from vllm.v1.outputs import SamplerOutput
|
from vllm.v1.outputs import SamplerOutput
|
||||||
from vllm.v1.worker.worker_base import WorkerBase
|
from vllm.v1.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
|
|||||||
@ -24,8 +24,8 @@ from vllm.utils import (
|
|||||||
get_distributed_init_method,
|
get_distributed_init_method,
|
||||||
get_ip,
|
get_ip,
|
||||||
get_open_port,
|
get_open_port,
|
||||||
make_async,
|
|
||||||
)
|
)
|
||||||
|
from vllm.utils.func import make_async
|
||||||
from vllm.v1.outputs import SamplerOutput
|
from vllm.v1.outputs import SamplerOutput
|
||||||
|
|
||||||
if ray is not None:
|
if ray is not None:
|
||||||
|
|||||||
@ -27,8 +27,9 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
|||||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
per_token_group_quant_fp8,
|
per_token_group_quant_fp8,
|
||||||
)
|
)
|
||||||
from vllm.utils import has_deep_gemm, run_once
|
from vllm.utils import has_deep_gemm
|
||||||
from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
|
from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous
|
||||||
|
from vllm.utils.func import run_once
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -24,7 +24,7 @@ from vllm.inputs import TokensPrompt
|
|||||||
from vllm.inputs.data import PromptType
|
from vllm.inputs.data import PromptType
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.utils import supports_kw
|
from vllm.utils.func import supports_kw
|
||||||
|
|
||||||
from .interfaces_base import VllmModel, is_pooling_model
|
from .interfaces_base import VllmModel, is_pooling_model
|
||||||
|
|
||||||
|
|||||||
@ -15,7 +15,7 @@ import torch.nn as nn
|
|||||||
from typing_extensions import TypeIs, TypeVar
|
from typing_extensions import TypeIs, TypeVar
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import supports_kw
|
from vllm.utils.func import supports_kw
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
|||||||
@ -25,7 +25,8 @@ from typing_extensions import TypeVar, assert_never
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
from vllm.transformers_utils.processor import cached_processor_from_config
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens
|
||||||
from vllm.utils import flatten_2d_lists, full_groupby, get_allowed_kwarg_only_overrides
|
from vllm.utils import flatten_2d_lists, full_groupby
|
||||||
|
from vllm.utils.func import get_allowed_kwarg_only_overrides
|
||||||
from vllm.utils.jsontree import JSONTree, json_map_leaves
|
from vllm.utils.jsontree import JSONTree, json_map_leaves
|
||||||
|
|
||||||
from .hasher import MultiModalHasher
|
from .hasher import MultiModalHasher
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import os
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import run_once
|
from vllm.utils.func import run_once
|
||||||
|
|
||||||
TRACE_HEADERS = ["traceparent", "tracestate"]
|
TRACE_HEADERS = ["traceparent", "tracestate"]
|
||||||
|
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from transformers.processing_utils import ProcessorMixin
|
|||||||
from transformers.video_processing_utils import BaseVideoProcessor
|
from transformers.video_processing_utils import BaseVideoProcessor
|
||||||
from typing_extensions import TypeVar
|
from typing_extensions import TypeVar
|
||||||
|
|
||||||
from vllm.utils import get_allowed_kwarg_only_overrides
|
from vllm.utils.func import get_allowed_kwarg_only_overrides
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
|
|||||||
@ -2,7 +2,6 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import concurrent
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import datetime
|
import datetime
|
||||||
import enum
|
import enum
|
||||||
@ -43,7 +42,6 @@ from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
|
|||||||
from collections import UserDict, defaultdict
|
from collections import UserDict, defaultdict
|
||||||
from collections.abc import (
|
from collections.abc import (
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
Awaitable,
|
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
Generator,
|
Generator,
|
||||||
@ -85,7 +83,7 @@ from packaging import version
|
|||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
from torch.library import Library
|
from torch.library import Library
|
||||||
from transformers.tokenization_utils_base import BatchEncoding
|
from transformers.tokenization_utils_base import BatchEncoding
|
||||||
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
|
from typing_extensions import Never, TypeIs, assert_never
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import enable_trace_function_call, init_logger
|
from vllm.logger import enable_trace_function_call, init_logger
|
||||||
@ -174,7 +172,6 @@ def set_default_torch_num_threads(num_threads: int):
|
|||||||
torch.set_num_threads(old_num_threads)
|
torch.set_num_threads(old_num_threads)
|
||||||
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
U = TypeVar("U")
|
U = TypeVar("U")
|
||||||
|
|
||||||
@ -452,24 +449,6 @@ def in_loop(event_loop: AbstractEventLoop) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def make_async(
|
|
||||||
func: Callable[P, T], executor: concurrent.futures.Executor | None = None
|
|
||||||
) -> Callable[P, Awaitable[T]]:
|
|
||||||
"""Take a blocking function, and run it on in an executor thread.
|
|
||||||
|
|
||||||
This function prevents the blocking function from blocking the
|
|
||||||
asyncio event loop.
|
|
||||||
The code in this function needs to be thread safe.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
p_func = partial(func, *args, **kwargs)
|
|
||||||
return loop.run_in_executor(executor=executor, func=p_func)
|
|
||||||
|
|
||||||
return _async_wrapper
|
|
||||||
|
|
||||||
|
|
||||||
async def merge_async_iterators(
|
async def merge_async_iterators(
|
||||||
*iterators: AsyncGenerator[T, None],
|
*iterators: AsyncGenerator[T, None],
|
||||||
) -> AsyncGenerator[tuple[int, T], None]:
|
) -> AsyncGenerator[tuple[int, T], None]:
|
||||||
@ -1254,90 +1233,6 @@ def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
|
|||||||
enable_trace_function_call(log_path)
|
enable_trace_function_call(log_path)
|
||||||
|
|
||||||
|
|
||||||
# `functools` helpers
|
|
||||||
def identity(value: T, **kwargs) -> T:
|
|
||||||
"""Returns the first provided value."""
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
F = TypeVar("F", bound=Callable[..., Any])
|
|
||||||
|
|
||||||
|
|
||||||
def deprecate_args(
|
|
||||||
start_index: int,
|
|
||||||
is_deprecated: bool | Callable[[], bool] = True,
|
|
||||||
additional_message: str | None = None,
|
|
||||||
) -> Callable[[F], F]:
|
|
||||||
if not callable(is_deprecated):
|
|
||||||
is_deprecated = partial(identity, is_deprecated)
|
|
||||||
|
|
||||||
def wrapper(fn: F) -> F:
|
|
||||||
params = inspect.signature(fn).parameters
|
|
||||||
pos_types = (
|
|
||||||
inspect.Parameter.POSITIONAL_ONLY,
|
|
||||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
||||||
)
|
|
||||||
pos_kws = [kw for kw, param in params.items() if param.kind in pos_types]
|
|
||||||
|
|
||||||
@wraps(fn)
|
|
||||||
def inner(*args, **kwargs):
|
|
||||||
if is_deprecated():
|
|
||||||
deprecated_args = pos_kws[start_index : len(args)]
|
|
||||||
if deprecated_args:
|
|
||||||
msg = (
|
|
||||||
f"The positional arguments {deprecated_args} are "
|
|
||||||
"deprecated and will be removed in a future update."
|
|
||||||
)
|
|
||||||
if additional_message is not None:
|
|
||||||
msg += f" {additional_message}"
|
|
||||||
|
|
||||||
warnings.warn(
|
|
||||||
DeprecationWarning(msg),
|
|
||||||
stacklevel=3, # The inner function takes up one level
|
|
||||||
)
|
|
||||||
|
|
||||||
return fn(*args, **kwargs)
|
|
||||||
|
|
||||||
return inner # type: ignore
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def deprecate_kwargs(
|
|
||||||
*kws: str,
|
|
||||||
is_deprecated: bool | Callable[[], bool] = True,
|
|
||||||
additional_message: str | None = None,
|
|
||||||
) -> Callable[[F], F]:
|
|
||||||
deprecated_kws = set(kws)
|
|
||||||
|
|
||||||
if not callable(is_deprecated):
|
|
||||||
is_deprecated = partial(identity, is_deprecated)
|
|
||||||
|
|
||||||
def wrapper(fn: F) -> F:
|
|
||||||
@wraps(fn)
|
|
||||||
def inner(*args, **kwargs):
|
|
||||||
if is_deprecated():
|
|
||||||
deprecated_kwargs = kwargs.keys() & deprecated_kws
|
|
||||||
if deprecated_kwargs:
|
|
||||||
msg = (
|
|
||||||
f"The keyword arguments {deprecated_kwargs} are "
|
|
||||||
"deprecated and will be removed in a future update."
|
|
||||||
)
|
|
||||||
if additional_message is not None:
|
|
||||||
msg += f" {additional_message}"
|
|
||||||
|
|
||||||
warnings.warn(
|
|
||||||
DeprecationWarning(msg),
|
|
||||||
stacklevel=3, # The inner function takes up one level
|
|
||||||
)
|
|
||||||
|
|
||||||
return fn(*args, **kwargs)
|
|
||||||
|
|
||||||
return inner # type: ignore
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=8)
|
@lru_cache(maxsize=8)
|
||||||
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
|
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
|
||||||
# Note: cuda_visible_devices is not used, but we keep it as an argument for
|
# Note: cuda_visible_devices is not used, but we keep it as an argument for
|
||||||
@ -1426,21 +1321,6 @@ def weak_bind(
|
|||||||
return weak_bound
|
return weak_bound
|
||||||
|
|
||||||
|
|
||||||
def run_once(f: Callable[P, None]) -> Callable[P, None]:
|
|
||||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
|
|
||||||
if wrapper.has_run: # type: ignore[attr-defined]
|
|
||||||
return
|
|
||||||
|
|
||||||
with wrapper.lock: # type: ignore[attr-defined]
|
|
||||||
if not wrapper.has_run: # type: ignore[attr-defined]
|
|
||||||
wrapper.has_run = True # type: ignore[attr-defined]
|
|
||||||
return f(*args, **kwargs)
|
|
||||||
|
|
||||||
wrapper.has_run = False # type: ignore[attr-defined]
|
|
||||||
wrapper.lock = threading.Lock() # type: ignore[attr-defined]
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
class StoreBoolean(Action):
|
class StoreBoolean(Action):
|
||||||
def __call__(self, parser, namespace, values, option_string=None):
|
def __call__(self, parser, namespace, values, option_string=None):
|
||||||
if values.lower() == "true":
|
if values.lower() == "true":
|
||||||
@ -1929,122 +1809,6 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwarg
|
|||||||
return await task(*args, **kwargs)
|
return await task(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
|
||||||
def supports_kw(
|
|
||||||
callable: Callable[..., object],
|
|
||||||
kw_name: str,
|
|
||||||
*,
|
|
||||||
requires_kw_only: bool = False,
|
|
||||||
allow_var_kwargs: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""Check if a keyword is a valid kwarg for a callable; if requires_kw_only
|
|
||||||
disallows kwargs names that can also be positional arguments.
|
|
||||||
"""
|
|
||||||
params = inspect.signature(callable).parameters
|
|
||||||
if not params:
|
|
||||||
return False
|
|
||||||
|
|
||||||
param_val = params.get(kw_name)
|
|
||||||
|
|
||||||
# Types where the it may be valid, i.e., explicitly defined & nonvariadic
|
|
||||||
passable_kw_types = set(
|
|
||||||
(
|
|
||||||
inspect.Parameter.POSITIONAL_ONLY,
|
|
||||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
|
||||||
inspect.Parameter.KEYWORD_ONLY,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if param_val:
|
|
||||||
is_sig_param = param_val.kind in passable_kw_types
|
|
||||||
# We want kwargs only, but this is passable as a positional arg
|
|
||||||
if (
|
|
||||||
requires_kw_only
|
|
||||||
and is_sig_param
|
|
||||||
and param_val.kind != inspect.Parameter.KEYWORD_ONLY
|
|
||||||
):
|
|
||||||
return False
|
|
||||||
if (requires_kw_only and param_val.kind == inspect.Parameter.KEYWORD_ONLY) or (
|
|
||||||
not requires_kw_only and is_sig_param
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# If we're okay with var-kwargs, it's supported as long as
|
|
||||||
# the kw_name isn't something like *args, **kwargs
|
|
||||||
if allow_var_kwargs:
|
|
||||||
# Get the last param; type is ignored here because params is a proxy
|
|
||||||
# mapping, but it wraps an ordered dict, and they appear in order.
|
|
||||||
# Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
|
|
||||||
last_param = params[next(reversed(params))] # type: ignore
|
|
||||||
return (
|
|
||||||
last_param.kind == inspect.Parameter.VAR_KEYWORD
|
|
||||||
and last_param.name != kw_name
|
|
||||||
)
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def get_allowed_kwarg_only_overrides(
|
|
||||||
callable: Callable[..., object],
|
|
||||||
overrides: Mapping[str, object] | None,
|
|
||||||
*,
|
|
||||||
requires_kw_only: bool = True,
|
|
||||||
allow_var_kwargs: bool = False,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Given a callable which has one or more keyword only params and a dict
|
|
||||||
mapping param names to values, drop values that can be not be kwarg
|
|
||||||
expanded to overwrite one or more keyword-only args. This is used in a
|
|
||||||
few places to handle custom processor overrides for multimodal models,
|
|
||||||
e.g., for profiling when processor options provided by the user
|
|
||||||
may affect the number of mm tokens per instance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
callable: Callable which takes 0 or more keyword only arguments.
|
|
||||||
If None is provided, all overrides names are allowed.
|
|
||||||
overrides: Potential overrides to be used when invoking the callable.
|
|
||||||
allow_var_kwargs: Allows overrides that are expandable for var kwargs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing the kwargs to be leveraged which may be used
|
|
||||||
to overwrite one or more keyword only arguments when invoking the
|
|
||||||
callable.
|
|
||||||
"""
|
|
||||||
if not overrides:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
# Drop any mm_processor_kwargs provided by the user that
|
|
||||||
# are not kwargs, unless it can fit it var_kwargs param
|
|
||||||
filtered_overrides = {
|
|
||||||
kwarg_name: val
|
|
||||||
for kwarg_name, val in overrides.items()
|
|
||||||
if supports_kw(
|
|
||||||
callable,
|
|
||||||
kwarg_name,
|
|
||||||
requires_kw_only=requires_kw_only,
|
|
||||||
allow_var_kwargs=allow_var_kwargs,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
# If anything is dropped, log a warning
|
|
||||||
dropped_keys = overrides.keys() - filtered_overrides.keys()
|
|
||||||
if dropped_keys:
|
|
||||||
if requires_kw_only:
|
|
||||||
logger.warning(
|
|
||||||
"The following intended overrides are not keyword-only args "
|
|
||||||
"and will be dropped: %s",
|
|
||||||
dropped_keys,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"The following intended overrides are not keyword args "
|
|
||||||
"and will be dropped: %s",
|
|
||||||
dropped_keys,
|
|
||||||
)
|
|
||||||
|
|
||||||
return filtered_overrides
|
|
||||||
|
|
||||||
|
|
||||||
# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
|
# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
|
||||||
# In particular, the FakeScalarType is not supported for earlier versions of
|
# In particular, the FakeScalarType is not supported for earlier versions of
|
||||||
# PyTorch which breaks dynamo for any ops registered using ScalarType.
|
# PyTorch which breaks dynamo for any ops registered using ScalarType.
|
||||||
|
|||||||
258
vllm/utils/func.py
Normal file
258
vllm/utils/func.py
Normal file
@ -0,0 +1,258 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Contains helpers that are applied to functions.
|
||||||
|
|
||||||
|
This is similar in concept to the `functools` module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import concurrent.futures
|
||||||
|
import inspect
|
||||||
|
import threading
|
||||||
|
import warnings
|
||||||
|
from collections.abc import Awaitable, Callable, Mapping
|
||||||
|
from functools import lru_cache, partial, wraps
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
F = TypeVar("F", bound=Callable[..., Any])
|
||||||
|
|
||||||
|
|
||||||
|
def identity(value: T, **kwargs) -> T:
|
||||||
|
"""Returns the first provided value."""
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def make_async(
|
||||||
|
func: Callable[P, T],
|
||||||
|
executor: concurrent.futures.Executor | None = None,
|
||||||
|
) -> Callable[P, Awaitable[T]]:
|
||||||
|
"""
|
||||||
|
Take a blocking function, and run it on in an executor thread.
|
||||||
|
|
||||||
|
This function prevents the blocking function from blocking the
|
||||||
|
asyncio event loop.
|
||||||
|
The code in this function needs to be thread safe.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future[T]:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
p_func = partial(func, *args, **kwargs)
|
||||||
|
return loop.run_in_executor(executor=executor, func=p_func)
|
||||||
|
|
||||||
|
return _async_wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def run_once(f: Callable[P, None]) -> Callable[P, None]:
|
||||||
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
|
||||||
|
if wrapper.has_run: # type: ignore[attr-defined]
|
||||||
|
return
|
||||||
|
|
||||||
|
with wrapper.lock: # type: ignore[attr-defined]
|
||||||
|
if not wrapper.has_run: # type: ignore[attr-defined]
|
||||||
|
wrapper.has_run = True # type: ignore[attr-defined]
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
wrapper.has_run = False # type: ignore[attr-defined]
|
||||||
|
wrapper.lock = threading.Lock() # type: ignore[attr-defined]
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def deprecate_args(
|
||||||
|
start_index: int,
|
||||||
|
is_deprecated: bool | Callable[[], bool] = True,
|
||||||
|
additional_message: str | None = None,
|
||||||
|
) -> Callable[[F], F]:
|
||||||
|
if not callable(is_deprecated):
|
||||||
|
is_deprecated = partial(identity, is_deprecated)
|
||||||
|
|
||||||
|
def wrapper(fn: F) -> F:
|
||||||
|
params = inspect.signature(fn).parameters
|
||||||
|
pos_types = (
|
||||||
|
inspect.Parameter.POSITIONAL_ONLY,
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
)
|
||||||
|
pos_kws = [kw for kw, param in params.items() if param.kind in pos_types]
|
||||||
|
|
||||||
|
@wraps(fn)
|
||||||
|
def inner(*args, **kwargs):
|
||||||
|
if is_deprecated():
|
||||||
|
deprecated_args = pos_kws[start_index : len(args)]
|
||||||
|
if deprecated_args:
|
||||||
|
msg = (
|
||||||
|
f"The positional arguments {deprecated_args} are "
|
||||||
|
"deprecated and will be removed in a future update."
|
||||||
|
)
|
||||||
|
if additional_message is not None:
|
||||||
|
msg += f" {additional_message}"
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
DeprecationWarning(msg),
|
||||||
|
stacklevel=3, # The inner function takes up one level
|
||||||
|
)
|
||||||
|
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return inner # type: ignore
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def deprecate_kwargs(
|
||||||
|
*kws: str,
|
||||||
|
is_deprecated: bool | Callable[[], bool] = True,
|
||||||
|
additional_message: str | None = None,
|
||||||
|
) -> Callable[[F], F]:
|
||||||
|
deprecated_kws = set(kws)
|
||||||
|
|
||||||
|
if not callable(is_deprecated):
|
||||||
|
is_deprecated = partial(identity, is_deprecated)
|
||||||
|
|
||||||
|
def wrapper(fn: F) -> F:
|
||||||
|
@wraps(fn)
|
||||||
|
def inner(*args, **kwargs):
|
||||||
|
if is_deprecated():
|
||||||
|
deprecated_kwargs = kwargs.keys() & deprecated_kws
|
||||||
|
if deprecated_kwargs:
|
||||||
|
msg = (
|
||||||
|
f"The keyword arguments {deprecated_kwargs} are "
|
||||||
|
"deprecated and will be removed in a future update."
|
||||||
|
)
|
||||||
|
if additional_message is not None:
|
||||||
|
msg += f" {additional_message}"
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
DeprecationWarning(msg),
|
||||||
|
stacklevel=3, # The inner function takes up one level
|
||||||
|
)
|
||||||
|
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return inner # type: ignore
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def supports_kw(
|
||||||
|
callable: Callable[..., object],
|
||||||
|
kw_name: str,
|
||||||
|
*,
|
||||||
|
requires_kw_only: bool = False,
|
||||||
|
allow_var_kwargs: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if a keyword is a valid kwarg for a callable; if requires_kw_only
|
||||||
|
disallows kwargs names that can also be positional arguments.
|
||||||
|
"""
|
||||||
|
params = inspect.signature(callable).parameters
|
||||||
|
if not params:
|
||||||
|
return False
|
||||||
|
|
||||||
|
param_val = params.get(kw_name)
|
||||||
|
|
||||||
|
# Types where the it may be valid, i.e., explicitly defined & nonvariadic
|
||||||
|
passable_kw_types = set(
|
||||||
|
(
|
||||||
|
inspect.Parameter.POSITIONAL_ONLY,
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if param_val:
|
||||||
|
is_sig_param = param_val.kind in passable_kw_types
|
||||||
|
# We want kwargs only, but this is passable as a positional arg
|
||||||
|
if (
|
||||||
|
requires_kw_only
|
||||||
|
and is_sig_param
|
||||||
|
and param_val.kind != inspect.Parameter.KEYWORD_ONLY
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
if (requires_kw_only and param_val.kind == inspect.Parameter.KEYWORD_ONLY) or (
|
||||||
|
not requires_kw_only and is_sig_param
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# If we're okay with var-kwargs, it's supported as long as
|
||||||
|
# the kw_name isn't something like *args, **kwargs
|
||||||
|
if allow_var_kwargs:
|
||||||
|
# Get the last param; type is ignored here because params is a proxy
|
||||||
|
# mapping, but it wraps an ordered dict, and they appear in order.
|
||||||
|
# Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
|
||||||
|
last_param = params[next(reversed(params))] # type: ignore
|
||||||
|
return (
|
||||||
|
last_param.kind == inspect.Parameter.VAR_KEYWORD
|
||||||
|
and last_param.name != kw_name
|
||||||
|
)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_allowed_kwarg_only_overrides(
|
||||||
|
callable: Callable[..., object],
|
||||||
|
overrides: Mapping[str, object] | None,
|
||||||
|
*,
|
||||||
|
requires_kw_only: bool = True,
|
||||||
|
allow_var_kwargs: bool = False,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Given a callable which has one or more keyword only params and a dict
|
||||||
|
mapping param names to values, drop values that can be not be kwarg
|
||||||
|
expanded to overwrite one or more keyword-only args. This is used in a
|
||||||
|
few places to handle custom processor overrides for multimodal models,
|
||||||
|
e.g., for profiling when processor options provided by the user
|
||||||
|
may affect the number of mm tokens per instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callable: Callable which takes 0 or more keyword only arguments.
|
||||||
|
If None is provided, all overrides names are allowed.
|
||||||
|
overrides: Potential overrides to be used when invoking the callable.
|
||||||
|
allow_var_kwargs: Allows overrides that are expandable for var kwargs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing the kwargs to be leveraged which may be used
|
||||||
|
to overwrite one or more keyword only arguments when invoking the
|
||||||
|
callable.
|
||||||
|
"""
|
||||||
|
if not overrides:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Drop any mm_processor_kwargs provided by the user that
|
||||||
|
# are not kwargs, unless it can fit it var_kwargs param
|
||||||
|
filtered_overrides = {
|
||||||
|
kwarg_name: val
|
||||||
|
for kwarg_name, val in overrides.items()
|
||||||
|
if supports_kw(
|
||||||
|
callable,
|
||||||
|
kwarg_name,
|
||||||
|
requires_kw_only=requires_kw_only,
|
||||||
|
allow_var_kwargs=allow_var_kwargs,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
# If anything is dropped, log a warning
|
||||||
|
dropped_keys = overrides.keys() - filtered_overrides.keys()
|
||||||
|
if dropped_keys:
|
||||||
|
if requires_kw_only:
|
||||||
|
logger.warning(
|
||||||
|
"The following intended overrides are not keyword-only args "
|
||||||
|
"and will be dropped: %s",
|
||||||
|
dropped_keys,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"The following intended overrides are not keyword args "
|
||||||
|
"and will be dropped: %s",
|
||||||
|
dropped_keys,
|
||||||
|
)
|
||||||
|
|
||||||
|
return filtered_overrides
|
||||||
@ -29,7 +29,8 @@ from vllm.tracing import init_tracer
|
|||||||
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
|
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv, deprecate_kwargs
|
from vllm.utils import Device, as_list, cancel_task_threadsafe, cdiv
|
||||||
|
from vllm.utils.func import deprecate_kwargs
|
||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
from vllm.v1.engine.core_client import EngineCoreClient
|
from vllm.v1.engine.core_client import EngineCoreClient
|
||||||
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
|
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user