From c8ea982d9b86e145a16092017528d068a7f94630 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Wed, 14 May 2025 13:28:16 +0100
Subject: [PATCH] Update deprecated type hinting in `platform`, `plugins`,
 `triton_utils`, `vllm_flash_attn` (#18129)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 pyproject.toml              |  5 -----
 vllm/platforms/cuda.py      | 13 ++++++-------
 vllm/platforms/interface.py |  6 +++---
 vllm/platforms/rocm.py      | 10 +++++-----
 vllm/platforms/tpu.py       |  4 ++--
 vllm/plugins/__init__.py    |  4 ++--
 6 files changed, 18 insertions(+), 24 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index b3ca68f9f8ac..46cf7a801fd6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -78,13 +78,8 @@ exclude = [
 "vllm/executor/**/*.py" = ["UP006", "UP035"]
 "vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
 "vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
-"vllm/platforms/**/*.py" = ["UP006", "UP035"]
-"vllm/plugins/**/*.py" = ["UP006", "UP035"]
 "vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
 "vllm/spec_decode/**/*.py" = ["UP006", "UP035"]
-"vllm/transformers_utils/**/*.py" = ["UP006", "UP035"]
-"vllm/triton_utils/**/*.py" = ["UP006", "UP035"]
-"vllm/vllm_flash_attn/**/*.py" = ["UP006", "UP035"]
 "vllm/worker/**/*.py" = ["UP006", "UP035"]
 "vllm/utils.py" = ["UP006", "UP035"]
 
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 2343e6d82868..9163b97c51a0 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -5,8 +5,7 @@ pynvml. However, it should not initialize cuda context.
 
 import os
 from functools import wraps
-from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
-                    Union)
+from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
 
 import torch
 from typing_extensions import ParamSpec
@@ -56,7 +55,7 @@ class CudaPlatformBase(Platform):
     device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
 
     @property
-    def supported_dtypes(self) -> List[torch.dtype]:
+    def supported_dtypes(self) -> list[torch.dtype]:
         if self.has_device_capability(80):
             # Ampere and Hopper or later NVIDIA GPUs.
             return [torch.bfloat16, torch.float16, torch.float32]
@@ -93,7 +92,7 @@ class CudaPlatformBase(Platform):
         return True
 
     @classmethod
-    def is_fully_connected(cls, device_ids: List[int]) -> bool:
+    def is_fully_connected(cls, device_ids: list[int]) -> bool:
         raise NotImplementedError
 
     @classmethod
@@ -335,7 +334,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
     @with_nvml_context
     def has_device_capability(
         cls,
-        capability: Union[Tuple[int, int], int],
+        capability: Union[tuple[int, int], int],
         device_id: int = 0,
     ) -> bool:
         try:
@@ -365,7 +364,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
 
     @classmethod
     @with_nvml_context
-    def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
+    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
         """
         query if the set of gpus are fully connected by nvlink (1 hop)
         """
@@ -430,7 +429,7 @@ class NonNvmlCudaPlatform(CudaPlatformBase):
         return device_props.total_memory
 
     @classmethod
-    def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
+    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
         logger.exception(
             "NVLink detection not possible, as context support was"
             " not found. Assuming no NVLink available.")
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index cf30f7529563..b09e31e9ed46 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -4,7 +4,7 @@ import os
 import platform
 import random
 from platform import uname
-from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, NamedTuple, Optional, Union
 
 import numpy as np
 import torch
@@ -200,7 +200,7 @@ class Platform:
     @classmethod
     def has_device_capability(
         cls,
-        capability: Union[Tuple[int, int], int],
+        capability: Union[tuple[int, int], int],
         device_id: int = 0,
     ) -> bool:
         """
@@ -362,7 +362,7 @@ class Platform:
         raise NotImplementedError
 
     @classmethod
-    def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]:
+    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
         """
         Return the platform specific values for (-inf, inf)
         """
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index f3d64f01b0f7..c8b86087578d 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -2,7 +2,7 @@
 
 import os
 from functools import cache, lru_cache, wraps
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
@@ -35,7 +35,7 @@ except ImportError as e:
     logger.warning("Failed to import from vllm._rocm_C with %r", e)
 
 # Models not supported by ROCm.
-_ROCM_UNSUPPORTED_MODELS: List[str] = []
+_ROCM_UNSUPPORTED_MODELS: list[str] = []
 
 # Models partially supported by ROCm.
 # Architecture -> Reason.
@@ -43,7 +43,7 @@ _ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                     "Triton flash attention. For half-precision SWA support, "
                     "please use CK flash attention by setting "
                     "`VLLM_USE_TRITON_FLASH_ATTN=0`")
-_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
+_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
     "Qwen2ForCausalLM":
     _ROCM_SWA_REASON,
     "MistralForCausalLM":
@@ -58,7 +58,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
      "excessive use of shared memory. If this happens, disable Triton FA "
      "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
 }
-_ROCM_DEVICE_ID_NAME_MAP: Dict[str, str] = {
+_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
     "0x74a0": "AMD_Instinct_MI300A",
     "0x74a1": "AMD_Instinct_MI300X",
     "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
@@ -203,7 +203,7 @@ class RocmPlatform(Platform):
 
     @staticmethod
     @with_amdsmi_context
-    def is_fully_connected(physical_device_ids: List[int]) -> bool:
+    def is_fully_connected(physical_device_ids: list[int]) -> bool:
         """
         Query if the set of gpus are fully connected by xgmi (1 hop)
         """
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index d0a5af3587c4..41ed94fb619e 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Optional, Union, cast
 
 import torch
 from tpu_info import device
@@ -73,7 +73,7 @@ class TpuPlatform(Platform):
         return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"
 
     @classmethod
-    def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]:
+    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
         return torch.finfo(dtype).min, torch.finfo(dtype).max
 
     @classmethod
diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py
index 389cb8728103..d72ab2bd088c 100644
--- a/vllm/plugins/__init__.py
+++ b/vllm/plugins/__init__.py
@@ -2,7 +2,7 @@
 
 import logging
 import os
-from typing import Callable, Dict
+from typing import Callable
 
 import torch
 
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
 plugins_loaded = False
 
 
-def load_plugins_by_group(group: str) -> Dict[str, Callable]:
+def load_plugins_by_group(group: str) -> dict[str, Callable]:
     import sys
     if sys.version_info < (3, 10):
         from importlib_metadata import entry_points
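
Note for reviewers, appended after the patch: the change is mechanical, but the rationale may not be obvious. PEP 585 (Python 3.9) made the builtin containers subscriptable in annotations, deprecating the typing.List/typing.Dict/typing.Tuple aliases; ruff's UP006 rule flags the deprecated annotations and UP035 the deprecated imports, so deleting the per-file ignores in pyproject.toml re-enables both rules for the converted directories. A minimal before/after sketch follows; count_tokens is a hypothetical function for illustration, not code from this patch:

# Illustrative sketch of the PEP 585 migration that ruff UP006/UP035 enforce.
# `count_tokens` is a made-up example, not a vLLM function.

# Before (deprecated since Python 3.9): container aliases from `typing`.
from typing import Dict, List  # UP035 flags this import


def count_tokens_old(prompts: List[str]) -> Dict[str, int]:  # UP006
    return {p: len(p.split()) for p in prompts}


# After: builtin generics; no `typing` import needed for plain containers.
def count_tokens_new(prompts: list[str]) -> dict[str, int]:
    return {p: len(p.split()) for p in prompts}


print(count_tokens_new(["hello world"]))  # {'hello world': 2}

The Optional and Union imports are deliberately left untouched: rewriting them to PEP 604's `X | None` / `X | Y` spelling (ruff's UP007) requires a Python 3.10+ floor, which vLLM's supported version range presumably did not yet allow at the time of this patch.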