From f4c98b4d4cbc1ae7c51ec2e29d07ae6fb01e6094 Mon Sep 17 00:00:00 2001
From: Bella kira <89331823+Avabowler@users.noreply.github.com>
Date: Thu, 27 Mar 2025 14:43:43 +0800
Subject: [PATCH] [Misc] Consolidate LRUCache implementations (#15481)

Signed-off-by: Bella kira <2374035698@qq.com>
---
 vllm/multimodal/processing.py |   3 +-
 vllm/utils.py                 | 159 ++++++++++++++++++++++------
 2 files changed, 105 insertions(+), 57 deletions(-)

diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index fec77acc1d197..c8864c33fe372 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -12,7 +12,6 @@ from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
                     TypeVar, Union, cast)
 
 import torch
-from cachetools import LRUCache
 from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
 from typing_extensions import assert_never
 
@@ -21,7 +20,7 @@ from vllm.jsontree import json_map_leaves, json_reduce_leaves
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                                encode_tokens)
-from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
+from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby
 
 from .hasher import MultiModalHasher
 from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
diff --git a/vllm/utils.py b/vllm/utils.py
index 73de826266daa..516b33dca1dc8 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -33,15 +33,17 @@ import uuid
 import warnings
 import weakref
 from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
-from collections import OrderedDict, UserDict, defaultdict
+from collections import UserDict, defaultdict
 from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
-                             Iterable, Iterator, Mapping)
+                             Iterable, Iterator, KeysView, Mapping)
 from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
+from types import MappingProxyType
 from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
-                    Optional, Type, TypeVar, Union)
+                    Optional, Type, TypeVar, Union, cast, overload)
 from uuid import uuid4
 
+import cachetools
 import cloudpickle
 import numpy as np
 import numpy.typing as npt
@@ -173,6 +175,7 @@ U = TypeVar("U")
 
 _K = TypeVar("_K", bound=Hashable)
 _V = TypeVar("_V")
+_T = TypeVar("_T")
 
 
 class _Sentinel:
@@ -206,6 +209,19 @@ class Counter:
         self.counter = 0
 
 
+class _MappingOrderCacheView(UserDict[_K, _V]):
+
+    def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]):
+        super().__init__(data)
+        self.ordered_keys = ordered_keys
+
+    def __iter__(self) -> Iterator[_K]:
+        return iter(self.ordered_keys)
+
+    def keys(self) -> KeysView[_K]:
+        return KeysView(self.ordered_keys)
+
+
 class CacheInfo(NamedTuple):
     hits: int
     total: int
@@ -218,45 +234,62 @@ class CacheInfo(NamedTuple):
 
         return self.hits / self.total
 
 
-class LRUCache(Generic[_K, _V]):
-    """Note: This class is not thread safe!"""
+class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
 
-    def __init__(self, capacity: int) -> None:
-        self.cache = OrderedDict[_K, _V]()
+    def __init__(self,
+                 capacity: float,
+                 getsizeof: Optional[Callable[[_V], float]] = None):
+        super().__init__(capacity, getsizeof)
         self.pinned_items = set[_K]()
         self.capacity = capacity
 
         self._hits = 0
         self._total = 0
 
-    def __contains__(self, key: _K) -> bool:
-        return key in self.cache
-
-    def __len__(self) -> int:
-        return len(self.cache)
-
-    def __getitem__(self, key: _K) -> _V:
-        value = self.cache[key]  # Raise KeyError if not exists
-        self.cache.move_to_end(key)
-        return value
-
-    def __setitem__(self, key: _K, value: _V) -> None:
-        self.put(key, value)
-
     def __delitem__(self, key: _K) -> None:
-        self.pop(key)
+        run_on_remove = key in self
+        value = self.__getitem__(key)
+        super().__delitem__(key)
+        if key in self.pinned_items:
+            # TODO: warn when a pinned item is deleted
+            self._unpin(key)
+        if run_on_remove:
+            self._on_remove(key, value)
+
+    @property
+    def cache(self) -> Mapping[_K, _V]:
+        """Return the internal cache dictionary in order (read-only)."""
+        return _MappingOrderCacheView(
+            self._Cache__data,  # type: ignore
+            self.order)
+
+    @property
+    def order(self) -> Mapping[_K, None]:
+        """Return the internal order dictionary (read-only)."""
+        return MappingProxyType(self._LRUCache__order)  # type: ignore
 
     def stat(self) -> CacheInfo:
         return CacheInfo(hits=self._hits, total=self._total)
 
     def touch(self, key: _K) -> None:
-        self.cache.move_to_end(key)
+        self._LRUCache__update(key)  # type: ignore
 
-    def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
-        value: Optional[_V]
-        if key in self.cache:
-            value = self.cache[key]
-            self.cache.move_to_end(key)
+    @overload
+    def get(self, key: _K, /) -> Optional[_V]:
+        ...
+
+    @overload
+    def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]:
+        ...
+
+    def get(self,
+            key: _K,
+            /,
+            default: Optional[Union[_V,
+                                    _T]] = None) -> Optional[Union[_V, _T]]:
+        value: Optional[Union[_V, _T]]
+        if key in self:
+            value = self.__getitem__(key)
 
             self._hits += 1
         else:
@@ -265,60 +298,76 @@ class LRUCache(Generic[_K, _V]):
         self._total += 1
         return value
 
+    @overload
+    def pop(self, key: _K) -> _V:
+        ...
+
+    @overload
+    def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]:
+        ...
+
+    def pop(self,
+            key: _K,
+            default: Optional[Union[_V,
+                                    _T]] = None) -> Optional[Union[_V, _T]]:
+        value: Optional[Union[_V, _T]]
+        if key not in self:
+            return default
+
+        value = self[key]
+        del self[key]
+        return value
+
     def put(self, key: _K, value: _V) -> None:
-        self.cache[key] = value
-        self.cache.move_to_end(key)
-        self._remove_old_if_needed()
+        self.__setitem__(key, value)
 
     def pin(self, key: _K) -> None:
         """
         Pins a key in the cache preventing it from being
         evicted in the LRU order.
         """
-        if key not in self.cache:
+        if key not in self:
             raise ValueError(f"Cannot pin key: {key} not in cache.")
         self.pinned_items.add(key)
 
     def _unpin(self, key: _K) -> None:
+        """
+        Unpins a key in the cache, allowing it to be
+        evicted in the LRU order.
+ """ self.pinned_items.remove(key) def _on_remove(self, key: _K, value: Optional[_V]) -> None: pass def remove_oldest(self, *, remove_pinned: bool = False) -> None: - if not self.cache: + if len(self) == 0: return + self.popitem(remove_pinned=remove_pinned) + + def _remove_old_if_needed(self) -> None: + while self.currsize > self.capacity: + self.remove_oldest() + + def clear(self) -> None: + while len(self) > 0: + self.remove_oldest(remove_pinned=True) + + def popitem(self, remove_pinned: bool = False): + """Remove and return the `(key, value)` pair least recently used.""" if not remove_pinned: # pop the oldest item in the cache that is not pinned lru_key = next( - (key for key in self.cache if key not in self.pinned_items), + (key for key in self.order if key not in self.pinned_items), ALL_PINNED_SENTINEL) if lru_key is ALL_PINNED_SENTINEL: raise RuntimeError("All items are pinned, " "cannot remove oldest from the cache.") else: - lru_key = next(iter(self.cache)) - self.pop(lru_key) # type: ignore - - def _remove_old_if_needed(self) -> None: - while len(self.cache) > self.capacity: - self.remove_oldest() - - def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]: - run_on_remove = key in self.cache - value = self.cache.pop(key, default) - # remove from pinned items - if key in self.pinned_items: - self._unpin(key) - if run_on_remove: - self._on_remove(key, value) - return value - - def clear(self) -> None: - while len(self.cache) > 0: - self.remove_oldest(remove_pinned=True) - self.cache.clear() + lru_key = next(iter(self.order)) + value = self.pop(cast(_K, lru_key)) + return (lru_key, value) class PyObjectCache: