[cuda] manually import the correct pynvml module (#12679)

fixes problems like https://github.com/vllm-project/vllm/pull/12635 and
https://github.com/vllm-project/vllm/pull/12636 and
https://github.com/vllm-project/vllm/pull/12565

---------

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2025-02-03 15:58:21 +08:00 committed by GitHub
parent b9986454fe
commit ad4a9dc817
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 56 additions and 9 deletions

View File

@ -33,7 +33,8 @@ def cuda_platform_plugin() -> Optional[str]:
is_cuda = False
try:
import pynvml
from vllm.utils import import_pynvml
pynvml = import_pynvml()
pynvml.nvmlInit()
try:
if pynvml.nvmlDeviceGetCount() > 0:

View File

@ -8,7 +8,6 @@ from functools import lru_cache, wraps
from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
Union)
import pynvml
import torch
from typing_extensions import ParamSpec
@ -16,6 +15,7 @@ from typing_extensions import ParamSpec
import vllm._C # noqa
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import import_pynvml
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
@ -29,13 +29,7 @@ logger = init_logger(__name__)
_P = ParamSpec("_P")
_R = TypeVar("_R")
if pynvml.__file__.endswith("__init__.py"):
logger.warning(
"You are using a deprecated `pynvml` package. Please install"
" `nvidia-ml-py` instead, and make sure to uninstall `pynvml`."
" When both of them are installed, `pynvml` will take precedence"
" and cause errors. See https://pypi.org/project/pynvml "
"for more information.")
pynvml = import_pynvml()
# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details

View File

@ -2208,3 +2208,55 @@ def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any],
else:
func = partial(method, obj) # type: ignore
return func(*args, **kwargs)
def import_pynvml():
"""
Historical comments:
libnvml.so is the library behind nvidia-smi, and
pynvml is a Python wrapper around it. We use it to get GPU
status without initializing CUDA context in the current process.
Historically, there are two packages that provide pynvml:
- `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official
wrapper. It is a dependency of vLLM, and is installed when users
install vLLM. It provides a Python module named `pynvml`.
- `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper.
Prior to version 12.0, it also provides a Python module `pynvml`,
and therefore conflicts with the official one. What's worse,
the module is a Python package, and has higher priority than
the official one which is a standalone Python file.
This causes errors when both of them are installed.
Starting from version 12.0, it migrates to a new module
named `pynvml_utils` to avoid the conflict.
TL;DR: if users have pynvml<12.0 installed, it will cause problems.
Otherwise, `import pynvml` will import the correct module.
We take the safest approach here, to manually import the correct
`pynvml.py` module from the `nvidia-ml-py` package.
"""
if TYPE_CHECKING:
import pynvml
return pynvml
if "pynvml" in sys.modules:
import pynvml
if pynvml.__file__.endswith("__init__.py"):
# this is pynvml < 12.0
raise RuntimeError(
"You are using a deprecated `pynvml` package. "
"Please uninstall `pynvml` or upgrade to at least"
" version 12.0. See https://pypi.org/project/pynvml "
"for more information.")
return sys.modules["pynvml"]
import importlib.util
import os
import site
for site_dir in site.getsitepackages():
pynvml_path = os.path.join(site_dir, "pynvml.py")
if os.path.exists(pynvml_path):
spec = importlib.util.spec_from_file_location(
"pynvml", pynvml_path)
pynvml = importlib.util.module_from_spec(spec)
sys.modules["pynvml"] = pynvml
spec.loader.exec_module(pynvml)
return pynvml