[Bugfix] Fix size calculation of processing cache (#15114)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 1fe0fd12d3
commit 3d446433ec
@@ -7,15 +7,20 @@ from unittest.mock import MagicMock
 import numpy as np
 import pytest
+import torch
 from transformers import ProcessorMixin
 
 from vllm.config import ModelConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
+                                    MultiModalKwargsItem,
+                                    MultiModalSharedField)
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
-                                        PromptIndexTargets, PromptInsertion,
-                                        PromptReplacement, apply_text_matches,
+                                        ProcessingCache, PromptIndexTargets,
+                                        PromptInsertion, PromptReplacement,
+                                        apply_text_matches,
                                         apply_token_matches,
                                         find_mm_placeholders,
                                         find_text_matches, find_token_matches,
@@ -890,6 +895,45 @@ def test_find_mm_placeholders(
     assert result == expected
 
 
+def _dummy_elem(modality: str, key: str, size: int):
+    return MultiModalFieldElem(
+        modality=modality,
+        key=key,
+        data=torch.empty((size, ), dtype=torch.int8),
+        field=MultiModalSharedField(1),
+    )
+
+
+def _dummy_item(modality: str, size_by_key: dict[str, int]):
+    return MultiModalKwargsItem.from_elems([
+        _dummy_elem(modality, key, size) for key, size in size_by_key.items()
+    ])
+
+
+def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
+    return MultiModalKwargs.from_items([
+        _dummy_item(modality, size_by_key)
+        for modality, size_by_key in size_by_key_modality.items()
+    ])
+
+
+# yapf: disable
+@pytest.mark.parametrize(
+    ("item", "expected_size"),
+    [
+        (_dummy_item("a", {"a1": 100}), 100),
+        (_dummy_item("a", {"a1": 100, "a2": 110}), 210),
+        (_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460),  # noqa: E501
+    ],
+)
+# yapf: enable
+def test_cache_item_size(item, expected_size):
+    cache = ProcessingCache.get_lru_cache(2048, type(item))
+    cache[""] = item
+
+    assert cache.currsize == expected_size
+
+
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
 @pytest.mark.parametrize(
     ("limit", "num_supported", "is_valid"),
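The expected sizes in the new test follow from the fact that each dummy element wraps an int8 tensor with `size` elements, so its byte count equals `size`, and the cache sums those counts across keys and modalities. A minimal standalone check of that arithmetic with plain PyTorch (illustrative only, independent of the vLLM helpers):

```python
import torch

# Each dummy element stores an int8 tensor with `size` elements,
# so its byte count equals `size` (1 byte per element).
a1 = torch.empty((100, ), dtype=torch.int8)
a2 = torch.empty((110, ), dtype=torch.int8)
b1 = torch.empty((120, ), dtype=torch.int8)
b2 = torch.empty((130, ), dtype=torch.int8)

assert a1.nbytes == 100                                 # single-key item
assert a1.nbytes + a2.nbytes == 210                     # two keys, one modality
assert sum(t.nbytes for t in (a1, a2, b1, b2)) == 460   # two modalities
```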
@@ -26,7 +26,7 @@ from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
 from .hasher import MultiModalHasher
 from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                      MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
-                     MultiModalKwargsItem, PlaceholderRange)
+                     MultiModalKwargsItem, NestedTensors, PlaceholderRange)
 from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
                     MultiModalDataParser)
 
@@ -853,33 +853,62 @@ class ProcessingCache:
 
     @staticmethod
     def get_lru_cache(
-        capacity_gb: int,
+        capacity_gb: float,
         value_type: type[_V],
+        *,
+        debug: bool = False,
     ) -> LRUCache[str, _V]:
 
-        def get_size(leaf: object) -> int:
+        def get_leaf_size(leaf: object) -> int:
+            # MultiModalKwargs is not a subclass of dict
+            if isinstance(leaf, MultiModalKwargs):
+                return get_item_size(leaf.data)
+
+            # MultiModalKwargsItem is not a subclass of dict
+            if isinstance(leaf, MultiModalKwargsItem):
+                leaf_data = {k: v.data for k, v in leaf.items()}
+                return get_item_size(leaf_data)
+
+            # sys.getsizeof doesn't work for tensors
             if isinstance(leaf, torch.Tensor):
-                return leaf.nbytes  # sys.getsizeof doesn't work for tensors
+                return leaf.nbytes
 
             return sys.getsizeof(leaf)
 
-        return LRUCache[str, _V](
-            GiB_bytes * capacity_gb,
-            getsizeof=lambda x: json_reduce_leaves(
+        def get_item_size(
+            value: Union[MultiModalKwargs, MultiModalKwargsItem,
+                         Mapping[str, NestedTensors]]
+        ) -> int:
+            size = json_reduce_leaves(
                 lambda a, b: a + b,
-                json_map_leaves(get_size, x),
-            ),
-        )
+                json_map_leaves(get_leaf_size, value),
+            )
+
+            if debug:
+                logger.debug("Calculated size of %s to be %.2f GiB",
+                             type(value), size / GiB_bytes)
+
+            return size
+
+        return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size)
 
-    def __init__(self, capacity_gb: int) -> None:
+    def __init__(
+        self,
+        capacity_gb: float,
+        *,
+        debug_cache_hit_ratio_steps: Optional[int] = None,
+    ) -> None:
         super().__init__()
 
-        # DEBUG: Set to None to disable
-        self.debug_cache_hit_ratio_steps: Optional[int] = None
+        self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
         self.debug_cache_hits = 0
         self.debug_cache_total = 0
 
-        self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
+        self._cache = self.get_lru_cache(
+            capacity_gb,
+            MultiModalKwargsItem,
+            debug=bool(debug_cache_hit_ratio_steps),
+        )
 
     def _maybe_log_cache_stats(self) -> None:
         steps = self.debug_cache_hit_ratio_steps
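The comments in the new code state the underlying problem: `MultiModalKwargs` and `MultiModalKwargsItem` are mapping-like but not `dict` subclasses, so the generic leaf walk previously saw a whole item as one leaf and fell back to `sys.getsizeof`, which reports only the shallow container size and misses the tensors inside. The new `get_item_size` unwraps those wrappers into plain dicts of tensors before reducing. A rough standalone sketch of the same idea, using a `Mapping`/`Sequence` check instead of vLLM's `json_map_leaves`/`json_reduce_leaves` (illustrative, not the actual helpers):

```python
import sys
from collections.abc import Mapping, Sequence

import torch


def nested_size(value) -> int:
    """Recursively sum the in-memory size of tensors nested in dicts/lists."""
    if isinstance(value, torch.Tensor):
        return value.nbytes  # sys.getsizeof does not account for tensor storage
    if isinstance(value, Mapping):
        return sum(nested_size(v) for v in value.values())
    if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
        return sum(nested_size(v) for v in value)
    return sys.getsizeof(value)


# A dict of tensors: the size is dominated by the tensor payloads, not the
# few dozen bytes that sys.getsizeof would report for the dict object itself.
item = {"pixel_values": torch.empty((3, 224, 224), dtype=torch.int8)}
print(nested_size(item))  # 150528
```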
@@ -890,6 +919,9 @@ class ProcessingCache:
         if total > 0 and total % steps == 0:
             logger.debug("ProcessingCache: hit_ratio = %.2f",
                          self.debug_cache_hits / total)
+            logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
+                         self._cache.currsize / GiB_bytes,
+                         self._cache.maxsize / GiB_bytes)
 
     def get(
         self,
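With the revised constructor, the hit-ratio logging is opted into per instance instead of by editing the hard-coded `None`, and enabling it also turns on the `debug=True` size logging above. A hedged usage sketch (the keyword arguments follow the signature in this diff; the logging-level environment variable is an assumption about the deployment):

```python
from vllm.multimodal.processing import ProcessingCache

# Report hit ratio and current/max cache size every 32 lookups.
# The messages are emitted at DEBUG level, so debug logging must be on
# (e.g. VLLM_LOGGING_LEVEL=DEBUG, if that matches your setup).
cache = ProcessingCache(capacity_gb=4.0, debug_cache_hit_ratio_steps=32)
```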