[Core][MM] Cleanup MultiModalCache (#25006)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Lukas Geiger 2025-09-18 05:08:41 +01:00 committed by GitHub
parent 32baf1d036
commit b98219670f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import operator
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
@ -91,26 +92,15 @@ _V = TypeVar("_V", bound=MultiModalCacheValue)
class MultiModalCache: class MultiModalCache:
@classmethod @classmethod
def get_leaf_size( def get_leaf_size(cls, leaf: object) -> int:
cls,
leaf: object,
*,
debug: bool = False,
) -> int:
if isinstance(leaf, MultiModalProcessorCacheItem): if isinstance(leaf, MultiModalProcessorCacheItem):
return cls.get_leaf_size(leaf.item) return cls.get_leaf_size(leaf.item)
if isinstance(leaf, MultiModalProcessorCacheItemMetadata): if isinstance(leaf, MultiModalProcessorCacheItemMetadata):
return leaf.item_size return leaf.item_size
# These are not subclasses of dict # These are not subclasses of dict
if isinstance(leaf, MultiModalKwargsItems): if isinstance(leaf, (MultiModalKwargs, MultiModalKwargsItems,
return cls.get_item_size(leaf.data) # type: ignore MultiModalKwargsItem, MultiModalFieldElem)):
if isinstance(leaf, MultiModalKwargsItem):
return cls.get_item_size(leaf.data) # type: ignore
if isinstance(leaf, MultiModalKwargs):
return cls.get_item_size(leaf.data) # type: ignore
if isinstance(leaf, MultiModalFieldElem):
return cls.get_item_size(leaf.data) # type: ignore return cls.get_item_size(leaf.data) # type: ignore
# sys.getsizeof doesn't work for tensors # sys.getsizeof doesn't work for tensors
@ -126,11 +116,8 @@ class MultiModalCache:
*, *,
debug: bool = False, debug: bool = False,
) -> int: ) -> int:
size = json_reduce_leaves( size = json_reduce_leaves(operator.add,
lambda a, b: a + b, json_map_leaves(cls.get_leaf_size, value))
json_map_leaves(lambda x: cls.get_leaf_size(x, debug=debug),
value),
)
if debug: if debug:
leaf_count = json_count_leaves(value) leaf_count = json_count_leaves(value)