[distributed][misc] be consistent with pytorch for libcudart.so (#6346)

[distributed][misc] keep consistent with how pytorch finds libcudart.so (#6346)
This commit is contained in:
youkaichao 2024-07-11 19:35:17 -07:00 committed by GitHub
parent d6ab528997
commit 2b0fb53481
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,6 +4,9 @@ convenient for use when we just need to call a few functions.
""" """
import ctypes import ctypes
import glob
import os
import sys
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@ -33,6 +36,26 @@ class Function:
argtypes: List[Any] argtypes: List[Any]
def get_pytorch_default_cudart_library_path() -> str:
# code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa
lib_folder = "cuda_runtime"
lib_name = "libcudart.so.*[0-9]"
lib_path = None
for path in sys.path:
nvidia_path = os.path.join(path, "nvidia")
if not os.path.exists(nvidia_path):
continue
candidate_lib_paths = glob.glob(
os.path.join(nvidia_path, lib_folder, "lib", lib_name))
if candidate_lib_paths and not lib_path:
lib_path = candidate_lib_paths[0]
if lib_path:
break
if not lib_path:
raise ValueError(f"{lib_name} not found in the system path {sys.path}")
return lib_path
class CudaRTLibrary: class CudaRTLibrary:
exported_functions = [ exported_functions = [
# cudaError_t cudaSetDevice ( int device ) # cudaError_t cudaSetDevice ( int device )
@ -77,9 +100,7 @@ class CudaRTLibrary:
def __init__(self, so_file: Optional[str] = None): def __init__(self, so_file: Optional[str] = None):
if so_file is None: if so_file is None:
assert torch.version.cuda is not None so_file = get_pytorch_default_cudart_library_path()
major_version = torch.version.cuda.split(".")[0]
so_file = f"libcudart.so.{major_version}"
if so_file not in CudaRTLibrary.path_to_library_cache: if so_file not in CudaRTLibrary.path_to_library_cache:
lib = ctypes.CDLL(so_file) lib = ctypes.CDLL(so_file)
CudaRTLibrary.path_to_library_cache[so_file] = lib CudaRTLibrary.path_to_library_cache[so_file] = lib