Add cache to cuda get_device_capability (#19436)

Signed-off-by: mgoin <mgoin64@gmail.com>
Michael Goin committed 2025-06-11 05:37:05 -04:00
parent a2142f0196
commit 7484e1fce2

@@ -6,7 +6,7 @@ pynvml. However, it should not initialize cuda context.
 import os
 from datetime import timedelta
-from functools import wraps
+from functools import cache, wraps
 from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
 
 import torch
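
For context on the new import: `functools.cache` (Python 3.9+) is equivalent to `functools.lru_cache(maxsize=None)`; the first call with a given set of arguments runs the function, and every later identical call returns the stored result without re-executing the body. A minimal sketch of that behavior (the `fake_capability_query` name and counter are illustrative, not part of vLLM):

from functools import cache

query_count = 0

@cache
def fake_capability_query(device_id: int = 0) -> tuple[int, int]:
    """Stand-in for an expensive driver query; runs once per device_id."""
    global query_count
    query_count += 1
    return (9, 0)  # e.g. compute capability 9.0

fake_capability_query(0)
fake_capability_query(0)
assert query_count == 1  # second call was served from the cache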
@@ -389,6 +389,7 @@ class CudaPlatformBase(Platform):
 class NvmlCudaPlatform(CudaPlatformBase):
 
     @classmethod
+    @cache
     @with_nvml_context
     def get_device_capability(cls,
                               device_id: int = 0
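
Decorators apply bottom-up, so this ordering means `cache` wraps the `with_nvml_context`-wrapped function and `classmethod` wraps the cached callable: on a cache hit, the NVML context is never entered, which is the point of the change. Note that `cls` becomes part of the cache key, so each platform class caches independently. A minimal sketch of that layering, using an illustrative stand-in for the NVML wrapper (the counter and `Platform` class here are not vLLM's actual code):

from functools import cache, wraps

nvml_entries = 0

def with_nvml_context(fn):
    """Illustrative stand-in for vLLM's NVML init/shutdown wrapper."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        global nvml_entries
        nvml_entries += 1  # pretend to call nvmlInit()/nvmlShutdown()
        return fn(*args, **kwargs)
    return wrapper

class Platform:
    @classmethod
    @cache                  # keyed on (cls, device_id)
    @with_nvml_context      # only entered on a cache miss
    def get_device_capability(cls, device_id: int = 0):
        return (9, 0)

Platform.get_device_capability(0)
Platform.get_device_capability(0)
assert nvml_entries == 1    # NVML context was set up exactly once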
@@ -486,6 +487,7 @@ class NonNvmlCudaPlatform(CudaPlatformBase):
 class NonNvmlCudaPlatform(CudaPlatformBase):
 
     @classmethod
+    @cache
     def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
         major, minor = torch.cuda.get_device_capability(device_id)
         return DeviceCapability(major=major, minor=minor)
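
The net effect, sketched below: repeated capability lookups for the same device return the same cached `DeviceCapability` without re-entering `torch.cuda.get_device_capability`. The import path is an assumption based on the file shown in this diff (vllm/platforms/cuda.py) and requires a CUDA build of vLLM:

# Hypothetical usage; assumes NonNvmlCudaPlatform is importable as shown.
from vllm.platforms.cuda import NonNvmlCudaPlatform

cap_first = NonNvmlCudaPlatform.get_device_capability(0)
cap_again = NonNvmlCudaPlatform.get_device_capability(0)
assert cap_first is cap_again  # same cached object; torch.cuda queried once

Caching this for the life of the process is safe because a device's compute capability cannot change at runtime; the unbounded cache also holds a strong reference to `cls`, which is harmless for these module-level platform classes.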