diff --git a/tests/utils_/test_cache.py b/tests/utils_/test_cache.py
new file mode 100644
index 0000000000000..e361006fd8e66
--- /dev/null
+++ b/tests/utils_/test_cache.py
@@ -0,0 +1,125 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.utils.cache import CacheInfo, LRUCache
+
+
+class TestLRUCache(LRUCache):
+    def _on_remove(self, key, value):
+        if not hasattr(self, "_remove_counter"):
+            self._remove_counter = 0
+        self._remove_counter += 1
+
+
+def test_lru_cache():
+    cache = TestLRUCache(3)
+    assert cache.stat() == CacheInfo(hits=0, total=0)
+    assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
+
+    cache.put(1, 1)
+    assert len(cache) == 1
+
+    cache.put(1, 1)
+    assert len(cache) == 1
+
+    cache.put(2, 2)
+    assert len(cache) == 2
+
+    cache.put(3, 3)
+    assert len(cache) == 3
+    assert set(cache.cache) == {1, 2, 3}
+
+    cache.put(4, 4)
+    assert len(cache) == 3
+    assert set(cache.cache) == {2, 3, 4}
+    assert cache._remove_counter == 1
+
+    assert cache.get(2) == 2
+    assert cache.stat() == CacheInfo(hits=1, total=1)
+    assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
+
+    assert cache[2] == 2
+    assert cache.stat() == CacheInfo(hits=2, total=2)
+    assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
+
+    cache.put(5, 5)
+    assert set(cache.cache) == {2, 4, 5}
+    assert cache._remove_counter == 2
+
+    assert cache.pop(5) == 5
+    assert len(cache) == 2
+    assert set(cache.cache) == {2, 4}
+    assert cache._remove_counter == 3
+
+    assert cache.get(-1) is None
+    assert cache.stat() == CacheInfo(hits=2, total=3)
+    assert cache.stat(delta=True) == CacheInfo(hits=0, total=1)
+
+    cache.pop(10)
+    assert len(cache) == 2
+    assert set(cache.cache) == {2, 4}
+    assert cache._remove_counter == 3
+
+    cache.get(10)
+    assert len(cache) == 2
+    assert set(cache.cache) == {2, 4}
+    assert cache._remove_counter == 3
+
+    cache.put(6, 6)
+    assert len(cache) == 3
+    assert set(cache.cache) == {2, 4, 6}
+    assert 2 in cache
+    assert 4 in cache
+    assert 6 in cache
+
+    cache.remove_oldest()
+    assert len(cache) == 2
+    assert set(cache.cache) == {2, 6}
+    assert cache._remove_counter == 4
+
+    cache.clear()
+    assert len(cache) == 0
+    assert cache._remove_counter == 6
+    assert cache.stat() == CacheInfo(hits=0, total=0)
+    assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
+
+    cache._remove_counter = 0
+
+    cache[1] = 1
+    assert len(cache) == 1
+
+    cache[1] = 1
+    assert len(cache) == 1
+
+    cache[2] = 2
+    assert len(cache) == 2
+
+    cache[3] = 3
+    assert len(cache) == 3
+    assert set(cache.cache) == {1, 2, 3}
+
+    cache[4] = 4
+    assert len(cache) == 3
+    assert set(cache.cache) == {2, 3, 4}
+    assert cache._remove_counter == 1
+    assert cache[2] == 2
+
+    cache[5] = 5
+    assert set(cache.cache) == {2, 4, 5}
+    assert cache._remove_counter == 2
+
+    del cache[5]
+    assert len(cache) == 2
+    assert set(cache.cache) == {2, 4}
+    assert cache._remove_counter == 3
+
+    cache.pop(10)
+    assert len(cache) == 2
+    assert set(cache.cache) == {2, 4}
+    assert cache._remove_counter == 3
+
+    cache[6] = 6
+    assert len(cache) == 3
+    assert set(cache.cache) == {2, 4, 6}
+    assert 2 in cache
+    assert 4 in cache
+    assert 6 in cache
diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py
index 71c82feac36bc..cd5fa550498be 100644
--- a/tests/utils_/test_utils.py
+++ b/tests/utils_/test_utils.py
@@ -23,11 +23,8 @@ from vllm_test_utils.monitor import monitor
 
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
-# isort: off
 from vllm.utils import (
-    CacheInfo,
     FlexibleArgumentParser,
-    LRUCache,
     MemorySnapshot,
     PlaceholderModule,
     bind_kv_cache,
@@ -50,7 +47,6 @@ from vllm.utils import (
     unique_filepath,
 )
 
-# isort: on
 from ..utils import create_new_process_for_each_test, error_on_warning
 
 
@@ -557,128 +553,6 @@ def test_bind_kv_cache_pp():
     assert ctx["layers.0.self_attn"].kv_cache[1] is kv_cache[1][0]
-
-
-class TestLRUCache(LRUCache):
-    def _on_remove(self, key, value):
-        if not hasattr(self, "_remove_counter"):
-            self._remove_counter = 0
-        self._remove_counter += 1
-
-
-def test_lru_cache():
-    cache = TestLRUCache(3)
-    assert cache.stat() == CacheInfo(hits=0, total=0)
-    assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
-
-    cache.put(1, 1)
-    assert len(cache) == 1
-
-    cache.put(1, 1)
-    assert len(cache) == 1
-
-    cache.put(2, 2)
-    assert len(cache) == 2
-
-    cache.put(3, 3)
-    assert len(cache) == 3
-    assert set(cache.cache) == {1, 2, 3}
-
-    cache.put(4, 4)
-    assert len(cache) == 3
-    assert set(cache.cache) == {2, 3, 4}
-    assert cache._remove_counter == 1
-
-    assert cache.get(2) == 2
-    assert cache.stat() == CacheInfo(hits=1, total=1)
-    assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
-
-    assert cache[2] == 2
-    assert cache.stat() == CacheInfo(hits=2, total=2)
-    assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
-
-    cache.put(5, 5)
-    assert set(cache.cache) == {2, 4, 5}
-    assert cache._remove_counter == 2
-
-    assert cache.pop(5) == 5
-    assert len(cache) == 2
-    assert set(cache.cache) == {2, 4}
-    assert cache._remove_counter == 3
-
-    assert cache.get(-1) is None
-    assert cache.stat() == CacheInfo(hits=2, total=3)
-    assert cache.stat(delta=True) == CacheInfo(hits=0, total=1)
-
-    cache.pop(10)
-    assert len(cache) == 2
-    assert set(cache.cache) == {2, 4}
-    assert cache._remove_counter == 3
-
-    cache.get(10)
-    assert len(cache) == 2
-    assert set(cache.cache) == {2, 4}
-    assert cache._remove_counter == 3
-
-    cache.put(6, 6)
-    assert len(cache) == 3
-    assert set(cache.cache) == {2, 4, 6}
-    assert 2 in cache
-    assert 4 in cache
-    assert 6 in cache
-
-    cache.remove_oldest()
-    assert len(cache) == 2
-    assert set(cache.cache) == {2, 6}
-    assert cache._remove_counter == 4
-
-    cache.clear()
-    assert len(cache) == 0
-    assert cache._remove_counter == 6
-    assert cache.stat() == CacheInfo(hits=0, total=0)
-    assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
-
-    cache._remove_counter = 0
-
-    cache[1] = 1
-    assert len(cache) == 1
-
-    cache[1] = 1
-    assert len(cache) == 1
-
-    cache[2] = 2
-    assert len(cache) == 2
-
-    cache[3] = 3
-    assert len(cache) == 3
-    assert set(cache.cache) == {1, 2, 3}
-
-    cache[4] = 4
-    assert len(cache) == 3
-    assert set(cache.cache) == {2, 3, 4}
-    assert cache._remove_counter == 1
-    assert cache[2] == 2
-
-    cache[5] = 5
-    assert set(cache.cache) == {2, 4, 5}
-    assert cache._remove_counter == 2
-
-    del cache[5]
-    assert len(cache) == 2
-    assert set(cache.cache) == {2, 4}
-    assert cache._remove_counter == 3
-
-    cache.pop(10)
-    assert len(cache) == 2
-    assert set(cache.cache) == {2, 4}
-    assert cache._remove_counter == 3
-
-    cache[6] = 6
-    assert len(cache) == 3
-    assert set(cache.cache) == {2, 4, 6}
-    assert 2 in cache
-    assert 4 in cache
-    assert 6 in cache
 
 
 @pytest.mark.parametrize(
     ("src_dtype", "tgt_dtype", "expected_result"),
     [
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index edf34b483e9ab..771c8608f4a8a 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -32,7 +32,8 @@ from vllm.model_executor.models.interfaces import is_pooling_model
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
 from vllm.model_executor.utils import get_packed_modules_mapping
-from vllm.utils import LRUCache, is_pin_memory_available
+from vllm.utils import is_pin_memory_available
+from vllm.utils.cache import LRUCache
 
 logger = init_logger(__name__)
 
diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py
index 15aa91a040921..7febc393157fd 100644
--- a/vllm/multimodal/cache.py
+++ b/vllm/multimodal/cache.py
@@ -17,7 +17,8 @@ from vllm.distributed.device_communicators.shm_object_storage import (
 )
 from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME
 from vllm.logger import init_logger
-from vllm.utils import GiB_bytes, LRUCache, MiB_bytes
+from vllm.utils import GiB_bytes, MiB_bytes
+from vllm.utils.cache import LRUCache
 from vllm.utils.jsontree import json_count_leaves, json_map_leaves, json_reduce_leaves
 
 from .inputs import (
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index c06bbbbb23aba..4a6a79ad067ba 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -51,7 +51,6 @@ from collections.abc import (
     Hashable,
     Iterable,
     Iterator,
-    KeysView,
     Mapping,
     Sequence,
 )
@@ -60,24 +59,19 @@ from concurrent.futures.process import ProcessPoolExecutor
 from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
 from pathlib import Path
-from types import MappingProxyType
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Generic,
     Literal,
-    NamedTuple,
     TextIO,
     TypeVar,
     Union,
-    cast,
-    overload,
 )
 from urllib.parse import urlparse
 from uuid import uuid4
 
-import cachetools
 import cbor2
 import cloudpickle
 import numpy as np
@@ -183,13 +177,6 @@
 U = TypeVar("U")
 _K = TypeVar("_K", bound=Hashable)
 _V = TypeVar("_V")
-_T = TypeVar("_T")
-
-
-class _Sentinel: ...
-
-
-ALL_PINNED_SENTINEL = _Sentinel()
 
 
 class Device(enum.Enum):
@@ -215,243 +202,6 @@ class Counter:
         self.counter = 0
 
 
-class _MappingOrderCacheView(UserDict[_K, _V]):
-    def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]):
-        super().__init__(data)
-        self.ordered_keys = ordered_keys
-
-    def __iter__(self) -> Iterator[_K]:
-        return iter(self.ordered_keys)
-
-    def keys(self) -> KeysView[_K]:
-        return KeysView(self.ordered_keys)
-
-
-class CacheInfo(NamedTuple):
-    hits: int
-    total: int
-
-    @property
-    def hit_ratio(self) -> float:
-        if self.total == 0:
-            return 0
-
-        return self.hits / self.total
-
-    def __sub__(self, other: CacheInfo):
-        return CacheInfo(
-            hits=self.hits - other.hits,
-            total=self.total - other.total,
-        )
-
-
-class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
-    def __init__(self, capacity: float, getsizeof: Callable[[_V], float] | None = None):
-        super().__init__(capacity, getsizeof)
-
-        self.pinned_items = set[_K]()
-
-        self._hits = 0
-        self._total = 0
-        self._last_info = CacheInfo(hits=0, total=0)
-
-    def __getitem__(self, key: _K, *, update_info: bool = True) -> _V:
-        value = super().__getitem__(key)
-
-        if update_info:
-            self._hits += 1
-            self._total += 1
-
-        return value
-
-    def __delitem__(self, key: _K) -> None:
-        run_on_remove = key in self
-        value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]
-        super().__delitem__(key)
-        if key in self.pinned_items:
-            # Todo: add warning to inform that del pinned item
-            self._unpin(key)
-        if run_on_remove:
-            self._on_remove(key, value)
-
-    @property
-    def cache(self) -> Mapping[_K, _V]:
-        """Return the internal cache dictionary in order (read-only)."""
-        return _MappingOrderCacheView(
-            self._Cache__data,  # type: ignore
-            self.order,
-        )
-
-    @property
-    def order(self) -> Mapping[_K, None]:
-        """Return the internal order dictionary (read-only)."""
-        return MappingProxyType(self._LRUCache__order)  # type: ignore
-
-    @property
-    def capacity(self) -> float:
-        return self.maxsize
-
-    @property
-    def usage(self) -> float:
-        if self.maxsize == 0:
-            return 0
-
-        return self.currsize / self.maxsize
-
-    def stat(self, *, delta: bool = False) -> CacheInfo:
-        """
-        Gets the cumulative number of hits and queries against this cache.
-
-        If `delta=True`, instead gets these statistics
-        since the last call that also passed `delta=True`.
-        """
-        info = CacheInfo(hits=self._hits, total=self._total)
-
-        if delta:
-            info_delta = info - self._last_info
-            self._last_info = info
-            info = info_delta
-
-        return info
-
-    def touch(self, key: _K) -> None:
-        try:
-            self._LRUCache__order.move_to_end(key)  # type: ignore
-        except KeyError:
-            self._LRUCache__order[key] = None  # type: ignore
-
-    @overload
-    def get(self, key: _K, /) -> _V | None: ...
-
-    @overload
-    def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: ...
-
-    def get(
-        self, key: _K, /, default: Union[_V, _T] | None = None
-    ) -> Union[_V, _T] | None:
-        value: Union[_V, _T] | None
-        if key in self:
-            value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]
-
-            self._hits += 1
-        else:
-            value = default
-
-        self._total += 1
-        return value
-
-    @overload
-    def pop(self, key: _K) -> _V: ...
-
-    @overload
-    def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ...
-
-    def pop(
-        self, key: _K, default: Union[_V, _T] | None = None
-    ) -> Union[_V, _T] | None:
-        value: Union[_V, _T] | None
-        if key not in self:
-            return default
-
-        value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]
-        self.__delitem__(key)
-        return value
-
-    def put(self, key: _K, value: _V) -> None:
-        self.__setitem__(key, value)
-
-    def pin(self, key: _K) -> None:
-        """
-        Pins a key in the cache preventing it from being
-        evicted in the LRU order.
-        """
-        if key not in self:
-            raise ValueError(f"Cannot pin key: {key} not in cache.")
-        self.pinned_items.add(key)
-
-    def _unpin(self, key: _K) -> None:
-        """
-        Unpins a key in the cache allowing it to be
-        evicted in the LRU order.
-        """
-        self.pinned_items.remove(key)
-
-    def _on_remove(self, key: _K, value: _V | None) -> None:
-        pass
-
-    def remove_oldest(self, *, remove_pinned: bool = False) -> None:
-        if len(self) == 0:
-            return
-
-        self.popitem(remove_pinned=remove_pinned)
-
-    def _remove_old_if_needed(self) -> None:
-        while self.currsize > self.capacity:
-            self.remove_oldest()
-
-    def popitem(self, remove_pinned: bool = False):
-        """Remove and return the `(key, value)` pair least recently used."""
-        if not remove_pinned:
-            # pop the oldest item in the cache that is not pinned
-            lru_key = next(
-                (key for key in self.order if key not in self.pinned_items),
-                ALL_PINNED_SENTINEL,
-            )
-            if lru_key is ALL_PINNED_SENTINEL:
-                raise RuntimeError(
-                    "All items are pinned, cannot remove oldest from the cache."
-                )
-        else:
-            lru_key = next(iter(self.order))
-        value = self.pop(cast(_K, lru_key))
-        return (lru_key, value)
-
-    def clear(self) -> None:
-        while len(self) > 0:
-            self.remove_oldest(remove_pinned=True)
-
-        self._hits = 0
-        self._total = 0
-        self._last_info = CacheInfo(hits=0, total=0)
-
-
-class PyObjectCache:
-    """Used to cache python objects to avoid object allocations
-    across scheduler iterations.
-    """
-
-    def __init__(self, obj_builder):
-        self._obj_builder = obj_builder
-        self._index = 0
-
-        self._obj_cache = []
-        for _ in range(128):
-            self._obj_cache.append(self._obj_builder())
-
-    def _grow_cache(self):
-        # Double the size of the cache
-        num_objs = len(self._obj_cache)
-        for _ in range(num_objs):
-            self._obj_cache.append(self._obj_builder())
-
-    def get_object(self):
-        """Returns a pre-allocated cached object. If there is not enough
-        objects, then the cache size will double.
-        """
-        if self._index >= len(self._obj_cache):
-            self._grow_cache()
-            assert self._index < len(self._obj_cache)
-
-        obj = self._obj_cache[self._index]
-        self._index += 1
-
-        return obj
-
-    def reset(self):
-        """Makes all cached-objects available for the next scheduler iteration."""
-        self._index = 0
-
-
 @cache
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
diff --git a/vllm/utils/cache.py b/vllm/utils/cache.py
new file mode 100644
index 0000000000000..a57ef9b70ccc8
--- /dev/null
+++ b/vllm/utils/cache.py
@@ -0,0 +1,220 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations
+
+from collections import UserDict
+from collections.abc import Hashable, Iterator, KeysView, Mapping
+from types import MappingProxyType
+from typing import Callable, Generic, NamedTuple, TypeVar, Union, cast, overload
+
+import cachetools
+
+_K = TypeVar("_K", bound=Hashable)
+_V = TypeVar("_V")
+_T = TypeVar("_T")
+
+
+class _Sentinel: ...
+
+
+ALL_PINNED_SENTINEL = _Sentinel()
+
+
+class _MappingOrderCacheView(UserDict[_K, _V]):
+    def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]):
+        super().__init__(data)
+        self.ordered_keys = ordered_keys
+
+    def __iter__(self) -> Iterator[_K]:
+        return iter(self.ordered_keys)
+
+    def keys(self) -> KeysView[_K]:
+        return KeysView(self.ordered_keys)
+
+
+class CacheInfo(NamedTuple):
+    hits: int
+    total: int
+
+    @property
+    def hit_ratio(self) -> float:
+        if self.total == 0:
+            return 0
+
+        return self.hits / self.total
+
+    def __sub__(self, other: CacheInfo):
+        return CacheInfo(
+            hits=self.hits - other.hits,
+            total=self.total - other.total,
+        )
+
+
+class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
+    def __init__(self, capacity: float, getsizeof: Callable[[_V], float] | None = None):
+        super().__init__(capacity, getsizeof)
+
+        self.pinned_items = set[_K]()
+
+        self._hits = 0
+        self._total = 0
+        self._last_info = CacheInfo(hits=0, total=0)
+
+    def __getitem__(self, key: _K, *, update_info: bool = True) -> _V:
+        value = super().__getitem__(key)
+
+        if update_info:
+            self._hits += 1
+            self._total += 1
+
+        return value
+
+    def __delitem__(self, key: _K) -> None:
+        run_on_remove = key in self
+        value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]
+        super().__delitem__(key)
+        if key in self.pinned_items:
+            # TODO: warn when a pinned item is deleted
+            self._unpin(key)
+        if run_on_remove:
+            self._on_remove(key, value)
+
+    @property
+    def cache(self) -> Mapping[_K, _V]:
+        """Return the internal cache dictionary in order (read-only)."""
+        return _MappingOrderCacheView(
+            self._Cache__data,  # type: ignore
+            self.order,
+        )
+
+    @property
+    def order(self) -> Mapping[_K, None]:
+        """Return the internal order dictionary (read-only)."""
+        return MappingProxyType(self._LRUCache__order)  # type: ignore
+
+    @property
+    def capacity(self) -> float:
+        return self.maxsize
+
+    @property
+    def usage(self) -> float:
+        if self.maxsize == 0:
+            return 0
+
+        return self.currsize / self.maxsize
+
+    def stat(self, *, delta: bool = False) -> CacheInfo:
+        """
+        Gets the cumulative number of hits and queries against this cache.
+
+        If `delta=True`, instead gets these statistics
+        since the last call that also passed `delta=True`.
+        """
+        info = CacheInfo(hits=self._hits, total=self._total)
+
+        if delta:
+            info_delta = info - self._last_info
+            self._last_info = info
+            info = info_delta
+
+        return info
+
+    def touch(self, key: _K) -> None:
+        try:
+            self._LRUCache__order.move_to_end(key)  # type: ignore
+        except KeyError:
+            self._LRUCache__order[key] = None  # type: ignore
+
+    @overload
+    def get(self, key: _K, /) -> _V | None: ...
+
+    @overload
+    def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]: ...
+
+    def get(
+        self, key: _K, /, default: Union[_V, _T] | None = None
+    ) -> Union[_V, _T] | None:
+        value: Union[_V, _T] | None
+        if key in self:
+            value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]
+
+            self._hits += 1
+        else:
+            value = default
+
+        self._total += 1
+        return value
+
+    @overload
+    def pop(self, key: _K) -> _V: ...
+
+    @overload
+    def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]: ...
+
+    def pop(
+        self, key: _K, default: Union[_V, _T] | None = None
+    ) -> Union[_V, _T] | None:
+        value: Union[_V, _T] | None
+        if key not in self:
+            return default
+
+        value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]
+        self.__delitem__(key)
+        return value
+
+    def put(self, key: _K, value: _V) -> None:
+        self.__setitem__(key, value)
+
+    def pin(self, key: _K) -> None:
+        """
+        Pins a key in the cache preventing it from being
+        evicted in the LRU order.
+        """
+        if key not in self:
+            raise ValueError(f"Cannot pin key: {key} not in cache.")
+        self.pinned_items.add(key)
+
+    def _unpin(self, key: _K) -> None:
+        """
+        Unpins a key in the cache allowing it to be
+        evicted in the LRU order.
+        """
+        self.pinned_items.remove(key)
+
+    def _on_remove(self, key: _K, value: _V | None) -> None:
+        pass
+
+    def remove_oldest(self, *, remove_pinned: bool = False) -> None:
+        if len(self) == 0:
+            return
+
+        self.popitem(remove_pinned=remove_pinned)
+
+    def _remove_old_if_needed(self) -> None:
+        while self.currsize > self.capacity:
+            self.remove_oldest()
+
+    def popitem(self, remove_pinned: bool = False):
+        """Remove and return the `(key, value)` pair least recently used."""
+        if not remove_pinned:
+            # pop the oldest item in the cache that is not pinned
+            lru_key = next(
+                (key for key in self.order if key not in self.pinned_items),
+                ALL_PINNED_SENTINEL,
+            )
+            if lru_key is ALL_PINNED_SENTINEL:
+                raise RuntimeError(
+                    "All items are pinned, cannot remove oldest from the cache."
+                )
+        else:
+            lru_key = next(iter(self.order))
+        value = self.pop(cast(_K, lru_key))
+        return (lru_key, value)
+
+    def clear(self) -> None:
+        while len(self) > 0:
+            self.remove_oldest(remove_pinned=True)
+
+        self._hits = 0
+        self._total = 0
+        self._last_info = CacheInfo(hits=0, total=0)
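
Reviewer note, not part of the patch: a minimal usage sketch of the relocated statistics API, assuming only what the new vllm/utils/cache.py above defines. stat() reports cumulative hits and queries; stat(delta=True) reports them since the previous delta call. Note that get() counts a query even on a miss, while __contains__ does not touch the counters.

# Sketch only: exercises CacheInfo/stat() as defined in the new module above.
from vllm.utils.cache import CacheInfo, LRUCache

cache: LRUCache[str, int] = LRUCache(2)
cache.put("a", 1)

assert cache.get("a") == 1           # hit: hits=1, total=1
assert cache.get("missing") is None  # miss: hits=1, total=2
assert cache.stat() == CacheInfo(hits=1, total=2)

assert cache.stat(delta=True) == CacheInfo(hits=1, total=2)  # first delta call
cache.get("a")                                               # one more hit
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)  # only since last delta
assert cache.stat() == CacheInfo(hits=2, total=3)            # cumulative view unchanged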
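
A second sketch, again grounded only in the code in this diff, for the pinning behaviour: popitem() skips pinned keys when choosing an eviction victim and raises RuntimeError once every key is pinned, while remove_oldest(remove_pinned=True) (which clear() relies on) evicts regardless of pins.

# Sketch only: pinned entries survive LRU eviction until eviction is forced.
from vllm.utils.cache import LRUCache

cache: LRUCache[int, str] = LRUCache(2)
cache.put(1, "one")
cache.put(2, "two")
cache.pin(1)

cache.put(3, "three")  # cache is full; key 1 is pinned, so the next-oldest key 2 is evicted
assert 1 in cache and 3 in cache and 2 not in cache

cache.pin(3)
try:
    cache.remove_oldest()  # every key is pinned now
except RuntimeError:
    pass

cache.remove_oldest(remove_pinned=True)  # forced eviction ignores pins
assert len(cache) == 1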