From ba1fcd84a7f1dc907c17bf4ba4fab6762a9f33a1 Mon Sep 17 00:00:00 2001
From: Johnny Yang <24908445+jcyang43@users.noreply.github.com>
Date: Wed, 26 Nov 2025 14:46:36 -0800
Subject: [PATCH] [TPU] add tpu_inference (#27277)

Signed-off-by: Johnny Yang
---
 requirements/tpu.txt                                      | 4 +---
 vllm/distributed/device_communicators/tpu_communicator.py | 8 --------
 vllm/platforms/tpu.py                                     | 4 +++-
 vllm/v1/worker/tpu_worker.py                              | 2 +-
 4 files changed, 5 insertions(+), 13 deletions(-)

diff --git a/requirements/tpu.txt b/requirements/tpu.txt
index 4241cbb2b033..e6fff58f7b79 100644
--- a/requirements/tpu.txt
+++ b/requirements/tpu.txt
@@ -12,6 +12,4 @@ ray[data]
 setuptools==78.1.0
 nixl==0.3.0
 tpu_info==0.4.0
-
-# Install torch_xla
-torch_xla[tpu, pallas]==2.8.0
\ No newline at end of file
+tpu-inference==0.11.1
diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py
index a7724a86cc6a..fa99078e9ff0 100644
--- a/vllm/distributed/device_communicators/tpu_communicator.py
+++ b/vllm/distributed/device_communicators/tpu_communicator.py
@@ -97,11 +97,3 @@ class TpuCommunicator(DeviceCommunicatorBase):
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
         return xm.all_gather(input_, dim=dim)
-
-
-if USE_TPU_INFERENCE:
-    from tpu_inference.distributed.device_communicators import (
-        TpuCommunicator as TpuInferenceCommunicator,
-    )
-
-    TpuCommunicator = TpuInferenceCommunicator  # type: ignore
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 944344a22957..aa5ddbe43659 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -267,7 +267,9 @@ class TpuPlatform(Platform):
 
 
 try:
-    from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
+    from tpu_inference.platforms.tpu_platforms import (
+        TpuPlatform as TpuInferencePlatform,
+    )
 
     TpuPlatform = TpuInferencePlatform  # type: ignore
     USE_TPU_INFERENCE = True
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index e1a109eca0a8..ce18ca6c3716 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -346,6 +346,6 @@ class TPUWorker:
 
 
 if USE_TPU_INFERENCE:
-    from tpu_inference.worker import TPUWorker as TpuInferenceWorker
+    from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker
 
     TPUWorker = TpuInferenceWorker  # type: ignore