@lru_cache(maxsize=None)
def is_tpu() -> bool:
    """Return True when the ``libtpu`` module can be located.

    The check only asks the import system for a module spec; ``libtpu``
    itself is never imported. The result is memoized by ``lru_cache``, so
    the lookup runs at most once per process (until ``cache_clear()``).

    Returns:
        True if ``importlib.util.find_spec("libtpu")`` finds a spec,
        False otherwise.
    """
    # ``importlib.util`` is a submodule: a bare ``import importlib`` does not
    # guarantee that ``importlib.util`` is bound, so import it explicitly
    # instead of relying on some other module having loaded it first.
    from importlib.util import find_spec

    return find_spec("libtpu") is not None