[Misc] Clean up and consolidate LRUCache (#11339)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2024-12-20 00:59:32 +08:00 committed by GitHub
parent e24113a8fe
commit cdf22afdda
5 changed files with 34 additions and 67 deletions

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Hashable, Optional, TypeVar
from typing import Any, Callable, Dict, Optional, TypeVar
from torch import nn
@@ -24,14 +24,13 @@ class AdapterModel(ABC):
T = TypeVar('T')
class AdapterLRUCache(LRUCache[T]):
class AdapterLRUCache(LRUCache[int, T]):
def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
None]):
def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
super().__init__(capacity)
self.deactivate_fn = deactivate_fn
def _on_remove(self, key: Hashable, value: Optional[T]):
def _on_remove(self, key: int, value: Optional[T]):
logger.debug("Removing adapter int id: %d", key)
self.deactivate_fn(key)
return super()._on_remove(key, value)
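
With the cache now generic over its key type, AdapterLRUCache fixes the key to int (the adapter ID) and types the eviction callback as Callable[[int], object]. Below is a minimal sketch, not part of the commit, of how the re-typed class behaves; the import path and the dict stand-in for the adapter object are assumptions:

    from vllm.adapter_commons.models import AdapterLRUCache  # assumed import path

    def deactivate(adapter_id: int) -> None:
        # Stand-in for the real hook (e.g. unloading a LoRA adapter).
        print(f"deactivating adapter {adapter_id}")

    # Keys are adapter IDs (int); values are the adapter objects (dicts here).
    cache: AdapterLRUCache[dict] = AdapterLRUCache(capacity=2,
                                                   deactivate_fn=deactivate)
    cache.put(1, {"name": "adapter-a"})
    cache.put(2, {"name": "adapter-b"})
    cache.put(3, {"name": "adapter-c"})  # over capacity: evicts ID 1 and
                                         # calls deactivate(1) via _on_remove()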

View File

@@ -22,7 +22,7 @@ class TokenizerGroup(BaseTokenizerGroup):
self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
max_loras = tokenizer_config.get("max_loras", 0)
self.lora_tokenizers = LRUCache[AnyTokenizer](
self.lora_tokenizers = LRUCache[int, AnyTokenizer](
capacity=max(max_loras, max_num_seqs) if enable_lora else 0)
@classmethod

View File

@@ -21,14 +21,13 @@ import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import UserDict, defaultdict
from collections import OrderedDict, UserDict, defaultdict
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field
from functools import lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
Dict, Generator, Generic, Hashable, List, Literal,
Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union,
overload)
Optional, Tuple, Type, TypeVar, Union, overload)
from uuid import uuid4
import numpy as np
@@ -154,10 +153,12 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
}
P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")
U = TypeVar("U")
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
class _Sentinel:
...
@@ -190,50 +191,48 @@ class Counter:
self.counter = 0
class LRUCache(Generic[T]):
class LRUCache(Generic[_K, _V]):
def __init__(self, capacity: int):
self.cache: OrderedDict[Hashable, T] = OrderedDict()
self.pinned_items: Set[Hashable] = set()
def __init__(self, capacity: int) -> None:
self.cache = OrderedDict[_K, _V]()
self.pinned_items = set[_K]()
self.capacity = capacity
def __contains__(self, key: Hashable) -> bool:
def __contains__(self, key: _K) -> bool:
return key in self.cache
def __len__(self) -> int:
return len(self.cache)
def __getitem__(self, key: Hashable) -> T:
def __getitem__(self, key: _K) -> _V:
value = self.cache[key] # Raise KeyError if not exists
self.cache.move_to_end(key)
return value
def __setitem__(self, key: Hashable, value: T) -> None:
def __setitem__(self, key: _K, value: _V) -> None:
self.put(key, value)
def __delitem__(self, key: Hashable) -> None:
def __delitem__(self, key: _K) -> None:
self.pop(key)
def touch(self, key: Hashable) -> None:
def touch(self, key: _K) -> None:
self.cache.move_to_end(key)
def get(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
value: Optional[T]
def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
value: Optional[_V]
if key in self.cache:
value = self.cache[key]
self.cache.move_to_end(key)
else:
value = default_value
value = default
return value
def put(self, key: Hashable, value: T) -> None:
def put(self, key: _K, value: _V) -> None:
self.cache[key] = value
self.cache.move_to_end(key)
self._remove_old_if_needed()
def pin(self, key: Hashable) -> None:
def pin(self, key: _K) -> None:
"""
Pins a key in the cache preventing it from being
evicted in the LRU order.
@@ -242,13 +241,13 @@ class LRUCache(Generic[T]):
raise ValueError(f"Cannot pin key: {key} not in cache.")
self.pinned_items.add(key)
def _unpin(self, key: Hashable) -> None:
def _unpin(self, key: _K) -> None:
self.pinned_items.remove(key)
def _on_remove(self, key: Hashable, value: Optional[T]):
def _on_remove(self, key: _K, value: Optional[_V]) -> None:
pass
def remove_oldest(self, remove_pinned=False):
def remove_oldest(self, *, remove_pinned: bool = False) -> None:
if not self.cache:
return
@@ -262,17 +261,15 @@ class LRUCache(Generic[T]):
"cannot remove oldest from the cache.")
else:
lru_key = next(iter(self.cache))
self.pop(lru_key)
self.pop(lru_key) # type: ignore
def _remove_old_if_needed(self) -> None:
while len(self.cache) > self.capacity:
self.remove_oldest()
def pop(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
run_on_remove = key in self.cache
value: Optional[T] = self.cache.pop(key, default_value)
value = self.cache.pop(key, default)
# remove from pinned items
if key in self.pinned_items:
self._unpin(key)
@@ -280,7 +277,7 @@ class LRUCache(Generic[T]):
self._on_remove(key, value)
return value
def clear(self):
def clear(self) -> None:
while len(self.cache) > 0:
self.remove_oldest(remove_pinned=True)
self.cache.clear()
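
Taken together, the hunks above make LRUCache generic over both key and value (LRUCache[_K, _V]), rename the default_value parameter of get()/pop() to default, make remove_pinned keyword-only, and add return annotations. A minimal usage sketch of the consolidated class, assuming exactly the behavior shown in the diff (the keys and values are illustrative):

    from vllm.utils import LRUCache

    # Both the key and the value type are now explicit type parameters.
    cache = LRUCache[str, int](capacity=2)
    cache.put("a", 1)
    cache["b"] = 2                # __setitem__ delegates to put()
    assert cache.get("a") == 1    # a hit refreshes recency, so "b" is now the LRU entry
    cache.pin("b")                # pinned keys are skipped by remove_oldest()
    cache.put("c", 3)             # over capacity: evicts "a", the oldest unpinned key
    assert "a" not in cache
    assert cache.pop("a", default=-1) == -1   # renamed keyword: default, not default_value
    cache.clear()                 # clear() evicts everything, pinned entries included
    assert len(cache) == 0
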
@@ -843,10 +840,6 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
return [item for sublist in lists for item in sublist]
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
"""
Unlike :class:`itertools.groupby`, groups are not broken by

View File

@@ -8,7 +8,7 @@ from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalKwargs, MultiModalRegistry)
from vllm.v1.utils import LRUDictCache
from vllm.utils import LRUCache
logger = init_logger(__name__)
@@ -44,7 +44,7 @@ class MMInputMapperClient:
# Init cache
self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
# DEBUG: Set to None to disable
self.mm_debug_cache_hit_ratio_steps = None
@@ -120,7 +120,7 @@ class MMInputMapperServer:
def __init__(self, model_config):
self.use_cache = not model_config.disable_mm_preprocessor_cache
self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
def process_inputs(
self,

View File

@@ -1,4 +1,3 @@
from collections import OrderedDict
from collections.abc import Sequence
from contextlib import contextmanager
from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,
@@ -102,27 +101,3 @@ def make_zmq_socket(
finally:
ctx.destroy(linger=0)
K = TypeVar('K')
V = TypeVar('V')
class LRUDictCache(Generic[K, V]):
def __init__(self, size: int):
self.cache: OrderedDict[K, V] = OrderedDict()
self.size = size
def get(self, key: K, default=None) -> V:
if key not in self.cache:
return default
self.cache.move_to_end(key)
return self.cache[key]
def put(self, key: K, value: V):
self.cache[key] = value
self.cache.move_to_end(key)
if len(self.cache) > self.size:
self.cache.popitem(last=False)
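
The removed LRUDictCache only exposed get() and put(), which the consolidated vllm.utils.LRUCache also provides (plus pinning and the _on_remove() hook), so the two mm_cache call sites in the previous file need nothing beyond the class swap shown there. A hypothetical migration sketch; the hash keys, the dict stand-in for MultiModalKwargs, and the cache size are made up:

    from vllm.utils import LRUCache

    MM_CACHE_SIZE = 256  # illustrative value; the real constant sits next to the caches above

    # Was: LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
    mm_cache = LRUCache[str, dict](MM_CACHE_SIZE)
    mm_cache.put("mm_hash_0", {"pixel_values": "..."})
    hit = mm_cache.get("mm_hash_0")   # same get()/put() surface as LRUDictCache
    miss = mm_cache.get("unknown")    # returns None (the default) on a miss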