[Bugfix] Fix size calculation of processing cache (#15114)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date:   2025-03-19 20:53:19 +08:00 (committed by GitHub)
Parent: 1fe0fd12d3
Commit: 3d446433ec
2 changed files with 92 additions and 16 deletions
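
Context for the fix (an illustrative sketch, not part of the diff): the new comments below note that MultiModalKwargs / MultiModalKwargsItem are not dict subclasses and that sys.getsizeof doesn't work for tensors, so a naive sys.getsizeof-based estimate only sees the shallow Python wrapper and badly undercounts cached items. The FakeKwargsItem class here is a hypothetical stand-in used only to demonstrate that effect:

import sys
import torch

class FakeKwargsItem:
    """Hypothetical stand-in for a cached item wrapping tensors (not a dict subclass)."""
    def __init__(self, data: dict):
        self.data = data

item = FakeKwargsItem({"pixel_values": torch.empty((1024 * 1024, ), dtype=torch.int8)})

# The wrapped tensor really occupies 1 MiB...
print(item.data["pixel_values"].nbytes)  # 1048576
# ...but sys.getsizeof only reports the shallow object size (a few dozen bytes),
# so a getsizeof-based LRU cache would treat the item as nearly free.
print(sys.getsizeof(item))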


@@ -7,15 +7,20 @@ from unittest.mock import MagicMock
 import numpy as np
 import pytest
+import torch
 from transformers import ProcessorMixin
 from vllm.config import ModelConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
+                                    MultiModalKwargsItem,
+                                    MultiModalSharedField)
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
-                                        PromptIndexTargets, PromptInsertion,
-                                        PromptReplacement, apply_text_matches,
+                                        ProcessingCache, PromptIndexTargets,
+                                        PromptInsertion, PromptReplacement,
+                                        apply_text_matches,
                                         apply_token_matches,
                                         find_mm_placeholders,
                                         find_text_matches, find_token_matches,
@@ -890,6 +895,45 @@ def test_find_mm_placeholders(
     assert result == expected
 
 
+def _dummy_elem(modality: str, key: str, size: int):
+    return MultiModalFieldElem(
+        modality=modality,
+        key=key,
+        data=torch.empty((size, ), dtype=torch.int8),
+        field=MultiModalSharedField(1),
+    )
+
+
+def _dummy_item(modality: str, size_by_key: dict[str, int]):
+    return MultiModalKwargsItem.from_elems([
+        _dummy_elem(modality, key, size) for key, size in size_by_key.items()
+    ])
+
+
+def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
+    return MultiModalKwargs.from_items([
+        _dummy_item(modality, size_by_key)
+        for modality, size_by_key in size_by_key_modality.items()
+    ])
+
+
+# yapf: disable
+@pytest.mark.parametrize(
+    ("item", "expected_size"),
+    [
+        (_dummy_item("a", {"a1": 100}), 100),
+        (_dummy_item("a", {"a1": 100, "a2": 110}), 210),
+        (_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460),  # noqa: E501
+    ],
+)
+# yapf: enable
+def test_cache_item_size(item, expected_size):
+    cache = ProcessingCache.get_lru_cache(2048, type(item))
+
+    cache[""] = item
+    assert cache.currsize == expected_size
+
+
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
 @pytest.mark.parametrize(
     ("limit", "num_supported", "is_valid"),


@@ -26,7 +26,7 @@ from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
 from .hasher import MultiModalHasher
 from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                      MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
-                     MultiModalKwargsItem, PlaceholderRange)
+                     MultiModalKwargsItem, NestedTensors, PlaceholderRange)
 from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
                     MultiModalDataParser)
@@ -853,33 +853,62 @@ class ProcessingCache:
     @staticmethod
     def get_lru_cache(
-        capacity_gb: int,
+        capacity_gb: float,
         value_type: type[_V],
+        *,
+        debug: bool = False,
     ) -> LRUCache[str, _V]:
 
-        def get_size(leaf: object) -> int:
+        def get_leaf_size(leaf: object) -> int:
+            # MultiModalKwargs is not a subclass of dict
+            if isinstance(leaf, MultiModalKwargs):
+                return get_item_size(leaf.data)
+
+            # MultiModalKwargsItem is not a subclass of dict
+            if isinstance(leaf, MultiModalKwargsItem):
+                leaf_data = {k: v.data for k, v in leaf.items()}
+                return get_item_size(leaf_data)
+
+            # sys.getsizeof doesn't work for tensors
             if isinstance(leaf, torch.Tensor):
-                return leaf.nbytes  # sys.getsizeof doesn't work for tensors
+                return leaf.nbytes
 
             return sys.getsizeof(leaf)
 
-        return LRUCache[str, _V](
-            GiB_bytes * capacity_gb,
-            getsizeof=lambda x: json_reduce_leaves(
+        def get_item_size(
+            value: Union[MultiModalKwargs, MultiModalKwargsItem,
+                         Mapping[str, NestedTensors]]
+        ) -> int:
+            size = json_reduce_leaves(
                 lambda a, b: a + b,
-                json_map_leaves(get_size, x),
-            ),
-        )
+                json_map_leaves(get_leaf_size, value),
+            )
 
-    def __init__(self, capacity_gb: int) -> None:
+            if debug:
+                logger.debug("Calculated size of %s to be %.2f GiB",
+                             type(value), size / GiB_bytes)
+
+            return size
+
+        return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size)
+
+    def __init__(
+        self,
+        capacity_gb: float,
+        *,
+        debug_cache_hit_ratio_steps: Optional[int] = None,
+    ) -> None:
         super().__init__()
 
-        # DEBUG: Set to None to disable
-        self.debug_cache_hit_ratio_steps: Optional[int] = None
+        self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
         self.debug_cache_hits = 0
         self.debug_cache_total = 0
 
-        self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
+        self._cache = self.get_lru_cache(
+            capacity_gb,
+            MultiModalKwargsItem,
+            debug=bool(debug_cache_hit_ratio_steps),
+        )
 
     def _maybe_log_cache_stats(self) -> None:
         steps = self.debug_cache_hit_ratio_steps
@@ -890,6 +919,9 @@ class ProcessingCache:
         if total > 0 and total % steps == 0:
             logger.debug("ProcessingCache: hit_ratio = %.2f",
                          self.debug_cache_hits / total)
+            logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
+                         self._cache.currsize / GiB_bytes,
+                         self._cache.maxsize / GiB_bytes)
 
     def get(
         self,
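
For readers unfamiliar with json_map_leaves / json_reduce_leaves, the new getsizeof callback is conceptually a recursive sum of tensor byte counts over a nested structure. Below is a minimal, self-contained sketch of that idea; it is not vLLM's actual helper code, and it skips the MultiModalKwargs / MultiModalKwargsItem unwrapping handled by get_leaf_size above:

import sys
import torch

def approx_item_size(value) -> int:
    # Tensors: count the underlying storage, since sys.getsizeof doesn't.
    if isinstance(value, torch.Tensor):
        return value.nbytes
    # Nested containers: sum over their leaves.
    if isinstance(value, dict):
        return sum(approx_item_size(v) for v in value.values())
    if isinstance(value, (list, tuple)):
        return sum(approx_item_size(v) for v in value)
    # Plain Python leaves: fall back to the shallow object size.
    return sys.getsizeof(value)

nested = {
    "a": {"a1": torch.empty((100, ), dtype=torch.int8)},
    "b": [torch.empty((110, ), dtype=torch.int8)],
}
assert approx_item_size(nested) == 210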