mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 13:05:47 +08:00
Add cache to cuda get_device_capability (#19436)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
a2142f0196
commit
7484e1fce2
@@ -6,7 +6,7 @@ pynvml. However, it should not initialize cuda context.
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from functools import wraps
|
from functools import cache, wraps
|
||||||
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
|
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -389,6 +389,7 @@ class CudaPlatformBase(Platform):
|
|||||||
class NvmlCudaPlatform(CudaPlatformBase):
|
class NvmlCudaPlatform(CudaPlatformBase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@cache
|
||||||
@with_nvml_context
|
@with_nvml_context
|
||||||
def get_device_capability(cls,
|
def get_device_capability(cls,
|
||||||
device_id: int = 0
|
device_id: int = 0
|
||||||
@@ -486,6 +487,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
|
|||||||
class NonNvmlCudaPlatform(CudaPlatformBase):
|
class NonNvmlCudaPlatform(CudaPlatformBase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@cache
|
||||||
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
|
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
|
||||||
major, minor = torch.cuda.get_device_capability(device_id)
|
major, minor = torch.cuda.get_device_capability(device_id)
|
||||||
return DeviceCapability(major=major, minor=minor)
|
return DeviceCapability(major=major, minor=minor)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user