From 48ac2bed5b6271f82187de61245a85d987197c6f Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Sat, 17 May 2025 00:23:12 -0700
Subject: [PATCH] [Hardware][TPU] Optionally import for TPU backend (#18269)

Signed-off-by: Siyuan Liu
Signed-off-by: Jade Zheng
Co-authored-by: Carol Zheng
Co-authored-by: Jade Zheng
Co-authored-by: Hongmin Fan
---
 .../distributed/device_communicators/tpu_communicator.py | 9 +++++++++
 vllm/platforms/tpu.py                                     | 8 ++++++++
 vllm/v1/worker/tpu_worker.py                              | 8 ++++++++
 3 files changed, 25 insertions(+)

diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py
index de66ceaeef6f1..a1775279661d1 100644
--- a/vllm/distributed/device_communicators/tpu_communicator.py
+++ b/vllm/distributed/device_communicators/tpu_communicator.py
@@ -91,3 +91,12 @@ class TpuCommunicator(DeviceCommunicatorBase):
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
         return xm.all_gather(input_, dim=dim)
+
+
+try:
+    from tpu_commons.distributed.device_communicators import (
+        TpuCommunicator as TpuCommonsCommunicator)
+    TpuCommunicator = TpuCommonsCommunicator  # type: ignore
+except ImportError:
+    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
+    pass
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 41ed94fb619e4..6c573c1b3635e 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -194,3 +194,11 @@ class TpuPlatform(Platform):
         if params.sampling_type == SamplingType.RANDOM_SEED:
             raise ValueError(
                 "Torch XLA does not support per-request seed.")
+
+
+try:
+    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
+    TpuPlatform = TpuCommonsPlatform  # type: ignore
+except ImportError:
+    logger.info("tpu_commons not found, using vLLM's TpuPlatform")
+    pass
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 25715407ceeee..ae3735ab0255f 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -267,3 +267,11 @@ def init_tpu_worker_distributed_environment(
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size,
                                       parallel_config.enable_expert_parallel)
+
+
+try:
+    from tpu_commons.worker import TPUWorker as TPUCommonsWorker
+    TPUWorker = TPUCommonsWorker  # type: ignore
+except ImportError:
+    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
+    pass
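
Note on the pattern: all three hunks apply the same optional-override idiom.
At module import time, a try/except ImportError attempts to import the class
from the external tpu_commons package and rebinds the module-level name to it;
if the package is absent, the name keeps pointing at vLLM's own implementation
and an informational log line records the fallback. The sketch below is a
minimal, self-contained illustration of that idiom only; the module name
fast_backend, the class FastWidget, and the file optional_override.py are
hypothetical stand-ins, not names from vLLM or tpu_commons.

    # optional_override.py -- minimal sketch of the optional-import override
    # idiom used in the patch; fast_backend and FastWidget are hypothetical.
    import logging

    logger = logging.getLogger(__name__)


    class Widget:
        """Default implementation shipped with this module."""

        def run(self) -> str:
            return "default widget"


    try:
        # If the optional package is importable, rebind the public name so
        # that "from optional_override import Widget" yields the override.
        from fast_backend import FastWidget as _FastWidget
        Widget = _FastWidget  # type: ignore[misc]
    except ImportError:
        logger.info("fast_backend not found, using the built-in Widget")

Because the rebinding happens once at import time, callers never branch on
which backend is installed: they always write
"from optional_override import Widget" and receive whichever class won the
try/except, just as importers of TpuCommunicator, TpuPlatform, and TPUWorker
do after this patch.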