[Core] Use platform-agnostic device control for DP engine core (#17245)
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
commit 289199feb6
parent b9fd0d7a69
@@ -34,24 +34,6 @@ pynvml = import_pynvml()
 torch.backends.cuda.enable_cudnn_sdp(False)
 
 
-def device_id_to_physical_device_id(device_id: int) -> int:
-    if "CUDA_VISIBLE_DEVICES" in os.environ:
-        device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
-        if device_ids == [""]:
-            msg = (
-                "CUDA_VISIBLE_DEVICES is set to empty string, which means"
-                " GPU support is disabled. If you are using ray, please unset"
-                " the environment variable `CUDA_VISIBLE_DEVICES` inside the"
-                " worker/actor. "
-                "Check https://github.com/vllm-project/vllm/issues/8402 for"
-                " more information.")
-            raise RuntimeError(msg)
-        physical_device_id = device_ids[device_id]
-        return int(physical_device_id)
-    else:
-        return device_id
-
-
 def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
 
     @wraps(fn)
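
The helper removed above translates a logical device index (its position in CUDA_VISIBLE_DEVICES) into the physical index the driver exposes; the same behaviour now lives on the Platform base class (see the hunk against class Platform: further down). A minimal, self-contained sketch of that mapping, runnable without a GPU; the demo env-var value is made up:

import os


def logical_to_physical(device_id: int,
                        env_var: str = "CUDA_VISIBLE_DEVICES") -> int:
    """Map a logical device index to the physical index the driver exposes."""
    if env_var in os.environ:
        device_ids = os.environ[env_var].split(",")
        if device_ids == [""]:
            raise RuntimeError(f"{env_var} is set to an empty string; "
                               "device support is disabled.")
        # The logical index is a position in the visible-device list.
        return int(device_ids[device_id])
    # Without the env var set, logical and physical indices coincide.
    return device_id


if __name__ == "__main__":
    os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"   # made-up example value
    print(logical_to_physical(0))  # -> 2
    print(logical_to_physical(1))  # -> 3
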
@@ -338,7 +320,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
                               device_id: int = 0
                               ) -> Optional[DeviceCapability]:
         try:
-            physical_device_id = device_id_to_physical_device_id(device_id)
+            physical_device_id = cls.device_id_to_physical_device_id(device_id)
             handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
             major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
             return DeviceCapability(major=major, minor=minor)
@@ -360,20 +342,20 @@ class NvmlCudaPlatform(CudaPlatformBase):
     @classmethod
     @with_nvml_context
     def get_device_name(cls, device_id: int = 0) -> str:
-        physical_device_id = device_id_to_physical_device_id(device_id)
+        physical_device_id = cls.device_id_to_physical_device_id(device_id)
         return cls._get_physical_device_name(physical_device_id)
 
     @classmethod
     @with_nvml_context
     def get_device_uuid(cls, device_id: int = 0) -> str:
-        physical_device_id = device_id_to_physical_device_id(device_id)
+        physical_device_id = cls.device_id_to_physical_device_id(device_id)
         handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
         return pynvml.nvmlDeviceGetUUID(handle)
 
     @classmethod
     @with_nvml_context
     def get_device_total_memory(cls, device_id: int = 0) -> int:
-        physical_device_id = device_id_to_physical_device_id(device_id)
+        physical_device_id = cls.device_id_to_physical_device_id(device_id)
         handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
         return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
 
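
Each NVML-backed query above first resolves the logical index to the physical one, then asks NVML by physical index, while with_nvml_context brackets the call with NVML initialization and shutdown. A hedged usage sketch of that call pattern, independent of vLLM (requires the nvidia-ml-py package and an NVIDIA GPU; the standalone helper is illustrative, not vLLM's API):

import os

import pynvml  # provided by the nvidia-ml-py package


def get_device_name(logical_id: int = 0) -> str:
    # Resolve the logical index (position within CUDA_VISIBLE_DEVICES)
    # to the physical index NVML enumerates.
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    physical_id = int(visible.split(",")[logical_id]) if visible else logical_id

    pynvml.nvmlInit()          # what with_nvml_context does on entry
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_id)
        name = pynvml.nvmlDeviceGetName(handle)
        return name.decode() if isinstance(name, bytes) else name
    finally:
        pynvml.nvmlShutdown()  # ...and on exit
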
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import enum
+import os
 import platform
 import random
 from platform import uname
@@ -161,6 +162,24 @@ class Platform:
     def is_sleep_mode_available(self) -> bool:
         return self._enum == PlatformEnum.CUDA
 
+    @classmethod
+    def device_id_to_physical_device_id(cls, device_id: int):
+        if cls.device_control_env_var in os.environ:
+            device_ids = os.environ[cls.device_control_env_var].split(",")
+            if device_ids == [""]:
+                msg = (f"{cls.device_control_env_var} is set to empty string, "
+                       "which means current platform support is disabled. If "
+                       "you are using ray, please unset the environment "
+                       f"variable `{cls.device_control_env_var}` inside the "
+                       "worker/actor. Check "
+                       "https://github.com/vllm-project/vllm/issues/8402 for "
+                       "more information.")
+                raise RuntimeError(msg)
+            physical_device_id = device_ids[device_id]
+            return int(physical_device_id)
+        else:
+            return device_id
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
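
The new classmethod keys off a per-platform device_control_env_var class attribute instead of hard-coding CUDA_VISIBLE_DEVICES, so any platform that selects devices through an environment variable inherits the same logical-to-physical mapping. A minimal sketch of the pattern with made-up subclasses; only the CUDA_VISIBLE_DEVICES name is taken from the diff, the NPU variable name is hypothetical:

import os


class Platform:
    # Name of the env var that restricts which devices this platform may use.
    device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

    @classmethod
    def device_id_to_physical_device_id(cls, device_id: int) -> int:
        if cls.device_control_env_var in os.environ:
            device_ids = os.environ[cls.device_control_env_var].split(",")
            if device_ids == [""]:
                raise RuntimeError(
                    f"{cls.device_control_env_var} is set to an empty string; "
                    "device support is disabled.")
            return int(device_ids[device_id])
        return device_id


class CudaLikePlatform(Platform):
    device_control_env_var = "CUDA_VISIBLE_DEVICES"


class MadeUpNpuPlatform(Platform):
    # Hypothetical accelerator that uses its own visibility variable.
    device_control_env_var = "FAKE_NPU_VISIBLE_DEVICES"


if __name__ == "__main__":
    os.environ["FAKE_NPU_VISIBLE_DEVICES"] = "4,5,6"
    print(MadeUpNpuPlatform.device_id_to_physical_device_id(1))  # -> 5
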
@@ -95,15 +95,6 @@ def with_amdsmi_context(fn):
     return wrapper
 
 
-def device_id_to_physical_device_id(device_id: int) -> int:
-    if "CUDA_VISIBLE_DEVICES" in os.environ:
-        device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
-        physical_device_id = device_ids[device_id]
-        return int(physical_device_id)
-    else:
-        return device_id
-
-
 @cache
 def on_mi250_mi300() -> bool:
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
@@ -238,7 +229,7 @@ class RocmPlatform(Platform):
     @with_amdsmi_context
     @lru_cache(maxsize=8)
     def get_device_name(cls, device_id: int = 0) -> str:
-        physical_device_id = device_id_to_physical_device_id(device_id)
+        physical_device_id = cls.device_id_to_physical_device_id(device_id)
         handle = amdsmi_get_processor_handles()[physical_device_id]
         asic_info = amdsmi_get_gpu_asic_info(handle)
         device_name: str = asic_info["device_id"]
@@ -622,13 +622,12 @@ class DPEngineCoreProc(EngineCoreProc):
         assert 0 <= local_dp_rank <= dp_rank < dp_size
 
         from vllm.platforms import current_platform
-        if current_platform.is_cuda_alike():
-            from vllm.platforms.cuda import device_id_to_physical_device_id
-            tp_size = vllm_config.parallel_config.tensor_parallel_size
-            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
-                str(device_id_to_physical_device_id(i))
-                for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
-                               tp_size))
+        device_control_env_var = current_platform.device_control_env_var
+        tp_size = vllm_config.parallel_config.tensor_parallel_size
+        os.environ[device_control_env_var] = ",".join(
+            str(current_platform.device_id_to_physical_device_id(i))
+            for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) *
+                           tp_size))
 
         self.local_dp_rank = local_dp_rank
         self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
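
With the engine-core change above, each data-parallel rank pins its own slice of devices by writing the platform's control variable before workers start: local DP rank r with tensor-parallel size tp gets logical devices [r * tp, (r + 1) * tp), each mapped to its physical id. A small sketch of the same arithmetic, independent of vLLM (the env-var contents below are made up):

import os


def pin_devices_for_dp_rank(local_dp_rank: int, tp_size: int,
                            env_var: str = "CUDA_VISIBLE_DEVICES") -> str:
    """Return the value the DP engine core would write into env_var."""

    def to_physical(i: int) -> int:
        visible = os.environ.get(env_var)
        return int(visible.split(",")[i]) if visible else i

    # Rank r owns logical devices [r * tp_size, (r + 1) * tp_size).
    return ",".join(
        str(to_physical(i))
        for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) * tp_size))


if __name__ == "__main__":
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"  # made-up value
    # DP rank 1 with TP=2 gets physical devices 2 and 3.
    print(pin_devices_for_dp_rank(local_dp_rank=1, tp_size=2))  # -> "2,3"
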