mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-05 05:24:37 +08:00
[Misc] Consolidate LRUCache implementations (#15481)
Signed-off-by: Bella kira <2374035698@qq.com>
This commit is contained in:
parent
e1e0fd7543
commit
f4c98b4d4c
@ -12,7 +12,6 @@ from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
|
|||||||
TypeVar, Union, cast)
|
TypeVar, Union, cast)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from cachetools import LRUCache
|
|
||||||
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
|
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
|
||||||
from typing_extensions import assert_never
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
@ -21,7 +20,7 @@ from vllm.jsontree import json_map_leaves, json_reduce_leaves
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
|
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
|
||||||
encode_tokens)
|
encode_tokens)
|
||||||
from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
|
from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby
|
||||||
|
|
||||||
from .hasher import MultiModalHasher
|
from .hasher import MultiModalHasher
|
||||||
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
||||||
|
|||||||
159
vllm/utils.py
159
vllm/utils.py
@ -33,15 +33,17 @@ import uuid
|
|||||||
import warnings
|
import warnings
|
||||||
import weakref
|
import weakref
|
||||||
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
|
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
|
||||||
from collections import OrderedDict, UserDict, defaultdict
|
from collections import UserDict, defaultdict
|
||||||
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
|
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
|
||||||
Iterable, Iterator, Mapping)
|
Iterable, Iterator, KeysView, Mapping)
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import cache, lru_cache, partial, wraps
|
from functools import cache, lru_cache, partial, wraps
|
||||||
|
from types import MappingProxyType
|
||||||
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
|
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
|
||||||
Optional, Type, TypeVar, Union)
|
Optional, Type, TypeVar, Union, cast, overload)
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import cachetools
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
@ -173,6 +175,7 @@ U = TypeVar("U")
|
|||||||
|
|
||||||
_K = TypeVar("_K", bound=Hashable)
|
_K = TypeVar("_K", bound=Hashable)
|
||||||
_V = TypeVar("_V")
|
_V = TypeVar("_V")
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
class _Sentinel:
|
class _Sentinel:
|
||||||
@ -206,6 +209,19 @@ class Counter:
|
|||||||
self.counter = 0
|
self.counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
class _MappingOrderCacheView(UserDict[_K, _V]):
|
||||||
|
|
||||||
|
def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]):
|
||||||
|
super().__init__(data)
|
||||||
|
self.ordered_keys = ordered_keys
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[_K]:
|
||||||
|
return iter(self.ordered_keys)
|
||||||
|
|
||||||
|
def keys(self) -> KeysView[_K]:
|
||||||
|
return KeysView(self.ordered_keys)
|
||||||
|
|
||||||
|
|
||||||
class CacheInfo(NamedTuple):
|
class CacheInfo(NamedTuple):
|
||||||
hits: int
|
hits: int
|
||||||
total: int
|
total: int
|
||||||
@ -218,45 +234,62 @@ class CacheInfo(NamedTuple):
|
|||||||
return self.hits / self.total
|
return self.hits / self.total
|
||||||
|
|
||||||
|
|
||||||
class LRUCache(Generic[_K, _V]):
|
class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
|
||||||
"""Note: This class is not thread safe!"""
|
|
||||||
|
|
||||||
def __init__(self, capacity: int) -> None:
|
def __init__(self,
|
||||||
self.cache = OrderedDict[_K, _V]()
|
capacity: float,
|
||||||
|
getsizeof: Optional[Callable[[_V], float]] = None):
|
||||||
|
super().__init__(capacity, getsizeof)
|
||||||
self.pinned_items = set[_K]()
|
self.pinned_items = set[_K]()
|
||||||
self.capacity = capacity
|
self.capacity = capacity
|
||||||
|
|
||||||
self._hits = 0
|
self._hits = 0
|
||||||
self._total = 0
|
self._total = 0
|
||||||
|
|
||||||
def __contains__(self, key: _K) -> bool:
|
|
||||||
return key in self.cache
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.cache)
|
|
||||||
|
|
||||||
def __getitem__(self, key: _K) -> _V:
|
|
||||||
value = self.cache[key] # Raise KeyError if not exists
|
|
||||||
self.cache.move_to_end(key)
|
|
||||||
return value
|
|
||||||
|
|
||||||
def __setitem__(self, key: _K, value: _V) -> None:
|
|
||||||
self.put(key, value)
|
|
||||||
|
|
||||||
def __delitem__(self, key: _K) -> None:
|
def __delitem__(self, key: _K) -> None:
|
||||||
self.pop(key)
|
run_on_remove = key in self
|
||||||
|
value = self.__getitem__(key)
|
||||||
|
super().__delitem__(key)
|
||||||
|
if key in self.pinned_items:
|
||||||
|
# Todo: add warning to inform that del pinned item
|
||||||
|
self._unpin(key)
|
||||||
|
if run_on_remove:
|
||||||
|
self._on_remove(key, value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cache(self) -> Mapping[_K, _V]:
|
||||||
|
"""Return the internal cache dictionary in order (read-only)."""
|
||||||
|
return _MappingOrderCacheView(
|
||||||
|
self._Cache__data, # type: ignore
|
||||||
|
self.order)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def order(self) -> Mapping[_K, None]:
|
||||||
|
"""Return the internal order dictionary (read-only)."""
|
||||||
|
return MappingProxyType(self._LRUCache__order) # type: ignore
|
||||||
|
|
||||||
def stat(self) -> CacheInfo:
|
def stat(self) -> CacheInfo:
|
||||||
return CacheInfo(hits=self._hits, total=self._total)
|
return CacheInfo(hits=self._hits, total=self._total)
|
||||||
|
|
||||||
def touch(self, key: _K) -> None:
|
def touch(self, key: _K) -> None:
|
||||||
self.cache.move_to_end(key)
|
self._LRUCache__update(key) # type: ignore
|
||||||
|
|
||||||
def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
|
@overload
|
||||||
value: Optional[_V]
|
def get(self, key: _K, /) -> Optional[_V]:
|
||||||
if key in self.cache:
|
...
|
||||||
value = self.cache[key]
|
|
||||||
self.cache.move_to_end(key)
|
@overload
|
||||||
|
def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def get(self,
|
||||||
|
key: _K,
|
||||||
|
/,
|
||||||
|
default: Optional[Union[_V,
|
||||||
|
_T]] = None) -> Optional[Union[_V, _T]]:
|
||||||
|
value: Optional[Union[_V, _T]]
|
||||||
|
if key in self:
|
||||||
|
value = self.__getitem__(key)
|
||||||
|
|
||||||
self._hits += 1
|
self._hits += 1
|
||||||
else:
|
else:
|
||||||
@ -265,60 +298,76 @@ class LRUCache(Generic[_K, _V]):
|
|||||||
self._total += 1
|
self._total += 1
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def pop(self, key: _K) -> _V:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def pop(self,
|
||||||
|
key: _K,
|
||||||
|
default: Optional[Union[_V,
|
||||||
|
_T]] = None) -> Optional[Union[_V, _T]]:
|
||||||
|
value: Optional[Union[_V, _T]]
|
||||||
|
if key not in self:
|
||||||
|
return default
|
||||||
|
|
||||||
|
value = self[key]
|
||||||
|
del self[key]
|
||||||
|
return value
|
||||||
|
|
||||||
def put(self, key: _K, value: _V) -> None:
|
def put(self, key: _K, value: _V) -> None:
|
||||||
self.cache[key] = value
|
self.__setitem__(key, value)
|
||||||
self.cache.move_to_end(key)
|
|
||||||
self._remove_old_if_needed()
|
|
||||||
|
|
||||||
def pin(self, key: _K) -> None:
|
def pin(self, key: _K) -> None:
|
||||||
"""
|
"""
|
||||||
Pins a key in the cache preventing it from being
|
Pins a key in the cache preventing it from being
|
||||||
evicted in the LRU order.
|
evicted in the LRU order.
|
||||||
"""
|
"""
|
||||||
if key not in self.cache:
|
if key not in self:
|
||||||
raise ValueError(f"Cannot pin key: {key} not in cache.")
|
raise ValueError(f"Cannot pin key: {key} not in cache.")
|
||||||
self.pinned_items.add(key)
|
self.pinned_items.add(key)
|
||||||
|
|
||||||
def _unpin(self, key: _K) -> None:
|
def _unpin(self, key: _K) -> None:
|
||||||
|
"""
|
||||||
|
Unpins a key in the cache allowing it to be
|
||||||
|
evicted in the LRU order.
|
||||||
|
"""
|
||||||
self.pinned_items.remove(key)
|
self.pinned_items.remove(key)
|
||||||
|
|
||||||
def _on_remove(self, key: _K, value: Optional[_V]) -> None:
|
def _on_remove(self, key: _K, value: Optional[_V]) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def remove_oldest(self, *, remove_pinned: bool = False) -> None:
|
def remove_oldest(self, *, remove_pinned: bool = False) -> None:
|
||||||
if not self.cache:
|
if len(self) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
self.popitem(remove_pinned=remove_pinned)
|
||||||
|
|
||||||
|
def _remove_old_if_needed(self) -> None:
|
||||||
|
while self.currsize > self.capacity:
|
||||||
|
self.remove_oldest()
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
while len(self) > 0:
|
||||||
|
self.remove_oldest(remove_pinned=True)
|
||||||
|
|
||||||
|
def popitem(self, remove_pinned: bool = False):
|
||||||
|
"""Remove and return the `(key, value)` pair least recently used."""
|
||||||
if not remove_pinned:
|
if not remove_pinned:
|
||||||
# pop the oldest item in the cache that is not pinned
|
# pop the oldest item in the cache that is not pinned
|
||||||
lru_key = next(
|
lru_key = next(
|
||||||
(key for key in self.cache if key not in self.pinned_items),
|
(key for key in self.order if key not in self.pinned_items),
|
||||||
ALL_PINNED_SENTINEL)
|
ALL_PINNED_SENTINEL)
|
||||||
if lru_key is ALL_PINNED_SENTINEL:
|
if lru_key is ALL_PINNED_SENTINEL:
|
||||||
raise RuntimeError("All items are pinned, "
|
raise RuntimeError("All items are pinned, "
|
||||||
"cannot remove oldest from the cache.")
|
"cannot remove oldest from the cache.")
|
||||||
else:
|
else:
|
||||||
lru_key = next(iter(self.cache))
|
lru_key = next(iter(self.order))
|
||||||
self.pop(lru_key) # type: ignore
|
value = self.pop(cast(_K, lru_key))
|
||||||
|
return (lru_key, value)
|
||||||
def _remove_old_if_needed(self) -> None:
|
|
||||||
while len(self.cache) > self.capacity:
|
|
||||||
self.remove_oldest()
|
|
||||||
|
|
||||||
def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
|
|
||||||
run_on_remove = key in self.cache
|
|
||||||
value = self.cache.pop(key, default)
|
|
||||||
# remove from pinned items
|
|
||||||
if key in self.pinned_items:
|
|
||||||
self._unpin(key)
|
|
||||||
if run_on_remove:
|
|
||||||
self._on_remove(key, value)
|
|
||||||
return value
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
while len(self.cache) > 0:
|
|
||||||
self.remove_oldest(remove_pinned=True)
|
|
||||||
self.cache.clear()
|
|
||||||
|
|
||||||
|
|
||||||
class PyObjectCache:
|
class PyObjectCache:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user