mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-26 02:49:37 +08:00
[hardware][cuda] use device id under CUDA_VISIBLE_DEVICES for get_device_capability (#6216)
This commit is contained in:
parent
4f0e0ea131
commit
a3c9435d93
@ -2,6 +2,7 @@
|
||||
pynvml. However, it should not initialize cuda context.
|
||||
"""
|
||||
|
||||
import os
|
||||
from functools import lru_cache, wraps
|
||||
from typing import Tuple
|
||||
|
||||
@ -23,12 +24,27 @@ def with_nvml_context(fn):
|
||||
return wrapper
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]:
    """Return the CUDA compute capability of a device, looked up via NVML.

    ``device_id`` is handed directly to ``nvmlDeviceGetHandleByIndex``; this
    function performs no ``CUDA_VISIBLE_DEVICES`` translation of its own.
    Results are memoized per ``device_id`` by the ``lru_cache`` decorator.

    Args:
        device_id: Device index passed to NVML. Defaults to 0.

    Returns:
        The value produced by ``nvmlDeviceGetCudaComputeCapability`` for the
        device handle.
    """
    return pynvml.nvmlDeviceGetCudaComputeCapability(
        pynvml.nvmlDeviceGetHandleByIndex(device_id))
|
||||
|
||||
|
||||
def device_id_to_physical_device_id(device_id: int) -> int:
    """Translate a logical device index into a physical device index.

    When ``CUDA_VISIBLE_DEVICES`` is set, ``device_id`` is used as a position
    in its comma-separated list and the entry found there is returned as an
    integer. When the variable is not set, ``device_id`` is returned
    unchanged.

    Args:
        device_id: Index of the device among the visible devices.

    Returns:
        The integer listed at position ``device_id`` in
        ``CUDA_VISIBLE_DEVICES``, or ``device_id`` itself when the variable
        is not set.

    Raises:
        ValueError: If ``CUDA_VISIBLE_DEVICES`` is empty or blank, or if the
            selected entry is not an integer.
        IndexError: If ``device_id`` is outside the list of visible devices.
    """
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return device_id

    if not visible_devices.strip():
        # An empty value would otherwise surface as a cryptic
        # "invalid literal for int()" error; say what is actually wrong.
        raise ValueError(
            "CUDA_VISIBLE_DEVICES is set to an empty string, so no device "
            f"is visible and device_id {device_id} cannot be resolved.")

    selected_entry = visible_devices.split(",")[device_id]
    # Convert only the selected entry: another entry that is not a plain
    # integer (e.g. a GPU UUID) must not break the lookup of a valid one.
    try:
        return int(selected_entry)
    except ValueError:
        raise ValueError(
            f"CUDA_VISIBLE_DEVICES entry {selected_entry!r} for device_id "
            f"{device_id} is not an integer device index.") from None
|
||||
|
||||
|
||||
class CudaPlatform(Platform):
    """CUDA implementation of ``Platform``."""

    _enum = PlatformEnum.CUDA

    @staticmethod
    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        """Return the compute capability of a visible CUDA device.

        ``device_id`` is first mapped through
        ``device_id_to_physical_device_id`` so that ``CUDA_VISIBLE_DEVICES``
        is honoured, and the resulting index is then passed to
        ``get_physical_device_capability``.

        The method itself is deliberately not cached or wrapped in an NVML
        context: the environment lookup runs on every call, while the NVML
        query is handled (and memoized) by ``get_physical_device_capability``.

        Args:
            device_id: Index of the device among the visible devices.
                Defaults to 0.

        Returns:
            The value returned by ``get_physical_device_capability`` for the
            mapped device index.
        """
        physical_device_id = device_id_to_physical_device_id(device_id)
        return get_physical_device_capability(physical_device_id)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user