Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Synced 2025-12-15 07:24:57 +08:00

[Misc] Clean up and consolidate LRUCache (#11339)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

parent e24113a8fe
commit cdf22afdda
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Hashable, Optional, TypeVar
+from typing import Any, Callable, Dict, Optional, TypeVar

 from torch import nn

@@ -24,14 +24,13 @@ class AdapterModel(ABC):
 T = TypeVar('T')


-class AdapterLRUCache(LRUCache[T]):
+class AdapterLRUCache(LRUCache[int, T]):

-    def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
-                                                               None]):
+    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
         super().__init__(capacity)
         self.deactivate_fn = deactivate_fn

-    def _on_remove(self, key: Hashable, value: Optional[T]):
+    def _on_remove(self, key: int, value: Optional[T]):
         logger.debug("Removing adapter int id: %d", key)
         self.deactivate_fn(key)
         return super()._on_remove(key, value)

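Note on the hunk above: AdapterLRUCache now derives from LRUCache[int, T], so its keys are adapter int ids and deactivate_fn takes an int. A minimal usage sketch follows; the import path and the dummy adapter type are assumptions for illustration, not taken from this diff.

    # Illustrative sketch only: the module path and DummyAdapter are assumed.
    from vllm.adapter_commons.models import AdapterLRUCache  # assumed location


    class DummyAdapter:
        """Stand-in for a real adapter model object."""


    def deactivate(adapter_id: int) -> None:
        # Invoked from _on_remove whenever an adapter id is evicted or popped.
        print(f"deactivating adapter {adapter_id}")


    cache = AdapterLRUCache[DummyAdapter](capacity=2, deactivate_fn=deactivate)
    cache.put(1, DummyAdapter())
    cache.put(2, DummyAdapter())
    cache.put(3, DummyAdapter())  # over capacity -> deactivate(1) fires for the LRU entry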
@@ -22,7 +22,7 @@ class TokenizerGroup(BaseTokenizerGroup):
         self.max_input_length = max_input_length
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
         max_loras = tokenizer_config.get("max_loras", 0)
-        self.lora_tokenizers = LRUCache[AnyTokenizer](
+        self.lora_tokenizers = LRUCache[int, AnyTokenizer](
             capacity=max(max_loras, max_num_seqs) if enable_lora else 0)

     @classmethod

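One behavioral note on the capacity expression above: when enable_lora is false the cache is built with capacity 0, and with the eviction loop shown in the LRUCache hunks further down every put() is undone immediately, so the tokenizer cache stays empty. A minimal sketch, using str values as stand-ins for tokenizer objects:

    from vllm.utils import LRUCache

    # capacity=0 mirrors the `if enable_lora else 0` branch above.
    lora_tokenizers = LRUCache[int, str](capacity=0)
    lora_tokenizers.put(1, "tokenizer-for-lora-1")
    assert len(lora_tokenizers) == 0  # evicted right away by _remove_old_if_needed()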
@@ -21,14 +21,13 @@ import uuid
 import warnings
 import weakref
 from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
-from collections import UserDict, defaultdict
+from collections import OrderedDict, UserDict, defaultdict
 from collections.abc import Iterable, Mapping
 from dataclasses import dataclass, field
 from functools import lru_cache, partial, wraps
 from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
                     Dict, Generator, Generic, Hashable, List, Literal,
-                    Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union,
-                    overload)
+                    Optional, Tuple, Type, TypeVar, Union, overload)
 from uuid import uuid4

 import numpy as np
@@ -154,10 +153,12 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
 }

 P = ParamSpec('P')
-K = TypeVar("K")
 T = TypeVar("T")
 U = TypeVar("U")

+_K = TypeVar("_K", bound=Hashable)
+_V = TypeVar("_V")
+

 class _Sentinel:
     ...
@@ -190,50 +191,48 @@ class Counter:
         self.counter = 0


-class LRUCache(Generic[T]):
+class LRUCache(Generic[_K, _V]):

-    def __init__(self, capacity: int):
-        self.cache: OrderedDict[Hashable, T] = OrderedDict()
-        self.pinned_items: Set[Hashable] = set()
+    def __init__(self, capacity: int) -> None:
+        self.cache = OrderedDict[_K, _V]()
+        self.pinned_items = set[_K]()
         self.capacity = capacity

-    def __contains__(self, key: Hashable) -> bool:
+    def __contains__(self, key: _K) -> bool:
         return key in self.cache

     def __len__(self) -> int:
         return len(self.cache)

-    def __getitem__(self, key: Hashable) -> T:
+    def __getitem__(self, key: _K) -> _V:
         value = self.cache[key]  # Raise KeyError if not exists
         self.cache.move_to_end(key)
         return value

-    def __setitem__(self, key: Hashable, value: T) -> None:
+    def __setitem__(self, key: _K, value: _V) -> None:
         self.put(key, value)

-    def __delitem__(self, key: Hashable) -> None:
+    def __delitem__(self, key: _K) -> None:
         self.pop(key)

-    def touch(self, key: Hashable) -> None:
+    def touch(self, key: _K) -> None:
         self.cache.move_to_end(key)

-    def get(self,
-            key: Hashable,
-            default_value: Optional[T] = None) -> Optional[T]:
-        value: Optional[T]
+    def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
+        value: Optional[_V]
         if key in self.cache:
             value = self.cache[key]
             self.cache.move_to_end(key)
         else:
-            value = default_value
+            value = default
         return value

-    def put(self, key: Hashable, value: T) -> None:
+    def put(self, key: _K, value: _V) -> None:
         self.cache[key] = value
         self.cache.move_to_end(key)
         self._remove_old_if_needed()

-    def pin(self, key: Hashable) -> None:
+    def pin(self, key: _K) -> None:
         """
         Pins a key in the cache preventing it from being
         evicted in the LRU order.
@@ -242,13 +241,13 @@ class LRUCache(Generic[T]):
             raise ValueError(f"Cannot pin key: {key} not in cache.")
         self.pinned_items.add(key)

-    def _unpin(self, key: Hashable) -> None:
+    def _unpin(self, key: _K) -> None:
         self.pinned_items.remove(key)

-    def _on_remove(self, key: Hashable, value: Optional[T]):
+    def _on_remove(self, key: _K, value: Optional[_V]) -> None:
         pass

-    def remove_oldest(self, remove_pinned=False):
+    def remove_oldest(self, *, remove_pinned: bool = False) -> None:
         if not self.cache:
             return

@@ -262,17 +261,15 @@ class LRUCache(Generic[T]):
                 "cannot remove oldest from the cache.")
         else:
             lru_key = next(iter(self.cache))
-        self.pop(lru_key)
+        self.pop(lru_key)  # type: ignore

     def _remove_old_if_needed(self) -> None:
         while len(self.cache) > self.capacity:
             self.remove_oldest()

-    def pop(self,
-            key: Hashable,
-            default_value: Optional[T] = None) -> Optional[T]:
+    def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
         run_on_remove = key in self.cache
-        value: Optional[T] = self.cache.pop(key, default_value)
+        value = self.cache.pop(key, default)
         # remove from pinned items
         if key in self.pinned_items:
             self._unpin(key)
@@ -280,7 +277,7 @@ class LRUCache(Generic[T]):
             self._on_remove(key, value)
         return value

-    def clear(self):
+    def clear(self) -> None:
         while len(self.cache) > 0:
             self.remove_oldest(remove_pinned=True)
         self.cache.clear()
@@ -843,10 +840,6 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
     return [item for sublist in lists for item in sublist]


-_K = TypeVar("_K", bound=Hashable)
-_V = TypeVar("_V")
-
-
 def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
     """
     Unlike :class:`itertools.groupby`, groups are not broken by

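For reference, the consolidated LRUCache is now generic over both key and value types, get()/pop() take default instead of default_value, and remove_pinned is keyword-only. A short usage sketch; the keys and values below are arbitrary examples, not code from the repository:

    from vllm.utils import LRUCache

    cache = LRUCache[str, int](capacity=2)
    cache.put("a", 1)
    cache["b"] = 2              # __setitem__ delegates to put()
    cache.pin("a")              # pinned keys are skipped by normal LRU eviction
    cache.put("c", 3)           # over capacity -> "b" (oldest unpinned) is evicted
    assert "a" in cache and "b" not in cache
    assert cache.get("b", 0) == 0            # `default` replaces `default_value`
    cache.remove_oldest(remove_pinned=True)  # flag is keyword-only after this change
    cache.clear()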
@@ -8,7 +8,7 @@ from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                              MultiModalKwargs, MultiModalRegistry)
-from vllm.v1.utils import LRUDictCache
+from vllm.utils import LRUCache

 logger = init_logger(__name__)

@@ -44,7 +44,7 @@ class MMInputMapperClient:

         # Init cache
         self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
+        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)

         # DEBUG: Set to None to disable
         self.mm_debug_cache_hit_ratio_steps = None
@@ -120,7 +120,7 @@ class MMInputMapperServer:

     def __init__(self, model_config):
         self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = LRUDictCache[str, MultiModalKwargs](MM_CACHE_SIZE)
+        self.mm_cache = LRUCache[str, MultiModalKwargs](MM_CACHE_SIZE)

     def process_inputs(
         self,

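The swap above is close to drop-in because LRUDictCache's get()/put() calls map directly onto LRUCache's. A hedged sketch of the cache-then-compute pattern involved; the capacity, hash key, and mapped value below are placeholders, not the mapper's real logic:

    from vllm.utils import LRUCache

    # Placeholder capacity; the mapper passes its module-level MM_CACHE_SIZE.
    mm_cache = LRUCache[str, dict](256)


    def map_input(mm_hash: str) -> dict:
        # Hypothetical get-or-compute flow standing in for the preprocessing step.
        cached = mm_cache.get(mm_hash)
        if cached is not None:
            return cached
        mapped = {"hash": mm_hash}  # stand-in for the real MultiModalKwargs result
        mm_cache.put(mm_hash, mapped)
        return mapped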
@@ -1,4 +1,3 @@
-from collections import OrderedDict
 from collections.abc import Sequence
 from contextlib import contextmanager
 from typing import (Any, Generic, Iterator, List, Optional, TypeVar, Union,
@@ -102,27 +101,3 @@ def make_zmq_socket(

     finally:
         ctx.destroy(linger=0)
-
-
-K = TypeVar('K')
-V = TypeVar('V')
-
-
-class LRUDictCache(Generic[K, V]):
-
-    def __init__(self, size: int):
-        self.cache: OrderedDict[K, V] = OrderedDict()
-        self.size = size
-
-    def get(self, key: K, default=None) -> V:
-        if key not in self.cache:
-            return default
-
-        self.cache.move_to_end(key)
-        return self.cache[key]
-
-    def put(self, key: K, value: V):
-        self.cache[key] = value
-        self.cache.move_to_end(key)
-        if len(self.cache) > self.size:
-            self.cache.popitem(last=False)