mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-26 16:57:05 +08:00
101 lines
2.6 KiB
Python
101 lines
2.6 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import sys
|
|
from collections.abc import Mapping
|
|
from dataclasses import dataclass
|
|
from typing import TypeVar, Union
|
|
|
|
import torch
|
|
|
|
from vllm.logger import init_logger
|
|
from vllm.utils import GiB_bytes, LRUCache
|
|
from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves
|
|
|
|
from .inputs import (MultiModalFieldElem, MultiModalKwargs,
|
|
MultiModalKwargsItem, MultiModalKwargsItems,
|
|
NestedTensors)
|
|
|
|
# Module-level logger; used by MultiModalCache.get_item_size to emit the
# optional debug message about computed sizes.
logger = init_logger(__name__)
|
|
|
|
|
|
@dataclass
class MultiModalCacheItemMetadata:
    """Metadata record that stores only the size of a cache value."""

    # Size of the described value, as computed by
    # `MultiModalCache.get_item_size`.
    size: int

    @classmethod
    def wraps(cls, value: "MultiModalCacheValue"):
        """Build a metadata record describing ``value``.

        Args:
            value: The cache value to measure.

        Returns:
            A new instance whose ``size`` is the result of
            ``MultiModalCache.get_item_size(value)``.
        """
        measured_size = MultiModalCache.get_item_size(value)
        return cls(size=measured_size)
|
|
|
|
|
|
# Union of every value kind that the cache helpers in this module accept:
# the multi-modal kwargs containers, a plain mapping of names to nested
# tensors, or a size-only metadata record.
MultiModalCacheValue = Union[
    MultiModalKwargsItems,
    MultiModalKwargsItem,
    MultiModalKwargs,
    Mapping[str, NestedTensors],
    MultiModalCacheItemMetadata,
]

# Type variable bound to MultiModalCacheValue; it parameterises the value
# type of the LRUCache returned by `MultiModalCache.get_lru_cache`.
_V = TypeVar("_V", bound=MultiModalCacheValue)
|
|
|
|
|
|
class MultiModalCache:
    """Size-accounting helpers and an LRU-cache factory for cache values."""

    @classmethod
    def get_leaf_size(
        cls,
        leaf: object,
        *,
        debug: bool = False,
    ) -> int:
        """Return the size of a single leaf object.

        Args:
            leaf: The object to measure.
            debug: Accepted for signature parity with ``get_item_size``;
                it is not forwarded to the nested calls made here.

        Returns:
            The measured size of the payload for objects exposing ``.data``,
            ``nbytes`` for tensors, the stored ``size`` for metadata
            records, and ``sys.getsizeof(leaf)`` for anything else.
        """
        # Wrapper objects that expose their payload through `.data`:
        # measure the payload instead of the wrapper. (The kwargs
        # containers are not subclasses of dict.)
        data_wrappers = (
            MultiModalFieldElem,
            MultiModalKwargsItems,
            MultiModalKwargsItem,
            MultiModalKwargs,
        )
        if isinstance(leaf, data_wrappers):
            return cls.get_item_size(leaf.data)  # type: ignore

        # sys.getsizeof doesn't work for tensors
        if isinstance(leaf, torch.Tensor):
            return leaf.nbytes

        # Metadata records already carry a precomputed size.
        if isinstance(leaf, MultiModalCacheItemMetadata):
            return leaf.size

        return sys.getsizeof(leaf)

    @classmethod
    def get_item_size(
        cls,
        value: MultiModalCacheValue,
        *,
        debug: bool = False,
    ) -> int:
        """Return the total size of ``value`` by summing its leaf sizes.

        Args:
            value: The cache value to measure.
            debug: When true, log the computed size (in GiB) at debug level.
                The flag is also passed to ``get_leaf_size`` for each leaf.

        Returns:
            The sum of ``get_leaf_size`` over all leaves of ``value``.
        """

        def measure_leaf(leaf):
            return cls.get_leaf_size(leaf, debug=debug)

        # First replace every leaf by its size, then add the sizes up.
        leaf_sizes = json_map_leaves(measure_leaf, value)
        size = json_reduce_leaves(lambda lhs, rhs: lhs + rhs, leaf_sizes)

        if debug:
            logger.debug("Calculated size of %s to be %.2f GiB", type(value),
                         size / GiB_bytes)

        return size

    @classmethod
    def get_lru_cache(
        cls,
        capacity_gb: float,
        value_type: type[_V],
        *,
        debug: bool = False,
    ) -> LRUCache[str, _V]:
        """Create an ``LRUCache`` whose entries are sized by ``get_item_size``.

        Args:
            capacity_gb: Cache capacity in GiB; converted by multiplying
                with ``GiB_bytes``.
            value_type: Not referenced in the body; it only appears in the
                signature, where it binds ``_V`` for the return annotation.
            debug: Forwarded to ``get_item_size`` each time an entry is
                measured.

        Returns:
            A new ``LRUCache`` with the computed capacity and a ``getsizeof``
            callback that delegates to ``get_item_size``.
        """
        capacity = GiB_bytes * capacity_gb

        def measure_entry(entry):
            return cls.get_item_size(entry, debug=debug)

        return LRUCache(capacity, getsizeof=measure_entry)
|