diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 48d1aacba1858..7ab5146fd743c 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -6,7 +6,7 @@ pynvml. However, it should not initialize cuda context.
 
 import os
 from datetime import timedelta
-from functools import wraps
+from functools import cache, wraps
 from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
 
 import torch
@@ -389,6 +389,7 @@ class CudaPlatformBase(Platform):
 class NvmlCudaPlatform(CudaPlatformBase):
 
     @classmethod
+    @cache
     @with_nvml_context
     def get_device_capability(cls,
                               device_id: int = 0
@@ -486,6 +487,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
 class NonNvmlCudaPlatform(CudaPlatformBase):
 
     @classmethod
+    @cache
     def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
         major, minor = torch.cuda.get_device_capability(device_id)
         return DeviceCapability(major=major, minor=minor)