Update deprecated type hinting in platform, plugins, triton_utils, vllm_flash_attn (#18129)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Harry Mellor on 2025-05-14 13:28:16 +01:00, committed by GitHub
parent dc372b9c8a
commit c8ea982d9b
6 changed files with 18 additions and 24 deletions


@@ -78,13 +78,8 @@ exclude = [
"vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
-"vllm/platforms/**/*.py" = ["UP006", "UP035"]
-"vllm/plugins/**/*.py" = ["UP006", "UP035"]
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
"vllm/spec_decode/**/*.py" = ["UP006", "UP035"]
"vllm/transformers_utils/**/*.py" = ["UP006", "UP035"]
-"vllm/triton_utils/**/*.py" = ["UP006", "UP035"]
-"vllm/vllm_flash_attn/**/*.py" = ["UP006", "UP035"]
"vllm/worker/**/*.py" = ["UP006", "UP035"]
"vllm/utils.py" = ["UP006", "UP035"]


@@ -5,8 +5,7 @@ pynvml. However, it should not initialize cuda context.
import os
from functools import wraps
-from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
-                    Union)
+from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
import torch
from typing_extensions import ParamSpec
@@ -56,7 +55,7 @@ class CudaPlatformBase(Platform):
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
    @property
-    def supported_dtypes(self) -> List[torch.dtype]:
+    def supported_dtypes(self) -> list[torch.dtype]:
        if self.has_device_capability(80):
            # Ampere and Hopper or later NVIDIA GPUs.
            return [torch.bfloat16, torch.float16, torch.float32]
@@ -93,7 +92,7 @@ class CudaPlatformBase(Platform):
        return True
    @classmethod
-    def is_fully_connected(cls, device_ids: List[int]) -> bool:
+    def is_fully_connected(cls, device_ids: list[int]) -> bool:
        raise NotImplementedError
    @classmethod
@@ -335,7 +334,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
    @with_nvml_context
    def has_device_capability(
        cls,
-        capability: Union[Tuple[int, int], int],
+        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        try:
@@ -365,7 +364,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
    @classmethod
    @with_nvml_context
-    def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
+    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        """
        query if the set of gpus are fully connected by nvlink (1 hop)
        """
@@ -430,7 +429,7 @@ class NonNvmlCudaPlatform(CudaPlatformBase):
        return device_props.total_memory
    @classmethod
-    def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
+    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
        logger.exception(
            "NVLink detection not possible, as context support was"
            " not found. Assuming no NVLink available.")


@@ -4,7 +4,7 @@ import os
import platform
import random
from platform import uname
-from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, NamedTuple, Optional, Union
import numpy as np
import torch
@@ -200,7 +200,7 @@ class Platform:
    @classmethod
    def has_device_capability(
        cls,
-        capability: Union[Tuple[int, int], int],
+        capability: Union[tuple[int, int], int],
        device_id: int = 0,
    ) -> bool:
        """
@@ -362,7 +362,7 @@ class Platform:
        raise NotImplementedError
    @classmethod
-    def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]:
+    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        """
        Return the platform specific values for (-inf, inf)
        """


@@ -2,7 +2,7 @@
import os
from functools import cache, lru_cache, wraps
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Optional
import torch
@@ -35,7 +35,7 @@ except ImportError as e:
    logger.warning("Failed to import from vllm._rocm_C with %r", e)
# Models not supported by ROCm.
-_ROCM_UNSUPPORTED_MODELS: List[str] = []
+_ROCM_UNSUPPORTED_MODELS: list[str] = []
# Models partially supported by ROCm.
# Architecture -> Reason.
@@ -43,7 +43,7 @@ _ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                    "Triton flash attention. For half-precision SWA support, "
                    "please use CK flash attention by setting "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
-_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
+_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
    "Qwen2ForCausalLM":
    _ROCM_SWA_REASON,
    "MistralForCausalLM":
@@ -58,7 +58,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
     "excessive use of shared memory. If this happens, disable Triton FA "
     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}
-_ROCM_DEVICE_ID_NAME_MAP: Dict[str, str] = {
+_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
    "0x74a0": "AMD_Instinct_MI300A",
    "0x74a1": "AMD_Instinct_MI300X",
    "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF
@@ -203,7 +203,7 @@ class RocmPlatform(Platform):
    @staticmethod
    @with_amdsmi_context
-    def is_fully_connected(physical_device_ids: List[int]) -> bool:
+    def is_fully_connected(physical_device_ids: list[int]) -> bool:
        """
        Query if the set of gpus are fully connected by xgmi (1 hop)
        """


@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
-from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Optional, Union, cast
import torch
from tpu_info import device
@@ -73,7 +73,7 @@ class TpuPlatform(Platform):
        return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"
    @classmethod
-    def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]:
+    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
        return torch.finfo(dtype).min, torch.finfo(dtype).max
    @classmethod
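
Note that the TPU override returns the dtype's finite extremes from torch.finfo rather than literal infinities. For reference, what those calls evaluate to for bfloat16 (standard torch behaviour, shown only for context):

    import torch

    lo = torch.finfo(torch.bfloat16).min  # approximately -3.39e+38
    hi = torch.finfo(torch.bfloat16).max  # approximately  3.39e+38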


@@ -2,7 +2,7 @@
import logging
import os
-from typing import Callable, Dict
+from typing import Callable
import torch
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
plugins_loaded = False
-def load_plugins_by_group(group: str) -> Dict[str, Callable]:
+def load_plugins_by_group(group: str) -> dict[str, Callable]:
    import sys
    if sys.version_info < (3, 10):
        from importlib_metadata import entry_points
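
load_plugins_by_group returns a mapping from entry-point name to the callable it loads; the version check exists because keyword selection with entry_points(group=...) only landed in the standard library in Python 3.10, so older interpreters fall back to the importlib_metadata backport. A condensed sketch of that pattern, not the exact body of the vLLM function; the group name in the usage comment is only an example:

    import sys
    from typing import Callable

    def discover_entry_points(group: str) -> dict[str, Callable]:
        if sys.version_info < (3, 10):
            from importlib_metadata import entry_points
        else:
            from importlib.metadata import entry_points
        # ep.load() imports and returns the object each entry point refers to.
        return {ep.name: ep.load() for ep in entry_points(group=group)}

    # e.g. plugins = discover_entry_points("vllm.general_plugins")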