diff --git a/requirements/tpu.txt b/requirements/tpu.txt index 4241cbb2b033..e6fff58f7b79 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -12,6 +12,4 @@ ray[data] setuptools==78.1.0 nixl==0.3.0 tpu_info==0.4.0 - -# Install torch_xla -torch_xla[tpu, pallas]==2.8.0 \ No newline at end of file +tpu-inference==0.11.1 diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index a7724a86cc6a..fa99078e9ff0 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -97,11 +97,3 @@ class TpuCommunicator(DeviceCommunicatorBase): def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: assert dim == -1, "TPUs only support dim=-1 for all-gather." return xm.all_gather(input_, dim=dim) - - -if USE_TPU_INFERENCE: - from tpu_inference.distributed.device_communicators import ( - TpuCommunicator as TpuInferenceCommunicator, - ) - - TpuCommunicator = TpuInferenceCommunicator # type: ignore diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 944344a22957..aa5ddbe43659 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -267,7 +267,9 @@ class TpuPlatform(Platform): try: - from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform + from tpu_inference.platforms.tpu_platforms import ( + TpuPlatform as TpuInferencePlatform, + ) TpuPlatform = TpuInferencePlatform # type: ignore USE_TPU_INFERENCE = True diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index e1a109eca0a8..ce18ca6c3716 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -346,6 +346,6 @@ class TPUWorker: if USE_TPU_INFERENCE: - from tpu_inference.worker import TPUWorker as TpuInferenceWorker + from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker TPUWorker = TpuInferenceWorker # type: ignore