From 4cdb732cef5db3d5055ba06254cd8159c621b983 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 1 Apr 2024 07:07:38 +0000 Subject: [PATCH] Add TPU to setup --- setup.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 225fda0a0b412..813f865ffb63e 100644 --- a/setup.py +++ b/setup.py @@ -185,7 +185,8 @@ class cmake_build_ext(build_ext): def _is_cuda() -> bool: - return torch.version.cuda is not None and not _is_neuron() + has_cuda = torch.version.cuda is not None + return has_cuda and not (_is_neuron() or _is_tpu()) def _is_hip() -> bool: @@ -201,6 +202,14 @@ def _is_neuron() -> bool: return torch_neuronx_installed +def _is_tpu() -> bool: + return True # FIXME + + +def _build_custom_ops() -> bool: + return _is_cuda() or _is_hip() + + def _install_punica() -> bool: return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))) @@ -296,6 +305,8 @@ def get_vllm_version() -> str: if neuron_version != MAIN_CUDA_VERSION: neuron_version_str = neuron_version.replace(".", "")[:3] version += f"+neuron{neuron_version_str}" + elif _is_tpu(): + version += "+tpu" else: raise RuntimeError("Unknown runtime environment") @@ -322,6 +333,9 @@ def get_requirements() -> List[str]: elif _is_neuron(): with open(get_path("requirements-neuron.txt")) as f: requirements = f.read().strip().split("\n") + elif _is_tpu(): + with open(get_path("requirements-tpu.txt")) as f: + requirements = f.read().strip().split("\n") else: raise ValueError( "Unsupported platform, please use CUDA, ROCM or Neuron.") @@ -337,7 +351,7 @@ if _is_cuda(): if _install_punica(): ext_modules.append(CMakeExtension(name="vllm._punica_C")) -if not _is_neuron(): +if _build_custom_ops(): ext_modules.append(CMakeExtension(name="vllm._C")) package_data = { @@ -373,6 +387,6 @@ setup( python_requires=">=3.8", install_requires=get_requirements(), ext_modules=ext_modules, - cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, + cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {}, package_data=package_data, )