Add is_tpu

This commit is contained in:
Woosuk Kwon 2024-04-01 03:18:36 +00:00
parent d148c2ef00
commit 3b8f43024f

View File

@@ -1,6 +1,7 @@
import asyncio
import enum
import gc
import importlib
import os
import socket
import subprocess
@@ -126,6 +127,11 @@ def is_neuron() -> bool:
return transformers_neuronx is not None
@lru_cache(maxsize=None)
def is_tpu() -> bool:
    """Return True if the TPU runtime library (``libtpu``) is importable.

    The result is cached for the life of the process via ``lru_cache``,
    so the module search is performed at most once.
    """
    # NOTE: `import importlib` alone does not guarantee the `importlib.util`
    # submodule is bound as an attribute of `importlib`; import it explicitly
    # to avoid a potential AttributeError at call time.
    import importlib.util
    return importlib.util.find_spec("libtpu") is not None
@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""