diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py
index de66ceaeef6f..a1775279661d 100644
--- a/vllm/distributed/device_communicators/tpu_communicator.py
+++ b/vllm/distributed/device_communicators/tpu_communicator.py
@@ -91,3 +91,12 @@ class TpuCommunicator(DeviceCommunicatorBase):
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
         return xm.all_gather(input_, dim=dim)
+
+
+try:
+    from tpu_commons.distributed.device_communicators import (
+        TpuCommunicator as TpuCommonsCommunicator)
+    TpuCommunicator = TpuCommonsCommunicator  # type: ignore
+except ImportError:
+    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
+    pass
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 41ed94fb619e..6c573c1b3635 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -194,3 +194,11 @@ class TpuPlatform(Platform):
             if params.sampling_type == SamplingType.RANDOM_SEED:
                 raise ValueError(
                     "Torch XLA does not support per-request seed.")
+
+
+try:
+    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
+    TpuPlatform = TpuCommonsPlatform  # type: ignore
+except ImportError:
+    logger.info("tpu_commons not found, using vLLM's TpuPlatform")
+    pass
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 25715407ceee..ae3735ab0255 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -267,3 +267,11 @@ def init_tpu_worker_distributed_environment(
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size,
                                       parallel_config.enable_expert_parallel)
+
+
+try:
+    from tpu_commons.worker import TPUWorker as TPUCommonsWorker
+    TPUWorker = TPUCommonsWorker  # type: ignore
+except ImportError:
+    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
+    pass