[Misc] Clean up and consolidate LRUCache (#11339)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent e24113a8fe
commit cdf22afdda
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Hashable, Optional, TypeVar
+from typing import Any, Callable, Dict, Optional, TypeVar

 from torch import nn

@@ -24,14 +24,13 @@ class AdapterModel(ABC):
 T = TypeVar('T')


-class AdapterLRUCache(LRUCache[T]):
+class AdapterLRUCache(LRUCache[int, T]):

-    def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
-                                                               None]):
+    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
         super().__init__(capacity)
         self.deactivate_fn = deactivate_fn

-    def _on_remove(self, key: Hashable, value: Optional[T]):
+    def _on_remove(self, key: int, value: Optional[T]):
         logger.debug("Removing adapter int id: %d", key)
         self.deactivate_fn(key)
         return super()._on_remove(key, value)
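Note (not part of the diff): the adapter cache relies on the base class's _on_remove hook so that evicting an entry also deactivates the adapter for that integer id. A minimal, self-contained sketch of that pattern, using a toy OrderedDict-backed cache rather than vllm's actual LRUCache:

    from collections import OrderedDict
    from typing import Callable, Generic, Optional, TypeVar

    T = TypeVar("T")

    class MiniLRUCache(Generic[T]):
        def __init__(self, capacity: int) -> None:
            self.cache: "OrderedDict[int, T]" = OrderedDict()
            self.capacity = capacity

        def _on_remove(self, key: int, value: Optional[T]) -> None:
            pass  # subclasses hook eviction here

        def put(self, key: int, value: T) -> None:
            self.cache[key] = value
            self.cache.move_to_end(key)
            while len(self.cache) > self.capacity:
                old_key, old_value = self.cache.popitem(last=False)
                self._on_remove(old_key, old_value)

    class AdapterCache(MiniLRUCache[str]):
        def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
            super().__init__(capacity)
            self.deactivate_fn = deactivate_fn

        def _on_remove(self, key: int, value: Optional[str]) -> None:
            self.deactivate_fn(key)  # e.g. unload the adapter with this id

    cache = AdapterCache(capacity=1, deactivate_fn=lambda i: print("deactivate", i))
    cache.put(1, "lora-a")
    cache.put(2, "lora-b")  # over capacity: evicts id 1 and prints "deactivate 1"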
@@ -22,7 +22,7 @@ class TokenizerGroup(BaseTokenizerGroup):
         self.max_input_length = max_input_length
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
         max_loras = tokenizer_config.get("max_loras", 0)
-        self.lora_tokenizers = LRUCache[AnyTokenizer](
+        self.lora_tokenizers = LRUCache[int, AnyTokenizer](
             capacity=max(max_loras, max_num_seqs) if enable_lora else 0)

     @classmethod
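Note (not part of the diff): the capacity here is max(max_loras, max_num_seqs) when LoRA is enabled and 0 otherwise; with the consolidated cache, a capacity of 0 means every put is evicted immediately, which effectively disables per-LoRA tokenizer caching. A tiny sketch of that behavior, assuming the put/eviction semantics shown later in this diff:

    from collections import OrderedDict

    def put_with_capacity(cache: OrderedDict, capacity: int, key, value) -> None:
        # Same idea as LRUCache.put + _remove_old_if_needed in this commit.
        cache[key] = value
        cache.move_to_end(key)
        while len(cache) > capacity:
            cache.popitem(last=False)  # evict least recently used

    enabled = OrderedDict()
    put_with_capacity(enabled, capacity=2, key=1, value="tok-1")
    assert 1 in enabled

    disabled = OrderedDict()
    put_with_capacity(disabled, capacity=0, key=1, value="tok-1")
    assert 1 not in disabled  # capacity 0: nothing is retained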
@@ -21,14 +21,13 @@ import uuid
 import warnings
 import weakref
 from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
-from collections import UserDict, defaultdict
+from collections import OrderedDict, UserDict, defaultdict
 from collections.abc import Iterable, Mapping
 from dataclasses import dataclass, field
 from functools import lru_cache, partial, wraps
 from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
                     Dict, Generator, Generic, Hashable, List, Literal,
-                    Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union,
-                    overload)
+                    Optional, Tuple, Type, TypeVar, Union, overload)
 from uuid import uuid4

 import numpy as np
@@ -154,10 +153,12 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
 }

 P = ParamSpec('P')
-K = TypeVar("K")
 T = TypeVar("T")
 U = TypeVar("U")

+_K = TypeVar("_K", bound=Hashable)
+_V = TypeVar("_V")
+

 class _Sentinel:
     ...
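Note (not part of the diff): bounding the key type variable to Hashable documents that cache keys must be usable as dict keys, so a static checker can flag a cache keyed by an unhashable type such as list. A minimal illustration:

    from typing import Dict, Generic, Hashable, TypeVar

    _K = TypeVar("_K", bound=Hashable)
    _V = TypeVar("_V")

    class KeyedStore(Generic[_K, _V]):
        def __init__(self) -> None:
            self.data: Dict[_K, _V] = {}

        def put(self, key: _K, value: _V) -> None:
            self.data[key] = value

    store = KeyedStore[str, int]()   # fine: str is Hashable
    store.put("a", 1)
    # KeyedStore[list, int]()        # flagged by a type checker: list is unhashable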
@@ -190,50 +191,48 @@ class Counter:
         self.counter = 0


-class LRUCache(Generic[T]):
+class LRUCache(Generic[_K, _V]):

-    def __init__(self, capacity: int):
-        self.cache: OrderedDict[Hashable, T] = OrderedDict()
-        self.pinned_items: Set[Hashable] = set()
+    def __init__(self, capacity: int) -> None:
+        self.cache = OrderedDict[_K, _V]()
+        self.pinned_items = set[_K]()
         self.capacity = capacity

-    def __contains__(self, key: Hashable) -> bool:
+    def __contains__(self, key: _K) -> bool:
         return key in self.cache

     def __len__(self) -> int:
         return len(self.cache)

-    def __getitem__(self, key: Hashable) -> T:
+    def __getitem__(self, key: _K) -> _V:
         value = self.cache[key]  # Raise KeyError if not exists
         self.cache.move_to_end(key)
         return value

-    def __setitem__(self, key: Hashable, value: T) -> None:
+    def __setitem__(self, key: _K, value: _V) -> None:
         self.put(key, value)

-    def __delitem__(self, key: Hashable) -> None:
+    def __delitem__(self, key: _K) -> None:
         self.pop(key)

-    def touch(self, key: Hashable) -> None:
+    def touch(self, key: _K) -> None:
         self.cache.move_to_end(key)

-    def get(self,
-            key: Hashable,
-            default_value: Optional[T] = None) -> Optional[T]:
-        value: Optional[T]
+    def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
+        value: Optional[_V]
         if key in self.cache:
             value = self.cache[key]
             self.cache.move_to_end(key)
         else:
-            value = default_value
+            value = default
         return value

-    def put(self, key: Hashable, value: T) -> None:
+    def put(self, key: _K, value: _V) -> None:
         self.cache[key] = value
         self.cache.move_to_end(key)
         self._remove_old_if_needed()

-    def pin(self, key: Hashable) -> None:
+    def pin(self, key: _K) -> None:
         """
         Pins a key in the cache preventing it from being
         evicted in the LRU order.
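Note (not part of the diff): the cache is an OrderedDict in which every access moves the key to the end, so the least recently used entry is always the first key. A compact demonstration of that invariant:

    from collections import OrderedDict

    cache: "OrderedDict[str, int]" = OrderedDict()
    cache["a"] = 1
    cache["b"] = 2
    cache["c"] = 3

    cache.move_to_end("a")           # touching "a" makes it most recently used
    assert next(iter(cache)) == "b"  # "b" is now the LRU entry

    cache.popitem(last=False)        # evict the oldest (least recently used) item
    assert list(cache) == ["c", "a"]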
@@ -242,13 +241,13 @@ class LRUCache(Generic[T]):
             raise ValueError(f"Cannot pin key: {key} not in cache.")
         self.pinned_items.add(key)

-    def _unpin(self, key: Hashable) -> None:
+    def _unpin(self, key: _K) -> None:
         self.pinned_items.remove(key)

-    def _on_remove(self, key: Hashable, value: Optional[T]):
+    def _on_remove(self, key: _K, value: Optional[_V]) -> None:
         pass

-    def remove_oldest(self, remove_pinned=False):
+    def remove_oldest(self, *, remove_pinned: bool = False) -> None:
         if not self.cache:
             return
@@ -262,17 +261,15 @@ class LRUCache(Generic[T]):
                     "cannot remove oldest from the cache.")
         else:
             lru_key = next(iter(self.cache))
-        self.pop(lru_key)
+        self.pop(lru_key)  # type: ignore

     def _remove_old_if_needed(self) -> None:
         while len(self.cache) > self.capacity:
             self.remove_oldest()

-    def pop(self,
-            key: Hashable,
-            default_value: Optional[T] = None) -> Optional[T]:
+    def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
         run_on_remove = key in self.cache
-        value: Optional[T] = self.cache.pop(key, default_value)
+        value = self.cache.pop(key, default)
         # remove from pinned items
         if key in self.pinned_items:
             self._unpin(key)
@@ -280,7 +277,7 @@ class LRUCache(Generic[T]):
             self._on_remove(key, value)
         return value

-    def clear(self):
+    def clear(self) -> None:
         while len(self.cache) > 0:
             self.remove_oldest(remove_pinned=True)
         self.cache.clear()
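Note (not part of the diff): putting the hunks above together, the consolidated cache also tracks pinned keys that remove_oldest skips unless remove_pinned=True. The sketch below re-implements just enough of put, pin, and remove_oldest to show that interaction; it is a simplified stand-in, not vllm's actual implementation:

    from collections import OrderedDict
    from typing import Generic, Hashable, Optional, Set, TypeVar

    _K = TypeVar("_K", bound=Hashable)
    _V = TypeVar("_V")

    class TinyLRUCache(Generic[_K, _V]):
        def __init__(self, capacity: int) -> None:
            self.cache: "OrderedDict[_K, _V]" = OrderedDict()
            self.pinned: "Set[_K]" = set()
            self.capacity = capacity

        def put(self, key: _K, value: _V) -> None:
            self.cache[key] = value
            self.cache.move_to_end(key)
            while len(self.cache) > self.capacity:
                self.remove_oldest()

        def pin(self, key: _K) -> None:
            self.pinned.add(key)

        def remove_oldest(self, *, remove_pinned: bool = False) -> None:
            candidates = (k for k in self.cache
                          if remove_pinned or k not in self.pinned)
            lru_key = next(candidates, None)
            if lru_key is None:
                raise RuntimeError("All items are pinned")
            self.pinned.discard(lru_key)
            del self.cache[lru_key]

    c = TinyLRUCache[str, int](capacity=2)
    c.put("x", 1)
    c.pin("x")
    c.put("y", 2)
    c.put("z", 3)            # over capacity: evicts "y", the oldest unpinned key
    assert list(c.cache) == ["x", "z"]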
@@ -843,10 +840,6 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
     return [item for sublist in lists for item in sublist]


-_K = TypeVar("_K", bound=Hashable)
-_V = TypeVar("_V")
-
-
 def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
     """
     Unlike :class:`itertools.groupby`, groups are not broken by
@@ -8,7 +8,7 @@ from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                              MultiModalKwargs, MultiModalRegistry)
-from vllm.v1.utils import LRUDictCache
+from vllm.utils import LRUCache

 logger = init_logger(__name__)
@@ -44,7 +44,7 @@ class MMInputMapperClient:

         # Init cache
         self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
+        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)

         # DEBUG: Set to None to disable
         self.mm_debug_cache_hit_ratio_steps = None
@@ -120,7 +120,7 @@ class MMInputMapperServer:

     def __init__(self, model_config):
         self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
+        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)

     def process_inputs(
         self,
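Note (not part of the diff): both mapper classes now share the same cache type, keyed by a string hash of the multimodal input. A rough sketch of the hit/miss pattern, with hypothetical names (mm_hash, process_image, the placeholder result dict) standing in for vllm's real ones:

    import hashlib
    from collections import OrderedDict
    from typing import Dict

    MM_CACHE_SIZE = 256  # illustrative; the real constant lives in vllm

    # Toy stand-in for LRUCache[str, MultiModalKwargs]: str hash -> processed kwargs.
    mm_cache: "OrderedDict[str, Dict]" = OrderedDict()

    def mm_hash(image_bytes: bytes) -> str:
        return hashlib.sha256(image_bytes).hexdigest()

    def process_image(image_bytes: bytes) -> Dict:
        key = mm_hash(image_bytes)
        if key in mm_cache:                  # cache hit: reuse preprocessed inputs
            mm_cache.move_to_end(key)
            return mm_cache[key]
        result = {"pixel_values": len(image_bytes)}  # placeholder for real work
        mm_cache[key] = result
        if len(mm_cache) > MM_CACHE_SIZE:    # evict least recently used entry
            mm_cache.popitem(last=False)
        return result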
@@ -1,4 +1,3 @@
-from collections import OrderedDict
 from collections.abc import Sequence
 from contextlib import contextmanager
 from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,

@@ -102,27 +101,3 @@ def make_zmq_socket(

     finally:
         ctx.destroy(linger=0)
-
-
-K = TypeVar('K')
-V = TypeVar('V')
-
-
-class LRUDictCache(Generic[K, V]):
-
-    def __init__(self, size: int):
-        self.cache: OrderedDict[K, V] = OrderedDict()
-        self.size = size
-
-    def get(self, key: K, default=None) -> V:
-        if key not in self.cache:
-            return default
-
-        self.cache.move_to_end(key)
-        return self.cache[key]
-
-    def put(self, key: K, value: V):
-        self.cache[key] = value
-        self.cache.move_to_end(key)
-        if len(self.cache) > self.size:
-            self.cache.popitem(last=False)
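Note (not part of the diff): the deleted LRUDictCache and the consolidated vllm.utils.LRUCache have the same observable get/put behavior, which is why the callers above could switch with only an import and type-parameter change. A small check of that behavior, using a local copy of the deleted class rather than importing vllm:

    from collections import OrderedDict

    class OldStyleCache:
        """Local copy of the removed LRUDictCache behavior."""

        def __init__(self, size: int) -> None:
            self.cache: "OrderedDict[str, int]" = OrderedDict()
            self.size = size

        def get(self, key, default=None):
            if key not in self.cache:
                return default
            self.cache.move_to_end(key)
            return self.cache[key]

        def put(self, key, value) -> None:
            self.cache[key] = value
            self.cache.move_to_end(key)
            if len(self.cache) > self.size:
                self.cache.popitem(last=False)

    old = OldStyleCache(size=2)
    old.put("a", 1)
    old.put("b", 2)
    old.get("a")          # touch "a" so "b" becomes least recently used
    old.put("c", 3)       # evicts "b"
    assert old.get("b") is None and old.get("a") == 1 and old.get("c") == 3
    # Driving vllm.utils.LRUCache(capacity=2) with the same get/put sequence
    # leaves the same contents, with pin() and _on_remove() available on top.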