Add TPU to DeviceConfig

This commit is contained in:
Woosuk Kwon 2024-04-01 03:19:17 +00:00
parent 3b8f43024f
commit 824521c987

View File

@@ -10,7 +10,8 @@ from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config, get_hf_text_config
-from vllm.utils import get_cpu_memory, get_nvcc_cuda_version, is_hip, is_neuron
+from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_hip,
+                        is_neuron, is_tpu)
if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
@@ -582,6 +583,8 @@ class DeviceConfig:
# Automated device type detection
if is_neuron():
self.device_type = "neuron"
+        elif is_tpu():
+            self.device_type = "tpu"
else:
# We don't call torch.cuda.is_available() here to
# avoid initializing CUDA before workers are forked