[Core][Distributed] refactor custom allreduce to support multiple tp groups (#4754)
parent a7be4d0072
commit 702bee461f
tests/distributed/test_comm_ops.py
@@ -16,7 +16,7 @@ from vllm.test_utils import (init_test_distributed_environment,
 
 
 @ray.remote(num_gpus=1, max_calls=1)
-def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
+def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
                            distributed_init_port: str):
     # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
     # so that each worker can see all the GPUs
@@ -24,12 +24,12 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, tensor_parallel_size, rank,
+    init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
     num_elements = 8
     all_tensors = [
         torch.arange(num_elements, dtype=torch.float32, device="cuda") *
-        (r + 1) for r in range(tensor_parallel_size)
+        (r + 1) for r in range(tp_size)
     ]
     expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
     t = all_tensors[rank]
@@ -38,7 +38,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
 
 
 @ray.remote(num_gpus=1, max_calls=1)
-def all_gather_test_worker(tensor_parallel_size: int, rank: int,
+def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
                            distributed_init_port: str):
     # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
     # so that each worker can see all the GPUs
@@ -46,7 +46,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, tensor_parallel_size, rank,
+    init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
     num_dimensions = 3
     tensor_size = list(range(2, num_dimensions + 2))
@@ -57,7 +57,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
     all_tensors = [
         torch.arange(total_size, dtype=torch.float32,
                      device="cuda").reshape(tensor_size) * (r + 1)
-        for r in range(tensor_parallel_size)
+        for r in range(tp_size)
     ]
     expected = torch.cat(all_tensors, dim=all_gather_dimension)
     t = all_tensors[rank]
@@ -66,7 +66,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
 
 
 @ray.remote(num_gpus=1, max_calls=1)
-def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
+def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
                                       distributed_init_port: str):
     # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
     # so that each worker can see all the GPUs
@@ -74,7 +74,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, tensor_parallel_size, rank,
+    init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
     test_dict = {
         # device tensor
@@ -106,10 +106,10 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
-@pytest.mark.parametrize("tensor_parallel_size", [2])
+@pytest.mark.parametrize("tp_size", [2])
 @pytest.mark.parametrize("test_target", [
     all_reduce_test_worker, all_gather_test_worker,
     broadcast_tensor_dict_test_worker
 ])
-def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
-    multi_process_tensor_parallel(tensor_parallel_size, test_target)
+def test_multi_process_tensor_parallel(tp_size, test_target):
+    multi_process_tensor_parallel(tp_size, 1, test_target)
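The renamed workers keep the same numerical check: rank r contributes arange(num_elements) * (r + 1), and the all-reduced result must equal the element-wise sum over the tp_size ranks. A minimal CPU-only sketch of that arithmetic, outside of this commit and with no distributed setup assumed:

import torch

tp_size = 2        # mirrors @pytest.mark.parametrize("tp_size", [2])
num_elements = 8

all_tensors = [
    torch.arange(num_elements, dtype=torch.float32) * (r + 1)
    for r in range(tp_size)
]
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)

# closed form: arange * (1 + 2 + ... + tp_size)
closed_form = torch.arange(num_elements, dtype=torch.float32) * (
    tp_size * (tp_size + 1) / 2)
assert torch.allclose(expected, closed_form)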

tests/distributed/test_custom_all_reduce.py
@@ -6,8 +6,10 @@ import ray
 import torch
 import torch.distributed as dist
 
-from vllm.distributed import tensor_model_parallel_all_reduce
-from vllm.distributed.device_communicators import custom_all_reduce
+from vllm.distributed.communication_op import (  # noqa
+    graph_capture, tensor_model_parallel_all_reduce)
+from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
+                                             get_tp_ca_communicator)
 from vllm.test_utils import (init_test_distributed_environment,
                              multi_process_tensor_parallel)
 
@@ -18,17 +20,36 @@ for i, v in enumerate(test_sizes):
 
 
 @ray.remote(num_gpus=1, max_calls=1)
-def graph_allreduce(world_size, rank, distributed_init_port):
+def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, world_size, rank,
+    init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
 
-    custom_all_reduce.init_custom_ar()
+    group = get_tensor_model_parallel_group()
+
+    # A small all_reduce for warmup.
+    # this is needed because device communicators might be created lazily
+    # (e.g. NCCL). This will ensure that the communicator is initialized
+    # before any communication happens, so that this group can be used for
+    # graph capture immediately.
+    data = torch.zeros(1)
+    data = data.to(device=device)
+    torch.distributed.all_reduce(data, group=group)
+    torch.cuda.synchronize()
+    del data
+
+    # we use the first group to communicate once
+    # and the second group to communicate twice
+    # and so on
+    # this is used to demonstrate that each group can
+    # communicate independently
+    num_communication = rank // tp_size + 1
+
     for sz in test_sizes:
         for dtype in [torch.float32, torch.float16, torch.bfloat16]:
-            with custom_all_reduce.capture():
+            with graph_capture():
                 # use integers so result matches NCCL exactly
                 inp1 = torch.randint(1,
                                      16, (sz, ),
@@ -41,44 +62,52 @@ def graph_allreduce(world_size, rank, distributed_init_port):
                 torch.cuda.synchronize()
                 graph = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(graph):
-                    out1 = tensor_model_parallel_all_reduce(inp1)
-                    # the input buffer is immediately modified to test
-                    # synchronization
-                    dist.all_reduce(inp1)
-                    out2 = tensor_model_parallel_all_reduce(inp2)
-                    dist.all_reduce(inp2)
+                    for i in range(num_communication):
+                        out1 = tensor_model_parallel_all_reduce(inp1)
+                        # the input buffer is immediately modified to test
+                        # synchronization
+                        dist.all_reduce(inp1, group=group)
+                        out2 = tensor_model_parallel_all_reduce(inp2)
+                        dist.all_reduce(inp2, group=group)
                 graph.replay()
                 assert torch.allclose(out1, inp1)
                 assert torch.allclose(out2, inp2)
 
 
 @ray.remote(num_gpus=1, max_calls=1)
-def eager_allreduce(world_size, rank, distributed_init_port):
+def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
     del os.environ["CUDA_VISIBLE_DEVICES"]
     device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-    init_test_distributed_environment(1, world_size, rank,
+    init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
 
+    # we use the first group to communicate once
+    # and the second group to communicate twice
+    # and so on
+    # this is used to demonstrate that each group can
+    # communicate independently
+    num_communication = rank // tp_size + 1
     sz = 1024
-    custom_all_reduce.init_custom_ar()
-    fa = custom_all_reduce.get_handle()
+    fa = get_tp_ca_communicator()
     inp = torch.ones(sz, dtype=torch.float32, device=device)
-    out = fa.all_reduce_unreg(inp)
-    assert torch.allclose(out, inp * world_size)
+    out = inp
+    for _ in range(num_communication):
+        out = fa.all_reduce_unreg(out)
+    assert torch.allclose(out, inp * (tp_size**num_communication))
 
     inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
-    out = fa.all_reduce_unreg(inp)
-    assert torch.allclose(out, inp * world_size)
+    out = inp
+    for _ in range(num_communication):
+        out = fa.all_reduce_unreg(out)
+    assert torch.allclose(out, inp * (tp_size**num_communication))
 
 
-@pytest.mark.skipif(torch.cuda.device_count() < 2,
-                    reason="Need at least 2 GPUs to run the test.")
-@pytest.mark.parametrize("tensor_parallel_size", [2])
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("pipeline_parallel_size", [1, 2])
 @pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce])
-def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
-    multi_process_tensor_parallel(tensor_parallel_size, test_target)
-
-
-if __name__ == "__main__":
-    multi_process_tensor_parallel(2, graph_allreduce)
+def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
+    world_size = tp_size * pipeline_parallel_size
+    if world_size > torch.cuda.device_count():
+        pytest.skip("Not enough GPUs to run the test.")
+    multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)
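The refactored test drives several tensor-parallel groups at once: assuming consecutive ranks form one group of tp_size ranks, group i performs i + 1 all-reduces, so a tensor of ones must end up equal to tp_size ** (i + 1). A standalone sketch of that bookkeeping (plain Python, hypothetical helper name, no GPUs involved):

def tp_group_index(rank: int, tp_size: int) -> int:
    # assumption: consecutive ranks form one tensor-parallel group,
    # e.g. tp_size=2, pp_size=2 -> groups [0, 1] and [2, 3]
    return rank // tp_size


tp_size, pp_size = 2, 2
for rank in range(tp_size * pp_size):
    group = tp_group_index(rank, tp_size)
    num_communication = rank // tp_size + 1   # same formula as the test
    # repeatedly all-reducing a tensor of ones over tp_size ranks
    # multiplies it by tp_size each time
    expected = tp_size ** num_communication
    print(f"rank {rank}: group {group}, "
          f"{num_communication} all-reduce(s), expected value {expected}")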

tests/distributed/test_pynccl.py
@@ -5,7 +5,7 @@ import pytest
 import torch
 
 from vllm.distributed.communication_op import (  # noqa
-    graph_capture_mode, tensor_model_parallel_all_reduce)
+    graph_mode, tensor_model_parallel_all_reduce)
 from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
 from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
     ensure_model_parallel_initialized(2, 2)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with graph_capture_mode():
+    with graph_mode():
         # two tp groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
             tensor = tensor_model_parallel_all_reduce(tensor)

vllm/distributed/communication_op.py
@@ -1,5 +1,5 @@
 from collections import namedtuple
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -9,12 +9,13 @@ from .parallel_state import (get_cpu_world_group,
                              get_tensor_model_parallel_group,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
+                             get_tp_ca_communicator,
                              get_tp_pynccl_communicator)
 
 
 @contextmanager
-def graph_capture_mode():
-    # In graph capture, we have to be very careful about the collective
+def graph_mode():
+    # In graph mode, we have to be very careful about the collective
     # operations. The current status is:
     # allreduce \ Mode   |  Eager  |  Graph  |
     # --------------------------------------------
@@ -24,10 +25,32 @@ def graph_capture_mode():
     #
     # Note that custom allreduce will have a runtime check, if the tensor size
     # is too large, it will fallback to the next available option.
+    # In summary: When using CUDA graph, we use
+    # either custom all-reduce kernel or pynccl. When not using CUDA
+    # graph, we use either custom all-reduce kernel or PyTorch NCCL.
+    # We always prioritize using custom all-reduce kernel but fall back
+    # to PyTorch or pynccl if it is disabled or not supported.
     pynccl_comm = get_tp_pynccl_communicator()
-    assert pynccl_comm is not None
-    with pynccl_comm.change_state(enable=True,
-                                  stream=torch.cuda.current_stream()):
+    if pynccl_comm is None:
+        context = nullcontext()
+    else:
+        context = pynccl_comm.change_state(enable=True,
+                                           stream=torch.cuda.current_stream())
+    with context:
         yield
+
+
+@contextmanager
+def graph_capture():
+    """
+    `graph_capture` is a context manager which should include the code that
+    is capturing the CUDA graph. Its main purpose is to ensure that the
+    some operations will be run after the graph is captured, before the graph
+    is replayed.
+    """
+    ca_comm = get_tp_ca_communicator()
+    context = nullcontext() if ca_comm is None else ca_comm.capture()
+    with context:
+        yield
 
 
@@ -43,15 +66,15 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     TLDR: always assume this function modifies its input, but use the return
     value as the output.
     """
-    from vllm.distributed.device_communicators.custom_all_reduce import (
-        custom_all_reduce)
+    ca_comm = get_tp_ca_communicator()
 
     # Bypass the function if we are using only 1 GPU.
     if get_tensor_model_parallel_world_size() == 1:
         return input_
-    out = custom_all_reduce(input_)
-    if out is not None:
-        return out
+    if ca_comm is not None:
+        out = ca_comm.custom_all_reduce(input_)
+        if out is not None:
+            return out
     pynccl_comm = get_tp_pynccl_communicator()
     if (pynccl_comm is not None and not pynccl_comm.disabled):
         pynccl_comm.all_reduce(input_)
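Both new context managers follow the same optional-communicator pattern: fetch a possibly-None communicator and fall back to contextlib.nullcontext() when it is absent, so callers can always write a plain with block. A self-contained sketch of the pattern; FakeComm is a made-up stand-in for a pynccl or custom-allreduce communicator:

from contextlib import contextmanager, nullcontext


class FakeComm:
    """Stand-in for a device communicator with its own context manager."""

    @contextmanager
    def change_state(self, enable: bool):
        print(f"communicator enabled={enable}")
        yield
        print("communicator state restored")


@contextmanager
def graph_mode_sketch(comm):
    # fall back to a no-op context when no communicator was created,
    # e.g. single-GPU runs or platforms without the backend
    context = nullcontext() if comm is None else comm.change_state(enable=True)
    with context:
        yield


# works the same way whether or not a communicator exists
with graph_mode_sketch(None):
    print("capturing without a communicator")
with graph_mode_sketch(FakeComm()):
    print("capturing with a communicator")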

vllm/distributed/device_communicators/custom_all_reduce.py
@@ -1,155 +1,43 @@
 from contextlib import contextmanager
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union
 
 import torch
 import torch.distributed as dist
+from torch.distributed import ProcessGroup
 
 import vllm.envs as envs
+from vllm.distributed.parallel_state import (
+    get_local_rank, get_tensor_model_parallel_cpu_group)
 from vllm.logger import init_logger
 
 try:
     import pynvml
 
     from vllm._C import custom_ar
+
+    @contextmanager
+    def _nvml():
+        try:
+            pynvml.nvmlInit()
+            yield
+        finally:
+            pynvml.nvmlShutdown()
+
 except ImportError:
     # For AMD GPUs
     custom_ar = None
     pynvml = None
 
+    @contextmanager
+    def _nvml():
+        try:
+            yield
+        finally:
+            pass
+
 
 logger = init_logger(__name__)
 
-_CA_HANDLE: Optional["CustomAllreduce"] = None
-_IS_CAPTURING = False
-_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
-
-
-def init_custom_ar() -> None:
-    from vllm.distributed import (get_tensor_model_parallel_rank,
-                                  get_tensor_model_parallel_world_size)
-
-    global _CA_HANDLE
-    if _CA_HANDLE is not None:
-        return
-    rank = get_tensor_model_parallel_rank()
-    world_size = get_tensor_model_parallel_world_size()
-    if world_size == 1:
-        # No need to initialize custom allreduce for single GPU case.
-        return
-
-    if world_size not in _SUPPORTED_WORLD_SIZES:
-        logger.warning(
-            "Custom allreduce is disabled due to an unsupported world size: "
-            "%d. Supported world sizes: %s. To silence this warning, specify"
-            " disable_custom_all_reduce=True explicitly.", world_size,
-            str(_SUPPORTED_WORLD_SIZES))
-        return
-    num_dev = torch.cuda.device_count()
-    # note: num dev can be larger than world_size if we're only using
-    # first few GPUs
-    if num_dev < world_size:
-        logger.warning(
-            "Cannot test GPU P2P because not all GPUs are visible to the "
-            "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
-            " is set.")
-        return
-
-    # we only use a subset of GPUs here
-    # so we only need to check the nvlink connectivity of these GPUs
-    num_dev = world_size
-    # test nvlink first, this will filter out most of the cases
-    # where custom allreduce is not supported
-    cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
-    if cuda_visible_devices:
-        device_ids = list(map(int, cuda_visible_devices.split(",")))
-    else:
-        device_ids = list(range(num_dev))
-    # this checks hardware and driver support for NVLink
-    full_nvlink = _is_full_nvlink(device_ids)
-    if world_size > 2 and not full_nvlink:
-        logger.warning(
-            "Custom allreduce is disabled because it's not supported on more"
-            " than two PCIe-only GPUs. To silence this warning, specify"
-            " disable_custom_all_reduce=True explicitly.")
-        return
-    # test P2P capability, this checks software/cudaruntime support
-    # this is expensive to compute at the first time
-    # then we cache the result
-    if not _can_p2p(rank, world_size):
-        logger.warning(
-            "Custom allreduce is disabled because your platform lacks GPU P2P"
-            " capability or P2P test failed. To silence this warning, specify"
-            " disable_custom_all_reduce=True explicitly.")
-        return
-    _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink)
-
-
-def begin_capture() -> None:
-    global _IS_CAPTURING
-    _IS_CAPTURING = True
-
-
-def end_capture() -> None:
-    global _IS_CAPTURING
-    _IS_CAPTURING = False
-
-
-def is_capturing() -> bool:
-    return _IS_CAPTURING and _CA_HANDLE is not None
-
-
-def get_handle() -> Optional["CustomAllreduce"]:
-    return _CA_HANDLE
-
-
-def is_initialized() -> bool:
-    return _CA_HANDLE is not None
-
-
-@contextmanager
-def capture():
-    try:
-        begin_capture()
-        yield
-    finally:
-        end_capture()
-        handle = get_handle()
-        if handle is not None:
-            handle.register_graph_buffers()
-
-
-def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
-    ca_handle = get_handle()
-    # when custom allreduce is disabled, this will be None
-    if ca_handle is None:
-        return None
-    if is_capturing():
-        if torch.cuda.is_current_stream_capturing():
-            if ca_handle.should_custom_ar(input):
-                return ca_handle.all_reduce_reg(input)
-        else:
-            if ca_handle.should_custom_ar(input):
-                # if warm up, mimic the allocation pattern
-                # since custom allreduce is out-of-place
-                return torch.empty_like(input)
-    else:
-        # note: outside of cuda graph context,
-        # custom allreduce incurs a cost of cudaMemcpy, which should
-        # be small(<=1% of overall latency) compared to the performance
-        # gains of using custom kernels
-        if ca_handle.should_custom_ar(input):
-            return ca_handle.all_reduce_unreg(input)
-
-    return None
-
-
-@contextmanager
-def _nvml():
-    try:
-        pynvml.nvmlInit()
-        yield
-    finally:
-        pynvml.nvmlShutdown()
-
 
 @_nvml()
 def _is_full_nvlink(device_ids: List[int]) -> bool:
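The module now defines _nvml() in both branches of the optional-import block, so NVML setup and teardown only happen when pynvml is importable while callers can decorate probes unconditionally. A sketch of that optional-dependency pattern under an assumed module name some_gpu_lib (not a real package):

from contextlib import contextmanager

try:
    import some_gpu_lib  # hypothetical optional dependency, like pynvml here

    @contextmanager
    def _managed():
        try:
            some_gpu_lib.init()      # hypothetical init/shutdown API
            yield
        finally:
            some_gpu_lib.shutdown()

except ImportError:
    some_gpu_lib = None

    @contextmanager
    def _managed():
        # no-op fallback so decorated callers still work
        try:
            yield
        finally:
            pass


# a @contextmanager instance is also usable as a decorator,
# which is how @_nvml() wraps the NVLink probe above
@_managed()
def probe() -> bool:
    return some_gpu_lib is not None


print("optional dependency available:", probe())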
@@ -188,22 +76,112 @@ def _can_p2p(rank: int, world_size: int) -> bool:
 
 class CustomAllreduce:
 
+    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
+
     # max_size: max supported allreduce size
     def __init__(self,
-                 rank,
-                 world_size,
-                 full_nvlink,
+                 group: Optional[ProcessGroup] = None,
+                 device: Optional[Union[int, str, torch.device]] = None,
                  max_size=8192 * 1024) -> None:
+        """
+        Args:
+            group: the process group to work on. If None, it will use the
+                default process group.
+            device: the device to bind the CustomAllreduce to. If None,
+                it will be bind to f"cuda:{local_rank}".
+        It is the caller's responsibility to make sure each communicator
+        is bind to a unique device, and all communicators in this group
+        are in the same node.
+        """
+        self._IS_CAPTURING = False
+        self.disabled = True
+
+        if custom_ar is None:
+            # disable because of missing custom allreduce library
+            # e.g. in a non-cuda environment
+            return
+
+        group = group or get_tensor_model_parallel_cpu_group()
+        self.group = group
+
+        assert dist.get_backend(group) != dist.Backend.NCCL, (
+            "CustomAllreduce should be attached to a non-NCCL group.")
+
+        rank = dist.get_rank(group=self.group)
+        world_size = dist.get_world_size(group=self.group)
+        if world_size == 1:
+            # No need to initialize custom allreduce for single GPU case.
+            return
+
+        if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES:
+            logger.warning(
+                "Custom allreduce is disabled due to an unsupported world"
+                " size: %d. Supported world sizes: %s. To silence this "
+                "warning, specify disable_custom_all_reduce=True explicitly.",
+                world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
+            return
+
+        if device is None:
+            local_rank = get_local_rank()
+            device = torch.device(f"cuda:{local_rank}")
+        elif isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        # now `device` is a `torch.device` object
+        assert isinstance(device, torch.device)
+        self.device = device
+
+        cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
+        if cuda_visible_devices:
+            device_ids = list(map(int, cuda_visible_devices.split(",")))
+        else:
+            device_ids = list(range(torch.cuda.device_count()))
+
+        physical_device_id = device_ids[device.index]
+        tensor = torch.tensor([physical_device_id],
+                              dtype=torch.int,
+                              device="cpu")
+        gather_list = [
+            torch.tensor([0], dtype=torch.int, device="cpu")
+            for _ in range(world_size)
+        ]
+        dist.all_gather(gather_list, tensor, group=self.group)
+        physical_device_ids = [t.item() for t in gather_list]
+
+        # test nvlink first, this will filter out most of the cases
+        # where custom allreduce is not supported
+        # this checks hardware and driver support for NVLink
+        full_nvlink = _is_full_nvlink(physical_device_ids)
+        if world_size > 2 and not full_nvlink:
+            logger.warning(
+                "Custom allreduce is disabled because it's not supported on"
+                " more than two PCIe-only GPUs. To silence this warning, "
+                "specify disable_custom_all_reduce=True explicitly.")
+            return
+        # test P2P capability, this checks software/cudaruntime support
+        # this is expensive to compute at the first time
+        # then we cache the result
+        if not _can_p2p(rank, world_size):
+            logger.warning(
+                "Custom allreduce is disabled because your platform lacks "
+                "GPU P2P capability or P2P test failed. To silence this "
+                "warning, specify disable_custom_all_reduce=True explicitly.")
+            return
+
+        self.disabled = False
         # buffers memory are owned by this Python class and passed to C++
         # meta data composes of two parts: meta data for synchronization
         # (256 bytes) and a temporary buffer for storing intermediate
         # allreduce results.
         self.meta = torch.zeros(custom_ar.meta_size() + max_size,
                                 dtype=torch.uint8,
-                                device="cuda")
+                                device=self.device)
         # This is a pre-registered IPC buffer. In eager mode, input tensors
         # are first copied into this buffer before allreduce is performed
-        self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda")
+        self.buffer = torch.empty(max_size,
+                                  dtype=torch.uint8,
+                                  device=self.device)
         # This is a buffer for storing the tuples of pointers pointing to
         # IPC buffers from all ranks. Each registered tuple has size of
         # 8*world_size bytes where world_size is at most 8. Allocating 8MB
@@ -211,8 +189,9 @@ class CustomAllreduce:
         # needs less than 10000 of registered tuples.
         self.rank_data = torch.empty(8 * 1024 * 1024,
                                      dtype=torch.uint8,
-                                     device="cuda")
+                                     device=self.device)
         self.max_size = max_size
+        self.rank = rank
         self.world_size = world_size
         handles, offsets = self._get_ipc_meta(self.meta)
         self.full_nvlink = full_nvlink
@@ -221,6 +200,21 @@ class CustomAllreduce:
                                              self.full_nvlink)
         self.register_buffer(self.buffer)
 
+    @contextmanager
+    def capture(self):
+        """
+        The main responsibility of this context manager is the
+        `register_graph_buffers` call at the end of the context.
+        It records all the buffer addresses used in the CUDA graph.
+        """
+        try:
+            self._IS_CAPTURING = True
+            yield
+        finally:
+            self._IS_CAPTURING = False
+            if not self.disabled:
+                self.register_graph_buffers()
+
     def _get_ipc_meta(self, inp: torch.Tensor):
         data = inp.untyped_storage()._share_cuda_()
         shard_data = (
@@ -230,14 +224,29 @@ class CustomAllreduce:
         return self._gather_ipc_meta(shard_data)
 
     def _gather_ipc_meta(self, shard_data):
-        all_data: List[Optional[Any]] = [None] * self.world_size
-        dist.all_gather_object(all_data, shard_data)
+        # Note: don't use `[[None]] * self.world_size` here
+        # because it will create a list of the same reference
+        all_data: List[Optional[Any]] = [[None]
+                                         for i in range(self.world_size)]
+        all_data[self.rank][0] = shard_data
+
+        ranks = dist.get_process_group_ranks(group=self.group)
+        ranks.sort()
+        for i, rank in enumerate(ranks):
+            dist.broadcast_object_list(all_data[i],
+                                       src=rank,
+                                       group=self.group,
+                                       device="cpu")
+        # we cannot directly use `dist.all_gather_object` here
+        # because it is incompatible with `gloo` backend under inference mode.
+        # see https://github.com/pytorch/pytorch/issues/126032 for details.
 
         handles = []
         offsets = []
         for i in range(len(all_data)):
-            handles.append(all_data[i][0])  # type: ignore
-            offsets.append(all_data[i][1])  # type: ignore
+            handles.append(all_data[i][0][0])  # type: ignore
+            offsets.append(all_data[i][0][1])  # type: ignore
         return handles, offsets
 
     def register_buffer(self, inp: torch.Tensor):
@@ -269,8 +278,31 @@ class CustomAllreduce:
         custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
         return out
 
+    def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
+        # when custom allreduce is disabled, this will be None
+        if self.disabled:
+            return None
+        if self._IS_CAPTURING:
+            if torch.cuda.is_current_stream_capturing():
+                if self.should_custom_ar(input):
+                    return self.all_reduce_reg(input)
+            else:
+                if self.should_custom_ar(input):
+                    # if warm up, mimic the allocation pattern
+                    # since custom allreduce is out-of-place
+                    return torch.empty_like(input)
+        else:
+            # note: outside of cuda graph context,
+            # custom allreduce incurs a cost of cudaMemcpy, which should
+            # be small(<=1% of overall latency) compared to the performance
+            # gains of using custom kernels
+            if self.should_custom_ar(input):
+                return self.all_reduce_unreg(input)
+
+        return None
+
     def close(self):
-        if self._ptr:
+        if not self.disabled and self._ptr:
             custom_ar.dispose(self._ptr)
             self._ptr = 0
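The rewritten _gather_ipc_meta starts from the pitfall its comment calls out: [[None]] * world_size creates one inner list aliased world_size times, so writing all_data[rank][0] would show up in every slot. A tiny standalone demonstration (plain Python, no distributed setup):

world_size = 4
rank = 1

aliased = [[None]] * world_size                     # same inner list repeated
independent = [[None] for _ in range(world_size)]   # distinct inner lists

aliased[rank][0] = "my-ipc-handle"
independent[rank][0] = "my-ipc-handle"

print(aliased)       # every slot shows "my-ipc-handle"
print(independent)   # only index 1 is filled
assert all(slot[0] == "my-ipc-handle" for slot in aliased)
assert sum(slot[0] is not None for slot in independent) == 1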

vllm/distributed/device_communicators/pynccl.py
@@ -96,8 +96,10 @@ class PyNcclCommunicator:
         self.stream = torch.cuda.Stream()
 
         # A small all_reduce for warmup.
-        self.all_reduce(torch.zeros(1, device=device))
+        data = torch.zeros(1, device=device)
+        self.all_reduce(data)
         self.stream.synchronize()
+        del data
 
         # by default it is disabled, e.g. in profiling models and prefill phase.
         # to use it, use under `with obj.change_state(enable=True)`, usually

vllm/distributed/parallel_state.py
@@ -13,10 +13,13 @@ from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
+_ENABLE_CUSTOM_ALL_REDUCE = True
+
 # Tensor model parallel group that the current rank belongs to.
 _TP_DEVICE_GROUP: Optional[ProcessGroup] = None
 _TP_CPU_GROUP: Optional[ProcessGroup] = None
 _TP_PYNCCL_COMMUNICATOR = None
+_TP_CA_COMMUNICATOR = None
 # Pipeline model parallel group that the current rank belongs to.
 _PP_DEVICE_GROUP: Optional[ProcessGroup] = None
 
@@ -47,11 +50,21 @@ _PP_GLOBAL_RANKS: Optional[List[int]] = None
 _LOCAL_RANK = -1
 
 
+def set_custom_all_reduce(enable: bool):
+    global _ENABLE_CUSTOM_ALL_REDUCE
+    _ENABLE_CUSTOM_ALL_REDUCE = enable
+
+
 def get_tp_pynccl_communicator():
     global _TP_PYNCCL_COMMUNICATOR
     return _TP_PYNCCL_COMMUNICATOR
 
 
+def get_tp_ca_communicator():
+    global _TP_CA_COMMUNICATOR
+    return _TP_CA_COMMUNICATOR
+
+
 def get_local_rank():
     global _LOCAL_RANK
     return _LOCAL_RANK
@@ -100,6 +113,9 @@ def init_distributed_environment(
     if torch.cuda.is_available():
         data = data.to(device=f"cuda:{local_rank}")
     torch.distributed.all_reduce(data)
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    del data
 
 
 def initialize_model_parallel(
@@ -149,7 +165,8 @@ def initialize_model_parallel(
     rank = torch.distributed.get_rank()
 
     # Build the tensor model-parallel groups.
-    global _TP_DEVICE_GROUP, _TP_CPU_GROUP, _TP_PYNCCL_COMMUNICATOR
+    global _TP_DEVICE_GROUP, _TP_CPU_GROUP
+    global _TP_PYNCCL_COMMUNICATOR, _TP_CA_COMMUNICATOR
     assert _TP_DEVICE_GROUP is None, (
         "tensor model parallel group is already initialized")
     for i in range(num_tensor_model_parallel_groups):
@@ -168,6 +185,15 @@ def initialize_model_parallel(
         device=_LOCAL_RANK,
     )
 
+    # Initialize a custom fast all-reduce implementation.
+    if _ENABLE_CUSTOM_ALL_REDUCE:
+        from vllm.distributed.device_communicators.custom_all_reduce import (
+            CustomAllreduce)
+        _TP_CA_COMMUNICATOR = CustomAllreduce(
+            group=_TP_CPU_GROUP,
+            device=_LOCAL_RANK,
+        )
+
     # Build the pipeline model-parallel groups.
     global _PP_DEVICE_GROUP
     global _PP_GLOBAL_RANKS
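parallel_state now owns one custom-allreduce communicator per process, created inside initialize_model_parallel and exposed through a module-level getter, with set_custom_all_reduce acting as an opt-out switch that must be flipped before initialization. A minimal sketch of that module-singleton pattern in isolation; the Communicator class is a toy stand-in, not the vllm implementation:

from typing import Optional


class Communicator:
    """Toy stand-in for CustomAllreduce."""

    def __init__(self, group: str):
        self.group = group


_ENABLE_CUSTOM_ALL_REDUCE = True
_TP_CA_COMMUNICATOR: Optional[Communicator] = None


def set_custom_all_reduce(enable: bool) -> None:
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable


def initialize_model_parallel(group: str) -> None:
    global _TP_CA_COMMUNICATOR
    # only build the communicator if the feature was not disabled beforehand
    if _ENABLE_CUSTOM_ALL_REDUCE:
        _TP_CA_COMMUNICATOR = Communicator(group)


def get_tp_ca_communicator() -> Optional[Communicator]:
    return _TP_CA_COMMUNICATOR


set_custom_all_reduce(True)          # the worker flips this from its config
initialize_model_parallel("tp-cpu-group")
print(get_tp_ca_communicator())      # callers must handle None when disabled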

vllm/test_utils.py
@@ -6,24 +6,24 @@ from vllm.utils import get_open_port
 
 
 def init_test_distributed_environment(
-    pipeline_parallel_size: int,
-    tensor_parallel_size: int,
+    tp_size: int,
+    pp_size: int,
     rank: int,
     distributed_init_port: str,
     local_rank: int = -1,
 ) -> None:
     distributed_init_method = f"tcp://localhost:{distributed_init_port}"
     init_distributed_environment(
-        world_size=pipeline_parallel_size * tensor_parallel_size,
+        world_size=pp_size * tp_size,
         rank=rank,
         distributed_init_method=distributed_init_method,
        local_rank=local_rank)
-    ensure_model_parallel_initialized(tensor_parallel_size,
-                                      pipeline_parallel_size)
+    ensure_model_parallel_initialized(tp_size, pp_size)
 
 
 def multi_process_tensor_parallel(
-    tensor_parallel_size: int,
+    tp_size: int,
+    pp_size: int,
     test_target,
 ) -> None:
     # Using ray helps debugging the error when it failed
@@ -32,10 +32,9 @@ def multi_process_tensor_parallel(
 
     distributed_init_port = get_open_port()
     refs = []
-    for rank in range(tensor_parallel_size):
+    for rank in range(tp_size * pp_size):
         refs.append(
-            test_target.remote(tensor_parallel_size, rank,
-                               distributed_init_port))
+            test_target.remote(tp_size, pp_size, rank, distributed_init_port))
     ray.get(refs)
 
     ray.shutdown()
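multi_process_tensor_parallel now launches tp_size * pp_size workers and hands each one its global rank plus both parallel sizes. A rough sketch of the same launch loop using the standard library's multiprocessing in place of ray (the worker body is a placeholder, not one of the real test workers):

import multiprocessing as mp


def worker(tp_size: int, pp_size: int, rank: int, port: str) -> None:
    # placeholder for a test worker such as all_reduce_test_worker
    print(f"worker rank={rank} tp_size={tp_size} pp_size={pp_size} port={port}")


def run_all_workers(tp_size: int, pp_size: int, port: str = "29500") -> None:
    procs = []
    for rank in range(tp_size * pp_size):
        p = mp.Process(target=worker, args=(tp_size, pp_size, rank, port))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()
        assert p.exitcode == 0


if __name__ == "__main__":
    run_all_workers(tp_size=2, pp_size=2)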

vllm/worker/model_runner.py
@@ -12,8 +12,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict
-from vllm.distributed.communication_op import graph_capture_mode
-from vllm.distributed.device_communicators import custom_all_reduce
+from vllm.distributed.communication_op import graph_capture, graph_mode
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
@@ -942,13 +941,7 @@ class ModelRunner:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
 
-        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
-        # kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
-        # either custom all-reduce kernel or pynccl. When not using CUDA
-        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
-        # We always prioritize using custom all-reduce kernel but fall back
-        # to PyTorch or pynccl if it is disabled or not supported.
-        with custom_all_reduce.capture():
+        with graph_capture():
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
@@ -1040,7 +1033,7 @@ class CUDAGraphRunner:
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with graph_capture_mode():
+        with graph_mode():
             self.model(
                 input_ids,
                 positions,
@@ -1055,7 +1048,7 @@ class CUDAGraphRunner:
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self._graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self._graph, pool=memory_pool):  # noqa: SIM117
-            with graph_capture_mode():
+            with graph_mode():
                 hidden_states = self.model(
                     input_ids,
                     positions,
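The capture path now nests two context managers: graph_capture() wraps the whole capture loop and registers graph buffers when it exits, while graph_mode() wraps each individual torch.cuda.graph capture and switches the all-reduce backend only while recording. A toy sketch of that outer/inner split with printouts standing in for both effects (illustrative names, not the vllm API):

from contextlib import contextmanager


@contextmanager
def capture_session():
    # outer scope: one per model; finalize bookkeeping when it ends
    captured = []
    try:
        yield captured
    finally:
        print(f"registering buffers for {len(captured)} captured graphs")


@contextmanager
def capture_one_graph(captured, batch_size):
    # inner scope: one per CUDA graph; switch backends only while recording
    print(f"switching all-reduce backend for batch_size={batch_size}")
    try:
        yield
    finally:
        captured.append(batch_size)
        print(f"restored backend after batch_size={batch_size}")


batch_sizes = [8, 4, 2, 1]        # largest first, as in the capture loop
with capture_session() as captured:
    for bs in batch_sizes:
        with capture_one_graph(captured, bs):
            pass                  # the real code would run the model here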

vllm/worker/worker.py
@@ -11,9 +11,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          VisionLanguageConfig)
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
-                              init_distributed_environment)
-from vllm.distributed.device_communicators.custom_all_reduce import (
-    init_custom_ar)
+                              init_distributed_environment,
+                              set_custom_all_reduce)
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
@@ -302,16 +301,14 @@ def init_worker_distributed_environment(
     local_rank: int = -1,
 ) -> None:
     """Initialize the distributed environment."""
+    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
+
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank)
 
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
 
-    # Initialize a custom fast all-reduce implementation.
-    if not parallel_config.disable_custom_all_reduce:
-        init_custom_ar()
-
 
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.