[torch.compile] improve allreduce registration (#9061)

commit 663874e048 (parent cc90419e89)
@@ -265,24 +265,21 @@ class CustomAllreduce:
     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
         # when custom allreduce is disabled, this will be None
-        if self.disabled:
+        if self.disabled or not self.should_custom_ar(input):
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                if self.should_custom_ar(input):
-                    return self.all_reduce_reg(input)
+                return self.all_reduce_reg(input)
             else:
-                if self.should_custom_ar(input):
-                    # if warm up, mimic the allocation pattern
-                    # since custom allreduce is out-of-place
-                    return torch.empty_like(input)
+                # if warm up, mimic the allocation pattern
+                # since custom allreduce is out-of-place
+                return torch.empty_like(input)
         else:
             # note: outside of cuda graph context,
             # custom allreduce incurs a cost of cudaMemcpy, which should
             # be small(<=1% of overall latency) compared to the performance
             # gains of using custom kernels
-            if self.should_custom_ar(input):
-                return self.all_reduce_unreg(input)
+            return self.all_reduce_unreg(input)

         return None
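For context on the branch this hunk simplifies: `torch.cuda.is_current_stream_capturing()` reports whether the current CUDA stream is being captured into a CUDA graph, which is how `custom_all_reduce` chooses between the registered-buffer path and the warm-up path that only mimics the kernel's allocation. Below is a minimal sketch of that check, not vLLM code; `capture_aware_op` is a made-up stand-in, and the demo assumes a CUDA device (it is skipped on CPU-only machines).

```python
import torch

# Minimal sketch (not vLLM code): the same capture-state check that
# custom_all_reduce uses to pick between the registered-buffer path and the
# warm-up path that only mimics the real kernel's output allocation.
def capture_aware_op(x: torch.Tensor) -> torch.Tensor:
    if torch.cuda.is_current_stream_capturing():
        return x + 1              # stand-in for all_reduce_reg(input)
    return torch.empty_like(x)    # warm-up: allocate like the real kernel

if torch.cuda.is_available():
    x = torch.zeros(8, device="cuda")

    # warm up on a side stream, as recommended before CUDA graph capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        capture_aware_op(x)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        y = capture_aware_op(x)   # capturing: takes the registered path
    g.replay()                    # replays the captured kernel
```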
@@ -105,7 +105,7 @@ if supports_custom_op():
         group = _groups[group_name]()
         if group is None:
             raise ValueError(f"Group {group_name} is destroyed.")
-        group._all_reduce(tensor)
+        group._all_reduce_in_place(tensor)

     @inplace_all_reduce.register_fake
     def _(tensor: torch.Tensor, group_name: str) -> None:
@@ -118,7 +118,7 @@ if supports_custom_op():
         group = _groups[group_name]()
         if group is None:
             raise ValueError(f"Group {group_name} is destroyed.")
-        return group._all_reduce(tensor)
+        return group._all_reduce_out_place(tensor)

     @outplace_all_reduce.register_fake
     def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
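The two hunks above retarget the bodies of the registered `inplace_all_reduce` and `outplace_all_reduce` custom ops to the new group helpers. For readers unfamiliar with the registration machinery, here is a minimal sketch of the `torch.library.custom_op` / `register_fake` pattern such ops follow: the fake implementation gives `torch.compile` the shape/dtype propagation it needs to treat the op as an opaque node. The `demo` namespace and both ops are invented for illustration and are not part of vLLM; the sketch assumes a PyTorch version that provides `torch.library.custom_op` (2.4+).

```python
import torch
from torch.library import custom_op

# In-place op: declares which argument it mutates; its fake returns None.
@custom_op("demo::inplace_scale", mutates_args=("x",))
def inplace_scale(x: torch.Tensor, factor: float) -> None:
    x.mul_(factor)

@inplace_scale.register_fake
def _(x: torch.Tensor, factor: float) -> None:
    return None

# Out-of-place op: returns a tensor; its fake only has to produce an output
# with the right shape and dtype.
@custom_op("demo::outplace_scale", mutates_args=())
def outplace_scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

@outplace_scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)

t = torch.ones(4)
inplace_scale(t, 2.0)       # t is mutated in place, nothing is returned
u = outplace_scale(t, 3.0)  # fresh tensor; t is unchanged
```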
@@ -338,14 +338,17 @@ class GroupCoordinator:
             return input_

         if not supports_custom_op():
-            return self._all_reduce(input_)
+            self._all_reduce_in_place(input_)
+            return input_

         if self.tpu_communicator is not None and \
                 not self.tpu_communicator.disabled:
             # TPU handles Dynamo with its own logic.
-            return self._all_reduce(input_)
+            return self.tpu_communicator.all_reduce(input_)

-        if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
+        if self.ca_comm is not None and \
+                not self.ca_comm.disabled and \
+                self.ca_comm.should_custom_ar(input_):
             return torch.ops.vllm.outplace_all_reduce(
                 input_, group_name=self.unique_name)
         else:
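Note that the dispatch above reaches the registered op through the dispatcher path `torch.ops.vllm.outplace_all_reduce(...)` rather than through a plain Python function. A minimal sketch of that access pattern follows, again assuming `torch.library.custom_op` is available; the `demo2` namespace and op are invented for illustration.

```python
import torch
from torch.library import custom_op

# A registered custom op is reachable through the dispatcher as
# torch.ops.<namespace>.<name>, mirroring torch.ops.vllm.outplace_all_reduce.
@custom_op("demo2::add_one", mutates_args=())
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

@add_one.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)

x = torch.zeros(3)
y = torch.ops.demo2.add_one(x)  # dispatcher route, equivalent to add_one(x)
print(y)                        # tensor([1., 1., 1.])
```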
@@ -353,25 +356,15 @@ class GroupCoordinator:
                 group_name=self.unique_name)
             return input_

-    def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
-        """
-        The actual all-reduce implementation.
-
-        NOTE: This operation will be applied in-place or out-of-place.
-        Always assume this function modifies its input, but use the return
-        value as the output.
-        """
+    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
         ca_comm = self.ca_comm
+        assert ca_comm is not None
+        assert not ca_comm.disabled
+        out = ca_comm.custom_all_reduce(input_)
+        assert out is not None
+        return out

-        # For TPUs, use TPU communicator.
-        tpu_comm = self.tpu_communicator
-        if tpu_comm is not None and not tpu_comm.disabled:
-            return tpu_comm.all_reduce(input_)
-
-        if ca_comm is not None:
-            out = ca_comm.custom_all_reduce(input_)
-            if out is not None:
-                return out
+    def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
         pynccl_comm = self.pynccl_comm
         if (pynccl_comm is not None and not pynccl_comm.disabled):
             pynccl_comm.all_reduce(input_)
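The hunk above splits the old `_all_reduce` into two helpers with distinct contracts: `_all_reduce_out_place` must always hand back a fresh tensor (hence the asserts around `custom_all_reduce`), while `_all_reduce_in_place` mutates its argument and returns nothing. A toy sketch of that split, not vLLM code; both function names and their bodies are stand-ins.

```python
import torch

def all_reduce_out_place(x: torch.Tensor) -> torch.Tensor:
    out = x.clone()          # stand-in for ca_comm.custom_all_reduce(x)
    assert out is not None   # mirrors the assert in _all_reduce_out_place
    return out

def all_reduce_in_place(x: torch.Tensor) -> None:
    x.add_(0)                # stand-in for pynccl_comm.all_reduce(x)

x = torch.ones(4)
y = all_reduce_out_place(x)  # new tensor; x is untouched
all_reduce_in_place(x)       # x is updated in place; nothing is returned
assert y is not x
```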
@@ -380,7 +373,6 @@ class GroupCoordinator:
             ipex.distributed.all_reduce(input_, group=self.device_group)
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
-        return input_

     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
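The final hunk drops the trailing `return input_` from the in-place path: `torch.distributed.all_reduce` already reduces into its argument, so callers of `_all_reduce_in_place` keep using the tensor they passed in. The snippet below is a runnable illustration of that in-place contract using a single-process gloo group; the TCP address and port are arbitrary placeholders and it is not vLLM code.

```python
import torch
import torch.distributed as dist

# Single-process demo: gloo backend, world_size 1, so it runs on one
# machine without GPUs.
dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:29501",
                        rank=0,
                        world_size=1)
t = torch.ones(4)
ret = dist.all_reduce(t)   # reduces into t; returns None when async_op=False
print(ret is None, t)      # True tensor([1., 1., 1., 1.])
dist.destroy_process_group()
```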