mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 10:45:01 +08:00
[platform] add device_control env var (#12009)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
e8c23ff989
commit
458e63a2c6
@ -78,6 +78,7 @@ class CudaPlatformBase(Platform):
|
||||
device_type: str = "cuda"
|
||||
dispatch_key: str = "CUDA"
|
||||
ray_device_key: str = "GPU"
|
||||
device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
|
||||
|
||||
@classmethod
|
||||
def get_device_capability(cls,
|
||||
|
||||
@ -20,6 +20,7 @@ class HpuPlatform(Platform):
|
||||
device_type: str = "hpu"
|
||||
dispatch_key: str = "HPU"
|
||||
ray_device_key: str = "HPU"
|
||||
device_control_env_var: str = "HABANA_VISIBLE_MODULES"
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
|
||||
|
||||
@ -78,20 +78,30 @@ class Platform:
|
||||
_enum: PlatformEnum
|
||||
device_name: str
|
||||
device_type: str
|
||||
|
||||
# available dispatch keys:
|
||||
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
|
||||
# use "CPU" as a fallback for platforms not registered in PyTorch
|
||||
dispatch_key: str = "CPU"
|
||||
|
||||
# available ray device keys:
|
||||
# https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
|
||||
# empty string means the device does not support ray
|
||||
ray_device_key: str = ""
|
||||
|
||||
# platform-agnostic way to specify the device control environment variable,
|
||||
# e.g. CUDA_VISIBLE_DEVICES for CUDA.
|
||||
# hint: search for "get_visible_accelerator_ids_env_var" in
|
||||
# https://github.com/ray-project/ray/tree/master/python/ray/_private/accelerators # noqa
|
||||
device_control_env_var: str = "VLLM_DEVICE_CONTROL_ENV_VAR_PLACEHOLDER"
|
||||
|
||||
# The torch.compile backend for compiling simple and
|
||||
# standalone functions. The default value is "inductor" to keep
|
||||
# the same behavior as PyTorch.
|
||||
# NOTE: for the forward part of the model, vLLM has another separate
|
||||
# compilation strategy.
|
||||
simple_compile_backend: str = "inductor"
|
||||
|
||||
supported_quantization: list[str] = []
|
||||
|
||||
def is_cuda(self) -> bool:
|
||||
|
||||
@ -18,6 +18,7 @@ class NeuronPlatform(Platform):
|
||||
device_type: str = "neuron"
|
||||
ray_device_key: str = "neuron_cores"
|
||||
supported_quantization: list[str] = ["neuron_quant"]
|
||||
device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
|
||||
@ -65,6 +65,8 @@ class RocmPlatform(Platform):
|
||||
device_type: str = "cuda"
|
||||
dispatch_key: str = "CUDA"
|
||||
ray_device_key: str = "GPU"
|
||||
# rocm shares the same device control env var as CUDA
|
||||
device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
|
||||
|
||||
supported_quantization: list[str] = [
|
||||
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
|
||||
|
||||
@ -20,6 +20,7 @@ class TpuPlatform(Platform):
|
||||
device_type: str = "tpu"
|
||||
dispatch_key: str = "XLA"
|
||||
ray_device_key: str = "TPU"
|
||||
device_control_env_var: str = "TPU_VISIBLE_CHIPS"
|
||||
|
||||
supported_quantization: list[str] = [
|
||||
"tpu_int8", "compressed-tensors", "compressed_tensors"
|
||||
|
||||
@ -22,6 +22,7 @@ class XPUPlatform(Platform):
|
||||
# Intel XPU's device key is "GPU" for Ray.
|
||||
# see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
|
||||
ray_device_key: str = "GPU"
|
||||
device_control_env_var: str = "ONEAPI_DEVICE_SELECTOR"
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user