mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:15:01 +08:00
custom allreduce + torch.compile (#10121)
Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
519e8e4182
commit
9a88f89799
@ -86,7 +86,6 @@ If GPU/CPU communication cannot be established, you can use the following Python
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
|
||||
pynccl = PyNcclCommunicator(group=gloo_group, device=local_rank)
|
||||
pynccl.disabled = False
|
||||
|
||||
s = torch.cuda.Stream()
|
||||
with torch.cuda.stream(s):
|
||||
|
||||
@ -60,7 +60,7 @@ def worker_fn():
|
||||
tensor = torch.ones(16, 1024, 1024,
|
||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||
with pynccl_comm.change_state(enable=True):
|
||||
pynccl_comm.all_reduce(tensor)
|
||||
tensor = pynccl_comm.all_reduce(tensor)
|
||||
result = tensor.mean().cpu().item()
|
||||
assert result == pynccl_comm.world_size
|
||||
|
||||
@ -84,12 +84,12 @@ def multiple_allreduce_worker_fn():
|
||||
with pynccl_comm.change_state(enable=True):
|
||||
# two groups can communicate independently
|
||||
if torch.distributed.get_rank() in [0, 1]:
|
||||
pynccl_comm.all_reduce(tensor)
|
||||
pynccl_comm.all_reduce(tensor)
|
||||
tensor = pynccl_comm.all_reduce(tensor)
|
||||
tensor = pynccl_comm.all_reduce(tensor)
|
||||
result = tensor.mean().cpu().item()
|
||||
assert result == 4
|
||||
else:
|
||||
pynccl_comm.all_reduce(tensor)
|
||||
tensor = pynccl_comm.all_reduce(tensor)
|
||||
result = tensor.mean().cpu().item()
|
||||
assert result == 2
|
||||
|
||||
@ -140,14 +140,11 @@ def worker_fn_with_cudagraph():
|
||||
with torch.cuda.graph(
|
||||
graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
|
||||
enable=True):
|
||||
# operation during the graph capture is recorded but not executed
|
||||
# see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
|
||||
pynccl_comm.all_reduce(a)
|
||||
a_out = pynccl_comm.all_reduce(a)
|
||||
pynccl_comm.stream.synchronize()
|
||||
assert a.mean().cpu().item() == pynccl_comm.world_size**0
|
||||
graph.replay()
|
||||
pynccl_comm.stream.synchronize()
|
||||
assert a.mean().cpu().item() == pynccl_comm.world_size**1
|
||||
assert a_out.mean().cpu().item() == pynccl_comm.world_size**1
|
||||
|
||||
|
||||
@worker_fn_wrapper
|
||||
|
||||
@ -70,14 +70,12 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
|
||||
rank=rank,
|
||||
world_size=WORLD_SIZE)
|
||||
pynccl1 = PyNcclCommunicator(pg1, device=rank)
|
||||
pynccl1.disabled = False
|
||||
if rank <= 2:
|
||||
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
|
||||
port=port2,
|
||||
rank=rank,
|
||||
world_size=3)
|
||||
pynccl2 = PyNcclCommunicator(pg2, device=rank)
|
||||
pynccl2.disabled = False
|
||||
data = torch.tensor([rank]).cuda()
|
||||
pynccl1.all_reduce(data)
|
||||
pg1.barrier()
|
||||
|
||||
@ -106,30 +106,30 @@ class PyNcclCommunicator:
|
||||
self.stream.synchronize()
|
||||
del data
|
||||
|
||||
# by default it is disabled, e.g. in profiling models and prefill phase.
|
||||
# to use it, use under `with obj.change_state(enable=True)`, usually
|
||||
# when we are using CUDA graph.
|
||||
self.disabled = True
|
||||
|
||||
def all_reduce(self,
|
||||
tensor: torch.Tensor,
|
||||
in_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None):
|
||||
stream=None) -> torch.Tensor:
|
||||
if self.disabled:
|
||||
return
|
||||
return None
|
||||
# nccl communicator created on a specific device
|
||||
# will only work on tensors on the same device
|
||||
# otherwise it will cause "illegal memory access"
|
||||
assert tensor.device == self.device, (
|
||||
assert in_tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
f"but the input tensor is on {in_tensor.device}")
|
||||
|
||||
out_tensor = torch.empty_like(in_tensor)
|
||||
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
|
||||
buffer_type(tensor.data_ptr()), tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
|
||||
buffer_type(out_tensor.data_ptr()),
|
||||
in_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(in_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op), self.comm,
|
||||
cudaStream_t(stream.cuda_stream))
|
||||
return out_tensor
|
||||
|
||||
def all_gather(self,
|
||||
output_tensor: torch.Tensor,
|
||||
|
||||
@ -96,42 +96,24 @@ def _register_group(group: "GroupCoordinator") -> None:
|
||||
_groups[group.unique_name] = weakref.ref(group)
|
||||
|
||||
|
||||
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
|
||||
assert group_name in _groups, f"Group {group_name} is not found."
|
||||
group = _groups[group_name]()
|
||||
if group is None:
|
||||
raise ValueError(f"Group {group_name} is destroyed.")
|
||||
return group._all_reduce_out_place(tensor)
|
||||
|
||||
|
||||
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
|
||||
return torch.empty_like(tensor)
|
||||
|
||||
|
||||
if supports_custom_op():
|
||||
|
||||
def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
|
||||
assert group_name in _groups, f"Group {group_name} is not found."
|
||||
group = _groups[group_name]()
|
||||
if group is None:
|
||||
raise ValueError(f"Group {group_name} is destroyed.")
|
||||
group._all_reduce_in_place(tensor)
|
||||
|
||||
def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
|
||||
return
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="inplace_all_reduce",
|
||||
op_func=inplace_all_reduce,
|
||||
mutates_args=["tensor"],
|
||||
fake_impl=inplace_all_reduce_fake,
|
||||
)
|
||||
|
||||
def outplace_all_reduce(tensor: torch.Tensor,
|
||||
group_name: str) -> torch.Tensor:
|
||||
assert group_name in _groups, f"Group {group_name} is not found."
|
||||
group = _groups[group_name]()
|
||||
if group is None:
|
||||
raise ValueError(f"Group {group_name} is destroyed.")
|
||||
return group._all_reduce_out_place(tensor)
|
||||
|
||||
def outplace_all_reduce_fake(tensor: torch.Tensor,
|
||||
group_name: str) -> torch.Tensor:
|
||||
return torch.empty_like(tensor)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="outplace_all_reduce",
|
||||
op_func=outplace_all_reduce,
|
||||
op_name="all_reduce",
|
||||
op_func=all_reduce,
|
||||
mutates_args=[],
|
||||
fake_impl=outplace_all_reduce_fake,
|
||||
fake_impl=all_reduce_fake,
|
||||
)
|
||||
|
||||
|
||||
@ -317,30 +299,13 @@ class GroupCoordinator:
|
||||
stream.wait_stream(curr_stream)
|
||||
|
||||
with torch.cuda.stream(stream), maybe_ca_context:
|
||||
# In graph mode, we have to be very careful about the collective
|
||||
# operations. The current status is:
|
||||
# allreduce \ Mode | Eager | Graph |
|
||||
# --------------------------------------------
|
||||
# custom allreduce | enabled | enabled |
|
||||
# PyNccl | disabled| enabled |
|
||||
# torch.distributed | enabled | disabled|
|
||||
#
|
||||
# Note that custom allreduce will have a runtime check, if the
|
||||
# tensor size is too large, it will fallback to the next
|
||||
# available option.
|
||||
# In summary: When using CUDA graph, we use
|
||||
# either custom all-reduce kernel or pynccl. When not using
|
||||
# CUDA graph, we use either custom all-reduce kernel or
|
||||
# PyTorch NCCL. We always prioritize using custom all-reduce
|
||||
# kernel but fall back to PyTorch or pynccl if it is
|
||||
# disabled or not supported.
|
||||
pynccl_comm = self.pynccl_comm
|
||||
maybe_pynccl_context: Any
|
||||
if not pynccl_comm:
|
||||
maybe_pynccl_context = nullcontext()
|
||||
else:
|
||||
maybe_pynccl_context = pynccl_comm.change_state(
|
||||
enable=True, stream=torch.cuda.current_stream())
|
||||
stream=torch.cuda.current_stream())
|
||||
with maybe_pynccl_context:
|
||||
yield graph_capture_context
|
||||
|
||||
@ -356,8 +321,8 @@ class GroupCoordinator:
|
||||
coordinator.
|
||||
|
||||
In addition, PyTorch custom ops do not support mutation or returning
|
||||
a new tensor in the same op. So we need to figure out if the op is
|
||||
in-place or out-of-place ahead of time.
|
||||
a new tensor in the same op. So we always make the all-reduce operation
|
||||
out-of-place.
|
||||
"""
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if self.world_size == 1:
|
||||
@ -368,10 +333,6 @@ class GroupCoordinator:
|
||||
ipex.distributed.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
if not supports_custom_op():
|
||||
self._all_reduce_in_place(input_)
|
||||
return input_
|
||||
|
||||
if self.tpu_communicator is not None and \
|
||||
not self.tpu_communicator.disabled:
|
||||
# TPU handles Dynamo with its own logic.
|
||||
@ -385,30 +346,31 @@ class GroupCoordinator:
|
||||
not self.xpu_communicator.disabled:
|
||||
return self.xpu_communicator.all_reduce(input_)
|
||||
|
||||
if self.ca_comm is not None and \
|
||||
not self.ca_comm.disabled and \
|
||||
self.ca_comm.should_custom_ar(input_):
|
||||
return torch.ops.vllm.outplace_all_reduce(
|
||||
input_, group_name=self.unique_name)
|
||||
else:
|
||||
torch.ops.vllm.inplace_all_reduce(input_,
|
||||
group_name=self.unique_name)
|
||||
return input_
|
||||
return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
|
||||
|
||||
def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
# always try custom allreduce first,
|
||||
# and then pynccl.
|
||||
ca_comm = self.ca_comm
|
||||
assert ca_comm is not None
|
||||
assert not ca_comm.disabled
|
||||
out = ca_comm.custom_all_reduce(input_)
|
||||
assert out is not None
|
||||
return out
|
||||
|
||||
def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
|
||||
if ca_comm is not None and not ca_comm.disabled and \
|
||||
ca_comm.should_custom_ar(input_):
|
||||
out = ca_comm.custom_all_reduce(input_)
|
||||
assert out is not None
|
||||
return out
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if (pynccl_comm is not None and not pynccl_comm.disabled):
|
||||
pynccl_comm.all_reduce(input_)
|
||||
else:
|
||||
torch.distributed.all_reduce(input_, group=self.device_group)
|
||||
assert pynccl_comm is not None
|
||||
# TODO: pynccl should not use `stream=`
|
||||
# it can just always use the current stream.
|
||||
out = pynccl_comm.all_reduce(input_,
|
||||
stream=torch.cuda.current_stream())
|
||||
if out is None:
|
||||
# fall back to the default all-reduce using PyTorch.
|
||||
# this usually happens during testing.
|
||||
# when we run the model, allreduce only happens for the TP
|
||||
# group, where we always have either custom allreduce or pynccl.
|
||||
out = input_.clone()
|
||||
torch.distributed.all_reduce(out, group=self.device_group)
|
||||
return out
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
|
||||
@ -10,6 +10,7 @@ import torch.nn as nn
|
||||
|
||||
from vllm.compilation.compile_context import set_compile_context
|
||||
from vllm.config import CompilationLevel, VllmConfig
|
||||
from vllm.distributed.parallel_state import graph_capture
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
from vllm.logger import init_logger
|
||||
@ -570,8 +571,9 @@ class GPUModelRunner:
|
||||
# Trigger CUDA graph capture for specific shapes.
|
||||
# Capture the large shapes first so that the smaller shapes
|
||||
# can reuse the memory pool allocated for the large shapes.
|
||||
for num_tokens in reversed(self.cudagraph_batch_sizes):
|
||||
self._dummy_run(self.model, num_tokens, self.kv_caches)
|
||||
with graph_capture():
|
||||
for num_tokens in reversed(self.cudagraph_batch_sizes):
|
||||
self._dummy_run(self.model, num_tokens, self.kv_caches)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user