Update deprecated type hinting in platform, plugins, triton_utils, vllm_flash_attn (#18129)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-05-14 13:28:16 +01:00 committed by GitHub
parent dc372b9c8a
commit c8ea982d9b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 18 additions and 24 deletions

View File

@ -78,13 +78,8 @@ exclude = [
"vllm/executor/**/*.py" = ["UP006", "UP035"] "vllm/executor/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"] "vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"] "vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
"vllm/platforms/**/*.py" = ["UP006", "UP035"]
"vllm/plugins/**/*.py" = ["UP006", "UP035"]
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"] "vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
"vllm/spec_decode/**/*.py" = ["UP006", "UP035"] "vllm/spec_decode/**/*.py" = ["UP006", "UP035"]
"vllm/transformers_utils/**/*.py" = ["UP006", "UP035"]
"vllm/triton_utils/**/*.py" = ["UP006", "UP035"]
"vllm/vllm_flash_attn/**/*.py" = ["UP006", "UP035"]
"vllm/worker/**/*.py" = ["UP006", "UP035"] "vllm/worker/**/*.py" = ["UP006", "UP035"]
"vllm/utils.py" = ["UP006", "UP035"] "vllm/utils.py" = ["UP006", "UP035"]

View File

@ -5,8 +5,7 @@ pynvml. However, it should not initialize cuda context.
import os import os
from functools import wraps from functools import wraps
from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar, from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
Union)
import torch import torch
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
@ -56,7 +55,7 @@ class CudaPlatformBase(Platform):
device_control_env_var: str = "CUDA_VISIBLE_DEVICES" device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
@property @property
def supported_dtypes(self) -> List[torch.dtype]: def supported_dtypes(self) -> list[torch.dtype]:
if self.has_device_capability(80): if self.has_device_capability(80):
# Ampere and Hopper or later NVIDIA GPUs. # Ampere and Hopper or later NVIDIA GPUs.
return [torch.bfloat16, torch.float16, torch.float32] return [torch.bfloat16, torch.float16, torch.float32]
@ -93,7 +92,7 @@ class CudaPlatformBase(Platform):
return True return True
@classmethod @classmethod
def is_fully_connected(cls, device_ids: List[int]) -> bool: def is_fully_connected(cls, device_ids: list[int]) -> bool:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
@ -335,7 +334,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
@with_nvml_context @with_nvml_context
def has_device_capability( def has_device_capability(
cls, cls,
capability: Union[Tuple[int, int], int], capability: Union[tuple[int, int], int],
device_id: int = 0, device_id: int = 0,
) -> bool: ) -> bool:
try: try:
@ -365,7 +364,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
@classmethod @classmethod
@with_nvml_context @with_nvml_context
def is_fully_connected(cls, physical_device_ids: List[int]) -> bool: def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
""" """
query if the set of gpus are fully connected by nvlink (1 hop) query if the set of gpus are fully connected by nvlink (1 hop)
""" """
@ -430,7 +429,7 @@ class NonNvmlCudaPlatform(CudaPlatformBase):
return device_props.total_memory return device_props.total_memory
@classmethod @classmethod
def is_fully_connected(cls, physical_device_ids: List[int]) -> bool: def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
logger.exception( logger.exception(
"NVLink detection not possible, as context support was" "NVLink detection not possible, as context support was"
" not found. Assuming no NVLink available.") " not found. Assuming no NVLink available.")

View File

@ -4,7 +4,7 @@ import os
import platform import platform
import random import random
from platform import uname from platform import uname
from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union from typing import TYPE_CHECKING, NamedTuple, Optional, Union
import numpy as np import numpy as np
import torch import torch
@ -200,7 +200,7 @@ class Platform:
@classmethod @classmethod
def has_device_capability( def has_device_capability(
cls, cls,
capability: Union[Tuple[int, int], int], capability: Union[tuple[int, int], int],
device_id: int = 0, device_id: int = 0,
) -> bool: ) -> bool:
""" """
@ -362,7 +362,7 @@ class Platform:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
""" """
Return the platform specific values for (-inf, inf) Return the platform specific values for (-inf, inf)
""" """

View File

@ -2,7 +2,7 @@
import os import os
from functools import cache, lru_cache, wraps from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING, Dict, List, Optional from typing import TYPE_CHECKING, Optional
import torch import torch
@ -35,7 +35,7 @@ except ImportError as e:
logger.warning("Failed to import from vllm._rocm_C with %r", e) logger.warning("Failed to import from vllm._rocm_C with %r", e)
# Models not supported by ROCm. # Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = [] _ROCM_UNSUPPORTED_MODELS: list[str] = []
# Models partially supported by ROCm. # Models partially supported by ROCm.
# Architecture -> Reason. # Architecture -> Reason.
@ -43,7 +43,7 @@ _ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
"Triton flash attention. For half-precision SWA support, " "Triton flash attention. For half-precision SWA support, "
"please use CK flash attention by setting " "please use CK flash attention by setting "
"`VLLM_USE_TRITON_FLASH_ATTN=0`") "`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { _ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
"Qwen2ForCausalLM": "Qwen2ForCausalLM":
_ROCM_SWA_REASON, _ROCM_SWA_REASON,
"MistralForCausalLM": "MistralForCausalLM":
@ -58,7 +58,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
"excessive use of shared memory. If this happens, disable Triton FA " "excessive use of shared memory. If this happens, disable Triton FA "
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
} }
_ROCM_DEVICE_ID_NAME_MAP: Dict[str, str] = { _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
"0x74a0": "AMD_Instinct_MI300A", "0x74a0": "AMD_Instinct_MI300A",
"0x74a1": "AMD_Instinct_MI300X", "0x74a1": "AMD_Instinct_MI300X",
"0x74b5": "AMD_Instinct_MI300X", # MI300X VF "0x74b5": "AMD_Instinct_MI300X", # MI300X VF
@ -203,7 +203,7 @@ class RocmPlatform(Platform):
@staticmethod @staticmethod
@with_amdsmi_context @with_amdsmi_context
def is_fully_connected(physical_device_ids: List[int]) -> bool: def is_fully_connected(physical_device_ids: list[int]) -> bool:
""" """
Query if the set of gpus are fully connected by xgmi (1 hop) Query if the set of gpus are fully connected by xgmi (1 hop)
""" """

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast from typing import TYPE_CHECKING, Optional, Union, cast
import torch import torch
from tpu_info import device from tpu_info import device
@ -73,7 +73,7 @@ class TpuPlatform(Platform):
return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"
@classmethod @classmethod
def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
return torch.finfo(dtype).min, torch.finfo(dtype).max return torch.finfo(dtype).min, torch.finfo(dtype).max
@classmethod @classmethod

View File

@ -2,7 +2,7 @@
import logging import logging
import os import os
from typing import Callable, Dict from typing import Callable
import torch import torch
@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
plugins_loaded = False plugins_loaded = False
def load_plugins_by_group(group: str) -> Dict[str, Callable]: def load_plugins_by_group(group: str) -> dict[str, Callable]:
import sys import sys
if sys.version_info < (3, 10): if sys.version_info < (3, 10):
from importlib_metadata import entry_points from importlib_metadata import entry_points