diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 16ec84b43cacc..81a141e86206a 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -1,4 +1,3 @@ -import ray import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -6,6 +5,7 @@ from torch.distributed import ProcessGroup from vllm.platforms import current_platform if current_platform.is_tpu(): + import ray import torch_xla.core.xla_model as xm import torch_xla.runtime as xr from torch_xla._internal import pjrt