mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-26 04:12:34 +08:00
Add TPU to setup
This commit is contained in:
parent
27c592b97b
commit
4cdb732cef
20
setup.py
20
setup.py
@ -185,7 +185,8 @@ class cmake_build_ext(build_ext):
|
||||
|
||||
|
||||
def _is_cuda() -> bool:
|
||||
return torch.version.cuda is not None and not _is_neuron()
|
||||
has_cuda = torch.version.cuda is not None
|
||||
return has_cuda and not (_is_neuron() or _is_tpu())
|
||||
|
||||
|
||||
def _is_hip() -> bool:
|
||||
@ -201,6 +202,14 @@ def _is_neuron() -> bool:
|
||||
return torch_neuronx_installed
|
||||
|
||||
|
||||
def _is_tpu() -> bool:
|
||||
return True # FIXME
|
||||
|
||||
|
||||
def _build_custom_ops() -> bool:
|
||||
return _is_cuda() or _is_hip()
|
||||
|
||||
|
||||
def _install_punica() -> bool:
|
||||
return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0")))
|
||||
|
||||
@ -296,6 +305,8 @@ def get_vllm_version() -> str:
|
||||
if neuron_version != MAIN_CUDA_VERSION:
|
||||
neuron_version_str = neuron_version.replace(".", "")[:3]
|
||||
version += f"+neuron{neuron_version_str}"
|
||||
elif _is_tpu():
|
||||
version += "+tpu"
|
||||
else:
|
||||
raise RuntimeError("Unknown runtime environment")
|
||||
|
||||
@ -322,6 +333,9 @@ def get_requirements() -> List[str]:
|
||||
elif _is_neuron():
|
||||
with open(get_path("requirements-neuron.txt")) as f:
|
||||
requirements = f.read().strip().split("\n")
|
||||
elif _is_tpu():
|
||||
with open(get_path("requirements-tpu.txt")) as f:
|
||||
requirements = f.read().strip().split("\n")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported platform, please use CUDA, ROCM or Neuron.")
|
||||
@ -337,7 +351,7 @@ if _is_cuda():
|
||||
if _install_punica():
|
||||
ext_modules.append(CMakeExtension(name="vllm._punica_C"))
|
||||
|
||||
if not _is_neuron():
|
||||
if _build_custom_ops():
|
||||
ext_modules.append(CMakeExtension(name="vllm._C"))
|
||||
|
||||
package_data = {
|
||||
@ -373,6 +387,6 @@ setup(
|
||||
python_requires=">=3.8",
|
||||
install_requires=get_requirements(),
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
|
||||
cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
|
||||
package_data=package_data,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user