[Hardware][TPU] Optionally import for TPU backend (#18269)

Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
Co-authored-by: Carol Zheng <cazheng@google.com>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
Co-authored-by: Hongmin Fan <fanhongmin@google.com>
Authored by Siyuan Liu on 2025-05-17 00:23:12 -07:00; committed by GitHub
parent 3e0d435027
commit 48ac2bed5b
3 changed files with 25 additions and 0 deletions
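
All three hunks below apply the same optional-override pattern: try to import a drop-in replacement from the external tpu_commons package, rebind the vLLM name to it, and fall back to vLLM's in-tree implementation when the package is not installed. As a rough illustration of that pattern outside the vLLM codebase (a minimal sketch; plugin_pkg, DefaultBackend, and PluginBackend are made-up names, not vLLM or tpu_commons APIs):

import logging

logger = logging.getLogger(__name__)


class DefaultBackend:
    """In-tree implementation used when no plugin is installed."""

    def run(self) -> str:
        return "default backend"


try:
    # If the optional plugin package is importable, its class shadows the
    # in-tree one; later imports of DefaultBackend from this module pick it up.
    from plugin_pkg.backends import PluginBackend
    DefaultBackend = PluginBackend  # type: ignore
except ImportError:
    # Plugin not installed: keep the in-tree class and log the fallback.
    logger.info("plugin_pkg not found, using the default backend")

Because the rebinding happens at module import time, any caller that imports the name from the module transparently gets the plugin class whenever it is available.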


@@ -91,3 +91,12 @@ class TpuCommunicator(DeviceCommunicatorBase):
    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        assert dim == -1, "TPUs only support dim=-1 for all-gather."
        return xm.all_gather(input_, dim=dim)


try:
    from tpu_commons.distributed.device_communicators import (
        TpuCommunicator as TpuCommonsCommunicator)
    TpuCommunicator = TpuCommonsCommunicator  # type: ignore
except ImportError:
    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
    pass


@@ -194,3 +194,11 @@ class TpuPlatform(Platform):
        if params.sampling_type == SamplingType.RANDOM_SEED:
            raise ValueError(
                "Torch XLA does not support per-request seed.")


try:
    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
    TpuPlatform = TpuCommonsPlatform  # type: ignore
except ImportError:
    logger.info("tpu_commons not found, using vLLM's TpuPlatform")
    pass


@@ -267,3 +267,11 @@ def init_tpu_worker_distributed_environment(
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size,
                                      parallel_config.enable_expert_parallel)


try:
    from tpu_commons.worker import TPUWorker as TPUCommonsWorker
    TPUWorker = TPUCommonsWorker  # type: ignore
except ImportError:
    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
    pass