vllm/vllm/platforms/rocm.py

32 lines
885 B
Python

import os
from functools import lru_cache
from typing import Tuple
import torch
from vllm.logger import init_logger
from .interface import Platform, PlatformEnum
# Module-level logger, named after this module.
logger = init_logger(__name__)

# Treat an unset VLLM_WORKER_MULTIPROC_METHOD the same as an explicit "fork":
# in both cases warn and force the variable to "spawn" for this process.
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", "fork") == "fork":
    logger.warning(
        "`fork` method is not supported by ROCm. "
        "VLLM_WORKER_MULTIPROC_METHOD is overridden to"
        " `spawn` instead."
    )
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
class RocmPlatform(Platform):
    """Platform subclass tagged with ``PlatformEnum.ROCM``.

    Both device queries below forward to the matching ``torch.cuda``
    function and memoize the result with ``lru_cache``.
    """

    _enum = PlatformEnum.ROCM

    @staticmethod
    @lru_cache(maxsize=8)
    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        """Return ``torch.cuda.get_device_capability(device_id)``.

        Args:
            device_id: Device index passed through to torch; defaults to 0.

        Returns:
            Whatever torch reports for the device, annotated here as a
            pair of ints. Results are cached (at most 8 entries).
        """
        capability = torch.cuda.get_device_capability(device_id)
        return capability

    @staticmethod
    @lru_cache(maxsize=8)
    def get_device_name(device_id: int = 0) -> str:
        """Return ``torch.cuda.get_device_name(device_id)``.

        Args:
            device_id: Device index passed through to torch; defaults to 0.

        Returns:
            The device name string reported by torch. Results are cached
            (at most 8 entries).
        """
        name = torch.cuda.get_device_name(device_id)
        return name