From 82c73fd5104e010c2c98820f3e761e1e4f36c135 Mon Sep 17 00:00:00 2001
From: Gene Der Su
Date: Mon, 9 Dec 2024 23:41:11 -0800
Subject: [PATCH] [Bugfix] cuda error running llama 3.2 (#11047)

---
 vllm/platforms/cuda.py | 35 ++++++++++++++++++++++++++++-------
 1 file changed, 28 insertions(+), 7 deletions(-)

diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 10f83fd304281..ae1fd6d5ce068 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -4,7 +4,8 @@ pynvml. However, it should not initialize cuda context.
 
 import os
 from functools import lru_cache, wraps
-from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar
+from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
+                    Union)
 
 import pynvml
 import torch
@@ -78,7 +79,9 @@ class CudaPlatformBase(Platform):
     dispatch_key: str = "CUDA"
 
     @classmethod
-    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
+    def get_device_capability(cls,
+                              device_id: int = 0
+                              ) -> Optional[DeviceCapability]:
         raise NotImplementedError
 
     @classmethod
@@ -144,11 +147,29 @@ class NvmlCudaPlatform(CudaPlatformBase):
     @classmethod
     @lru_cache(maxsize=8)
     @with_nvml_context
-    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
-        physical_device_id = device_id_to_physical_device_id(device_id)
-        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
-        major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
-        return DeviceCapability(major=major, minor=minor)
+    def get_device_capability(cls,
+                              device_id: int = 0
+                              ) -> Optional[DeviceCapability]:
+        try:
+            physical_device_id = device_id_to_physical_device_id(device_id)
+            handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
+            major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
+            return DeviceCapability(major=major, minor=minor)
+        except RuntimeError:
+            return None
+
+    @classmethod
+    @lru_cache(maxsize=8)
+    @with_nvml_context
+    def has_device_capability(
+        cls,
+        capability: Union[Tuple[int, int], int],
+        device_id: int = 0,
+    ) -> bool:
+        try:
+            return super().has_device_capability(capability, device_id)
+        except RuntimeError:
+            return False
 
     @classmethod
     @lru_cache(maxsize=8)
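Note: for context, a minimal caller-side sketch of the behavior this patch enables. The probe_gpu helper is hypothetical and not part of the patch; it only assumes NvmlCudaPlatform is importable from vllm.platforms.cuda, as in the diff above, and that DeviceCapability exposes major/minor fields, as the removed lines show.

    from vllm.platforms.cuda import NvmlCudaPlatform

    def probe_gpu(device_id: int = 0):
        # Hypothetical helper (not part of this patch). With this change,
        # get_device_capability() returns None instead of raising
        # RuntimeError when NVML cannot query the device, so callers can
        # degrade gracefully rather than crash at startup.
        cap = NvmlCudaPlatform.get_device_capability(device_id)
        if cap is None:
            print("NVML unavailable; skipping compute-capability checks")
            return None
        print(f"compute capability sm_{cap.major}{cap.minor}")
        # has_device_capability() likewise now returns False instead of
        # propagating the NVML RuntimeError.
        if NvmlCudaPlatform.has_device_capability((8, 0), device_id):
            print("Ampere or newer")
        return cap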