[Misc] Clean up and consolidate LRUCache (#11339)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung, 2024-12-20 00:59:32 +08:00 (committed by GitHub)
parent e24113a8fe
commit cdf22afdda
GPG Key ID: B5690EEEBB952194
5 changed files with 34 additions and 67 deletions

View File

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Hashable, Optional, TypeVar
+from typing import Any, Callable, Dict, Optional, TypeVar
 
 from torch import nn

@@ -24,14 +24,13 @@ class AdapterModel(ABC):
 T = TypeVar('T')
 
 
-class AdapterLRUCache(LRUCache[T]):
+class AdapterLRUCache(LRUCache[int, T]):
 
-    def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
-                                                               None]):
+    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
         super().__init__(capacity)
         self.deactivate_fn = deactivate_fn
 
-    def _on_remove(self, key: Hashable, value: Optional[T]):
+    def _on_remove(self, key: int, value: Optional[T]):
         logger.debug("Removing adapter int id: %d", key)
         self.deactivate_fn(key)
         return super()._on_remove(key, value)
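The callback contract is unchanged: _on_remove still fires the deactivation function whenever an adapter falls out of the LRU window; only the key type is now pinned to int. A minimal sketch of the hook in action (FakeAdapter, the deactivate function, and the capacity are made up for illustration; the import path assumes AdapterLRUCache stays in vllm.adapter_commons.models):

    from vllm.adapter_commons.models import AdapterLRUCache

    class FakeAdapter:
        def __init__(self, adapter_id: int) -> None:
            self.id = adapter_id

    def deactivate(adapter_id: int) -> None:
        # Stand-in for releasing whatever resources the adapter holds.
        print(f"deactivated adapter {adapter_id}")

    cache = AdapterLRUCache[FakeAdapter](capacity=2, deactivate_fn=deactivate)
    cache.put(1, FakeAdapter(1))
    cache.put(2, FakeAdapter(2))
    cache.put(3, FakeAdapter(3))  # over capacity: evicts id 1 and calls deactivate(1)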

View File

@@ -22,7 +22,7 @@ class TokenizerGroup(BaseTokenizerGroup):
         self.max_input_length = max_input_length
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
         max_loras = tokenizer_config.get("max_loras", 0)
-        self.lora_tokenizers = LRUCache[AnyTokenizer](
+        self.lora_tokenizers = LRUCache[int, AnyTokenizer](
             capacity=max(max_loras, max_num_seqs) if enable_lora else 0)
 
     @classmethod
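Keying the tokenizer cache as LRUCache[int, AnyTokenizer] documents that lookups happen by LoRA integer id and lets a type checker reject string keys. A hedged sketch of the access pattern (string values stand in for real tokenizer objects):

    from vllm.utils import LRUCache

    lora_tokenizers = LRUCache[int, str](capacity=4)
    lora_tokenizers.put(7, "tokenizer-for-lora-7")
    assert lora_tokenizers.get(7) == "tokenizer-for-lora-7"
    assert lora_tokenizers.get(99) is None  # miss falls back to the default
    # lora_tokenizers.get("7") is now a type error rather than a silent miss.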

View File

@@ -21,14 +21,13 @@ import uuid
 import warnings
 import weakref
 from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
-from collections import UserDict, defaultdict
+from collections import OrderedDict, UserDict, defaultdict
 from collections.abc import Iterable, Mapping
 from dataclasses import dataclass, field
 from functools import lru_cache, partial, wraps
 from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
                     Dict, Generator, Generic, Hashable, List, Literal,
-                    Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union,
-                    overload)
+                    Optional, Tuple, Type, TypeVar, Union, overload)
 from uuid import uuid4
 
 import numpy as np
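The import moves from typing.OrderedDict (an alias usable only in annotations) to collections.OrderedDict because, from Python 3.9 on, the runtime class is itself subscriptable, so one name serves both the annotation and the constructor:

    from collections import OrderedDict

    cache = OrderedDict[str, int]()  # valid at runtime on Python 3.9+
    cache["a"] = 1
    print(type(cache))  # <class 'collections.OrderedDict'>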
@@ -154,10 +153,12 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
 }
 
 P = ParamSpec('P')
-K = TypeVar("K")
 T = TypeVar("T")
 U = TypeVar("U")
+_K = TypeVar("_K", bound=Hashable)
+_V = TypeVar("_V")
 
 
 class _Sentinel:
     ...
@@ -190,50 +191,48 @@ class Counter:
         self.counter = 0
 
 
-class LRUCache(Generic[T]):
+class LRUCache(Generic[_K, _V]):
 
-    def __init__(self, capacity: int):
-        self.cache: OrderedDict[Hashable, T] = OrderedDict()
-        self.pinned_items: Set[Hashable] = set()
+    def __init__(self, capacity: int) -> None:
+        self.cache = OrderedDict[_K, _V]()
+        self.pinned_items = set[_K]()
         self.capacity = capacity
 
-    def __contains__(self, key: Hashable) -> bool:
+    def __contains__(self, key: _K) -> bool:
         return key in self.cache
 
     def __len__(self) -> int:
         return len(self.cache)
 
-    def __getitem__(self, key: Hashable) -> T:
+    def __getitem__(self, key: _K) -> _V:
         value = self.cache[key]  # Raise KeyError if not exists
         self.cache.move_to_end(key)
         return value
 
-    def __setitem__(self, key: Hashable, value: T) -> None:
+    def __setitem__(self, key: _K, value: _V) -> None:
         self.put(key, value)
 
-    def __delitem__(self, key: Hashable) -> None:
+    def __delitem__(self, key: _K) -> None:
         self.pop(key)
 
-    def touch(self, key: Hashable) -> None:
+    def touch(self, key: _K) -> None:
         self.cache.move_to_end(key)
 
-    def get(self,
-            key: Hashable,
-            default_value: Optional[T] = None) -> Optional[T]:
-        value: Optional[T]
+    def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
+        value: Optional[_V]
         if key in self.cache:
             value = self.cache[key]
             self.cache.move_to_end(key)
         else:
-            value = default_value
+            value = default
         return value
 
-    def put(self, key: Hashable, value: T) -> None:
+    def put(self, key: _K, value: _V) -> None:
         self.cache[key] = value
         self.cache.move_to_end(key)
         self._remove_old_if_needed()
 
-    def pin(self, key: Hashable) -> None:
+    def pin(self, key: _K) -> None:
         """
         Pins a key in the cache preventing it from being
         evicted in the LRU order.
@@ -242,13 +241,13 @@ class LRUCache(Generic[T]):
             raise ValueError(f"Cannot pin key: {key} not in cache.")
         self.pinned_items.add(key)
 
-    def _unpin(self, key: Hashable) -> None:
+    def _unpin(self, key: _K) -> None:
         self.pinned_items.remove(key)
 
-    def _on_remove(self, key: Hashable, value: Optional[T]):
+    def _on_remove(self, key: _K, value: Optional[_V]) -> None:
         pass
 
-    def remove_oldest(self, remove_pinned=False):
+    def remove_oldest(self, *, remove_pinned: bool = False) -> None:
         if not self.cache:
             return
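Making remove_pinned keyword-only forces call sites to spell out the intent; a bare positional flag no longer type-checks and raises at runtime. A tiny sketch:

    from vllm.utils import LRUCache

    cache = LRUCache[str, int](capacity=1)
    cache.put("a", 1)
    cache.remove_oldest(remove_pinned=True)  # explicit keyword: fine
    # cache.remove_oldest(True)  # TypeError: takes 1 positional argument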
@@ -262,17 +261,15 @@ class LRUCache(Generic[T]):
                 "cannot remove oldest from the cache.")
         else:
             lru_key = next(iter(self.cache))
-        self.pop(lru_key)
+        self.pop(lru_key)  # type: ignore
 
     def _remove_old_if_needed(self) -> None:
         while len(self.cache) > self.capacity:
             self.remove_oldest()
 
-    def pop(self,
-            key: Hashable,
-            default_value: Optional[T] = None) -> Optional[T]:
+    def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
         run_on_remove = key in self.cache
-        value: Optional[T] = self.cache.pop(key, default_value)
+        value = self.cache.pop(key, default)
         # remove from pinned items
         if key in self.pinned_items:
             self._unpin(key)
@@ -280,7 +277,7 @@ class LRUCache(Generic[T]):
             self._on_remove(key, value)
         return value
 
-    def clear(self):
+    def clear(self) -> None:
         while len(self.cache) > 0:
             self.remove_oldest(remove_pinned=True)
         self.cache.clear()
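Taken together, the class now carries its key and value types explicitly while keeping the pinning semantics. A short usage sketch of the consolidated cache (keys, values, and capacity here are arbitrary):

    from vllm.utils import LRUCache

    cache = LRUCache[str, int](capacity=2)
    cache.put("a", 1)
    cache.put("b", 2)
    cache.pin("a")     # "a" is exempt from LRU eviction
    cache.put("c", 3)  # over capacity: evicts "b", the oldest unpinned key
    assert "a" in cache and "b" not in cache
    assert cache.get("b", default=-1) == -1  # parameter renamed from default_value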
@@ -843,10 +840,6 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
     return [item for sublist in lists for item in sublist]
 
 
-_K = TypeVar("_K", bound=Hashable)
-_V = TypeVar("_V")
-
-
 def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
     """
     Unlike :class:`itertools.groupby`, groups are not broken by

View File

@@ -8,7 +8,7 @@ from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                              MultiModalKwargs, MultiModalRegistry)
-from vllm.v1.utils import LRUDictCache
+from vllm.utils import LRUCache
 
 logger = init_logger(__name__)

@@ -44,7 +44,7 @@ class MMInputMapperClient:
         # Init cache
         self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
+        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
 
         # DEBUG: Set to None to disable
         self.mm_debug_cache_hit_ratio_steps = None

@@ -120,7 +120,7 @@ class MMInputMapperServer:
     def __init__(self, model_config):
         self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
+        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)
 
     def process_inputs(
         self,

View File

@@ -1,4 +1,3 @@
-from collections import OrderedDict
 from collections.abc import Sequence
 from contextlib import contextmanager
 from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,

@@ -102,27 +101,3 @@ def make_zmq_socket(
     finally:
         ctx.destroy(linger=0)
-
-
-K = TypeVar('K')
-V = TypeVar('V')
-
-
-class LRUDictCache(Generic[K, V]):
-
-    def __init__(self, size: int):
-        self.cache: OrderedDict[K, V] = OrderedDict()
-        self.size = size
-
-    def get(self, key: K, default=None) -> V:
-        if key not in self.cache:
-            return default
-
-        self.cache.move_to_end(key)
-        return self.cache[key]
-
-    def put(self, key: K, value: V):
-        self.cache[key] = value
-        self.cache.move_to_end(key)
-        if len(self.cache) > self.size:
-            self.cache.popitem(last=False)
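Nothing is lost in this deletion: LRUCache already covers LRUDictCache's two methods with the same semantics (get returns the default on a miss and refreshes recency on a hit; put evicts the oldest entry past capacity), and adds pinning and eviction hooks on top. A sketch of the one-line migration for former call sites (the MM_CACHE_SIZE value and the bytes payload are placeholders for illustration):

    from vllm.utils import LRUCache

    MM_CACHE_SIZE = 256  # placeholder value for illustration

    # Before: mm_cache = LRUDictCache[str, bytes](MM_CACHE_SIZE)
    mm_cache = LRUCache[str, bytes](MM_CACHE_SIZE)
    mm_cache.put("mm-hash", b"kwargs")
    assert mm_cache.get("missing") is None  # same miss behavior as LRUDictCache.get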