mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 06:34:58 +08:00
[TPU] optimize the all-reduce performance (#15903)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
parent
1b84eff03a
commit
01b6113659
@ -22,6 +22,8 @@ if current_platform.is_tpu():
|
|||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
import torch_xla.runtime as xr
|
import torch_xla.runtime as xr
|
||||||
from torch_xla._internal import pjrt
|
from torch_xla._internal import pjrt
|
||||||
|
from torch_xla.distributed.xla_multiprocessing import (
|
||||||
|
create_optimized_replica_groups)
|
||||||
|
|
||||||
if USE_RAY:
|
if USE_RAY:
|
||||||
from vllm.executor import ray_utils
|
from vllm.executor import ray_utils
|
||||||
@ -79,9 +81,12 @@ class TpuCommunicator(DeviceCommunicatorBase):
|
|||||||
|
|
||||||
pjrt.initialize_multiprocess(local_rank, local_world_size)
|
pjrt.initialize_multiprocess(local_rank, local_world_size)
|
||||||
xr._init_world_size_ordinal()
|
xr._init_world_size_ordinal()
|
||||||
|
self.groups = create_optimized_replica_groups()
|
||||||
|
|
||||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||||
return xm.all_reduce(xm.REDUCE_SUM, input_)
|
# TODO: Remove the groups specification after XLA compiler can support
|
||||||
|
# auto-reordering the ring order for all-reduce.
|
||||||
|
return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups)
|
||||||
|
|
||||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||||
assert dim == -1, "TPUs only support dim=-1 for all-gather."
|
assert dim == -1, "TPUs only support dim=-1 for all-gather."
|
||||||
|
|||||||
@ -119,11 +119,13 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
|
|||||||
|
|
||||||
|
|
||||||
if supports_custom_op():
|
if supports_custom_op():
|
||||||
|
from vllm.platforms import current_platform
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="all_reduce",
|
op_name="all_reduce",
|
||||||
op_func=all_reduce,
|
op_func=all_reduce,
|
||||||
mutates_args=[],
|
mutates_args=[],
|
||||||
fake_impl=all_reduce_fake,
|
fake_impl=all_reduce_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -219,7 +221,8 @@ class GroupCoordinator:
|
|||||||
self.cpu_group, 1 << 22, 6)
|
self.cpu_group, 1 << 22, 6)
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
self.use_custom_op_call = current_platform.is_cuda_alike()
|
self.use_custom_op_call = (current_platform.is_cuda_alike()
|
||||||
|
or current_platform.is_tpu())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def first_rank(self):
|
def first_rank(self):
|
||||||
|
|||||||
@ -84,6 +84,12 @@ class TPUWorker:
|
|||||||
|
|
||||||
def init_device(self):
|
def init_device(self):
|
||||||
os.environ["PJRT_DEVICE"] = "TPU"
|
os.environ["PJRT_DEVICE"] = "TPU"
|
||||||
|
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
|
||||||
|
# ring, the xla tpu compiler flag
|
||||||
|
# `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
|
||||||
|
# fix this. It will be removed after the bug in XLA compiler is fixed.
|
||||||
|
os.environ["LIBTPU_INIT_ARGS"] = (
|
||||||
|
"--xla_tpu_force_1d_allreduce_at_chunk_count=1")
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
torch.set_default_dtype(self.model_config.dtype)
|
torch.set_default_dtype(self.model_config.dtype)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user