[hardware][misc] introduce platform abstraction (#6080)
This commit is contained in:
parent 9d6a8daa87
commit 482045ee77
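The refactor is mechanical at each call site: the CUDA-only helper get_device_capability_stateless from vllm.utils is replaced by a method on a process-wide current_platform object that is selected once at import time (CUDA or ROCm, otherwise None). A minimal sketch of the before/after pattern, using only names that appear in the diff below:

# before: CUDA-only helper in vllm.utils
# from vllm.utils import get_device_capability_stateless
# major, minor = get_device_capability_stateless()

# after: platform object chosen at import time in vllm/platforms/__init__.py
# (current_platform is None when torch has neither CUDA nor ROCm support)
from vllm.platforms import current_platform

major, minor = current_platform.get_device_capability()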
@@ -8,13 +8,13 @@ import pytest
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform
 
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
 
-capability = get_device_capability_stateless()
+capability = current_platform.get_device_capability()
 capability = capability[0] * 10 + capability[1]
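Throughout these hunks the (major, minor) capability tuple is folded into a single integer before comparing against a threshold. A quick illustration of that encoding (the values are examples, not from the diff):

capability = (8, 6)                              # e.g. an RTX 3090 reports compute capability 8.6
capability = capability[0] * 10 + capability[1]  # -> 86
assert capability >= 80                          # so a threshold of 80 means SM 8.0 (Ampere) or newer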
@@ -1,7 +1,7 @@
 import torch
 
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform
 
 
 def is_quant_method_supported(quant_method: str) -> bool:
@@ -9,7 +9,7 @@ def is_quant_method_supported(quant_method: str) -> bool:
     if not torch.cuda.is_available():
         return False
 
-    capability = get_device_capability_stateless()
+    capability = current_platform.get_device_capability()
     capability = capability[0] * 10 + capability[1]
     return (capability >=
             QUANTIZATION_METHODS[quant_method].get_min_capability())
@@ -2,13 +2,14 @@ import math
 
 import torch
 
-from vllm.utils import get_device_capability_stateless, is_cpu, is_hip
+from vllm.platforms import current_platform
+from vllm.utils import is_cpu, is_hip
 
 from .utils import (dense_to_crow_col, get_head_sliding_step,
                     get_sparse_attn_mask)
 
 IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
-                         and get_device_capability_stateless()[0] >= 8)
+                         and current_platform.get_device_capability()[0] >= 8)
 
 if IS_COMPUTE_8_OR_ABOVE:
     from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
@@ -5,7 +5,7 @@ import torch
 import triton
 import triton.language as tl
 
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform
 
 if triton.__version__ >= "2.1.0":
@@ -685,7 +685,7 @@ if triton.__version__ >= "2.1.0":
                               alibi_slopes=None,
                               sliding_window=None):
 
-        cap = get_device_capability_stateless()
+        cap = current_platform.get_device_capability()
         BLOCK = 128 if cap[0] >= 8 else 64
         # shape constraints
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
@@ -5,14 +5,14 @@ from typing import Optional
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform
 
 
 def _check_punica_support():
     if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
         return
 
-    if get_device_capability_stateless() < (8, 0):
+    if current_platform.get_device_capability() < (8, 0):
         raise ImportError(
             "punica LoRA kernels require compute capability >= 8.0")
     else:
@@ -14,7 +14,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
     find_first_name_or_class_match)
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform
 
 
 class CompressedTensorsConfig(QuantizationConfig):
@@ -85,7 +85,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         return []
 
     def _check_gptq_and_marlin_can_run(self):
-        capability = get_device_capability_stateless()
+        capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
         if capability < 80:
             raise RuntimeError("The quantization config is not supported for ",
@@ -12,7 +12,8 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import get_device_capability_stateless, print_warning_once
+from vllm.platforms import current_platform
+from vllm.utils import print_warning_once
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -20,7 +21,7 @@ logger = init_logger(__name__)
 
 
 def cutlass_fp8_supported() -> bool:
-    capability = get_device_capability_stateless()
+    capability = current_platform.get_device_capability()
     capability = capability[0] * 10 + capability[1]
 
     return ops.cutlass_scaled_mm_supports_fp8(capability)
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -173,7 +173,7 @@ class GPTQMarlinConfig(QuantizationConfig):
             return False
 
         # If the capability of the device is too low, cannot convert.
-        major, minor = get_device_capability_stateless()
+        major, minor = current_platform.get_device_capability()
         device_capability = major * 10 + minor
         if device_capability < cls.get_min_capability():
             return False
@@ -12,9 +12,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import (
     marlin_perm, marlin_scale_perm, marlin_scale_perm_single)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     get_pack_factor, quantize_weights, sort_weights)
-from vllm.utils import get_device_capability_stateless
+from vllm.platforms import current_platform
 
-__cuda_arch = get_device_capability_stateless()
+__cuda_arch = current_platform.get_device_capability()
 
 MARLIN_TILE = 16
@@ -35,7 +35,8 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.models.interfaces import (supports_lora,
                                                    supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import get_device_capability_stateless, is_tpu
+from vllm.platforms import current_platform
+from vllm.utils import is_tpu
 
 logger = init_logger(__name__)
 
@@ -46,7 +47,7 @@ def _get_quantization_config(
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(model_config, load_config)
-        capability = get_device_capability_stateless()
+        capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
         if capability < quant_config.get_min_capability():
             raise ValueError(
vllm/platforms/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
+from typing import Optional
+
+import torch
+
+from .interface import Platform, PlatformEnum
+
+current_platform: Optional[Platform]
+
+if torch.version.cuda is not None:
+    from .cuda import CudaPlatform
+    current_platform = CudaPlatform()
+elif torch.version.hip is not None:
+    from .rocm import RocmPlatform
+    current_platform = RocmPlatform()
+else:
+    current_platform = None
+
+__all__ = ['Platform', 'PlatformEnum', 'current_platform']
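A sketch of how callers are expected to use the new module, based only on what this file exports; note that current_platform is None when torch is built without CUDA or ROCm, so platform-agnostic code should guard the access:

from vllm.platforms import current_platform

if current_platform is not None and current_platform.is_cuda():
    major, minor = current_platform.get_device_capability()
    print(f"CUDA device 0 has compute capability {major}.{minor}")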
vllm/platforms/cuda.py (new file, 34 lines)
@@ -0,0 +1,34 @@
+"""Code inside this file can safely assume cuda platform, e.g. importing
+pynvml. However, it should not initialize cuda context.
+"""
+
+from functools import lru_cache, wraps
+from typing import Tuple
+
+import pynvml
+
+from .interface import Platform, PlatformEnum
+
+
+def with_nvml_context(fn):
+
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        pynvml.nvmlInit()
+        try:
+            return fn(*args, **kwargs)
+        finally:
+            pynvml.nvmlShutdown()
+
+    return wrapper
+
+
+class CudaPlatform(Platform):
+    _enum = PlatformEnum.CUDA
+
+    @staticmethod
+    @lru_cache(maxsize=8)
+    @with_nvml_context
+    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
+        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+        return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
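with_nvml_context brackets each query in a pynvml.nvmlInit()/nvmlShutdown() pair, so the capability lookup stays stateless and, per the file's docstring, never initializes a CUDA context in the calling process; the lru_cache on top means NVML is only touched once per device index. A small usage sketch (device index 0 assumed):

from vllm.platforms.cuda import CudaPlatform

platform = CudaPlatform()
# First call goes through NVML inside an init/shutdown pair; repeats hit the cache.
major, minor = platform.get_device_capability(device_id=0)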
vllm/platforms/interface.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+import enum
+from typing import Tuple
+
+
+class PlatformEnum(enum.Enum):
+    CUDA = enum.auto()
+    ROCM = enum.auto()
+
+
+class Platform:
+    _enum: PlatformEnum
+
+    def is_cuda(self) -> bool:
+        return self._enum == PlatformEnum.CUDA
+
+    def is_rocm(self) -> bool:
+        return self._enum == PlatformEnum.ROCM
+
+    @staticmethod
+    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
+        raise NotImplementedError
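The base class is deliberately small: an enum tag, two predicates, and a static get_device_capability that concrete platforms override. A hypothetical sketch of what another backend could look like (MyPlatform and its placeholder values are illustrative only, not part of this commit):

# Hypothetical example -- not part of this commit.
from typing import Tuple

from vllm.platforms.interface import Platform, PlatformEnum


class MyPlatform(Platform):
    # A real backend would add its own member to PlatformEnum (e.g. XPU = enum.auto()).
    _enum = PlatformEnum.CUDA  # placeholder tag for illustration

    @staticmethod
    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        # Report whatever "capability" means for this device family.
        return (0, 0)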
vllm/platforms/rocm.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+from functools import lru_cache
+from typing import Tuple
+
+import torch
+
+from .interface import Platform, PlatformEnum
+
+
+class RocmPlatform(Platform):
+    _enum = PlatformEnum.ROCM
+
+    @staticmethod
+    @lru_cache(maxsize=8)
+    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
+        return torch.cuda.get_device_capability(device_id)
@@ -866,13 +866,6 @@ def is_full_nvlink(device_ids: List[int]) -> bool:
     return True
 
 
-@lru_cache(maxsize=8)
-@with_nvml_context
-def get_device_capability_stateless(device_id: int = 0) -> Tuple[int, int]:
-    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
-    return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
-
-
 #From: https://stackoverflow.com/a/4104188/2749989
 def run_once(f):
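Note that the body removed here is the same NVML-backed implementation that now lives in CudaPlatform.get_device_capability (vllm/platforms/cuda.py above): the helper moves into the CUDA platform class rather than being duplicated.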
@@ -15,8 +15,8 @@ from vllm.distributed import (ensure_model_parallel_initialized,
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import get_device_capability_stateless
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.embedding_model_runner import EmbeddingModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@@ -333,7 +333,7 @@ def init_worker_distributed_environment(
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
     if torch_dtype == torch.bfloat16:
-        compute_capability = get_device_capability_stateless()
+        compute_capability = current_platform.get_device_capability()
         if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
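The worker-side dtype guard keeps its logic and only changes where the capability comes from: bfloat16 requires compute capability 8.0 (Ampere) or newer. Roughly, as a standalone sketch of the same check (the helper name is ours, not from the diff):

from vllm.platforms import current_platform


def gpu_supports_bfloat16() -> bool:
    # Matches the check above: bfloat16 needs compute capability >= 8.0.
    major, _ = current_platform.get_device_capability()
    return major >= 8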