[doc] fix doc build error caused by msgspec (#7659)

This commit is contained in:
youkaichao 2024-08-19 17:50:59 -07:00 committed by GitHub
parent 67e02fa8a4
commit e54ebc2f8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 41 additions and 9 deletions

View File

@ -3,6 +3,7 @@ sphinx-book-theme==1.0.1
sphinx-copybutton==0.5.2
myst-parser==2.0.0
sphinx-argparse==0.4.0
msgspec
# packages to install to build the documentation
pydantic

View File

@ -1,23 +1,54 @@
import torch
from .interface import Platform, PlatformEnum, UnspecifiedPlatform
current_platform: Platform
try:
import libtpu
except ImportError:
libtpu = None
# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
# they only indicate the build configuration, not the runtime environment.
# For example, people can install a cuda build of pytorch but run on tpu.
if libtpu is not None:
is_tpu = False
try:
import torch_xla.core.xla_model as xm
xm.xla_device(devkind="TPU")
is_tpu = True
except Exception:
pass
is_cuda = False
try:
import pynvml
pynvml.nvmlInit()
try:
if pynvml.nvmlDeviceGetCount() > 0:
is_cuda = True
finally:
pynvml.nvmlShutdown()
except Exception:
pass
is_rocm = False
try:
import amdsmi
amdsmi.amdsmi_init()
try:
if len(amdsmi.amdsmi_get_processor_handles()) > 0:
is_rocm = True
finally:
amdsmi.amdsmi_shut_down()
except Exception:
pass
if is_tpu:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
from .tpu import TpuPlatform
current_platform = TpuPlatform()
elif torch.version.cuda is not None:
elif is_cuda:
from .cuda import CudaPlatform
current_platform = CudaPlatform()
elif torch.version.hip is not None:
elif is_rocm:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
else: