[distributed][misc] be consistent with pytorch for libcudart.so (#6346)

[distributed][misc] keep consistent with how pytorch finds libcudart.so (#6346)
This commit is contained in:
youkaichao 2024-07-11 19:35:17 -07:00 committed by GitHub
parent d6ab528997
commit 2b0fb53481
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,6 +4,9 @@ convenient for use when we just need to call a few functions.
"""
import ctypes
import glob
import os
import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
@ -33,6 +36,26 @@ class Function:
argtypes: List[Any]
def get_pytorch_default_cudart_library_path() -> str:
# code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa
lib_folder = "cuda_runtime"
lib_name = "libcudart.so.*[0-9]"
lib_path = None
for path in sys.path:
nvidia_path = os.path.join(path, "nvidia")
if not os.path.exists(nvidia_path):
continue
candidate_lib_paths = glob.glob(
os.path.join(nvidia_path, lib_folder, "lib", lib_name))
if candidate_lib_paths and not lib_path:
lib_path = candidate_lib_paths[0]
if lib_path:
break
if not lib_path:
raise ValueError(f"{lib_name} not found in the system path {sys.path}")
return lib_path
class CudaRTLibrary:
exported_functions = [
# cudaError_t cudaSetDevice ( int device )
@ -77,9 +100,7 @@ class CudaRTLibrary:
def __init__(self, so_file: Optional[str] = None):
if so_file is None:
assert torch.version.cuda is not None
major_version = torch.version.cuda.split(".")[0]
so_file = f"libcudart.so.{major_version}"
so_file = get_pytorch_default_cudart_library_path()
if so_file not in CudaRTLibrary.path_to_library_cache:
lib = ctypes.CDLL(so_file)
CudaRTLibrary.path_to_library_cache[so_file] = lib