From 0b217da646fd4cc08cd0dd20d0ea69f81d64ab35 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 13 May 2025 16:32:51 +0100 Subject: [PATCH] Update deprecated type hinting in `vllm/adapter_commons` (#18073) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- pyproject.toml | 1 - vllm/adapter_commons/layers.py | 5 ++--- vllm/adapter_commons/models.py | 8 ++++---- vllm/adapter_commons/utils.py | 18 +++++++++--------- vllm/adapter_commons/worker_manager.py | 6 +++--- 5 files changed, 18 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6f5c560e800f0..ac8a3612907de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,6 @@ exclude = [ "vllm/version.py" = ["F401"] "vllm/_version.py" = ["ALL"] # Python 3.8 typing. TODO: Remove these excludes after v1.0.0 -"vllm/adapter_commons/**/*.py" = ["UP006", "UP035"] "vllm/attention/**/*.py" = ["UP006", "UP035"] "vllm/core/**/*.py" = ["UP006", "UP035"] "vllm/device_allocator/**/*.py" = ["UP006", "UP035"] diff --git a/vllm/adapter_commons/layers.py b/vllm/adapter_commons/layers.py index 18e0c5227d45c..9cc2b181fc7cc 100644 --- a/vllm/adapter_commons/layers.py +++ b/vllm/adapter_commons/layers.py @@ -1,15 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from typing import Tuple @dataclass class AdapterMapping: # Per every token in input_ids: - index_mapping: Tuple[int, ...] + index_mapping: tuple[int, ...] # Per sampled token: - prompt_mapping: Tuple[int, ...] + prompt_mapping: tuple[int, ...] def __post_init__(self): self.index_mapping = tuple(self.index_mapping) diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py index f9a5d2fffad5e..a84fbea2e444a 100644 --- a/vllm/adapter_commons/models.py +++ b/vllm/adapter_commons/models.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Optional, TypeVar +from typing import Any, Callable, Optional, TypeVar from torch import nn @@ -49,9 +49,9 @@ class AdapterModelManager(ABC): model: the model to be adapted. """ self.model: nn.Module = model - self._registered_adapters: Dict[int, Any] = {} + self._registered_adapters: dict[int, Any] = {} # Dict instead of a Set for compatibility with LRUCache. - self._active_adapters: Dict[int, None] = {} + self._active_adapters: dict[int, None] = {} self.adapter_type = 'Adapter' self._last_mapping = None @@ -97,7 +97,7 @@ class AdapterModelManager(ABC): raise NotImplementedError @abstractmethod - def list_adapters(self) -> Dict[int, Any]: + def list_adapters(self) -> dict[int, Any]: raise NotImplementedError @abstractmethod diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py index c2dc5433cc656..46e9629e1f55f 100644 --- a/vllm/adapter_commons/utils.py +++ b/vllm/adapter_commons/utils.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, Optional, Set +from typing import Any, Callable, Optional ## model functions -def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], +def deactivate_adapter(adapter_id: int, active_adapters: dict[int, None], deactivate_func: Callable) -> bool: if adapter_id in active_adapters: deactivate_func(adapter_id) @@ -13,7 +13,7 @@ def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], return False -def add_adapter(adapter: Any, registered_adapters: Dict[int, Any], +def add_adapter(adapter: Any, registered_adapters: dict[int, Any], capacity: int, add_func: Callable) -> bool: if adapter.id not in registered_adapters: if len(registered_adapters) >= capacity: @@ -32,23 +32,23 @@ def set_adapter_mapping(mapping: Any, last_mapping: Any, return last_mapping -def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any], +def remove_adapter(adapter_id: int, registered_adapters: dict[int, Any], deactivate_func: Callable) -> bool: deactivate_func(adapter_id) return bool(registered_adapters.pop(adapter_id, None)) -def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]: +def list_adapters(registered_adapters: dict[int, Any]) -> dict[int, Any]: return dict(registered_adapters) def get_adapter(adapter_id: int, - registered_adapters: Dict[int, Any]) -> Optional[Any]: + registered_adapters: dict[int, Any]) -> Optional[Any]: return registered_adapters.get(adapter_id) ## worker functions -def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any], +def set_active_adapters_worker(requests: set[Any], mapping: Optional[Any], apply_adapters_func, set_adapter_mapping_func) -> None: apply_adapters_func(requests) @@ -66,7 +66,7 @@ def add_adapter_worker(adapter_request: Any, list_adapters_func, return loaded -def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, +def apply_adapters_worker(adapter_requests: set[Any], list_adapters_func, adapter_slots: int, remove_adapter_func, add_adapter_func) -> None: models_that_exist = list_adapters_func() @@ -88,5 +88,5 @@ def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, add_adapter_func(models_map[adapter_id]) -def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]: +def list_adapters_worker(adapter_manager_list_adapters_func) -> set[int]: return set(adapter_manager_list_adapters_func()) diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py index ce24e08a5b56e..3c1d26404c990 100644 --- a/vllm/adapter_commons/worker_manager.py +++ b/vllm/adapter_commons/worker_manager.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Any, Optional, Set +from typing import Any, Optional import torch @@ -17,7 +17,7 @@ class AbstractWorkerManager(ABC): raise NotImplementedError @abstractmethod - def set_active_adapters(self, requests: Set[Any], + def set_active_adapters(self, requests: set[Any], mapping: Optional[Any]) -> None: raise NotImplementedError @@ -34,5 +34,5 @@ class AbstractWorkerManager(ABC): raise NotImplementedError @abstractmethod - def list_adapters(self) -> Set[int]: + def list_adapters(self) -> set[int]: raise NotImplementedError