From 9b5b39b650ef37d9086985eabfb9ed2f1c327075 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 14 May 2025 11:57:59 +0100 Subject: [PATCH] Update deprecated type hinting in `vllm/lora` (#18128) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- pyproject.toml | 1 - vllm/lora/fully_sharded_layers.py | 22 ++--- vllm/lora/layers.py | 62 +++++++------- vllm/lora/lora.py | 14 ++-- vllm/lora/models.py | 56 ++++++------- vllm/lora/ops/triton_ops/lora_expand_op.py | 8 +- .../ops/triton_ops/lora_kernel_metadata.py | 4 +- vllm/lora/ops/triton_ops/lora_shrink_op.py | 8 +- vllm/lora/ops/triton_ops/utils.py | 10 +-- vllm/lora/peft_helper.py | 4 +- vllm/lora/punica_wrapper/punica_base.py | 82 +++++++++---------- vllm/lora/punica_wrapper/punica_cpu.py | 46 +++++------ vllm/lora/punica_wrapper/punica_gpu.py | 36 ++++---- vllm/lora/punica_wrapper/punica_hpu.py | 30 +++---- vllm/lora/punica_wrapper/punica_tpu.py | 50 +++++------ vllm/lora/punica_wrapper/utils.py | 16 ++-- vllm/lora/resolver.py | 7 +- vllm/lora/utils.py | 18 ++-- vllm/lora/worker_manager.py | 22 ++--- 19 files changed, 245 insertions(+), 251 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ac8a3612907de..62196a842e060 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,6 @@ exclude = [ "vllm/distributed/**/*.py" = ["UP006", "UP035"] "vllm/engine/**/*.py" = ["UP006", "UP035"] "vllm/executor/**/*.py" = ["UP006", "UP035"] -"vllm/lora/**/*.py" = ["UP006", "UP035"] "vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"] "vllm/model_executor/models/**/*.py" = ["UP006", "UP035"] "vllm/platforms/**/*.py" = ["UP006", "UP035"] diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index e195f8cf5e8e9..b6b138a44051f 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # pylint: disable=unused-argument -from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Optional, Union, cast import torch import torch.nn as nn @@ -118,7 +118,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: # specifying kwargs so they can be easily accessed in decorator @@ -141,8 +141,8 @@ class MergedColumnParallelLinearWithShardedLoRA( """ def slice_lora_a( - self, lora_a: List[Union[torch.Tensor, None]] - ) -> List[Union[torch.Tensor, None]]: + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: #NOTE: lora_a contains 2 subloras, and each sublora could be None. 
output_shard_size = self.lora_a_stacked[0].shape[2] output_start_idx = self.tp_rank * output_shard_size @@ -165,7 +165,7 @@ class MergedColumnParallelLinearWithShardedLoRA( cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: # specifying kwargs so they can be easily accessed in decorator @@ -201,7 +201,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA): @classmethod @_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, + lora_config: LoRAConfig, packed_modules_list: list, model_config: Optional[PretrainedConfig]) -> bool: # specifying kwargs so they can be easily accessed in decorator return super().can_replace_layer( @@ -222,8 +222,8 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): """ def slice_lora_a( - self, lora_a: List[Union[torch.Tensor, None]] - ) -> List[Union[torch.Tensor, None]]: + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: # NOTE: lora_a contains 3 subloras, and each sublora could be None. shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] start_idx = [self.tp_rank * shard_size[i] for i in range(3)] @@ -248,7 +248,7 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA): cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: # specifying kwargs so they can be easily accessed in decorator @@ -281,7 +281,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: if bias is None: return bias - self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked) shard_size = self.lora_bias_stacked[0].shape[2] start_idx = self.tp_rank * shard_size @@ -341,7 +341,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: # specifying kwargs so they can be easily accessed in decorator diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 6749ec16a0973..023c8e9c9a864 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -3,7 +3,7 @@ # pylint: disable=unused-argument import math from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Optional, Union, cast import torch import torch.nn as nn @@ -82,14 +82,14 @@ class LoRAMapping(AdapterMapping): class BaseLayerWithLoRA(nn.Module): def slice_lora_a( - self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]] - ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: + self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: """Slice lora a if splitting for tensor parallelism.""" ... 
def slice_lora_b( - self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]] - ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: + self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]: """Slice lora b if splitting with tensor parallelism.""" ... @@ -128,7 +128,7 @@ class BaseLayerWithLoRA(nn.Module): cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: """Returns True if the layer can be replaced by this LoRA layer.""" @@ -140,7 +140,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: VocabParallelEmbedding) -> None: super().__init__() self.base_layer = base_layer - self.embeddings_slice: Optional[Tuple[int, int]] + self.embeddings_slice: Optional[tuple[int, int]] self.embeddings_weights: Optional[torch.Tensor] def create_lora_weights( @@ -279,7 +279,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: return type(source_layer) is VocabParallelEmbedding @@ -296,9 +296,9 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): self.base_layer = base_layer self.input_size = self.base_layer.input_size self.device = _get_lora_device(self.base_layer) - self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None + self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None - self.output_slices: Tuple[int, ...] + self.output_slices: tuple[int, ...] self.tp_size: int self.output_size: int self.n_slices: int @@ -365,7 +365,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): self.lora_b_stacked[s_index][index] = 0 if self.lora_config.bias_enabled: # Make mypy happy - self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked) self.lora_bias_stacked[s_index][index] = 0 @@ -399,7 +399,7 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): lora_b.T, non_blocking=True) if lora_bias is not None: - self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked) assert len(self.lora_bias_stacked) self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( @@ -497,7 +497,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: return type(source_layer) is ReplicatedLinear @@ -597,7 +597,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: return type(source_layer) is ColumnParallelLinear or ( @@ -674,13 +674,13 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ) for output_size in self.output_slices) def slice_lora_a( - self, lora_a: List[Union[torch.Tensor, None]] - ) -> List[Union[torch.Tensor, None]]: + self, lora_a: list[Union[torch.Tensor, None]] + ) -> list[Union[torch.Tensor, None]]: return lora_a def slice_lora_b( - self, lora_b: List[Union[torch.Tensor, None]] - ) -> List[Union[torch.Tensor, None]]: + self, lora_b: list[Union[torch.Tensor, None]] + ) -> 
list[Union[torch.Tensor, None]]: for i, (shard_id, shard_size) in enumerate( zip(self.output_ids, self.output_slices)): if (lora_b_i := lora_b[i]) is not None: @@ -689,8 +689,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): return lora_b def slice_bias( - self, bias: List[Union[torch.Tensor, - None]]) -> List[Union[torch.Tensor, None]]: + self, bias: list[Union[torch.Tensor, + None]]) -> list[Union[torch.Tensor, None]]: for i, (shard_id, shard_size) in enumerate( zip(self.output_ids, self.output_slices)): if (bias_i := bias[i]) is not None: @@ -725,7 +725,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): lora_b_i.T, non_blocking=True) if lora_bias is not None: - self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked) for i in range(self.n_slices): if (lora_bias_i := lora_bias[i]) is not None: @@ -740,7 +740,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: return (type(source_layer) is MergedColumnParallelLinear @@ -809,7 +809,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): @classmethod @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, - lora_config: LoRAConfig, packed_modules_list: List, + lora_config: LoRAConfig, packed_modules_list: list, model_config: Optional[PretrainedConfig]) -> bool: return type(source_layer) is QKVParallelLinear and len( packed_modules_list) == 1 @@ -869,7 +869,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: return (type(source_layer) is QKVParallelLinear @@ -923,7 +923,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): - output - bias """ - # Set up backprop all-reduce. + # set up backprop all-reduce. if self.base_layer.input_is_parallel: input_parallel = input_ else: @@ -958,7 +958,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: return type(source_layer) is RowParallelLinear @@ -981,7 +981,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: LogitsProcessor, hidden_size: int, dtype: torch.dtype, device: torch.device, - sharded_to_full_mapping: Optional[List[int]]) -> None: + sharded_to_full_mapping: Optional[list[int]]) -> None: super().__init__() self.base_layer = base_layer self.hidden_size = hidden_size @@ -1189,7 +1189,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: # Special handling for the LogitsProcessor. 
@@ -1256,7 +1256,7 @@ class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA): positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: return self.base_layer( positions, query, @@ -1265,7 +1265,7 @@ class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA): ) @property - def scaling_factor_to_offset(self) -> Dict[float, int]: + def scaling_factor_to_offset(self) -> dict[float, int]: return self.base_layer.scaling_factor_to_offset @classmethod @@ -1273,7 +1273,7 @@ class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA): cls, source_layer: nn.Module, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig], ) -> bool: """Returns True if the layer can be replaced by this LoRA layer.""" diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py index 00299bf6c2a81..294b49e0a8997 100644 --- a/vllm/lora/lora.py +++ b/vllm/lora/lora.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional -from typing import Sequence as GenericSequence +from collections.abc import Sequence as GenericSequence +from typing import Optional import torch import torch.types @@ -125,11 +125,11 @@ class PackedLoRALayerWeights(LoRALayerWeights): self, module_name: str, rank: int, - lora_alphas: List[Optional[int]], - lora_a: List[Optional[torch.Tensor]], - lora_b: List[Optional[torch.Tensor]], - bias: Optional[List[Optional[torch.Tensor]]] = None, - scaling: Optional[List[float]] = None, + lora_alphas: list[Optional[int]], + lora_a: list[Optional[torch.Tensor]], + lora_b: list[Optional[torch.Tensor]], + bias: Optional[list[Optional[torch.Tensor]]] = None, + scaling: Optional[list[float]] = None, ) -> None: super().__init__( module_name=module_name, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 9f9d808679d74..959fe4a672a6d 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -4,9 +4,9 @@ import copy import math import os import re +from collections.abc import Sequence from dataclasses import dataclass, field -from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type, - Union) +from typing import Any, Callable, Optional, Union import safetensors.torch import torch @@ -44,12 +44,12 @@ _GLOBAL_LORA_ID = 0 class LongContextLoRAContext: """Context for lora adapters that support long context.""" # The scaling factors to support long context lora fine tuned models. - scaling_factors: List[float] + scaling_factors: list[float] # dimension to apply rotary embedding. rot_dim: int # offsets to the sin_cos_cache for each lora_id loaded. # This value is dynamically modified. - offsets_by_lora_id: Dict[int, int] = field(default_factory=dict) + offsets_by_lora_id: dict[int, int] = field(default_factory=dict) def get_lora_id(): @@ -65,7 +65,7 @@ class LoRAModel(AdapterModel): self, lora_model_id: int, rank: int, - loras: Dict[str, LoRALayerWeights], + loras: dict[str, LoRALayerWeights], scaling_factor: Optional[float] = None, ) -> None: """ @@ -84,7 +84,7 @@ class LoRAModel(AdapterModel): lora_model_id > 0), f"a valid lora id should be greater than 0, got {self.id}" self.rank = rank - self.loras: Dict[str, LoRALayerWeights] = loras + self.loras: dict[str, LoRALayerWeights] = loras def clone(self, lora_model_id: int) -> "LoRAModel": """Return a copy of the object with different ids. 
@@ -113,19 +113,19 @@ class LoRAModel(AdapterModel): def from_lora_tensors( cls, lora_model_id: int, - tensors: Dict[str, torch.Tensor], + tensors: dict[str, torch.Tensor], peft_helper: PEFTHelper, device: str = "cuda", dtype: Optional[torch.dtype] = None, - embeddings: Optional[Dict[str, torch.Tensor]] = None, + embeddings: Optional[dict[str, torch.Tensor]] = None, target_embedding_padding: Optional[int] = None, - embedding_modules: Optional[Dict[str, str]] = None, - embedding_padding_modules: Optional[List[str]] = None, + embedding_modules: Optional[dict[str, str]] = None, + embedding_padding_modules: Optional[list[str]] = None, weights_mapper: Optional[WeightsMapper] = None, ) -> "LoRAModel": """Create a LoRAModel from a dictionary of tensors.""" pin_memory = str(device) == "cpu" and is_pin_memory_available() - loras: Dict[str, LoRALayerWeights] = {} + loras: dict[str, LoRALayerWeights] = {} for tensor_name, tensor in tensors.items(): module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( tensor_name, weights_mapper) @@ -187,15 +187,15 @@ class LoRAModel(AdapterModel): def from_local_checkpoint( cls, lora_dir: str, - expected_lora_modules: List[str], + expected_lora_modules: list[str], peft_helper: PEFTHelper, *, lora_model_id: Optional[int] = None, device: str = "cuda", dtype: Optional[torch.dtype] = None, target_embedding_padding: Optional[int] = None, - embedding_modules: Optional[Dict[str, str]] = None, - embedding_padding_modules: Optional[List[str]] = None, + embedding_modules: Optional[dict[str, str]] = None, + embedding_padding_modules: Optional[list[str]] = None, weights_mapper: Optional[WeightsMapper] = None, ) -> "LoRAModel": """Create a LoRAModel from a local checkpoint. @@ -220,9 +220,9 @@ class LoRAModel(AdapterModel): new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin") - unexpected_modules: List[Union[list[str], str]] + unexpected_modules: list[Union[list[str], str]] if os.path.isfile(lora_tensor_path): - tensors: Dict[str, torch.Tensor] = {} + tensors: dict[str, torch.Tensor] = {} # Find unexpected modules. # Use safetensor key as a source of truth to find expected modules. # in peft if you have target_modules A, B, C and C does not exist @@ -329,7 +329,7 @@ class LoRAModelManager(AdapterModelManager): self.max_num_seqs = max_num_seqs assert self.capacity >= self.lora_slots self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 - self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots + self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size self.long_lora_context: Optional[LongContextLoRAContext] = None self.punica_wrapper = get_punica_wrapper( @@ -339,7 +339,7 @@ class LoRAModelManager(AdapterModelManager): max_loras=self.lora_config.max_loras) # Scaling factor -> offset to the sin_cos_cache to it. # Used for long context lora. - self.scaling_factor_to_offset: Dict[float, int] = {} + self.scaling_factor_to_offset: dict[float, int] = {} super().__init__(model) self.supported_lora_modules = get_supported_lora_modules(self.model) @@ -358,9 +358,9 @@ class LoRAModelManager(AdapterModelManager): # text modules (e.g. ChatGLM) and hasattr(self.model, "get_mm_mapping")) self.is_pooling_model = is_pooling_model(self.model) - self.packed_modules: Dict[str, List[str]] = {} - self.modules: Dict[str, BaseLayerWithLoRA] = {} - # Dict instead of a Set for compatibility with LRUCache. 
+ self.packed_modules: dict[str, list[str]] = {} + self.modules: dict[str, BaseLayerWithLoRA] = {} + # Dict instead of a set for compatibility with LRUCache. self._last_mapping: Optional[LoRAMapping] = None self._create_lora_modules() self.model.lora_manager = self @@ -530,7 +530,7 @@ class LoRAModelManager(AdapterModelManager): lora_id: int, rank: int, scaling_factor: Optional[float], - embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel: + embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel: """Create zero-initialized LoRAModel for warmup.""" model = LoRAModel(lora_id, rank, {}, scaling_factor) for module_name, module in self.model.named_modules(): @@ -578,7 +578,7 @@ class LoRAModelManager(AdapterModelManager): else: parts = module_name.split(".") replacements = self.packed_modules_mapping[parts[-1]] - subloras: List[Optional[LoRALayerWeights]] = [] + subloras: list[Optional[LoRALayerWeights]] = [] for i, r in enumerate(replacements): lora = LoRALayerWeights.create_dummy_lora_weights( module_name + "." + r, @@ -630,8 +630,8 @@ class LoRAModelManager(AdapterModelManager): def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: for module_name, new_module_names in self.packed_modules.items(): - replacement_loras: List[Optional[LoRALayerWeights]] = [] - replaced_module: Set[str] = set() + replacement_loras: list[Optional[LoRALayerWeights]] = [] + replaced_module: set[str] = set() has_replacement = False for r in new_module_names: lora = self._get_lora_layer_weights(lora_model, r) @@ -694,7 +694,7 @@ class LoRAModelManager(AdapterModelManager): return remove_adapter(adapter_id, self._registered_adapters, self.deactivate_adapter) - def list_adapters(self) -> Dict[int, Any]: + def list_adapters(self) -> dict[int, Any]: return list_adapters(self._registered_adapters) def get_adapter(self, adapter_id: int) -> Optional[Any]: @@ -721,7 +721,7 @@ class LRUCacheLoRAModelManager(LoRAModelManager): self._active_adapters: LoRALRUCache = LoRALRUCache( self.lora_slots, self._deactivate_adapter) - def list_adapters(self) -> Dict[int, LoRAModel]: + def list_adapters(self) -> dict[int, LoRAModel]: """List all registered LoRAModels.""" return dict(self._registered_adapters.cache) @@ -786,7 +786,7 @@ def create_lora_manager( vocab_size: int, lora_config: LoRAConfig, device: torch.device, - lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager, + lora_manager_cls: type[LoRAModelManager] = LoRAModelManager, **kwargs) -> LoRAModelManager: """Create a LoRA adapter for a given model.""" if not hasattr(model, "packed_modules_mapping"): diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index e41ae1d9594a7..13ddaaf961f7b 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -6,8 +6,6 @@ Punica: Multi-Tenant LoRA Serving. https://arxiv.org/abs/2310.18547 """ -from typing import List - import torch import triton import triton.language as tl @@ -127,7 +125,7 @@ def _lora_expand_kernel( @torch.inference_mode() def _lora_expand( inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] - lora_b_weights: List[ + lora_b_weights: list[ torch.Tensor], # shape [num_lora, hidden_size, lora_rank] output_tensor: torch. 
Tensor, # shape [num_tokens, hidden_size * num_slices] @@ -143,7 +141,7 @@ def _lora_expand( """ Args: inputs (torch.Tensor): input tensor - lora_b_weights (List[torch.Tensor]): lora'b weight + lora_b_weights (list[torch.Tensor]): lora'b weight output_tensor (torch.Tensor): output tensor token_lora_mapping (torch.Tensor): A tensor mapping each input token to the lora-id related to that token. A value of -1 indicates that @@ -264,7 +262,7 @@ def _lora_expand( def _lora_expand_fake( inputs: torch.Tensor, - lora_b_weights: List[torch.Tensor], + lora_b_weights: list[torch.Tensor], output_tensor: torch.Tensor, token_lora_mapping: torch.Tensor, token_indices_sorted_by_lora_ids: torch.Tensor, diff --git a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py index 055e78f406f3e..ac459a83220c7 100644 --- a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py +++ b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py @@ -4,7 +4,7 @@ LoRA kernels metadata preparation utilities. """ from dataclasses import dataclass -from typing import Tuple, Union +from typing import Union import torch @@ -125,7 +125,7 @@ class LoRAKernelMeta: def meta_args( self, token_nums: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ This function returns the kernel metadata required for the current diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index fb0422cf0b0ee..c3871bd58ffa1 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -6,8 +6,6 @@ Punica: Multi-Tenant LoRA Serving. https://arxiv.org/abs/2310.18547 """ -from typing import List - import torch import triton import triton.language as tl @@ -98,7 +96,7 @@ def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K, @torch.inference_mode() def _lora_shrink( inputs: torch.Tensor, # shape [num_tokens, hidden_size] - lora_a_weights: List[ + lora_a_weights: list[ torch.Tensor], # shape [num_loras, lora_rank, hidden_size] output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank] token_lora_mapping: torch.Tensor, # shape [num_tokens] @@ -112,7 +110,7 @@ def _lora_shrink( """ Args: inputs (torch.Tensor): Input tensor - lora_a_weights (List[torch.Tensor]): LoRA weights + lora_a_weights (list[torch.Tensor]): LoRA weights output_tensor (torch.Tensor): output tensor token_lora_mapping (torch.Tensor): A tensor mapping each input token to the lora-id related to that token. 
A value of -1 indicates that @@ -219,7 +217,7 @@ def _lora_shrink( def _lora_shrink_fake( inputs: torch.Tensor, - lora_a_weights: List[torch.Tensor], + lora_a_weights: list[torch.Tensor], output_tensor: torch.Tensor, token_lora_mapping: torch.Tensor, token_indices_sorted_by_lora_ids: torch.Tensor, diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index f779bbccd31ad..6225635c2955f 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -1,14 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, List, Tuple - import torch -_LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} -_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} +_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {} +_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {} -def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: torch.device): +def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device): """ `_LORA_A_PTR_DICT` collects the required information during `profile_run`, After this, it remains constant and subsequent usage is through LUT. @@ -53,7 +51,7 @@ def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: torch.device): return _LORA_A_PTR_DICT.get(key) -def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int, +def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int, device: torch.device): """ `_LORA_B_PTR_DICT` collects the required information during `profile_run`, diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py index f6944368b36ee..d5de63f5baade 100644 --- a/vllm/lora/peft_helper.py +++ b/vllm/lora/peft_helper.py @@ -6,7 +6,7 @@ import json import math import os from dataclasses import MISSING, dataclass, field, fields -from typing import List, Literal, Optional, Union +from typing import Literal, Optional, Union from vllm.config import LoRAConfig from vllm.logger import init_logger @@ -40,7 +40,7 @@ class PEFTHelper: vllm_max_position_embeddings: Optional[int] = field(default=False) vllm_long_context_scaling_factor: Optional[float] = field(default=None) - def _validate_features(self) -> List[str]: + def _validate_features(self) -> list[str]: """ Check if there are any unsupported LoRA features. 
""" diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 78866c51895bb..e03f7329021b3 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -7,7 +7,7 @@ https://arxiv.org/abs/2310.18547 """ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Union import torch @@ -28,7 +28,7 @@ class PunicaWrapperABC(ABC): def update_metadata( self, mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], + lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -43,9 +43,9 @@ class PunicaWrapperABC(ABC): @abstractmethod def add_shrink( self, - y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + y: Union[tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], + lora_a_stacked: tuple[torch.Tensor, ...], scale: float, **kwargs, ) -> Optional[torch.Tensor]: @@ -59,10 +59,10 @@ class PunicaWrapperABC(ABC): def add_expand( self, y: torch.Tensor, - x: Union[Tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, @@ -91,13 +91,13 @@ class PunicaWrapperABC(ABC): def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, - output_slices: Tuple[int, ...], + output_slices: tuple[int, ...], *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None, + buffer: Optional[tuple[torch.Tensor, ...]] = None, **kwargs) -> Optional[torch.Tensor]: """ Applicable to linear-related lora. @@ -150,7 +150,7 @@ class PunicaWrapperBase(PunicaWrapperABC): # 5 is the number of indices tensors. 
# base_indices, sampler_indices, sampler_indices_padded, # embeddings_indices,long_lora_indices - self.indices_len: List[Optional[int]] = [None] * 5 + self.indices_len: list[Optional[int]] = [None] * 5 # these attributes are the information required for sgmv kernel self._seq_start_locs = torch.empty(max_batches, dtype=torch.long, @@ -171,7 +171,7 @@ class PunicaWrapperBase(PunicaWrapperABC): def _update_base_metadata( self, mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], + lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -228,8 +228,8 @@ class PunicaWrapperBase(PunicaWrapperABC): self, indices: torch.Tensor, output: torch.Tensor, - output_slices: Tuple[int, ...], - lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], + output_slices: tuple[int, ...], + lora_bias_stacked: tuple[Optional[torch.Tensor], ...], ): """Applies bias to output @@ -259,7 +259,7 @@ class PunicaWrapperBase(PunicaWrapperABC): @property def prefill_metadata( self - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: """ This property provides a convenient way to access the necessary metadata for prefill-related kernel computations. @@ -323,7 +323,7 @@ class PunicaWrapperBase(PunicaWrapperABC): def update_metadata( self, mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], + lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -341,8 +341,8 @@ class PunicaWrapperBase(PunicaWrapperABC): self.is_prefill = False @abstractmethod - def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], scale: float, **kwargs) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. @@ -352,9 +352,9 @@ class PunicaWrapperBase(PunicaWrapperABC): y[i] += (x @ lora_a_stacked[i]) * scale Args: - y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights scale (float): Scaling factor for the operation """ @@ -364,10 +364,10 @@ class PunicaWrapperBase(PunicaWrapperABC): @abstractmethod def add_expand(self, y: torch.Tensor, - x: Union[Tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs) -> Optional[torch.Tensor]: @@ -384,11 +384,11 @@ class PunicaWrapperBase(PunicaWrapperABC): Args: y (torch.Tensor): Output tensor. 
- x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): bias's weight - output_slices (Tuple[int, ...]): Every slice's size + output_slices (tuple[int, ...]): Every slice's size offset_start (int): The starting position of y, defaults to 0 add_inputs (bool): Defaults to True. @@ -422,13 +422,13 @@ class PunicaWrapperBase(PunicaWrapperABC): def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, - output_slices: Tuple[int, ...], + output_slices: tuple[int, ...], *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None, + buffer: Optional[tuple[torch.Tensor, ...]] = None, **kwargs) -> Optional[torch.Tensor]: """ Applicable to linear-related lora. @@ -445,12 +445,12 @@ class PunicaWrapperBase(PunicaWrapperABC): Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. - output_slices (Tuple[int, ...]): Every slice's size. - buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + output_slices (tuple[int, ...]): Every slice's size. + buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. """ # TODO: implement it based on torch ops raise NotImplementedError diff --git a/vllm/lora/punica_wrapper/punica_cpu.py b/vllm/lora/punica_wrapper/punica_cpu.py index 29428f4cfff31..8118a72d696a2 100644 --- a/vllm/lora/punica_wrapper/punica_cpu.py +++ b/vllm/lora/punica_wrapper/punica_cpu.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Union import torch @@ -150,8 +150,8 @@ class PunicaWrapperCPU(PunicaWrapperBase): shrink_fun(y, x, w_t_all, scale) y = y.view_as(y_org) - def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...], scale: float, **kwargs): """ Performs GEMM for multiple slices of lora_a. 
@@ -165,9 +165,9 @@ class PunicaWrapperCPU(PunicaWrapperBase): y[i] += (x @ lora_a_stacked[i]) * scale Args: - y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights scale (float): Scaling factor for the operation """ @@ -179,10 +179,10 @@ class PunicaWrapperCPU(PunicaWrapperBase): def add_expand(self, y: torch.Tensor, - x: Union[Tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs) -> None: @@ -198,11 +198,11 @@ class PunicaWrapperCPU(PunicaWrapperBase): Args: y (torch.Tensor): Output tensor. - x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): bias's weight - output_slices (Tuple[int, ...]): Every slice's size + output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ y_org = y @@ -250,13 +250,13 @@ class PunicaWrapperCPU(PunicaWrapperBase): def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, - output_slices: Tuple[int, ...], + output_slices: tuple[int, ...], *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None, + buffer: Optional[tuple[torch.Tensor, ...]] = None, **kwargs) -> None: """ Applicable to linear-related lora. @@ -273,12 +273,12 @@ class PunicaWrapperCPU(PunicaWrapperBase): Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. - output_slices (Tuple[int, ...]): Every slice's size. - buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + output_slices (tuple[int, ...]): Every slice's size. + buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. """ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index bb6d2808e46a1..224640ec71925 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -6,7 +6,7 @@ Punica: Multi-Tenant LoRA Serving. 
https://arxiv.org/abs/2310.18547 """ -from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final +from typing import TYPE_CHECKING, Optional, Union, final import torch @@ -57,7 +57,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): def update_metadata( self, mapping: LoRAMapping, - lora_index_to_id: List[Optional[int]], + lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -74,7 +74,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): self.prompt_mapping_meta.prepare_tensors(self.sampler_indices) def add_shrink(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], scale: float, **kwargs): """ Performs GEMM for multiple slices of lora_a. @@ -86,7 +86,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): Args: y (torch.Tensor): Output tensors x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights scale (float): Scaling factor for the operation """ @@ -102,9 +102,9 @@ class PunicaWrapperGPU(PunicaWrapperBase): def add_expand(self, y: torch.Tensor, x: torch.Tensor, - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs) -> None: @@ -121,10 +121,10 @@ class PunicaWrapperGPU(PunicaWrapperBase): Args: y (torch.Tensor): Output tensor. x (torch.Tensor): Input tensors - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): bias's weight - output_slices (Tuple[int, ...]): Every slice's size + output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ y_org = y @@ -181,11 +181,11 @@ class PunicaWrapperGPU(PunicaWrapperBase): def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, - output_slices: Tuple[int, ...], + output_slices: tuple[int, ...], *, buffer: Optional[torch.Tensor] = None, **kwargs) -> None: @@ -204,11 +204,11 @@ class PunicaWrapperGPU(PunicaWrapperBase): Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. - output_slices (Tuple[int, ...]): Every slice's size. + output_slices (tuple[int, ...]): Every slice's size. buffer (Optional[torch.Tensor]): Defaults to None. 
""" diff --git a/vllm/lora/punica_wrapper/punica_hpu.py b/vllm/lora/punica_wrapper/punica_hpu.py index 3661a7214648a..416c23e73bf85 100644 --- a/vllm/lora/punica_wrapper/punica_hpu.py +++ b/vllm/lora/punica_wrapper/punica_hpu.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final +from typing import TYPE_CHECKING, Optional, Union, final import torch from vllm_hpu_extension.ops import (dispatch_bgmv_embedding, @@ -28,7 +28,7 @@ class PunicaWrapperHPU(PunicaWrapperBase): def _update_base_metadata( self, mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], + lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, extra_vocab_size: int, @@ -48,9 +48,9 @@ class PunicaWrapperHPU(PunicaWrapperBase): # graph accumulation. Hence HPU appends `lora_offset` to a list and # converts it to a tensor only after it is ready. if long_lora_context: - index_mapping_indices: List[int] = list( + index_mapping_indices: list[int] = list( mapping.index_mapping).copy() - long_lora_offsets: List[int] = [] + long_lora_offsets: list[int] = [] for i in range(len(index_mapping_indices)): lora_offset: int = long_lora_context.offsets_by_lora_id.get( index_mapping_indices[i], 0) @@ -85,13 +85,13 @@ class PunicaWrapperHPU(PunicaWrapperBase): def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, - output_slices: Tuple[int, ...], + output_slices: tuple[int, ...], *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None, + buffer: Optional[tuple[torch.Tensor, ...]] = None, **kwargs) -> None: y_org = y x = x.view(-1, x.shape[-1]) @@ -122,9 +122,9 @@ class PunicaWrapperHPU(PunicaWrapperBase): def add_shrink( self, - y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + y: Union[tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], + lora_a_stacked: tuple[torch.Tensor, ...], scale: float, **kwargs, ) -> None: @@ -133,10 +133,10 @@ class PunicaWrapperHPU(PunicaWrapperBase): def add_expand( self, y: torch.Tensor, - x: Union[Tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs, diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 37544c755d909..f3153c6dab03c 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple, Union +from typing import Optional, Union import torch import torch.nn.functional as F @@ -77,8 +77,8 @@ class PunicaWrapperTPU(PunicaWrapperBase): self._get_token_lora_indices(x), y_offset, y_slice_size, add_inputs) - def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, 
lora_a_stacked: tuple[torch.Tensor, ...], scale: float, **kwargs) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. @@ -88,9 +88,9 @@ class PunicaWrapperTPU(PunicaWrapperBase): y[i] += (x @ lora_a_stacked[i]) * scale Args: - y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors x (torch.Tensor): Input tensor - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights scale (float): Scaling factor for the operation """ @@ -106,10 +106,10 @@ class PunicaWrapperTPU(PunicaWrapperBase): def add_expand(self, y: torch.Tensor, - x: Union[Tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], - output_slices: Tuple[int, ...], + x: Union[tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], + output_slices: tuple[int, ...], offset_start: int = 0, add_inputs=True, **kwargs) -> torch.Tensor: @@ -125,11 +125,11 @@ class PunicaWrapperTPU(PunicaWrapperBase): Args: y (torch.Tensor): Output tensor. - x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): bias's weight - output_slices (Tuple[int, ...]): Every slice's size + output_slices (tuple[int, ...]): Every slice's size add_inputs (bool): Defaults to True. """ y_org = y @@ -177,13 +177,13 @@ class PunicaWrapperTPU(PunicaWrapperBase): def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, ...], - lora_b_stacked: Tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], scale: float, - output_slices: Tuple[int, ...], + output_slices: tuple[int, ...], *, - buffer: Optional[Tuple[torch.Tensor, ...]] = None, + buffer: Optional[tuple[torch.Tensor, ...]] = None, **kwargs) -> torch.Tensor: """ Applicable to linear-related lora. @@ -200,12 +200,12 @@ class PunicaWrapperTPU(PunicaWrapperBase): Args: y (torch.Tensor): Output tensor. Will not be changed in-place. x (torch.Tensor): Input tensor (T, E) - lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. - lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. - lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. - output_slices (Tuple[int, ...]): Every slice's size. - buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + output_slices (tuple[int, ...]): Every slice's size. + buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. 
""" assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) @@ -284,8 +284,8 @@ class PunicaWrapperTPU(PunicaWrapperBase): self, indices: torch.Tensor, output: torch.Tensor, - output_slices: Tuple[int, ...], - lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], + output_slices: tuple[int, ...], + lora_bias_stacked: tuple[Optional[torch.Tensor], ...], ): """Applies bias to output diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py index f4e5542b177d4..1adb40b4c284b 100644 --- a/vllm/lora/punica_wrapper/utils.py +++ b/vllm/lora/punica_wrapper/utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Optional, Union import torch @@ -12,7 +12,7 @@ if TYPE_CHECKING: def compute_meta( token_lora_tensor: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: """ Get the information required for the sgmv kernel. With the features: 1. If consecutive requests in the batch use the same LoRA, this function @@ -43,14 +43,14 @@ def compute_meta( # TODO see if this can be vectorized def convert_mapping( mapping: "LoRAMapping", - lora_index_to_id: List[Optional[int]], + lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, extra_vocab_size: int, device: torch.device, long_lora_context: Optional["LongContextLoRAContext"] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor], List[int]]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], list[int]]: """Converts LoRAMapping to index tensors. Args: @@ -84,7 +84,7 @@ def convert_mapping( (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, long_lora_indices). 
""" - index_mapping_indices: List[int] = list(mapping.index_mapping).copy() + index_mapping_indices: list[int] = list(mapping.index_mapping).copy() embedding_indices = index_mapping_indices.copy() lora_indices = index_mapping_indices.copy() long_lora_offsets: Optional[torch.Tensor] = None @@ -92,7 +92,7 @@ def convert_mapping( long_lora_offsets = torch.zeros(len(index_mapping_indices), device=device, dtype=torch.long) - prompt_mapping: List[int] = [ + prompt_mapping: list[int] = [ lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping ] @@ -109,7 +109,7 @@ def convert_mapping( index_mapping_indices[i], 0) long_lora_offsets[i] = lora_offset - indices_list: List[Union[List[int], torch.Tensor]] = [ + indices_list: list[Union[list[int], torch.Tensor]] = [ index_mapping_indices, lora_indices, embedding_indices, diff --git a/vllm/lora/resolver.py b/vllm/lora/resolver.py index 6726ca9a903ff..33f35322fe85f 100644 --- a/vllm/lora/resolver.py +++ b/vllm/lora/resolver.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod +from collections.abc import Set from dataclasses import dataclass, field -from typing import AbstractSet, Dict, Optional +from typing import Optional from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -40,9 +41,9 @@ class LoRAResolver(ABC): @dataclass class _LoRAResolverRegistry: - resolvers: Dict[str, LoRAResolver] = field(default_factory=dict) + resolvers: dict[str, LoRAResolver] = field(default_factory=dict) - def get_supported_resolvers(self) -> AbstractSet[str]: + def get_supported_resolvers(self) -> Set[str]: """Get all registered resolver names.""" return self.resolvers.keys() diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 01064e5d007ec..b66850d4304f1 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -2,7 +2,7 @@ import os import re -from typing import List, Optional, Set, Tuple, Type, Union +from typing import Optional, Union import huggingface_hub from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, @@ -37,7 +37,7 @@ from vllm.model_executor.models.utils import WeightsMapper logger = init_logger(__name__) -_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { +_all_lora_classes: set[type[BaseLayerWithLoRA]] = { VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA, @@ -58,7 +58,7 @@ _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { def from_layer(layer: nn.Module, max_loras: int, lora_config: LoRAConfig, - packed_modules_list: List, + packed_modules_list: list, model_config: Optional[PretrainedConfig] = None) -> nn.Module: for lora_cls in _all_lora_classes: # specifying kwargs so they can be easily accessed in decorator @@ -99,7 +99,7 @@ def replace_submodule(model: nn.Module, module_name: str, def parse_fine_tuned_lora_name( name: str, weights_mapper: Optional[WeightsMapper] = None -) -> Tuple[str, bool, bool]: +) -> tuple[str, bool, bool]: """Parse the name of lora weights. args: @@ -108,7 +108,7 @@ def parse_fine_tuned_lora_name( weights_mapper: maps the name of weight, e.g. `model.` -> `language_model.model.`, return: - Tuple(module_name, is_lora_a): + tuple(module_name, is_lora_a): module_name: the name of the module, e.g. model.dense1, is_lora_a whether the tensor is lora_a or lora_b. is_bias whether the tensor is lora bias. 
@@ -147,8 +147,8 @@ def parse_fine_tuned_lora_name( raise ValueError(f"{name} is unsupported LoRA weight") -def is_regex_target_modules(load_modules: Union[str, List[str]], - expected_lora_modules: List[str]) -> bool: +def is_regex_target_modules(load_modules: Union[str, list[str]], + expected_lora_modules: list[str]) -> bool: """ PEFT supports passing `target_modules` in the form of regular expressions, such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to @@ -179,11 +179,11 @@ def is_regex_target_modules(load_modules: Union[str, List[str]], return False -def get_supported_lora_modules(model: nn.Module) -> List[str]: +def get_supported_lora_modules(model: nn.Module) -> list[str]: """ In vLLM, all linear layers support LoRA. """ - supported_lora_modules: Set[str] = set() + supported_lora_modules: set[str] = set() # step1: traverse the model to get all the linear subfixes. for name, module in model.named_modules(): if isinstance(module, (LinearBase, )): diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 108beb34b244a..8e5bc61066593 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from contextlib import contextmanager -from typing import Any, Dict, List, Literal, Optional, Set, Type, Union +from typing import Any, Literal, Optional, Union import torch @@ -27,7 +27,7 @@ class WorkerLoRAManager(AbstractWorkerManager): Every request, the requested LoRAs will be loaded (unless they are already loaded), and every other LoRA will be unloaded.""" - _manager_cls: Type[LoRAModelManager] = LoRAModelManager + _manager_cls: type[LoRAModelManager] = LoRAModelManager def __init__( self, @@ -36,9 +36,9 @@ class WorkerLoRAManager(AbstractWorkerManager): vocab_size: int, lora_config: LoRAConfig, device: torch.device, - embedding_modules: Dict[str, str], - embedding_padding_modules: List[str], - lora_model_cls: Type[LoRAModel] = LoRAModel, + embedding_modules: dict[str, str], + embedding_padding_modules: list[str], + lora_model_cls: type[LoRAModel] = LoRAModel, max_position_embeddings: Optional[int] = None, ): self._lora_model_cls = lora_model_cls @@ -88,7 +88,7 @@ class WorkerLoRAManager(AbstractWorkerManager): self._adapter_manager.supported_lora_modules) packed_modules_mapping = ( self._adapter_manager.packed_modules_mapping) - expected_lora_modules: List[str] = [] + expected_lora_modules: list[str] = [] for module in supported_lora_modules: if module in packed_modules_mapping: expected_lora_modules.extend( @@ -162,12 +162,12 @@ class WorkerLoRAManager(AbstractWorkerManager): def pin_adapter(self, adapter_id: int) -> bool: return self._adapter_manager.pin_adapter(adapter_id) - def set_active_adapters(self, requests: Set[Any], + def set_active_adapters(self, requests: set[Any], mapping: Optional[Any]) -> None: set_active_adapters_worker(requests, mapping, self._apply_adapters, self._adapter_manager.set_adapter_mapping) - def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + def _apply_adapters(self, adapter_requests: set[Any]) -> None: apply_adapters_worker(adapter_requests, self.list_adapters, self._adapter_manager.adapter_slots, self.remove_adapter, self.add_adapter) @@ -184,7 +184,7 @@ class WorkerLoRAManager(AbstractWorkerManager): def remove_all_adapters(self): self._adapter_manager.remove_all_adapters() - def list_adapters(self) -> Set[int]: + def list_adapters(self) -> set[int]: return list_adapters_worker(self._adapter_manager.list_adapters) @@ -195,7 +195,7 @@ 
class LRUCacheWorkerLoRAManager(WorkerLoRAManager): (unless they are already loaded) and least recently used LoRAs will be unloaded if the cache is above capacity.""" - _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + _manager_cls: type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager def create_lora_manager( self, @@ -213,7 +213,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): self._adapter_manager = lora_manager return lora_manager.model - def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None: + def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None: loras_map = { lora_request.lora_int_id: lora_request for lora_request in lora_requests if lora_request
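Illustrative sketch (not part of the patch): the diff above mechanically replaces the deprecated typing.List / Dict / Tuple / Set / Type generics with the PEP 585 built-ins (list, dict, tuple, set, type), moves Sequence and AbstractSet imports to collections.abc, and drops the "vllm/lora/**/*.py" = ["UP006", "UP035"] exclusion from pyproject.toml so ruff now enforces the modern style in this package. A minimal before/after sketch of that annotation style, using a hypothetical add_shrink_sketch stand-in shaped like the punica add_shrink signatures above:

    from collections.abc import Sequence   # moved from typing (ruff UP035)
    from typing import Optional, Union     # List/Dict/Tuple/Set/Type imports are no longer needed (UP006)

    import torch

    def add_shrink_sketch(
        y: Union[tuple[torch.Tensor, ...], torch.Tensor],  # was: Union[Tuple[torch.Tensor, ...], torch.Tensor]
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],          # was: Tuple[torch.Tensor, ...]
        scale: float,
        scaling: Optional[Sequence[float]] = None,         # was: Optional[List[float]]
    ) -> Optional[torch.Tensor]:
        """Toy stand-in: only the annotation syntax changes, runtime behaviour is untouched."""
        return None

Only the type annotations differ between the two spellings; no runtime semantics change anywhere in the patch, which is why the edits are confined to signatures, docstrings, and import blocks.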