Update deprecated type hinting in platform, plugins, triton_utils, vllm_flash_attn (#18129)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
commit c8ea982d9b (parent dc372b9c8a)
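
The change itself is mechanical: Ruff's pyupgrade rules UP006 (non-PEP 585 annotation) and UP035 (deprecated import) replace the typing module's capitalized generics with the builtin generics that became subscriptable in Python 3.9. A minimal before/after sketch of the pattern, with illustrative names rather than code from this diff:

# Before: deprecated typing generics, flagged by UP006/UP035
from typing import Dict, List, Tuple

def device_names(ids: List[int]) -> Dict[int, str]: ...
def capability() -> Tuple[int, int]: ...

# After: builtin generics (PEP 585, Python >= 3.9)
def device_names(ids: list[int]) -> dict[int, str]: ...
def capability() -> tuple[int, int]: ...

Names such as Optional and Union still come from typing throughout the diff, presumably because their PEP 604 replacement, the `X | Y` syntax, requires Python 3.10.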
pyproject.toml
@@ -78,13 +78,8 @@ exclude = [
 "vllm/executor/**/*.py" = ["UP006", "UP035"]
 "vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
 "vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
-"vllm/platforms/**/*.py" = ["UP006", "UP035"]
-"vllm/plugins/**/*.py" = ["UP006", "UP035"]
 "vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
 "vllm/spec_decode/**/*.py" = ["UP006", "UP035"]
 "vllm/transformers_utils/**/*.py" = ["UP006", "UP035"]
-"vllm/triton_utils/**/*.py" = ["UP006", "UP035"]
-"vllm/vllm_flash_attn/**/*.py" = ["UP006", "UP035"]
 "vllm/worker/**/*.py" = ["UP006", "UP035"]
 "vllm/utils.py" = ["UP006", "UP035"]
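
Deleting those four entries from the per-file-ignores table is what turns enforcement on for the freshly converted trees; something like `ruff check --select UP006,UP035 vllm/platforms` should now pass. The rewrite is also safe at runtime, since the typing aliases wrap the very builtins that replace them. A quick check, assuming Python >= 3.9:

from typing import List, Tuple

# typing.List and typing.Tuple are thin aliases over the builtins...
assert List[int].__origin__ is list
assert Tuple[int, int].__origin__ is tuple
# ...and since PEP 585 the builtins are subscriptable themselves.
assert list[int].__origin__ is list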
vllm/platforms/cuda.py
@@ -5,8 +5,7 @@ pynvml. However, it should not initialize cuda context.
 
 import os
 from functools import wraps
-from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
-                    Union)
+from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
 
 import torch
 from typing_extensions import ParamSpec

@@ -56,7 +55,7 @@ class CudaPlatformBase(Platform):
     device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
 
     @property
-    def supported_dtypes(self) -> List[torch.dtype]:
+    def supported_dtypes(self) -> list[torch.dtype]:
         if self.has_device_capability(80):
             # Ampere and Hopper or later NVIDIA GPUs.
             return [torch.bfloat16, torch.float16, torch.float32]

@@ -93,7 +92,7 @@ class CudaPlatformBase(Platform):
         return True
 
     @classmethod
-    def is_fully_connected(cls, device_ids: List[int]) -> bool:
+    def is_fully_connected(cls, device_ids: list[int]) -> bool:
         raise NotImplementedError
 
     @classmethod

@@ -335,7 +334,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
     @with_nvml_context
     def has_device_capability(
         cls,
-        capability: Union[Tuple[int, int], int],
+        capability: Union[tuple[int, int], int],
         device_id: int = 0,
     ) -> bool:
         try:

@@ -365,7 +364,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
 
     @classmethod
     @with_nvml_context
-    def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
+    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
         """
         query if the set of gpus are fully connected by nvlink (1 hop)
         """

@@ -430,7 +429,7 @@ class NonNvmlCudaPlatform(CudaPlatformBase):
         return device_props.total_memory
 
     @classmethod
-    def is_fully_connected(cls, physical_device_ids: List[int]) -> bool:
+    def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
         logger.exception(
             "NVLink detection not possible, as context support was"
             " not found. Assuming no NVLink available.")
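
The imports kept in cuda.py (`wraps`, `Callable`, `TypeVar`, and `ParamSpec` from typing_extensions) exist to type decorators such as `with_nvml_context` used above. A minimal sketch of that decorator pattern, with the NVML calls stubbed out in comments since this is not vLLM's actual body:

from functools import wraps
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")

def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """Run fn inside an initialized NVML context."""

    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        # pynvml.nvmlInit() would go here in the real decorator.
        try:
            return fn(*args, **kwargs)
        finally:
            # ...and pynvml.nvmlShutdown() here.
            pass

    return wrapper

ParamSpec preserves the wrapped function's full parameter types for type checkers, which a plain `Callable[..., _R]` would lose; that is why typing_extensions survives this cleanup.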
vllm/platforms/interface.py
@@ -4,7 +4,7 @@ import os
 import platform
 import random
 from platform import uname
-from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, NamedTuple, Optional, Union
 
 import numpy as np
 import torch

@@ -200,7 +200,7 @@ class Platform:
     @classmethod
     def has_device_capability(
         cls,
-        capability: Union[Tuple[int, int], int],
+        capability: Union[tuple[int, int], int],
        device_id: int = 0,
     ) -> bool:
         """

@@ -362,7 +362,7 @@ class Platform:
         raise NotImplementedError
 
     @classmethod
-    def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]:
+    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
         """
         Return the platform specific values for (-inf, inf)
         """
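
Both here and in the CUDA subclass, `capability` may be either a `(major, minor)` tuple or a flat integer such as 80, so implementations have to normalize before comparing. A hypothetical helper showing the two accepted spellings (`_as_tuple` is an invented name, not vLLM API):

from typing import Union

def _as_tuple(capability: Union[tuple[int, int], int]) -> tuple[int, int]:
    """Normalize 80 or (8, 0) into a (major, minor) tuple."""
    if isinstance(capability, int):
        return divmod(capability, 10)  # 80 -> (8, 0)
    return capability

assert _as_tuple(80) == _as_tuple((8, 0)) == (8, 0)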
vllm/platforms/rocm.py
@@ -2,7 +2,7 @@
 
 import os
 from functools import cache, lru_cache, wraps
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Optional
 
 import torch
 

@@ -35,7 +35,7 @@ except ImportError as e:
     logger.warning("Failed to import from vllm._rocm_C with %r", e)
 
 # Models not supported by ROCm.
-_ROCM_UNSUPPORTED_MODELS: List[str] = []
+_ROCM_UNSUPPORTED_MODELS: list[str] = []
 
 # Models partially supported by ROCm.
 # Architecture -> Reason.

@@ -43,7 +43,7 @@ _ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                     "Triton flash attention. For half-precision SWA support, "
                     "please use CK flash attention by setting "
                     "`VLLM_USE_TRITON_FLASH_ATTN=0`")
-_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
+_ROCM_PARTIALLY_SUPPORTED_MODELS: dict[str, str] = {
     "Qwen2ForCausalLM":
     _ROCM_SWA_REASON,
     "MistralForCausalLM":

@@ -58,7 +58,7 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
     "excessive use of shared memory. If this happens, disable Triton FA "
     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
 }
-_ROCM_DEVICE_ID_NAME_MAP: Dict[str, str] = {
+_ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
     "0x74a0": "AMD_Instinct_MI300A",
     "0x74a1": "AMD_Instinct_MI300X",
     "0x74b5": "AMD_Instinct_MI300X",  # MI300X VF

@@ -203,7 +203,7 @@ class RocmPlatform(Platform):
 
     @staticmethod
     @with_amdsmi_context
-    def is_fully_connected(physical_device_ids: List[int]) -> bool:
+    def is_fully_connected(physical_device_ids: list[int]) -> bool:
         """
         Query if the set of gpus are fully connected by xgmi (1 hop)
         """
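
These module-level tables gain builtin-generic annotations but keep their roles: `_ROCM_UNSUPPORTED_MODELS` lists architectures to reject outright, while `_ROCM_PARTIALLY_SUPPORTED_MODELS` maps an architecture to the warning explaining its limitation. A sketch of the lookup pattern, illustrative rather than vLLM's actual check:

def check_rocm_support(architecture: str) -> None:
    # Reject hard failures, warn on known limitations.
    if architecture in _ROCM_UNSUPPORTED_MODELS:
        raise ValueError(f"{architecture} is not supported on ROCm.")
    reason = _ROCM_PARTIALLY_SUPPORTED_MODELS.get(architecture)
    if reason is not None:
        logger.warning("%s is only partially supported: %s", architecture,
                       reason)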
vllm/platforms/tpu.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Optional, Union, cast
 
 import torch
 from tpu_info import device

@@ -73,7 +73,7 @@ class TpuPlatform(Platform):
         return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU"
 
     @classmethod
-    def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]:
+    def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]:
         return torch.finfo(dtype).min, torch.finfo(dtype).max
 
     @classmethod
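
On TPU, "infinity" is represented by the dtype's finite extremes, which `torch.finfo` reports directly. A quick concrete example of the values this returns:

import torch

# float16 saturates at +/-65504, so those stand in for (-inf, inf):
lo, hi = torch.finfo(torch.float16).min, torch.finfo(torch.float16).max
assert (lo, hi) == (-65504.0, 65504.0)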
vllm/plugins/__init__.py
@@ -2,7 +2,7 @@
 
 import logging
 import os
-from typing import Callable, Dict
+from typing import Callable
 
 import torch
 

@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
 plugins_loaded = False
 
 
-def load_plugins_by_group(group: str) -> Dict[str, Callable]:
+def load_plugins_by_group(group: str) -> dict[str, Callable]:
     import sys
     if sys.version_info < (3, 10):
         from importlib_metadata import entry_points
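
The version guard exists because the `entry_points(group=...)` keyword form only landed in `importlib.metadata` in Python 3.10; older interpreters fall back to the PyPI backport. A minimal sketch of the discovery pattern the function is built on, using vLLM's documented "vllm.general_plugins" group as the example:

import sys

if sys.version_info < (3, 10):
    from importlib_metadata import entry_points  # PyPI backport
else:
    from importlib.metadata import entry_points

# Map each registered plugin name to the callable it advertises.
plugins = {ep.name: ep.load()
           for ep in entry_points(group="vllm.general_plugins")}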