mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 20:35:23 +08:00
[cuda] manually import the correct pynvml module (#12679)
fixes problems like https://github.com/vllm-project/vllm/pull/12635 and https://github.com/vllm-project/vllm/pull/12636 and https://github.com/vllm-project/vllm/pull/12565 --------- Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
b9986454fe
commit
ad4a9dc817
@ -33,7 +33,8 @@ def cuda_platform_plugin() -> Optional[str]:
|
||||
is_cuda = False
|
||||
|
||||
try:
|
||||
import pynvml
|
||||
from vllm.utils import import_pynvml
|
||||
pynvml = import_pynvml()
|
||||
pynvml.nvmlInit()
|
||||
try:
|
||||
if pynvml.nvmlDeviceGetCount() > 0:
|
||||
|
||||
@ -8,7 +8,6 @@ from functools import lru_cache, wraps
|
||||
from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
|
||||
Union)
|
||||
|
||||
import pynvml
|
||||
import torch
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
@ -16,6 +15,7 @@ from typing_extensions import ParamSpec
|
||||
import vllm._C # noqa
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import import_pynvml
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
||||
|
||||
@ -29,13 +29,7 @@ logger = init_logger(__name__)
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
|
||||
if pynvml.__file__.endswith("__init__.py"):
|
||||
logger.warning(
|
||||
"You are using a deprecated `pynvml` package. Please install"
|
||||
" `nvidia-ml-py` instead, and make sure to uninstall `pynvml`."
|
||||
" When both of them are installed, `pynvml` will take precedence"
|
||||
" and cause errors. See https://pypi.org/project/pynvml "
|
||||
"for more information.")
|
||||
pynvml = import_pynvml()
|
||||
|
||||
# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
|
||||
# see https://github.com/huggingface/diffusers/issues/9704 for details
|
||||
|
||||
@ -2208,3 +2208,55 @@ def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any],
|
||||
else:
|
||||
func = partial(method, obj) # type: ignore
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def import_pynvml():
|
||||
"""
|
||||
Historical comments:
|
||||
|
||||
libnvml.so is the library behind nvidia-smi, and
|
||||
pynvml is a Python wrapper around it. We use it to get GPU
|
||||
status without initializing CUDA context in the current process.
|
||||
Historically, there are two packages that provide pynvml:
|
||||
- `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official
|
||||
wrapper. It is a dependency of vLLM, and is installed when users
|
||||
install vLLM. It provides a Python module named `pynvml`.
|
||||
- `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper.
|
||||
Prior to version 12.0, it also provides a Python module `pynvml`,
|
||||
and therefore conflicts with the official one. What's worse,
|
||||
the module is a Python package, and has higher priority than
|
||||
the official one which is a standalone Python file.
|
||||
This causes errors when both of them are installed.
|
||||
Starting from version 12.0, it migrates to a new module
|
||||
named `pynvml_utils` to avoid the conflict.
|
||||
|
||||
TL;DR: if users have pynvml<12.0 installed, it will cause problems.
|
||||
Otherwise, `import pynvml` will import the correct module.
|
||||
We take the safest approach here, to manually import the correct
|
||||
`pynvml.py` module from the `nvidia-ml-py` package.
|
||||
"""
|
||||
if TYPE_CHECKING:
|
||||
import pynvml
|
||||
return pynvml
|
||||
if "pynvml" in sys.modules:
|
||||
import pynvml
|
||||
if pynvml.__file__.endswith("__init__.py"):
|
||||
# this is pynvml < 12.0
|
||||
raise RuntimeError(
|
||||
"You are using a deprecated `pynvml` package. "
|
||||
"Please uninstall `pynvml` or upgrade to at least"
|
||||
" version 12.0. See https://pypi.org/project/pynvml "
|
||||
"for more information.")
|
||||
return sys.modules["pynvml"]
|
||||
import importlib.util
|
||||
import os
|
||||
import site
|
||||
for site_dir in site.getsitepackages():
|
||||
pynvml_path = os.path.join(site_dir, "pynvml.py")
|
||||
if os.path.exists(pynvml_path):
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"pynvml", pynvml_path)
|
||||
pynvml = importlib.util.module_from_spec(spec)
|
||||
sys.modules["pynvml"] = pynvml
|
||||
spec.loader.exec_module(pynvml)
|
||||
return pynvml
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user