[Misc] Consolidate LRUCache implementations (#15481)

Signed-off-by: Bella kira <2374035698@qq.com>
This commit is contained in:
Bella kira 2025-03-27 14:43:43 +08:00 committed by GitHub
parent e1e0fd7543
commit f4c98b4d4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 105 additions and 57 deletions

View File

@ -12,7 +12,6 @@ from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union, cast) TypeVar, Union, cast)
import torch import torch
from cachetools import LRUCache
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import assert_never from typing_extensions import assert_never
@ -21,7 +20,7 @@ from vllm.jsontree import json_map_leaves, json_reduce_leaves
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens) encode_tokens)
from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby
from .hasher import MultiModalHasher from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,

View File

@ -33,15 +33,17 @@ import uuid
import warnings import warnings
import weakref import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import OrderedDict, UserDict, defaultdict from collections import UserDict, defaultdict
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable, from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
Iterable, Iterator, Mapping) Iterable, Iterator, KeysView, Mapping)
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps from functools import cache, lru_cache, partial, wraps
from types import MappingProxyType
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, Type, TypeVar, Union) Optional, Type, TypeVar, Union, cast, overload)
from uuid import uuid4 from uuid import uuid4
import cachetools
import cloudpickle import cloudpickle
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
@ -173,6 +175,7 @@ U = TypeVar("U")
_K = TypeVar("_K", bound=Hashable) _K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V") _V = TypeVar("_V")
_T = TypeVar("_T")
class _Sentinel: class _Sentinel:
@ -206,6 +209,19 @@ class Counter:
self.counter = 0 self.counter = 0
class _MappingOrderCacheView(UserDict[_K, _V]):
def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]):
super().__init__(data)
self.ordered_keys = ordered_keys
def __iter__(self) -> Iterator[_K]:
return iter(self.ordered_keys)
def keys(self) -> KeysView[_K]:
return KeysView(self.ordered_keys)
class CacheInfo(NamedTuple): class CacheInfo(NamedTuple):
hits: int hits: int
total: int total: int
@ -218,45 +234,62 @@ class CacheInfo(NamedTuple):
return self.hits / self.total return self.hits / self.total
class LRUCache(Generic[_K, _V]): class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
"""Note: This class is not thread safe!"""
def __init__(self, capacity: int) -> None: def __init__(self,
self.cache = OrderedDict[_K, _V]() capacity: float,
getsizeof: Optional[Callable[[_V], float]] = None):
super().__init__(capacity, getsizeof)
self.pinned_items = set[_K]() self.pinned_items = set[_K]()
self.capacity = capacity self.capacity = capacity
self._hits = 0 self._hits = 0
self._total = 0 self._total = 0
def __contains__(self, key: _K) -> bool:
return key in self.cache
def __len__(self) -> int:
return len(self.cache)
def __getitem__(self, key: _K) -> _V:
value = self.cache[key] # Raise KeyError if not exists
self.cache.move_to_end(key)
return value
def __setitem__(self, key: _K, value: _V) -> None:
self.put(key, value)
def __delitem__(self, key: _K) -> None: def __delitem__(self, key: _K) -> None:
self.pop(key) run_on_remove = key in self
value = self.__getitem__(key)
super().__delitem__(key)
if key in self.pinned_items:
# TODO: emit a warning when a pinned item is deleted
self._unpin(key)
if run_on_remove:
self._on_remove(key, value)
@property
def cache(self) -> Mapping[_K, _V]:
"""Return the internal cache dictionary in order (read-only)."""
return _MappingOrderCacheView(
self._Cache__data, # type: ignore
self.order)
@property
def order(self) -> Mapping[_K, None]:
"""Return the internal order dictionary (read-only)."""
return MappingProxyType(self._LRUCache__order) # type: ignore
def stat(self) -> CacheInfo: def stat(self) -> CacheInfo:
return CacheInfo(hits=self._hits, total=self._total) return CacheInfo(hits=self._hits, total=self._total)
def touch(self, key: _K) -> None: def touch(self, key: _K) -> None:
self.cache.move_to_end(key) self._LRUCache__update(key) # type: ignore
def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]: @overload
value: Optional[_V] def get(self, key: _K, /) -> Optional[_V]:
if key in self.cache: ...
value = self.cache[key]
self.cache.move_to_end(key) @overload
def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]:
...
def get(self,
key: _K,
/,
default: Optional[Union[_V,
_T]] = None) -> Optional[Union[_V, _T]]:
value: Optional[Union[_V, _T]]
if key in self:
value = self.__getitem__(key)
self._hits += 1 self._hits += 1
else: else:
@ -265,60 +298,76 @@ class LRUCache(Generic[_K, _V]):
self._total += 1 self._total += 1
return value return value
@overload
def pop(self, key: _K) -> _V:
...
@overload
def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]:
...
def pop(self,
key: _K,
default: Optional[Union[_V,
_T]] = None) -> Optional[Union[_V, _T]]:
value: Optional[Union[_V, _T]]
if key not in self:
return default
value = self[key]
del self[key]
return value
def put(self, key: _K, value: _V) -> None: def put(self, key: _K, value: _V) -> None:
self.cache[key] = value self.__setitem__(key, value)
self.cache.move_to_end(key)
self._remove_old_if_needed()
def pin(self, key: _K) -> None: def pin(self, key: _K) -> None:
""" """
Pins a key in the cache preventing it from being Pins a key in the cache preventing it from being
evicted in the LRU order. evicted in the LRU order.
""" """
if key not in self.cache: if key not in self:
raise ValueError(f"Cannot pin key: {key} not in cache.") raise ValueError(f"Cannot pin key: {key} not in cache.")
self.pinned_items.add(key) self.pinned_items.add(key)
def _unpin(self, key: _K) -> None: def _unpin(self, key: _K) -> None:
"""
Unpins a key in the cache allowing it to be
evicted in the LRU order.
"""
self.pinned_items.remove(key) self.pinned_items.remove(key)
def _on_remove(self, key: _K, value: Optional[_V]) -> None: def _on_remove(self, key: _K, value: Optional[_V]) -> None:
pass pass
def remove_oldest(self, *, remove_pinned: bool = False) -> None: def remove_oldest(self, *, remove_pinned: bool = False) -> None:
if not self.cache: if len(self) == 0:
return return
self.popitem(remove_pinned=remove_pinned)
def _remove_old_if_needed(self) -> None:
while self.currsize > self.capacity:
self.remove_oldest()
def clear(self) -> None:
while len(self) > 0:
self.remove_oldest(remove_pinned=True)
def popitem(self, remove_pinned: bool = False):
"""Remove and return the `(key, value)` pair least recently used."""
if not remove_pinned: if not remove_pinned:
# pop the oldest item in the cache that is not pinned # pop the oldest item in the cache that is not pinned
lru_key = next( lru_key = next(
(key for key in self.cache if key not in self.pinned_items), (key for key in self.order if key not in self.pinned_items),
ALL_PINNED_SENTINEL) ALL_PINNED_SENTINEL)
if lru_key is ALL_PINNED_SENTINEL: if lru_key is ALL_PINNED_SENTINEL:
raise RuntimeError("All items are pinned, " raise RuntimeError("All items are pinned, "
"cannot remove oldest from the cache.") "cannot remove oldest from the cache.")
else: else:
lru_key = next(iter(self.cache)) lru_key = next(iter(self.order))
self.pop(lru_key) # type: ignore value = self.pop(cast(_K, lru_key))
return (lru_key, value)
def _remove_old_if_needed(self) -> None:
while len(self.cache) > self.capacity:
self.remove_oldest()
def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
run_on_remove = key in self.cache
value = self.cache.pop(key, default)
# remove from pinned items
if key in self.pinned_items:
self._unpin(key)
if run_on_remove:
self._on_remove(key, value)
return value
def clear(self) -> None:
while len(self.cache) > 0:
self.remove_oldest(remove_pinned=True)
self.cache.clear()
class PyObjectCache: class PyObjectCache: