mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-02 06:07:02 +08:00
[Bugfix] cuda error running llama 3.2 (#11047)
This commit is contained in:
parent
bfd610430c
commit
82c73fd510
@@ -4,7 +4,8 @@ pynvml. However, it should not initialize cuda context.
|
||||
|
||||
import os
|
||||
from functools import lru_cache, wraps
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar
|
||||
from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
|
||||
Union)
|
||||
|
||||
import pynvml
|
||||
import torch
|
||||
@@ -78,7 +79,9 @@ class CudaPlatformBase(Platform):
|
||||
dispatch_key: str = "CUDA"
|
||||
|
||||
@classmethod
|
||||
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
|
||||
def get_device_capability(cls,
|
||||
device_id: int = 0
|
||||
) -> Optional[DeviceCapability]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@@ -144,11 +147,29 @@ class NvmlCudaPlatform(CudaPlatformBase):
|
||||
@classmethod
|
||||
@lru_cache(maxsize=8)
|
||||
@with_nvml_context
|
||||
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
|
||||
physical_device_id = device_id_to_physical_device_id(device_id)
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
|
||||
major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
|
||||
return DeviceCapability(major=major, minor=minor)
|
||||
def get_device_capability(cls,
|
||||
device_id: int = 0
|
||||
) -> Optional[DeviceCapability]:
|
||||
try:
|
||||
physical_device_id = device_id_to_physical_device_id(device_id)
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
|
||||
major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
|
||||
return DeviceCapability(major=major, minor=minor)
|
||||
except RuntimeError:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=8)
|
||||
@with_nvml_context
|
||||
def has_device_capability(
|
||||
cls,
|
||||
capability: Union[Tuple[int, int], int],
|
||||
device_id: int = 0,
|
||||
) -> bool:
|
||||
try:
|
||||
return super().has_device_capability(capability, device_id)
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=8)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user