custom allreduce + torch.compile (#10121)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent 519e8e4182
commit 9a88f89799
```diff
@@ -86,7 +86,6 @@ If GPU/CPU communication cannot be established, you can use the following Python
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

 pynccl = PyNcclCommunicator(group=gloo_group, device=local_rank)
-pynccl.disabled = False

 s = torch.cuda.Stream()
 with torch.cuda.stream(s):
```
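The only change to the debugging guide's sanity-check snippet is dropping the manual `pynccl.disabled = False` toggle: the communicator is now usable as soon as it is constructed. Below is a minimal sketch of how the surrounding snippet fits together, assuming the script is launched with `torchrun` and each rank owns one GPU; the setup lines are illustrative and not taken from this diff.

```python
import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

# Assumed setup: torchrun provides the process group; pynccl is attached to a
# CPU (gloo) group, as in the debugging guide.
dist.init_process_group(backend="gloo")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

# No `pynccl.disabled = False` needed anymore.
pynccl = PyNcclCommunicator(group=dist.group.WORLD, device=local_rank)

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    data = torch.ones(128, device="cuda")
    # all_reduce is now out-of-place and returns a fresh tensor
    # (see the PyNcclCommunicator hunk below).
    out = pynccl.all_reduce(data, stream=s)
s.synchronize()
assert out.mean().item() == dist.get_world_size()
```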
```diff
@@ -60,7 +60,7 @@ def worker_fn():
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
     with pynccl_comm.change_state(enable=True):
-        pynccl_comm.all_reduce(tensor)
+        tensor = pynccl_comm.all_reduce(tensor)
     result = tensor.mean().cpu().item()
     assert result == pynccl_comm.world_size

```
```diff
@@ -84,12 +84,12 @@ def multiple_allreduce_worker_fn():
     with pynccl_comm.change_state(enable=True):
         # two groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.all_reduce(tensor)
-            pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
             result = tensor.mean().cpu().item()
             assert result == 4
         else:
-            pynccl_comm.all_reduce(tensor)
+            tensor = pynccl_comm.all_reduce(tensor)
             result = tensor.mean().cpu().item()
             assert result == 2

```
```diff
@@ -140,14 +140,11 @@ def worker_fn_with_cudagraph():
     with torch.cuda.graph(
             graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
                 enable=True):
-        # operation during the graph capture is recorded but not executed
-        # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture  # noqa
-        pynccl_comm.all_reduce(a)
+        a_out = pynccl_comm.all_reduce(a)
     pynccl_comm.stream.synchronize()
-    assert a.mean().cpu().item() == pynccl_comm.world_size**0
     graph.replay()
     pynccl_comm.stream.synchronize()
-    assert a.mean().cpu().item() == pynccl_comm.world_size**1
+    assert a_out.mean().cpu().item() == pynccl_comm.world_size**1


 @worker_fn_wrapper
```
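For readers unfamiliar with stream capture: inside `torch.cuda.graph(...)` the kernels are only recorded, and nothing runs until `graph.replay()`, which is why the test can only assert on `a_out` after the replay. A generic single-GPU sketch of that behavior follows (no collectives involved; names are illustrative).

```python
import torch

a = torch.zeros(4, device="cuda")
g = torch.cuda.CUDAGraph()

# Warm up on a side stream before capturing, as PyTorch recommends.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    b = a + 1
torch.cuda.current_stream().wait_stream(s)

with torch.cuda.graph(g):
    b = a + 1  # recorded into the graph, not executed yet

g.replay()  # the captured kernel actually runs here
torch.cuda.synchronize()
assert b.sum().item() == 4.0
```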
```diff
@@ -70,14 +70,12 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
                                        rank=rank,
                                        world_size=WORLD_SIZE)
     pynccl1 = PyNcclCommunicator(pg1, device=rank)
-    pynccl1.disabled = False
     if rank <= 2:
         pg2 = StatelessProcessGroup.create(host="127.0.0.1",
                                            port=port2,
                                            rank=rank,
                                            world_size=3)
         pynccl2 = PyNcclCommunicator(pg2, device=rank)
-        pynccl2.disabled = False
     data = torch.tensor([rank]).cuda()
     pynccl1.all_reduce(data)
     pg1.barrier()
```
```diff
@@ -106,30 +106,30 @@ class PyNcclCommunicator:
         self.stream.synchronize()
         del data

-        # by default it is disabled, e.g. in profiling models and prefill phase.
-        # to use it, use under `with obj.change_state(enable=True)`, usually
-        # when we are using CUDA graph.
-        self.disabled = True
-
     def all_reduce(self,
-                   tensor: torch.Tensor,
+                   in_tensor: torch.Tensor,
                    op: ReduceOp = ReduceOp.SUM,
-                   stream=None):
+                   stream=None) -> torch.Tensor:
         if self.disabled:
-            return
+            return None
         # nccl communicator created on a specific device
         # will only work on tensors on the same device
         # otherwise it will cause "illegal memory access"
-        assert tensor.device == self.device, (
+        assert in_tensor.device == self.device, (
             f"this nccl communicator is created to work on {self.device}, "
-            f"but the input tensor is on {tensor.device}")
+            f"but the input tensor is on {in_tensor.device}")
+
+        out_tensor = torch.empty_like(in_tensor)
+
         if stream is None:
             stream = self.stream
-        self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()),
-                                buffer_type(tensor.data_ptr()), tensor.numel(),
-                                ncclDataTypeEnum.from_torch(tensor.dtype),
+        self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()),
+                                buffer_type(out_tensor.data_ptr()),
+                                in_tensor.numel(),
+                                ncclDataTypeEnum.from_torch(in_tensor.dtype),
                                 ncclRedOpTypeEnum.from_torch(op), self.comm,
                                 cudaStream_t(stream.cuda_stream))
+        return out_tensor

     def all_gather(self,
                    output_tensor: torch.Tensor,
```
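`PyNcclCommunicator.all_reduce` is now functional: it leaves the input untouched, allocates an output with `torch.empty_like`, and returns it (or `None` when the communicator is disabled). Callers therefore have to rebind the result, which is exactly what the test changes above do. A hedged caller-side sketch; the fallback policy in it is illustrative, not something this commit prescribes.

```python
import torch

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator


def reduce_out_of_place(comm: PyNcclCommunicator,
                        tensor: torch.Tensor) -> torch.Tensor:
    out = comm.all_reduce(tensor)  # input is never mutated
    if out is None:
        # A disabled communicator now signals "nothing happened" by
        # returning None; here we simply keep the local values.
        out = tensor
    return out
```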
```diff
@@ -96,42 +96,24 @@ def _register_group(group: "GroupCoordinator") -> None:
     _groups[group.unique_name] = weakref.ref(group)


+def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    return group._all_reduce_out_place(tensor)
+
+
+def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+    return torch.empty_like(tensor)
+
+
 if supports_custom_op():
-
-    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
-        assert group_name in _groups, f"Group {group_name} is not found."
-        group = _groups[group_name]()
-        if group is None:
-            raise ValueError(f"Group {group_name} is destroyed.")
-        group._all_reduce_in_place(tensor)
-
-    def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
-        return
-
     direct_register_custom_op(
-        op_name="inplace_all_reduce",
-        op_func=inplace_all_reduce,
-        mutates_args=["tensor"],
-        fake_impl=inplace_all_reduce_fake,
-    )
-
-    def outplace_all_reduce(tensor: torch.Tensor,
-                            group_name: str) -> torch.Tensor:
-        assert group_name in _groups, f"Group {group_name} is not found."
-        group = _groups[group_name]()
-        if group is None:
-            raise ValueError(f"Group {group_name} is destroyed.")
-        return group._all_reduce_out_place(tensor)
-
-    def outplace_all_reduce_fake(tensor: torch.Tensor,
-                                 group_name: str) -> torch.Tensor:
-        return torch.empty_like(tensor)
-
-    direct_register_custom_op(
-        op_name="outplace_all_reduce",
-        op_func=outplace_all_reduce,
+        op_name="all_reduce",
+        op_func=all_reduce,
         mutates_args=[],
-        fake_impl=outplace_all_reduce_fake,
+        fake_impl=all_reduce_fake,
     )

```
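The two custom ops (`inplace_all_reduce`, which mutated its argument, and `outplace_all_reduce`, which returned a tensor) collapse into a single functional `all_reduce` op with `mutates_args=[]` and a shape-only fake implementation, which is the form torch.compile traces most cleanly. A generic sketch of the same pattern using plain `torch.library` (a recent PyTorch is assumed; this is not vLLM's `direct_register_custom_op` helper, just the underlying idea):

```python
import torch


@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real implementation: returns a new tensor, never mutates the input.
    return x * factor


@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Shape/dtype-only implementation used during tracing, analogous to
    # all_reduce_fake returning torch.empty_like(tensor).
    return torch.empty_like(x)


@torch.compile(fullgraph=True)
def fn(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.demo.scale(x, 2.0)


assert torch.equal(fn(torch.ones(4)), torch.full((4,), 2.0))
```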
```diff
@@ -317,30 +299,13 @@ class GroupCoordinator:
             stream.wait_stream(curr_stream)

         with torch.cuda.stream(stream), maybe_ca_context:
-            # In graph mode, we have to be very careful about the collective
-            # operations. The current status is:
-            # allreduce \ Mode   |  Eager  |  Graph  |
-            # --------------------------------------------
-            # custom allreduce   | enabled | enabled |
-            # PyNccl             | disabled| enabled |
-            # torch.distributed  | enabled | disabled|
-            #
-            # Note that custom allreduce will have a runtime check, if the
-            # tensor size is too large, it will fallback to the next
-            # available option.
-            # In summary: When using CUDA graph, we use
-            # either custom all-reduce kernel or pynccl. When not using
-            # CUDA graph, we use either custom all-reduce kernel or
-            # PyTorch NCCL. We always prioritize using custom all-reduce
-            # kernel but fall back to PyTorch or pynccl if it is
-            # disabled or not supported.
             pynccl_comm = self.pynccl_comm
             maybe_pynccl_context: Any
             if not pynccl_comm:
                 maybe_pynccl_context = nullcontext()
             else:
                 maybe_pynccl_context = pynccl_comm.change_state(
-                    enable=True, stream=torch.cuda.current_stream())
+                    stream=torch.cuda.current_stream())
             with maybe_pynccl_context:
                 yield graph_capture_context

```
```diff
@@ -356,8 +321,8 @@ class GroupCoordinator:
         coordinator.

         In addition, PyTorch custom ops do not support mutation or returning
-        a new tensor in the same op. So we need to figure out if the op is
-        in-place or out-of-place ahead of time.
+        a new tensor in the same op. So we always make the all-reduce operation
+        out-of-place.
         """
         # Bypass the function if we are using only 1 GPU.
         if self.world_size == 1:
```
```diff
@@ -368,10 +333,6 @@ class GroupCoordinator:
             ipex.distributed.all_reduce(input_, group=self.device_group)
             return input_

-        if not supports_custom_op():
-            self._all_reduce_in_place(input_)
-            return input_
-
         if self.tpu_communicator is not None and \
             not self.tpu_communicator.disabled:
             # TPU handles Dynamo with its own logic.
```
```diff
@@ -385,30 +346,31 @@ class GroupCoordinator:
             not self.xpu_communicator.disabled:
             return self.xpu_communicator.all_reduce(input_)

-        if self.ca_comm is not None and \
-            not self.ca_comm.disabled and \
-            self.ca_comm.should_custom_ar(input_):
-            return torch.ops.vllm.outplace_all_reduce(
-                input_, group_name=self.unique_name)
-        else:
-            torch.ops.vllm.inplace_all_reduce(input_,
-                                              group_name=self.unique_name)
-            return input_
+        return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)

     def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
+        # always try custom allreduce first,
+        # and then pynccl.
         ca_comm = self.ca_comm
-        assert ca_comm is not None
-        assert not ca_comm.disabled
-        out = ca_comm.custom_all_reduce(input_)
-        assert out is not None
-        return out
-
-    def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
+        if ca_comm is not None and not ca_comm.disabled and \
+            ca_comm.should_custom_ar(input_):
+            out = ca_comm.custom_all_reduce(input_)
+            assert out is not None
+            return out
         pynccl_comm = self.pynccl_comm
-        if (pynccl_comm is not None and not pynccl_comm.disabled):
-            pynccl_comm.all_reduce(input_)
-        else:
-            torch.distributed.all_reduce(input_, group=self.device_group)
+        assert pynccl_comm is not None
+        # TODO: pynccl should not use `stream=`
+        # it can just always use the current stream.
+        out = pynccl_comm.all_reduce(input_,
+                                     stream=torch.cuda.current_stream())
+        if out is None:
+            # fall back to the default all-reduce using PyTorch.
+            # this usually happens during testing.
+            # when we run the model, allreduce only happens for the TP
+            # group, where we always have either custom allreduce or pynccl.
+            out = input_.clone()
+            torch.distributed.all_reduce(out, group=self.device_group)
+        return out

     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
```
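`_all_reduce_out_place` now carries the whole runtime dispatch: custom allreduce when it accepts the tensor, then pynccl, and finally plain `torch.distributed`. Because `torch.distributed.all_reduce` mutates its argument while the custom op contract is functional, the fallback reduces a clone. A minimal standalone sketch of that last step, assuming `torch.distributed` is already initialized:

```python
import torch
import torch.distributed as dist


def all_reduce_out_of_place(x: torch.Tensor) -> torch.Tensor:
    # Keep the op functional: reduce a copy so the caller's tensor is untouched.
    out = x.clone()
    dist.all_reduce(out, op=dist.ReduceOp.SUM)
    return out
```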
```diff
@@ -10,6 +10,7 @@ import torch.nn as nn

 from vllm.compilation.compile_context import set_compile_context
 from vllm.config import CompilationLevel, VllmConfig
+from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
```
```diff
@@ -570,8 +571,9 @@ class GPUModelRunner:
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        for num_tokens in reversed(self.cudagraph_batch_sizes):
-            self._dummy_run(self.model, num_tokens, self.kv_caches)
+        with graph_capture():
+            for num_tokens in reversed(self.cudagraph_batch_sizes):
+                self._dummy_run(self.model, num_tokens, self.kv_caches)

         end_time = time.perf_counter()
         end_free_gpu_memory = torch.cuda.mem_get_info()[0]
```
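The v1 GPU model runner now enters `graph_capture()` before recording its CUDA graphs, so that collectives issued during the dummy runs go through graph-safe backends (custom allreduce or pynccl) rather than plain `torch.distributed`. A sketch of the pattern, with hypothetical names standing in for the runner's attributes:

```python
from vllm.distributed.parallel_state import graph_capture


def capture_all_shapes(runner) -> None:
    # `runner` is a stand-in for GPUModelRunner; attribute names mirror the
    # hunk above but are not guaranteed API.
    with graph_capture():
        for num_tokens in reversed(runner.cudagraph_batch_sizes):
            runner._dummy_run(runner.model, num_tokens, runner.kv_caches)
```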