Adds json_count_leaves utility function (#23899)

Signed-off-by: aditchawdhary <aditxy@hotmail.com>
This commit is contained in:
Adit Chawdhary 2025-08-29 17:58:13 +05:30 committed by GitHub
parent 67c14906aa
commit 4f7cde7272
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 72 additions and 10 deletions

View File

@ -379,9 +379,9 @@ def test_duplicate_dict_args(caplog_vllm, parser):
def test_supports_kw(callable,kw_name,requires_kw_only, def test_supports_kw(callable,kw_name,requires_kw_only,
allow_var_kwargs,is_supported): allow_var_kwargs,is_supported):
assert supports_kw( assert supports_kw(
callable=callable, callable=callable,
kw_name=kw_name, kw_name=kw_name,
requires_kw_only=requires_kw_only, requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs allow_var_kwargs=allow_var_kwargs
) == is_supported ) == is_supported
@ -948,6 +948,36 @@ def test_join_host_port():
assert join_host_port("::1", 5555) == "[::1]:5555" assert join_host_port("::1", 5555) == "[::1]:5555"
def test_json_count_leaves():
    """Check json_count_leaves on scalar, empty, flat and nested inputs."""
    from vllm.utils.jsontree import json_count_leaves

    # Each entry pairs an input tree with the leaf count it must produce.
    expected_counts = [
        # A bare scalar (including None) is itself a single leaf.
        (42, 1),
        ("hello", 1),
        (None, 1),
        # Containers with nothing inside them hold no leaves.
        ([], 0),
        ({}, 0),
        ((), 0),
        # One level deep: every element / dict value is a leaf.
        ([1, 2, 3], 3),
        ({"a": 1, "b": 2}, 2),
        ((1, 2, 3), 3),
        # Nested containers: only the innermost values are counted.
        ({"a": 1, "b": {"c": 2, "d": 3}}, 3),
        ([1, [2, 3], 4], 4),
        ({"list": [1, 2], "dict": {"x": 3}, "value": 4}, 4),
    ]

    for tree, leaf_count in expected_counts:
        assert json_count_leaves(tree) == leaf_count
def test_convert_ids_list_to_tokens(): def test_convert_ids_list_to_tokens():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct") tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
token_ids = tokenizer.encode("Hello, world!") token_ids = tokenizer.encode("Hello, world!")

View File

@ -10,7 +10,8 @@ from typing_extensions import TypeAlias, override
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache from vllm.utils import GiB_bytes, LRUCache
from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves from vllm.utils.jsontree import (json_count_leaves, json_map_leaves,
json_reduce_leaves)
from .inputs import (MultiModalFeatureSpec, MultiModalFieldElem, from .inputs import (MultiModalFeatureSpec, MultiModalFieldElem,
MultiModalKwargs, MultiModalKwargsItem, MultiModalKwargs, MultiModalKwargsItem,
@ -127,11 +128,32 @@ class MultiModalCache:
) )
if debug: if debug:
logger.debug("Calculated size of %s to be %.2f GiB", type(value), leaf_count = json_count_leaves(value)
size / GiB_bytes) logger.debug(
"Calculated size of %s to be %.2f GiB (%d leaves)",
type(value),
size / GiB_bytes,
leaf_count,
)
return size return size
@classmethod
def get_item_complexity(cls, value: MultiModalCacheValue) -> int:
    """
    Count the leaf elements contained in a multi-modal cache value.

    The count is a rough indicator of how structurally complex the value
    is, which can help when debugging cache performance or inspecting
    the shape of cached data.

    Args:
        value: The multi-modal cache value to inspect.

    Returns:
        The number of leaves found in the nested structure, as reported
        by `json_count_leaves`.
    """
    leaf_total = json_count_leaves(value)
    return leaf_total
@classmethod @classmethod
def get_lru_cache( def get_lru_cache(
cls, cls,
@ -184,7 +206,7 @@ class BaseMultiModalCache(ABC, Generic[_I, _O]):
""" """
Possibly update a multi-modal item based on whether it is Possibly update a multi-modal item based on whether it is
in the underlying cache. in the underlying cache.
This update is done out-of-place and updates the cache eviction order. This update is done out-of-place and updates the cache eviction order.
Args: Args:
@ -262,7 +284,7 @@ class BaseMultiModalProcessorCache(
in the underlying cache. in the underlying cache.
This **DOES NOT** update the cache eviction order. This **DOES NOT** update the cache eviction order.
Args: Args:
mm_hashes: The hash of each item to check. mm_hashes: The hash of each item to check.

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Helper functions to work with nested JSON structures.""" """Helper functions to work with nested JSON structures."""
from collections.abc import Iterable from collections.abc import Iterable
from functools import reduce from functools import reduce
from typing import Callable, TypeVar, Union, overload from typing import Callable, TypeVar, Union, overload
@ -8,8 +9,12 @@ from typing import Callable, TypeVar, Union, overload
_T = TypeVar("_T") _T = TypeVar("_T")
_U = TypeVar("_U") _U = TypeVar("_U")
JSONTree = Union[dict[str, "JSONTree[_T]"], list["JSONTree[_T]"], JSONTree = Union[
tuple["JSONTree[_T]", ...], _T] dict[str, "JSONTree[_T]"],
list["JSONTree[_T]"],
tuple["JSONTree[_T]", ...],
_T,
]
"""A nested JSON structure where the leaves need not be JSON-serializable.""" """A nested JSON structure where the leaves need not be JSON-serializable."""
@ -78,3 +83,8 @@ def json_reduce_leaves(
json_iter_leaves(value), json_iter_leaves(value),
initial, initial,
) )
def json_count_leaves(value: JSONTree[_T]) -> int:
    """Return how many leaves a nested JSON structure contains.

    Every item yielded by `json_iter_leaves` for `value` is counted once.
    """
    leaf_total = 0
    for _leaf in json_iter_leaves(value):
        leaf_total += 1
    return leaf_total