vllm/vllm/multimodal/cache.py
Cyrus Leung 27e8d1ea3e
[Refactor] Define MultiModalKwargsItems separate from MultiModalKwargs (#23053)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2025-08-18 09:52:00 +00:00

101 lines
2.6 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TypeVar, Union
import torch
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache
from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves
from .inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem, MultiModalKwargsItems,
NestedTensors)
# Module-level logger, named after this module; used for the optional debug
# size reports emitted by MultiModalCache.get_item_size.
logger = init_logger(__name__)
@dataclass
class MultiModalCacheItemMetadata:
    """Record of a cache value's size, kept in place of the value itself."""

    # Size of the wrapped value, as reported by MultiModalCache.get_item_size.
    size: int

    @classmethod
    def wraps(cls, value: "MultiModalCacheValue"):
        """Measure *value* and return a metadata object holding its size."""
        measured = MultiModalCache.get_item_size(value)
        return cls(size=measured)
# Union of every value type that MultiModalCache.get_item_size is annotated
# to accept: the multi-modal kwargs containers, a plain mapping of nested
# tensors, or a size-only metadata record.
MultiModalCacheValue = Union[
MultiModalKwargsItems,
MultiModalKwargsItem,
MultiModalKwargs,
Mapping[str, NestedTensors],
MultiModalCacheItemMetadata,
]
# Type variable bound to MultiModalCacheValue; MultiModalCache.get_lru_cache
# uses it to tie its `value_type` argument to the value type of the returned
# LRUCache.
_V = TypeVar("_V", bound=MultiModalCacheValue)
class MultiModalCache:
    """Size-accounting helpers for caches that hold multi-modal data.

    Every method is a classmethod; the class carries no state of its own.
    """

    @classmethod
    def get_leaf_size(
        cls,
        leaf: object,
        *,
        debug: bool = False,
    ) -> int:
        """Return the size, in bytes, attributed to a single leaf object.

        Args:
            leaf: The object to measure.
            debug: Forwarded to `get_item_size` when `leaf` wraps nested
                data, so that nested measurements are logged as well.

        Returns:
            The summed size of `leaf.data` for the multi-modal wrapper
            types, `nbytes` for a tensor, the recorded `size` for a
            `MultiModalCacheItemMetadata`, and `sys.getsizeof(leaf)` for
            anything else.
        """
        # Wrapper types (the kwargs containers are not subclasses of dict):
        # measure the payload they carry in `.data`. `debug` is passed along
        # so the flag is not silently dropped for nested payloads.
        if isinstance(
                leaf,
            (
                MultiModalFieldElem,
                MultiModalKwargsItems,
                MultiModalKwargsItem,
                MultiModalKwargs,
            ),
        ):
            return cls.get_item_size(leaf.data, debug=debug)  # type: ignore

        # sys.getsizeof doesn't work for tensors
        if isinstance(leaf, torch.Tensor):
            return leaf.nbytes

        # Metadata records already carry a precomputed size.
        if isinstance(leaf, MultiModalCacheItemMetadata):
            return leaf.size

        return sys.getsizeof(leaf)

    @classmethod
    def get_item_size(
        cls,
        value: MultiModalCacheValue,
        *,
        debug: bool = False,
    ) -> int:
        """Return the total size of `value` by summing its leaf sizes.

        Args:
            value: The (possibly nested) cache value to measure.
            debug: When true, log the computed size in GiB at debug level.

        Returns:
            The sum of `get_leaf_size` over every leaf of `value`.
        """
        leaf_sizes = json_map_leaves(
            lambda x: cls.get_leaf_size(x, debug=debug),
            value,
        )
        # NOTE(review): no initial value is supplied to the reduction, so a
        # `value` without any leaves may not be reducible -- confirm how
        # json_reduce_leaves behaves on empty input.
        size = json_reduce_leaves(lambda a, b: a + b, leaf_sizes)

        if debug:
            logger.debug("Calculated size of %s to be %.2f GiB", type(value),
                         size / GiB_bytes)

        return size

    @classmethod
    def get_lru_cache(
        cls,
        capacity_gb: float,
        value_type: type[_V],
        *,
        debug: bool = False,
    ) -> LRUCache[str, _V]:
        """Create an `LRUCache` whose entries are sized by `get_item_size`.

        Args:
            capacity_gb: Cache capacity in GiB; multiplied by `GiB_bytes`
                before being handed to `LRUCache`.
            value_type: Not used at runtime; it only determines the value
                type in the return annotation.
            debug: Passed to `get_item_size` every time an entry is sized.

        Returns:
            A new `LRUCache` using `get_item_size` as its `getsizeof`.
        """
        return LRUCache(
            GiB_bytes * capacity_gb,
            getsizeof=lambda x: cls.get_item_size(x, debug=debug),
        )