mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 08:04:33 +08:00
Update deprecated type hinting in vllm/adapter_commons (#18073)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
19324d660c
commit
0b217da646
@ -72,7 +72,6 @@ exclude = [
|
|||||||
"vllm/version.py" = ["F401"]
|
"vllm/version.py" = ["F401"]
|
||||||
"vllm/_version.py" = ["ALL"]
|
"vllm/_version.py" = ["ALL"]
|
||||||
# Python 3.8 typing. TODO: Remove these excludes after v1.0.0
|
# Python 3.8 typing. TODO: Remove these excludes after v1.0.0
|
||||||
"vllm/adapter_commons/**/*.py" = ["UP006", "UP035"]
|
|
||||||
"vllm/attention/**/*.py" = ["UP006", "UP035"]
|
"vllm/attention/**/*.py" = ["UP006", "UP035"]
|
||||||
"vllm/core/**/*.py" = ["UP006", "UP035"]
|
"vllm/core/**/*.py" = ["UP006", "UP035"]
|
||||||
"vllm/device_allocator/**/*.py" = ["UP006", "UP035"]
|
"vllm/device_allocator/**/*.py" = ["UP006", "UP035"]
|
||||||
|
|||||||
@ -1,15 +1,14 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AdapterMapping:
|
class AdapterMapping:
|
||||||
# Per every token in input_ids:
|
# Per every token in input_ids:
|
||||||
index_mapping: Tuple[int, ...]
|
index_mapping: tuple[int, ...]
|
||||||
# Per sampled token:
|
# Per sampled token:
|
||||||
prompt_mapping: Tuple[int, ...]
|
prompt_mapping: tuple[int, ...]
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.index_mapping = tuple(self.index_mapping)
|
self.index_mapping = tuple(self.index_mapping)
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Dict, Optional, TypeVar
|
from typing import Any, Callable, Optional, TypeVar
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
@ -49,9 +49,9 @@ class AdapterModelManager(ABC):
|
|||||||
model: the model to be adapted.
|
model: the model to be adapted.
|
||||||
"""
|
"""
|
||||||
self.model: nn.Module = model
|
self.model: nn.Module = model
|
||||||
self._registered_adapters: Dict[int, Any] = {}
|
self._registered_adapters: dict[int, Any] = {}
|
||||||
# Dict instead of a Set for compatibility with LRUCache.
|
# Dict instead of a Set for compatibility with LRUCache.
|
||||||
self._active_adapters: Dict[int, None] = {}
|
self._active_adapters: dict[int, None] = {}
|
||||||
self.adapter_type = 'Adapter'
|
self.adapter_type = 'Adapter'
|
||||||
self._last_mapping = None
|
self._last_mapping = None
|
||||||
|
|
||||||
@ -97,7 +97,7 @@ class AdapterModelManager(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list_adapters(self) -> Dict[int, Any]:
|
def list_adapters(self) -> dict[int, Any]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, Optional, Set
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
|
||||||
## model functions
|
## model functions
|
||||||
def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None],
|
def deactivate_adapter(adapter_id: int, active_adapters: dict[int, None],
|
||||||
deactivate_func: Callable) -> bool:
|
deactivate_func: Callable) -> bool:
|
||||||
if adapter_id in active_adapters:
|
if adapter_id in active_adapters:
|
||||||
deactivate_func(adapter_id)
|
deactivate_func(adapter_id)
|
||||||
@ -13,7 +13,7 @@ def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None],
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def add_adapter(adapter: Any, registered_adapters: Dict[int, Any],
|
def add_adapter(adapter: Any, registered_adapters: dict[int, Any],
|
||||||
capacity: int, add_func: Callable) -> bool:
|
capacity: int, add_func: Callable) -> bool:
|
||||||
if adapter.id not in registered_adapters:
|
if adapter.id not in registered_adapters:
|
||||||
if len(registered_adapters) >= capacity:
|
if len(registered_adapters) >= capacity:
|
||||||
@ -32,23 +32,23 @@ def set_adapter_mapping(mapping: Any, last_mapping: Any,
|
|||||||
return last_mapping
|
return last_mapping
|
||||||
|
|
||||||
|
|
||||||
def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any],
|
def remove_adapter(adapter_id: int, registered_adapters: dict[int, Any],
|
||||||
deactivate_func: Callable) -> bool:
|
deactivate_func: Callable) -> bool:
|
||||||
deactivate_func(adapter_id)
|
deactivate_func(adapter_id)
|
||||||
return bool(registered_adapters.pop(adapter_id, None))
|
return bool(registered_adapters.pop(adapter_id, None))
|
||||||
|
|
||||||
|
|
||||||
def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]:
|
def list_adapters(registered_adapters: dict[int, Any]) -> dict[int, Any]:
|
||||||
return dict(registered_adapters)
|
return dict(registered_adapters)
|
||||||
|
|
||||||
|
|
||||||
def get_adapter(adapter_id: int,
|
def get_adapter(adapter_id: int,
|
||||||
registered_adapters: Dict[int, Any]) -> Optional[Any]:
|
registered_adapters: dict[int, Any]) -> Optional[Any]:
|
||||||
return registered_adapters.get(adapter_id)
|
return registered_adapters.get(adapter_id)
|
||||||
|
|
||||||
|
|
||||||
## worker functions
|
## worker functions
|
||||||
def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any],
|
def set_active_adapters_worker(requests: set[Any], mapping: Optional[Any],
|
||||||
apply_adapters_func,
|
apply_adapters_func,
|
||||||
set_adapter_mapping_func) -> None:
|
set_adapter_mapping_func) -> None:
|
||||||
apply_adapters_func(requests)
|
apply_adapters_func(requests)
|
||||||
@ -66,7 +66,7 @@ def add_adapter_worker(adapter_request: Any, list_adapters_func,
|
|||||||
return loaded
|
return loaded
|
||||||
|
|
||||||
|
|
||||||
def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func,
|
def apply_adapters_worker(adapter_requests: set[Any], list_adapters_func,
|
||||||
adapter_slots: int, remove_adapter_func,
|
adapter_slots: int, remove_adapter_func,
|
||||||
add_adapter_func) -> None:
|
add_adapter_func) -> None:
|
||||||
models_that_exist = list_adapters_func()
|
models_that_exist = list_adapters_func()
|
||||||
@ -88,5 +88,5 @@ def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func,
|
|||||||
add_adapter_func(models_map[adapter_id])
|
add_adapter_func(models_map[adapter_id])
|
||||||
|
|
||||||
|
|
||||||
def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]:
|
def list_adapters_worker(adapter_manager_list_adapters_func) -> set[int]:
|
||||||
return set(adapter_manager_list_adapters_func())
|
return set(adapter_manager_list_adapters_func())
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Optional, Set
|
from typing import Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -17,7 +17,7 @@ class AbstractWorkerManager(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_active_adapters(self, requests: Set[Any],
|
def set_active_adapters(self, requests: set[Any],
|
||||||
mapping: Optional[Any]) -> None:
|
mapping: Optional[Any]) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -34,5 +34,5 @@ class AbstractWorkerManager(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list_adapters(self) -> Set[int]:
|
def list_adapters(self) -> set[int]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user