mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 07:54:57 +08:00
[distributed][misc] be consistent with pytorch for libcudart.so (#6346)
[distributed][misc] keep consistent with how pytorch finds libcudart.so (#6346)
This commit is contained in:
parent
d6ab528997
commit
2b0fb53481
@ -4,6 +4,9 @@ convenient for use when we just need to call a few functions.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import ctypes
|
import ctypes
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
@ -33,6 +36,26 @@ class Function:
|
|||||||
argtypes: List[Any]
|
argtypes: List[Any]
|
||||||
|
|
||||||
|
|
||||||
|
def get_pytorch_default_cudart_library_path() -> str:
|
||||||
|
# code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa
|
||||||
|
lib_folder = "cuda_runtime"
|
||||||
|
lib_name = "libcudart.so.*[0-9]"
|
||||||
|
lib_path = None
|
||||||
|
for path in sys.path:
|
||||||
|
nvidia_path = os.path.join(path, "nvidia")
|
||||||
|
if not os.path.exists(nvidia_path):
|
||||||
|
continue
|
||||||
|
candidate_lib_paths = glob.glob(
|
||||||
|
os.path.join(nvidia_path, lib_folder, "lib", lib_name))
|
||||||
|
if candidate_lib_paths and not lib_path:
|
||||||
|
lib_path = candidate_lib_paths[0]
|
||||||
|
if lib_path:
|
||||||
|
break
|
||||||
|
if not lib_path:
|
||||||
|
raise ValueError(f"{lib_name} not found in the system path {sys.path}")
|
||||||
|
return lib_path
|
||||||
|
|
||||||
|
|
||||||
class CudaRTLibrary:
|
class CudaRTLibrary:
|
||||||
exported_functions = [
|
exported_functions = [
|
||||||
# cudaError_t cudaSetDevice ( int device )
|
# cudaError_t cudaSetDevice ( int device )
|
||||||
@ -77,9 +100,7 @@ class CudaRTLibrary:
|
|||||||
|
|
||||||
def __init__(self, so_file: Optional[str] = None):
|
def __init__(self, so_file: Optional[str] = None):
|
||||||
if so_file is None:
|
if so_file is None:
|
||||||
assert torch.version.cuda is not None
|
so_file = get_pytorch_default_cudart_library_path()
|
||||||
major_version = torch.version.cuda.split(".")[0]
|
|
||||||
so_file = f"libcudart.so.{major_version}"
|
|
||||||
if so_file not in CudaRTLibrary.path_to_library_cache:
|
if so_file not in CudaRTLibrary.path_to_library_cache:
|
||||||
lib = ctypes.CDLL(so_file)
|
lib = ctypes.CDLL(so_file)
|
||||||
CudaRTLibrary.path_to_library_cache[so_file] = lib
|
CudaRTLibrary.path_to_library_cache[so_file] = lib
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user