diff --git a/vllm/config.py b/vllm/config.py index 62f1d70079648..a901729194c9b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -10,7 +10,8 @@ from transformers import PretrainedConfig from vllm.logger import init_logger from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import get_cpu_memory, get_nvcc_cuda_version, is_hip, is_neuron +from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_hip, + is_neuron, is_tpu) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -582,6 +583,8 @@ class DeviceConfig: # Automated device type detection if is_neuron(): self.device_type = "neuron" + elif is_tpu(): + self.device_type = "tpu" else: # We don't call torch.cuda.is_available() here to # avoid initializing CUDA before workers are forked