[TPU] Support multi-host inference (#7457)

Woosuk Kwon 2024-08-13 16:31:20 -07:00 committed by GitHub
parent 16422ea76f
commit a08df8322e
2 changed files with 11 additions and 4 deletions


@@ -8,7 +8,7 @@ vLLM supports Google Cloud TPUs using PyTorch XLA.
 Requirements
 ------------
-* Google Cloud TPU VM (single host)
+* Google Cloud TPU VM (single & multi host)
 * TPU versions: v5e, v5p, v4
 * Python: 3.10


@@ -1,3 +1,4 @@
+import ray
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -18,9 +19,15 @@ class TpuCommunicator:
             return
         self.disabled = False
-        local_rank = dist.get_rank(group)
-        world_size = dist.get_world_size(group)
-        pjrt.initialize_multiprocess(local_rank, world_size)
+        # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
+        # must be used together. Therefore, the local rank and world size can
+        # be simply calculated as follows.
+        global_rank = dist.get_rank(group)
+        global_world_size = dist.get_world_size(group)
+        num_nodes = len(ray.nodes())
+        local_world_size = global_world_size // num_nodes
+        local_rank = global_rank % local_world_size
+        pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()

     def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
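
The new rank math assumes a homogeneous cluster: every Ray node contributes the same number of TPU workers, so the per-node (local) world size is the global world size divided by the node count, and a worker's local rank is its offset within its own node. A minimal standalone sketch of that arithmetic (the helper name and the example sizes are illustrative, not part of the vLLM code):

def derive_local_rank(global_rank: int, global_world_size: int, num_nodes: int):
    # Assumes every node hosts the same number of TPU workers, which is what
    # the NOTE(woosuk) comment above requires when TP > 1 on TPUs.
    local_world_size = global_world_size // num_nodes
    local_rank = global_rank % local_world_size
    return local_rank, local_world_size

# Example: 2 TPU hosts with 4 chips each -> global world size 8.
# The worker with global rank 5 sits on the second host and gets local rank 1.
assert derive_local_rank(5, 8, 2) == (1, 4)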