From 2b0fb534813e9835077403723a484b7c03d47259 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 11 Jul 2024 19:35:17 -0700 Subject: [PATCH] [distributed][misc] be consistent with pytorch for libcudart.so (#6346) [distributed][misc] keep consistent with how pytorch finds libcudart.so (#6346) --- .../device_communicators/cuda_wrapper.py | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/device_communicators/cuda_wrapper.py b/vllm/distributed/device_communicators/cuda_wrapper.py index 24308235c4a48..5cac3c1d57bca 100644 --- a/vllm/distributed/device_communicators/cuda_wrapper.py +++ b/vllm/distributed/device_communicators/cuda_wrapper.py @@ -4,6 +4,9 @@ convenient for use when we just need to call a few functions. """ import ctypes +import glob +import os +import sys from dataclasses import dataclass from typing import Any, Dict, List, Optional @@ -33,6 +36,26 @@ class Function: argtypes: List[Any] +def get_pytorch_default_cudart_library_path() -> str: + # code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa + lib_folder = "cuda_runtime" + lib_name = "libcudart.so.*[0-9]" + lib_path = None + for path in sys.path: + nvidia_path = os.path.join(path, "nvidia") + if not os.path.exists(nvidia_path): + continue + candidate_lib_paths = glob.glob( + os.path.join(nvidia_path, lib_folder, "lib", lib_name)) + if candidate_lib_paths and not lib_path: + lib_path = candidate_lib_paths[0] + if lib_path: + break + if not lib_path: + raise ValueError(f"{lib_name} not found in the system path {sys.path}") + return lib_path + + class CudaRTLibrary: exported_functions = [ # ​cudaError_t cudaSetDevice ( int device ) @@ -77,9 +100,7 @@ class CudaRTLibrary: def __init__(self, so_file: Optional[str] = None): if so_file is None: - assert torch.version.cuda is not None - major_version = torch.version.cuda.split(".")[0] - so_file = f"libcudart.so.{major_version}" + so_file = get_pytorch_default_cudart_library_path() if so_file not in CudaRTLibrary.path_to_library_cache: lib = ctypes.CDLL(so_file) CudaRTLibrary.path_to_library_cache[so_file] = lib