mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 20:27:08 +08:00
[Core][Distributed] code deduplication in tp&pp with coordinator(#5293)
[Core][Distributed] add coordinator to reduce code duplication in tp and pp (#5293)
This commit is contained in:
parent
2135cacb45
commit
ea3890a5f0
@ -15,7 +15,8 @@ from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
|
|||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
|
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
|
||||||
from vllm.distributed import destroy_model_parallel
|
from vllm.distributed import (destroy_distributed_environment,
|
||||||
|
destroy_model_parallel)
|
||||||
from vllm.inputs import TextPrompt
|
from vllm.inputs import TextPrompt
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.multimodal import MultiModalData
|
from vllm.multimodal import MultiModalData
|
||||||
@ -54,6 +55,7 @@ def _read_prompts(filename: str) -> List[str]:
|
|||||||
|
|
||||||
def cleanup():
|
def cleanup():
|
||||||
destroy_model_parallel()
|
destroy_model_parallel()
|
||||||
|
destroy_distributed_environment()
|
||||||
with contextlib.suppress(AssertionError):
|
with contextlib.suppress(AssertionError):
|
||||||
torch.distributed.destroy_process_group()
|
torch.distributed.destroy_process_group()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|||||||
@ -7,9 +7,9 @@ import torch
|
|||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from vllm.distributed.communication_op import ( # noqa
|
from vllm.distributed.communication_op import ( # noqa
|
||||||
graph_capture, tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
|
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
|
||||||
get_tp_ca_communicator)
|
get_tp_group, graph_capture)
|
||||||
|
|
||||||
from ..utils import (init_test_distributed_environment,
|
from ..utils import (init_test_distributed_environment,
|
||||||
multi_process_tensor_parallel)
|
multi_process_tensor_parallel)
|
||||||
@ -91,7 +91,7 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
|
|||||||
# communicate independently
|
# communicate independently
|
||||||
num_communication = rank // tp_size + 1
|
num_communication = rank // tp_size + 1
|
||||||
sz = 1024
|
sz = 1024
|
||||||
fa = get_tp_ca_communicator()
|
fa = get_tp_group().ca_comm
|
||||||
inp = torch.ones(sz, dtype=torch.float32, device=device)
|
inp = torch.ones(sz, dtype=torch.float32, device=device)
|
||||||
out = inp
|
out = inp
|
||||||
for _ in range(num_communication):
|
for _ in range(num_communication):
|
||||||
|
|||||||
@ -6,10 +6,11 @@ import torch
|
|||||||
import torch.distributed
|
import torch.distributed
|
||||||
|
|
||||||
from vllm.distributed.communication_op import ( # noqa
|
from vllm.distributed.communication_op import ( # noqa
|
||||||
graph_capture, tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||||
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
|
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
|
||||||
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
||||||
|
get_world_group, graph_capture,
|
||||||
init_distributed_environment)
|
init_distributed_environment)
|
||||||
from vllm.utils import update_environment_variables
|
from vllm.utils import update_environment_variables
|
||||||
|
|
||||||
@ -53,7 +54,8 @@ def worker_fn_wrapper(fn):
|
|||||||
|
|
||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def worker_fn():
|
def worker_fn():
|
||||||
pynccl_comm = PyNcclCommunicator()
|
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||||
|
device=get_world_group().device)
|
||||||
tensor = torch.ones(16, 1024, 1024,
|
tensor = torch.ones(16, 1024, 1024,
|
||||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||||
with pynccl_comm.change_state(enable=True):
|
with pynccl_comm.change_state(enable=True):
|
||||||
@ -129,7 +131,8 @@ def test_pynccl_multiple_allreduce_with_vllm():
|
|||||||
def worker_fn_with_cudagraph():
|
def worker_fn_with_cudagraph():
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
pynccl_comm = PyNcclCommunicator()
|
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||||
|
device=get_world_group().device)
|
||||||
# run something in the default stream to initialize torch engine
|
# run something in the default stream to initialize torch engine
|
||||||
a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
|
a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@ -154,7 +157,8 @@ def test_pynccl_with_cudagraph():
|
|||||||
|
|
||||||
@worker_fn_wrapper
|
@worker_fn_wrapper
|
||||||
def send_recv_worker_fn():
|
def send_recv_worker_fn():
|
||||||
pynccl_comm = PyNcclCommunicator()
|
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
|
||||||
|
device=get_world_group().device)
|
||||||
if pynccl_comm.rank == 0:
|
if pynccl_comm.rank == 0:
|
||||||
tensor = torch.ones(16, 1024, 1024,
|
tensor = torch.ones(16, 1024, 1024,
|
||||||
dtype=torch.float32).cuda(pynccl_comm.rank)
|
dtype=torch.float32).cuda(pynccl_comm.rank)
|
||||||
|
|||||||
@ -12,7 +12,10 @@ from huggingface_hub import snapshot_download
|
|||||||
|
|
||||||
import vllm
|
import vllm
|
||||||
from vllm.config import LoRAConfig
|
from vllm.config import LoRAConfig
|
||||||
from vllm.distributed import destroy_model_parallel, initialize_model_parallel
|
from vllm.distributed import (destroy_distributed_environment,
|
||||||
|
destroy_model_parallel,
|
||||||
|
init_distributed_environment,
|
||||||
|
initialize_model_parallel)
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
@ -35,6 +38,7 @@ LONG_LORA_INFOS = [{
|
|||||||
|
|
||||||
def cleanup():
|
def cleanup():
|
||||||
destroy_model_parallel()
|
destroy_model_parallel()
|
||||||
|
destroy_distributed_environment()
|
||||||
with contextlib.suppress(AssertionError):
|
with contextlib.suppress(AssertionError):
|
||||||
torch.distributed.destroy_process_group()
|
torch.distributed.destroy_process_group()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
@ -64,15 +68,14 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def dist_init():
|
def dist_init():
|
||||||
if not torch.distributed.is_initialized():
|
temp_file = tempfile.mkstemp()[1]
|
||||||
temp_file = tempfile.mkstemp()[1]
|
init_distributed_environment(
|
||||||
torch.distributed.init_process_group(
|
world_size=1,
|
||||||
backend="nccl",
|
rank=0,
|
||||||
world_size=1,
|
distributed_init_method=f"file://{temp_file}",
|
||||||
rank=0,
|
local_rank=0,
|
||||||
init_method=f"file://{temp_file}",
|
backend="nccl",
|
||||||
)
|
)
|
||||||
torch.distributed.all_reduce(torch.zeros(1).cuda())
|
|
||||||
initialize_model_parallel(1, 1)
|
initialize_model_parallel(1, 1)
|
||||||
yield
|
yield
|
||||||
cleanup()
|
cleanup()
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.distributed.parallel_state import init_distributed_environment
|
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
||||||
|
init_distributed_environment)
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||||
@ -292,6 +293,7 @@ def distributed_init():
|
|||||||
rank=0,
|
rank=0,
|
||||||
distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}",
|
distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}",
|
||||||
local_rank=0)
|
local_rank=0)
|
||||||
|
ensure_model_parallel_initialized(1, 1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batch_size", list(range(2, 128)))
|
@pytest.mark.parametrize("batch_size", list(range(2, 128)))
|
||||||
|
|||||||
@ -110,7 +110,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
raise NotImplementedError("TPU version must be 4 or higher.")
|
raise NotImplementedError("TPU version must be 4 or higher.")
|
||||||
|
|
||||||
self.megacore_mode = None
|
self.megacore_mode = None
|
||||||
tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
|
tpu_type = torch_xla.tpu.get_tp_groupu_env()["TYPE"].lower()
|
||||||
if not tpu_type.endswith("lite"):
|
if not tpu_type.endswith("lite"):
|
||||||
if self.num_kv_heads % 2 == 0:
|
if self.num_kv_heads % 2 == 0:
|
||||||
self.megacore_mode = "kv_head"
|
self.megacore_mode = "kv_head"
|
||||||
|
|||||||
@ -1,317 +1,32 @@
|
|||||||
from collections import namedtuple
|
from typing import Any, Dict, Optional, Union
|
||||||
from contextlib import contextmanager, nullcontext
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.distributed import ProcessGroup
|
import torch.distributed
|
||||||
|
|
||||||
from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator,
|
from .parallel_state import get_tp_group
|
||||||
get_tensor_model_parallel_group,
|
|
||||||
get_tensor_model_parallel_rank,
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
get_tp_ca_communicator,
|
|
||||||
get_tp_pynccl_communicator)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class GraphCaptureContext:
|
|
||||||
stream: torch.cuda.Stream
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def graph_capture():
|
|
||||||
"""
|
|
||||||
`graph_capture` is a context manager which should surround the code that
|
|
||||||
is capturing the CUDA graph. Its main purpose is to ensure that the
|
|
||||||
some operations will be run after the graph is captured, before the graph
|
|
||||||
is replayed. It returns a `GraphCaptureContext` object which contains the
|
|
||||||
necessary data for the graph capture. Currently, it only contains the
|
|
||||||
stream that the graph capture is running on. This stream is set to the
|
|
||||||
current CUDA stream when the context manager is entered and reset to the
|
|
||||||
default stream when the context manager is exited. This is to ensure that
|
|
||||||
the graph capture is running on a separate stream from the default stream,
|
|
||||||
in order to explicitly distinguish the kernels to capture
|
|
||||||
from other kernels possibly launched on background in the default stream.
|
|
||||||
"""
|
|
||||||
stream = torch.cuda.Stream()
|
|
||||||
graph_capture_context = GraphCaptureContext(stream)
|
|
||||||
ca_comm = get_tp_ca_communicator()
|
|
||||||
maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture()
|
|
||||||
with torch.cuda.stream(stream), maybe_ca_context:
|
|
||||||
# In graph mode, we have to be very careful about the collective
|
|
||||||
# operations. The current status is:
|
|
||||||
# allreduce \ Mode | Eager | Graph |
|
|
||||||
# --------------------------------------------
|
|
||||||
# custom allreduce | enabled | enabled |
|
|
||||||
# PyNccl | disabled| enabled |
|
|
||||||
# torch.distributed | enabled | disabled|
|
|
||||||
#
|
|
||||||
# Note that custom allreduce will have a runtime check, if the tensor
|
|
||||||
# size is too large, it will fallback to the next available option.
|
|
||||||
# In summary: When using CUDA graph, we use
|
|
||||||
# either custom all-reduce kernel or pynccl. When not using CUDA
|
|
||||||
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
|
|
||||||
# We always prioritize using custom all-reduce kernel but fall back
|
|
||||||
# to PyTorch or pynccl if it is disabled or not supported.
|
|
||||||
tp_pynccl_comm = get_tp_pynccl_communicator()
|
|
||||||
pp_pynccl_comm = get_pp_pynccl_communicator()
|
|
||||||
if not tp_pynccl_comm:
|
|
||||||
maybe_tp_pynccl_context = nullcontext()
|
|
||||||
else:
|
|
||||||
maybe_tp_pynccl_context = tp_pynccl_comm.change_state(
|
|
||||||
enable=True, stream=torch.cuda.current_stream())
|
|
||||||
if not pp_pynccl_comm:
|
|
||||||
maybe_pp_pynccl_context = nullcontext()
|
|
||||||
else:
|
|
||||||
maybe_pp_pynccl_context = pp_pynccl_comm.change_state(
|
|
||||||
enable=True, stream=torch.cuda.current_stream())
|
|
||||||
with maybe_tp_pynccl_context, maybe_pp_pynccl_context:
|
|
||||||
yield graph_capture_context
|
|
||||||
|
|
||||||
|
|
||||||
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
||||||
"""All-reduce the input tensor across model parallel group.
|
"""All-reduce the input tensor across model parallel group."""
|
||||||
|
return get_tp_group().all_reduce(input_)
|
||||||
NOTE: This operation will be applied in-place on the input tensor if
|
|
||||||
disable_custom_all_reduce is set to True. Otherwise, this operation may or
|
|
||||||
may not be applied in place depending on whether custom all reduce is
|
|
||||||
invoked for a particular tensor, which further depends on the tensor size
|
|
||||||
and GPU topology.
|
|
||||||
|
|
||||||
TLDR: always assume this function modifies its input, but use the return
|
|
||||||
value as the output.
|
|
||||||
"""
|
|
||||||
ca_comm = get_tp_ca_communicator()
|
|
||||||
|
|
||||||
# Bypass the function if we are using only 1 GPU.
|
|
||||||
if get_tensor_model_parallel_world_size() == 1:
|
|
||||||
return input_
|
|
||||||
if ca_comm is not None:
|
|
||||||
out = ca_comm.custom_all_reduce(input_)
|
|
||||||
if out is not None:
|
|
||||||
return out
|
|
||||||
pynccl_comm = get_tp_pynccl_communicator()
|
|
||||||
if (pynccl_comm is not None and not pynccl_comm.disabled):
|
|
||||||
pynccl_comm.all_reduce(input_)
|
|
||||||
else:
|
|
||||||
torch.distributed.all_reduce(input_,
|
|
||||||
group=get_tensor_model_parallel_group())
|
|
||||||
return input_
|
|
||||||
|
|
||||||
|
|
||||||
def tensor_model_parallel_all_gather(input_: torch.Tensor,
|
def tensor_model_parallel_all_gather(input_: torch.Tensor,
|
||||||
dim: int = -1) -> torch.Tensor:
|
dim: int = -1) -> torch.Tensor:
|
||||||
"""All-gather the input tensor across model parallel group."""
|
"""All-gather the input tensor across model parallel group."""
|
||||||
world_size = get_tensor_model_parallel_world_size()
|
return get_tp_group().all_gather(input_, dim)
|
||||||
# Bypass the function if we are using only 1 GPU.
|
|
||||||
if world_size == 1:
|
|
||||||
return input_
|
|
||||||
assert -input_.dim() <= dim < input_.dim(), (
|
|
||||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
|
||||||
if dim < 0:
|
|
||||||
# Convert negative dim to positive.
|
|
||||||
dim += input_.dim()
|
|
||||||
input_size = input_.size()
|
|
||||||
# Allocate output tensor.
|
|
||||||
output_tensor = torch.empty((world_size, ) + input_size,
|
|
||||||
dtype=input_.dtype,
|
|
||||||
device=input_.device)
|
|
||||||
# All-gather.
|
|
||||||
torch.distributed.all_gather_into_tensor(
|
|
||||||
output_tensor, input_, group=get_tensor_model_parallel_group())
|
|
||||||
# Reshape
|
|
||||||
output_tensor = output_tensor.movedim(0, dim)
|
|
||||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
|
||||||
(world_size * input_size[dim], ) +
|
|
||||||
input_size[dim + 1:])
|
|
||||||
return output_tensor
|
|
||||||
|
|
||||||
|
|
||||||
def tensor_model_parallel_gather(input_: torch.Tensor,
|
def tensor_model_parallel_gather(input_: torch.Tensor,
|
||||||
dst: int = 0,
|
dst: int = 0,
|
||||||
dim: int = -1) -> torch.Tensor:
|
dim: int = -1) -> torch.Tensor:
|
||||||
"""Gather the input tensor across model parallel group.
|
"""Gather the input tensor across model parallel group."""
|
||||||
|
return get_tp_group().gather(input_, dst, dim)
|
||||||
NOTE: We assume that the input tensor is on the same device across
|
|
||||||
all the ranks.
|
|
||||||
"""
|
|
||||||
world_size = get_tensor_model_parallel_world_size()
|
|
||||||
# Bypass the function if we are using only 1 GPU.
|
|
||||||
if world_size == 1:
|
|
||||||
return input_
|
|
||||||
assert -input_.dim() <= dim < input_.dim(), (
|
|
||||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
|
||||||
if dim < 0:
|
|
||||||
# Convert negative dim to positive.
|
|
||||||
dim += input_.dim()
|
|
||||||
# Allocate output tensor.
|
|
||||||
if get_tensor_model_parallel_rank() == dst:
|
|
||||||
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
|
|
||||||
else:
|
|
||||||
gather_list = None
|
|
||||||
# Gather.
|
|
||||||
torch.distributed.gather(input_,
|
|
||||||
gather_list,
|
|
||||||
dst=dst,
|
|
||||||
group=get_tensor_model_parallel_group())
|
|
||||||
if get_tensor_model_parallel_rank() == dst:
|
|
||||||
output_tensor = torch.cat(gather_list, dim=dim)
|
|
||||||
else:
|
|
||||||
output_tensor = None
|
|
||||||
return output_tensor
|
|
||||||
|
|
||||||
|
|
||||||
def broadcast(input_: torch.Tensor,
|
def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
|
||||||
src: int = 0,
|
Any]]] = None,
|
||||||
group: Optional[ProcessGroup] = None):
|
src: int = 0):
|
||||||
"""Broadcast the input tensor."""
|
if not torch.distributed.is_initialized():
|
||||||
group = group or torch.distributed.group.WORLD
|
|
||||||
ranks = torch.distributed.get_process_group_ranks(group)
|
|
||||||
assert src in ranks, f"Invalid src rank ({src})"
|
|
||||||
|
|
||||||
# Bypass the function if we are using only 1 GPU.
|
|
||||||
world_size = torch.distributed.get_world_size(group=group)
|
|
||||||
if world_size == 1:
|
|
||||||
return input_
|
|
||||||
# Broadcast.
|
|
||||||
torch.distributed.broadcast(input_, src=src, group=group)
|
|
||||||
return input_
|
|
||||||
|
|
||||||
|
|
||||||
def broadcast_object_list(obj_list: List[Any],
|
|
||||||
src: int = 0,
|
|
||||||
group: Optional[ProcessGroup] = None):
|
|
||||||
"""Broadcast the input object list."""
|
|
||||||
group = group or torch.distributed.group.WORLD
|
|
||||||
ranks = torch.distributed.get_process_group_ranks(group)
|
|
||||||
assert src in ranks, f"Invalid src rank ({src})"
|
|
||||||
|
|
||||||
# Bypass the function if we are using only 1 GPU.
|
|
||||||
world_size = torch.distributed.get_world_size(group=group)
|
|
||||||
if world_size == 1:
|
|
||||||
return obj_list
|
|
||||||
# Broadcast.
|
|
||||||
torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
|
|
||||||
return obj_list
|
|
||||||
|
|
||||||
|
|
||||||
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
|
|
||||||
|
|
||||||
|
|
||||||
def _split_tensor_dict(
|
|
||||||
tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
|
|
||||||
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
|
|
||||||
"""Split the tensor dictionary into two parts:
|
|
||||||
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
|
|
||||||
by its metadata.
|
|
||||||
2. A list of tensors.
|
|
||||||
"""
|
|
||||||
metadata_list = []
|
|
||||||
tensor_list = []
|
|
||||||
for key, value in tensor_dict.items():
|
|
||||||
if isinstance(value, torch.Tensor):
|
|
||||||
# Note: we cannot use `value.device` here,
|
|
||||||
# because it contains not only the device type but also the device
|
|
||||||
# index (e.g. "cuda:0"). We only need the device type.
|
|
||||||
# receiving side will set the device index.
|
|
||||||
device = "cpu" if value.is_cpu else "cuda"
|
|
||||||
metadata_list.append(
|
|
||||||
(key, TensorMetadata(device, value.dtype, value.size())))
|
|
||||||
tensor_list.append(value)
|
|
||||||
else:
|
|
||||||
metadata_list.append((key, value))
|
|
||||||
return metadata_list, tensor_list
|
|
||||||
|
|
||||||
|
|
||||||
def broadcast_tensor_dict(
|
|
||||||
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
|
|
||||||
src: int = 0,
|
|
||||||
group: Optional[ProcessGroup] = None,
|
|
||||||
metadata_group: Optional[ProcessGroup] = None
|
|
||||||
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
|
|
||||||
"""Broadcast the input tensor dictionary.
|
|
||||||
`group` is used to broadcast the tensors, while `metadata_group` is used
|
|
||||||
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
|
|
||||||
dtypes).
|
|
||||||
"""
|
|
||||||
# Bypass the function if we are using only 1 GPU.
|
|
||||||
if (not torch.distributed.is_initialized()
|
|
||||||
or torch.distributed.get_world_size(group=group) == 1):
|
|
||||||
return tensor_dict
|
return tensor_dict
|
||||||
|
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
|
||||||
group = group or torch.distributed.group.WORLD
|
|
||||||
metadata_group = metadata_group or get_cpu_world_group()
|
|
||||||
ranks = torch.distributed.get_process_group_ranks(group)
|
|
||||||
assert src in ranks, f"Invalid src rank ({src})"
|
|
||||||
|
|
||||||
rank = torch.distributed.get_rank()
|
|
||||||
if rank == src:
|
|
||||||
metadata_list: List[Tuple[Any, Any]] = []
|
|
||||||
assert isinstance(
|
|
||||||
tensor_dict,
|
|
||||||
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
|
|
||||||
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
|
|
||||||
# `metadata_list` lives in CPU memory.
|
|
||||||
# `broadcast_object_list` involves serialization and deserialization,
|
|
||||||
# all happening on CPU. Therefore, we can use the CPU group.
|
|
||||||
torch.distributed.broadcast_object_list([metadata_list],
|
|
||||||
src=src,
|
|
||||||
group=metadata_group)
|
|
||||||
async_handles = []
|
|
||||||
for tensor in tensor_list:
|
|
||||||
if tensor.numel() == 0:
|
|
||||||
# Skip broadcasting empty tensors.
|
|
||||||
continue
|
|
||||||
if tensor.is_cpu:
|
|
||||||
# use metadata_group for CPU tensors
|
|
||||||
handle = torch.distributed.broadcast(tensor,
|
|
||||||
src=src,
|
|
||||||
group=metadata_group,
|
|
||||||
async_op=True)
|
|
||||||
else:
|
|
||||||
# use group for GPU tensors
|
|
||||||
handle = torch.distributed.broadcast(tensor,
|
|
||||||
src=src,
|
|
||||||
group=group,
|
|
||||||
async_op=True)
|
|
||||||
async_handles.append(handle)
|
|
||||||
for async_handle in async_handles:
|
|
||||||
async_handle.wait()
|
|
||||||
|
|
||||||
else:
|
|
||||||
recv_metadata_list = [None]
|
|
||||||
torch.distributed.broadcast_object_list(recv_metadata_list,
|
|
||||||
src=src,
|
|
||||||
group=metadata_group)
|
|
||||||
assert recv_metadata_list[0] is not None
|
|
||||||
tensor_dict = {}
|
|
||||||
async_handles = []
|
|
||||||
for key, value in recv_metadata_list[0]:
|
|
||||||
if isinstance(value, TensorMetadata):
|
|
||||||
tensor = torch.empty(value.size,
|
|
||||||
dtype=value.dtype,
|
|
||||||
device=value.device)
|
|
||||||
if tensor.numel() == 0:
|
|
||||||
# Skip broadcasting empty tensors.
|
|
||||||
tensor_dict[key] = tensor
|
|
||||||
continue
|
|
||||||
if tensor.is_cpu:
|
|
||||||
# use metadata_group for CPU tensors
|
|
||||||
handle = torch.distributed.broadcast(tensor,
|
|
||||||
src=src,
|
|
||||||
group=metadata_group,
|
|
||||||
async_op=True)
|
|
||||||
else:
|
|
||||||
# use group for GPU tensors
|
|
||||||
handle = torch.distributed.broadcast(tensor,
|
|
||||||
src=src,
|
|
||||||
group=group,
|
|
||||||
async_op=True)
|
|
||||||
async_handles.append(handle)
|
|
||||||
tensor_dict[key] = tensor
|
|
||||||
else:
|
|
||||||
tensor_dict[key] = value
|
|
||||||
for async_handle in async_handles:
|
|
||||||
async_handle.wait()
|
|
||||||
return tensor_dict
|
|
||||||
|
|||||||
@ -9,8 +9,7 @@ import vllm.envs as envs
|
|||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
|
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
|
||||||
gpu_p2p_access_check)
|
gpu_p2p_access_check)
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import is_in_the_same_node
|
||||||
get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node)
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -86,8 +85,8 @@ class CustomAllreduce:
|
|||||||
|
|
||||||
# max_size: max supported allreduce size
|
# max_size: max supported allreduce size
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
group: Optional[ProcessGroup] = None,
|
group: ProcessGroup,
|
||||||
device: Optional[Union[int, str, torch.device]] = None,
|
device: Union[int, str, torch.device],
|
||||||
max_size=8192 * 1024) -> None:
|
max_size=8192 * 1024) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -107,7 +106,6 @@ class CustomAllreduce:
|
|||||||
# e.g. in a non-cuda environment
|
# e.g. in a non-cuda environment
|
||||||
return
|
return
|
||||||
|
|
||||||
group = group or get_tensor_model_parallel_cpu_group()
|
|
||||||
self.group = group
|
self.group = group
|
||||||
|
|
||||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||||
@ -134,10 +132,7 @@ class CustomAllreduce:
|
|||||||
world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
|
world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
|
||||||
return
|
return
|
||||||
|
|
||||||
if device is None:
|
if isinstance(device, int):
|
||||||
local_rank = get_local_rank()
|
|
||||||
device = torch.device(f"cuda:{local_rank}")
|
|
||||||
elif isinstance(device, int):
|
|
||||||
device = torch.device(f"cuda:{device}")
|
device = torch.device(f"cuda:{device}")
|
||||||
elif isinstance(device, str):
|
elif isinstance(device, str):
|
||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
|
|||||||
@ -11,7 +11,6 @@ import torch.distributed as dist
|
|||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -162,7 +161,8 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
|
|||||||
f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
|
f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
|
||||||
)
|
)
|
||||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
if ((not is_distributed or get_local_rank() == 0)
|
from vllm.distributed.parallel_state import get_world_group
|
||||||
|
if ((not is_distributed or get_world_group().local_rank == 0)
|
||||||
and (not os.path.exists(path))):
|
and (not os.path.exists(path))):
|
||||||
# only the local master process (with local_rank == 0) can
|
# only the local master process (with local_rank == 0) can
|
||||||
# enter this block to calculate the cache
|
# enter this block to calculate the cache
|
||||||
@ -174,8 +174,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
|
|||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
json.dump(cache, f, indent=4)
|
json.dump(cache, f, indent=4)
|
||||||
if is_distributed:
|
if is_distributed:
|
||||||
cpu_world_group = get_cpu_world_group()
|
get_world_group().barrier()
|
||||||
dist.barrier(cpu_world_group)
|
|
||||||
logger.info("reading GPU P2P access cache from %s", path)
|
logger.info("reading GPU P2P access cache from %s", path)
|
||||||
with open(path, "r") as f:
|
with open(path, "r") as f:
|
||||||
cache = json.load(f)
|
cache = json.load(f)
|
||||||
|
|||||||
@ -9,7 +9,6 @@ from torch.distributed import ProcessGroup, ReduceOp
|
|||||||
from vllm.distributed.device_communicators.pynccl_wrapper import (
|
from vllm.distributed.device_communicators.pynccl_wrapper import (
|
||||||
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
|
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
|
||||||
ncclRedOpTypeEnum, ncclUniqueId)
|
ncclRedOpTypeEnum, ncclUniqueId)
|
||||||
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -19,8 +18,8 @@ class PyNcclCommunicator:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
group: Optional[ProcessGroup] = None,
|
group: ProcessGroup,
|
||||||
device: Optional[Union[int, str, torch.device]] = None,
|
device: Union[int, str, torch.device],
|
||||||
library_path: Optional[str] = None,
|
library_path: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -35,7 +34,6 @@ class PyNcclCommunicator:
|
|||||||
is bind to a unique device.
|
is bind to a unique device.
|
||||||
"""
|
"""
|
||||||
assert dist.is_initialized()
|
assert dist.is_initialized()
|
||||||
group = get_cpu_world_group() if group is None else group
|
|
||||||
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
assert dist.get_backend(group) != dist.Backend.NCCL, (
|
||||||
"PyNcclCommunicator should be attached to a non-NCCL group.")
|
"PyNcclCommunicator should be attached to a non-NCCL group.")
|
||||||
self.group = group
|
self.group = group
|
||||||
@ -77,10 +75,7 @@ class PyNcclCommunicator:
|
|||||||
byte_list = tensor.tolist()
|
byte_list = tensor.tolist()
|
||||||
for i, byte in enumerate(byte_list):
|
for i, byte in enumerate(byte_list):
|
||||||
self.unique_id.internal[i] = byte
|
self.unique_id.internal[i] = byte
|
||||||
if device is None:
|
if isinstance(device, int):
|
||||||
local_rank = get_local_rank()
|
|
||||||
device = torch.device(f"cuda:{local_rank}")
|
|
||||||
elif isinstance(device, int):
|
|
||||||
device = torch.device(f"cuda:{device}")
|
device = torch.device(f"cuda:{device}")
|
||||||
elif isinstance(device, str):
|
elif isinstance(device, str):
|
||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
|
|||||||
@ -2,83 +2,520 @@
|
|||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
|
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
|
||||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
"""Tensor and pipeline parallel groups."""
|
"""vLLM distributed state.
|
||||||
|
It takes over the control of the distributed environment from PyTorch.
|
||||||
|
The typical workflow is:
|
||||||
|
|
||||||
|
- call `init_distributed_environment` to initialize the distributed environment.
|
||||||
|
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
|
||||||
|
initialize the model parallel groups.
|
||||||
|
|
||||||
|
- any code dealing with the distributed stuff
|
||||||
|
|
||||||
|
- call `destroy_model_parallel` to destroy the model parallel groups.
|
||||||
|
- call `destroy_distributed_environment` to destroy the distributed environment.
|
||||||
|
|
||||||
|
If you only need to use the distributed environment without model/pipeline
|
||||||
|
parallelism, you can skip the model parallel initialization and destruction
|
||||||
|
steps.
|
||||||
|
"""
|
||||||
import contextlib
|
import contextlib
|
||||||
|
from collections import namedtuple
|
||||||
|
from contextlib import contextmanager, nullcontext
|
||||||
|
from dataclasses import dataclass
|
||||||
from multiprocessing import resource_tracker, shared_memory
|
from multiprocessing import resource_tracker, shared_memory
|
||||||
from typing import List, Optional
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import Backend, ProcessGroup
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GraphCaptureContext:
|
||||||
|
stream: torch.cuda.Stream
|
||||||
|
|
||||||
|
|
||||||
|
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
|
||||||
|
|
||||||
|
|
||||||
|
def _split_tensor_dict(
|
||||||
|
tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
|
||||||
|
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
|
||||||
|
"""Split the tensor dictionary into two parts:
|
||||||
|
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
|
||||||
|
by its metadata.
|
||||||
|
2. A list of tensors.
|
||||||
|
"""
|
||||||
|
metadata_list = []
|
||||||
|
tensor_list = []
|
||||||
|
for key, value in tensor_dict.items():
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
# Note: we cannot use `value.device` here,
|
||||||
|
# because it contains not only the device type but also the device
|
||||||
|
# index (e.g. "cuda:0"). We only need the device type.
|
||||||
|
# receiving side will set the device index.
|
||||||
|
device = "cpu" if value.is_cpu else "cuda"
|
||||||
|
metadata_list.append(
|
||||||
|
(key, TensorMetadata(device, value.dtype, value.size())))
|
||||||
|
tensor_list.append(value)
|
||||||
|
else:
|
||||||
|
metadata_list.append((key, value))
|
||||||
|
return metadata_list, tensor_list
|
||||||
|
|
||||||
|
|
||||||
|
class GroupCoordinator:
|
||||||
|
"""
|
||||||
|
PyTorch ProcessGroup wrapper for a group of processes.
|
||||||
|
PyTorch ProcessGroup is bound to one specific communication backend,
|
||||||
|
e.g. NCCL, Gloo, MPI, etc.
|
||||||
|
GroupCoordinator takes charge of all the communication operations among
|
||||||
|
the processes in the group. It can route the communication to
|
||||||
|
a specific implementation (e.g. switch allreduce implementation
|
||||||
|
based on the tensor size and cuda graph mode).
|
||||||
|
"""
|
||||||
|
|
||||||
|
# available attributes:
|
||||||
|
rank: int # global rank
|
||||||
|
ranks: List[int] # global ranks in the group
|
||||||
|
world_size: int # size of the group
|
||||||
|
# difference between `local_rank` and `rank_in_group`:
|
||||||
|
# if we have a group of size 4 across two nodes:
|
||||||
|
# Process | Node | Rank | Local Rank | Rank in Group
|
||||||
|
# 0 | 0 | 0 | 0 | 0
|
||||||
|
# 1 | 0 | 1 | 1 | 1
|
||||||
|
# 2 | 1 | 2 | 0 | 2
|
||||||
|
# 3 | 1 | 3 | 1 | 3
|
||||||
|
local_rank: int # local rank used to assign devices
|
||||||
|
rank_in_group: int # rank inside the group
|
||||||
|
cpu_group: ProcessGroup # group for CPU communication
|
||||||
|
device_group: ProcessGroup # group for device communication
|
||||||
|
use_pynccl: bool # a hint of whether to use PyNccl
|
||||||
|
use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
|
||||||
|
# communicators are only created for world size > 1
|
||||||
|
pynccl_comm: Optional[Any] # PyNccl communicator
|
||||||
|
ca_comm: Optional[Any] # Custom allreduce communicator
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
group_ranks: List[List[int]],
|
||||||
|
local_rank: int,
|
||||||
|
torch_distributed_backend: Union[str, Backend],
|
||||||
|
use_pynccl: bool,
|
||||||
|
use_custom_allreduce: bool,
|
||||||
|
):
|
||||||
|
|
||||||
|
self.rank = torch.distributed.get_rank()
|
||||||
|
self.local_rank = local_rank
|
||||||
|
self.device_group = None
|
||||||
|
self.cpu_group = None
|
||||||
|
|
||||||
|
for ranks in group_ranks:
|
||||||
|
device_group = torch.distributed.new_group(
|
||||||
|
ranks, backend=torch_distributed_backend)
|
||||||
|
# a group with `gloo` backend, to allow direct coordination between
|
||||||
|
# processes through the CPU.
|
||||||
|
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
||||||
|
if self.rank in ranks:
|
||||||
|
self.ranks = ranks
|
||||||
|
self.world_size = len(ranks)
|
||||||
|
self.rank_in_group = ranks.index(self.rank)
|
||||||
|
self.device_group = device_group
|
||||||
|
self.cpu_group = cpu_group
|
||||||
|
|
||||||
|
assert self.cpu_group is not None
|
||||||
|
assert self.device_group is not None
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.device = torch.device(f"cuda:{local_rank}")
|
||||||
|
else:
|
||||||
|
self.device = torch.device("cpu")
|
||||||
|
|
||||||
|
self.use_pynccl = use_pynccl
|
||||||
|
self.use_custom_allreduce = use_custom_allreduce
|
||||||
|
|
||||||
|
# lazy import to avoid documentation build error
|
||||||
|
from vllm.distributed.device_communicators.custom_all_reduce import (
|
||||||
|
CustomAllreduce)
|
||||||
|
from vllm.distributed.device_communicators.pynccl import (
|
||||||
|
PyNcclCommunicator)
|
||||||
|
|
||||||
|
self.pynccl_comm: Optional[PyNcclCommunicator]
|
||||||
|
if use_pynccl and self.world_size > 1:
|
||||||
|
self.pynccl_comm = PyNcclCommunicator(
|
||||||
|
group=self.cpu_group,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.pynccl_comm = None
|
||||||
|
|
||||||
|
self.ca_comm: Optional[CustomAllreduce]
|
||||||
|
if use_custom_allreduce and self.world_size > 1:
|
||||||
|
# Initialize a custom fast all-reduce implementation.
|
||||||
|
self.ca_comm = CustomAllreduce(
|
||||||
|
group=self.cpu_group,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.ca_comm = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def first_rank(self):
|
||||||
|
"""Return the global rank of the first process in the group"""
|
||||||
|
return self.ranks[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def last_rank(self):
|
||||||
|
"""Return the global rank of the last process in the group"""
|
||||||
|
return self.ranks[-1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def next_rank(self):
|
||||||
|
"""Return the global rank of the process that follows the caller"""
|
||||||
|
rank_in_group = self.rank_in_group
|
||||||
|
world_size = self.world_size
|
||||||
|
return self.ranks[(rank_in_group + 1) % world_size]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prev_rank(self):
|
||||||
|
"""Return the global rank of the process that precedes the caller"""
|
||||||
|
rank_in_group = self.rank_in_group
|
||||||
|
world_size = self.world_size
|
||||||
|
return self.ranks[(rank_in_group - 1) % world_size]
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def graph_capture(
|
||||||
|
self, graph_capture_context: Optional[GraphCaptureContext] = None):
|
||||||
|
if graph_capture_context is None:
|
||||||
|
stream = torch.cuda.Stream()
|
||||||
|
graph_capture_context = GraphCaptureContext(stream)
|
||||||
|
else:
|
||||||
|
stream = graph_capture_context.stream
|
||||||
|
|
||||||
|
ca_comm = self.ca_comm
|
||||||
|
maybe_ca_context = nullcontext(
|
||||||
|
) if ca_comm is None else ca_comm.capture()
|
||||||
|
with torch.cuda.stream(stream), maybe_ca_context:
|
||||||
|
# In graph mode, we have to be very careful about the collective
|
||||||
|
# operations. The current status is:
|
||||||
|
# allreduce \ Mode | Eager | Graph |
|
||||||
|
# --------------------------------------------
|
||||||
|
# custom allreduce | enabled | enabled |
|
||||||
|
# PyNccl | disabled| enabled |
|
||||||
|
# torch.distributed | enabled | disabled|
|
||||||
|
#
|
||||||
|
# Note that custom allreduce will have a runtime check, if the
|
||||||
|
# tensor size is too large, it will fallback to the next
|
||||||
|
# available option.
|
||||||
|
# In summary: When using CUDA graph, we use
|
||||||
|
# either custom all-reduce kernel or pynccl. When not using
|
||||||
|
# CUDA graph, we use either custom all-reduce kernel or
|
||||||
|
# PyTorch NCCL. We always prioritize using custom all-reduce
|
||||||
|
# kernel but fall back to PyTorch or pynccl if it is
|
||||||
|
# disabled or not supported.
|
||||||
|
pynccl_comm = self.pynccl_comm
|
||||||
|
maybe_pynccl_context: Any
|
||||||
|
if not pynccl_comm:
|
||||||
|
maybe_pynccl_context = nullcontext()
|
||||||
|
else:
|
||||||
|
maybe_pynccl_context = pynccl_comm.change_state(
|
||||||
|
enable=True, stream=torch.cuda.current_stream())
|
||||||
|
with maybe_pynccl_context:
|
||||||
|
yield graph_capture_context
|
||||||
|
|
||||||
|
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
NOTE: This operation will be applied in-place or out-of-place.
|
||||||
|
Always assume this function modifies its input, but use the return
|
||||||
|
value as the output.
|
||||||
|
"""
|
||||||
|
ca_comm = self.ca_comm
|
||||||
|
|
||||||
|
# Bypass the function if we are using only 1 GPU.
|
||||||
|
if self.world_size == 1:
|
||||||
|
return input_
|
||||||
|
if ca_comm is not None:
|
||||||
|
out = ca_comm.custom_all_reduce(input_)
|
||||||
|
if out is not None:
|
||||||
|
return out
|
||||||
|
pynccl_comm = self.pynccl_comm
|
||||||
|
if (pynccl_comm is not None and not pynccl_comm.disabled):
|
||||||
|
pynccl_comm.all_reduce(input_)
|
||||||
|
else:
|
||||||
|
torch.distributed.all_reduce(input_, group=self.device_group)
|
||||||
|
return input_
|
||||||
|
|
||||||
|
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||||
|
world_size = self.world_size
|
||||||
|
# Bypass the function if we are using only 1 GPU.
|
||||||
|
if world_size == 1:
|
||||||
|
return input_
|
||||||
|
assert -input_.dim() <= dim < input_.dim(), (
|
||||||
|
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||||
|
if dim < 0:
|
||||||
|
# Convert negative dim to positive.
|
||||||
|
dim += input_.dim()
|
||||||
|
input_size = input_.size()
|
||||||
|
# Allocate output tensor.
|
||||||
|
output_tensor = torch.empty((world_size, ) + input_size,
|
||||||
|
dtype=input_.dtype,
|
||||||
|
device=input_.device)
|
||||||
|
# All-gather.
|
||||||
|
torch.distributed.all_gather_into_tensor(output_tensor,
|
||||||
|
input_,
|
||||||
|
group=self.device_group)
|
||||||
|
# Reshape
|
||||||
|
output_tensor = output_tensor.movedim(0, dim)
|
||||||
|
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||||
|
(world_size *
|
||||||
|
input_size[dim], ) +
|
||||||
|
input_size[dim + 1:])
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
def gather(self,
|
||||||
|
input_: torch.Tensor,
|
||||||
|
dst: int = 0,
|
||||||
|
dim: int = -1) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
NOTE: We assume that the input tensor is on the same device across
|
||||||
|
all the ranks.
|
||||||
|
NOTE: `dst` is the local rank of the destination rank.
|
||||||
|
"""
|
||||||
|
world_size = self.world_size
|
||||||
|
# Bypass the function if we are using only 1 GPU.
|
||||||
|
if world_size == 1:
|
||||||
|
return input_
|
||||||
|
assert -input_.dim() <= dim < input_.dim(), (
|
||||||
|
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||||
|
if dim < 0:
|
||||||
|
# Convert negative dim to positive.
|
||||||
|
dim += input_.dim()
|
||||||
|
# Allocate output tensor.
|
||||||
|
if self.rank_in_group == dst:
|
||||||
|
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||||
|
else:
|
||||||
|
gather_list = None
|
||||||
|
# Gather.
|
||||||
|
torch.distributed.gather(input_,
|
||||||
|
gather_list,
|
||||||
|
dst=self.ranks[dst],
|
||||||
|
group=self.device_group)
|
||||||
|
if self.rank_in_group == dst:
|
||||||
|
output_tensor = torch.cat(gather_list, dim=dim)
|
||||||
|
else:
|
||||||
|
output_tensor = None
|
||||||
|
return output_tensor
|
||||||
|
|
||||||
|
def broadcast(self, input_: torch.Tensor, src: int = 0):
|
||||||
|
"""Broadcast the input tensor.
|
||||||
|
NOTE: `src` is the local rank of the source rank.
|
||||||
|
"""
|
||||||
|
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||||
|
|
||||||
|
# Bypass the function if we are using only 1 GPU.
|
||||||
|
if self.world_size == 1:
|
||||||
|
return input_
|
||||||
|
# Broadcast.
|
||||||
|
torch.distributed.broadcast(input_,
|
||||||
|
src=self.ranks[src],
|
||||||
|
group=self.device_group)
|
||||||
|
return input_
|
||||||
|
|
||||||
|
def broadcast_object_list(self,
|
||||||
|
obj_list: List[Any],
|
||||||
|
src: int = 0,
|
||||||
|
group: Optional[ProcessGroup] = None):
|
||||||
|
"""Broadcast the input object list.
|
||||||
|
NOTE: `src` is the local rank of the source rank.
|
||||||
|
"""
|
||||||
|
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||||
|
|
||||||
|
# Bypass the function if we are using only 1 GPU.
|
||||||
|
if self.world_size == 1:
|
||||||
|
return obj_list
|
||||||
|
# Broadcast.
|
||||||
|
torch.distributed.broadcast_object_list(obj_list,
|
||||||
|
src=self.ranks[src],
|
||||||
|
group=self.device_group)
|
||||||
|
return obj_list
|
||||||
|
|
||||||
|
def broadcast_tensor_dict(
|
||||||
|
self,
|
||||||
|
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
|
||||||
|
src: int = 0,
|
||||||
|
group: Optional[ProcessGroup] = None,
|
||||||
|
metadata_group: Optional[ProcessGroup] = None
|
||||||
|
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
|
||||||
|
"""Broadcast the input tensor dictionary.
|
||||||
|
NOTE: `src` is the local rank of the source rank.
|
||||||
|
"""
|
||||||
|
# Bypass the function if we are using only 1 GPU.
|
||||||
|
if (not torch.distributed.is_initialized() or self.world_size == 1):
|
||||||
|
return tensor_dict
|
||||||
|
|
||||||
|
group = self.device_group
|
||||||
|
metadata_group = self.cpu_group
|
||||||
|
assert src < self.world_size, f"Invalid src rank ({src})"
|
||||||
|
src = self.ranks[src]
|
||||||
|
|
||||||
|
rank = self.rank
|
||||||
|
if rank == src:
|
||||||
|
metadata_list: List[Tuple[Any, Any]] = []
|
||||||
|
assert isinstance(
|
||||||
|
tensor_dict,
|
||||||
|
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
|
||||||
|
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
|
||||||
|
# `metadata_list` lives in CPU memory.
|
||||||
|
# `broadcast_object_list` has serialization & deserialization,
|
||||||
|
# all happening on CPU. Therefore, we can use the CPU group.
|
||||||
|
torch.distributed.broadcast_object_list([metadata_list],
|
||||||
|
src=src,
|
||||||
|
group=metadata_group)
|
||||||
|
async_handles = []
|
||||||
|
for tensor in tensor_list:
|
||||||
|
if tensor.numel() == 0:
|
||||||
|
# Skip broadcasting empty tensors.
|
||||||
|
continue
|
||||||
|
if tensor.is_cpu:
|
||||||
|
# use metadata_group for CPU tensors
|
||||||
|
handle = torch.distributed.broadcast(tensor,
|
||||||
|
src=src,
|
||||||
|
group=metadata_group,
|
||||||
|
async_op=True)
|
||||||
|
else:
|
||||||
|
# use group for GPU tensors
|
||||||
|
handle = torch.distributed.broadcast(tensor,
|
||||||
|
src=src,
|
||||||
|
group=group,
|
||||||
|
async_op=True)
|
||||||
|
async_handles.append(handle)
|
||||||
|
for async_handle in async_handles:
|
||||||
|
async_handle.wait()
|
||||||
|
|
||||||
|
else:
|
||||||
|
recv_metadata_list = [None]
|
||||||
|
torch.distributed.broadcast_object_list(recv_metadata_list,
|
||||||
|
src=src,
|
||||||
|
group=metadata_group)
|
||||||
|
assert recv_metadata_list[0] is not None
|
||||||
|
tensor_dict = {}
|
||||||
|
async_handles = []
|
||||||
|
for key, value in recv_metadata_list[0]:
|
||||||
|
if isinstance(value, TensorMetadata):
|
||||||
|
tensor = torch.empty(value.size,
|
||||||
|
dtype=value.dtype,
|
||||||
|
device=value.device)
|
||||||
|
if tensor.numel() == 0:
|
||||||
|
# Skip broadcasting empty tensors.
|
||||||
|
tensor_dict[key] = tensor
|
||||||
|
continue
|
||||||
|
if tensor.is_cpu:
|
||||||
|
# use metadata_group for CPU tensors
|
||||||
|
handle = torch.distributed.broadcast(
|
||||||
|
tensor,
|
||||||
|
src=src,
|
||||||
|
group=metadata_group,
|
||||||
|
async_op=True)
|
||||||
|
else:
|
||||||
|
# use group for GPU tensors
|
||||||
|
handle = torch.distributed.broadcast(tensor,
|
||||||
|
src=src,
|
||||||
|
group=group,
|
||||||
|
async_op=True)
|
||||||
|
async_handles.append(handle)
|
||||||
|
tensor_dict[key] = tensor
|
||||||
|
else:
|
||||||
|
tensor_dict[key] = value
|
||||||
|
for async_handle in async_handles:
|
||||||
|
async_handle.wait()
|
||||||
|
return tensor_dict
|
||||||
|
|
||||||
|
def barrier(self):
|
||||||
|
"""Barrier synchronization among the group.
|
||||||
|
NOTE: don't use `device_group` here! `barrier` in NCCL is
|
||||||
|
terrible because it is internally a broadcast operation with
|
||||||
|
secretly created GPU tensors. It is easy to mess up the current
|
||||||
|
device. Use the CPU group instead.
|
||||||
|
"""
|
||||||
|
torch.distributed.barrier(group=self.cpu_group)
|
||||||
|
|
||||||
|
def destroy(self):
|
||||||
|
if self.device_group is not None:
|
||||||
|
torch.distributed.destroy_process_group(self.device_group)
|
||||||
|
self.device_group = None
|
||||||
|
if self.cpu_group is not None:
|
||||||
|
torch.distributed.destroy_process_group(self.cpu_group)
|
||||||
|
self.cpu_group = None
|
||||||
|
if self.pynccl_comm is not None:
|
||||||
|
self.pynccl_comm = None
|
||||||
|
if self.ca_comm is not None:
|
||||||
|
self.ca_comm = None
|
||||||
|
|
||||||
|
|
||||||
|
_WORLD: Optional[GroupCoordinator] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_group() -> GroupCoordinator:
|
||||||
|
assert _WORLD is not None, ("world group is not initialized")
|
||||||
|
return _WORLD
|
||||||
|
|
||||||
|
|
||||||
|
_TP: Optional[GroupCoordinator] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_tp_group() -> GroupCoordinator:
|
||||||
|
assert _TP is not None, ("tensor model parallel group is not initialized")
|
||||||
|
return _TP
|
||||||
|
|
||||||
|
|
||||||
|
# kept for backward compatibility
|
||||||
|
get_tensor_model_parallel_group = get_tp_group
|
||||||
|
|
||||||
|
_PP: Optional[GroupCoordinator] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_pp_group() -> GroupCoordinator:
|
||||||
|
assert _PP is not None, (
|
||||||
|
"pipeline model parallel group is not initialized")
|
||||||
|
return _PP
|
||||||
|
|
||||||
|
|
||||||
|
# kept for backward compatibility
|
||||||
|
get_pipeline_model_parallel_group = get_pp_group
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def graph_capture():
|
||||||
|
"""
|
||||||
|
`graph_capture` is a context manager which should surround the code that
|
||||||
|
is capturing the CUDA graph. Its main purpose is to ensure that the
|
||||||
|
some operations will be run after the graph is captured, before the graph
|
||||||
|
is replayed. It returns a `GraphCaptureContext` object which contains the
|
||||||
|
necessary data for the graph capture. Currently, it only contains the
|
||||||
|
stream that the graph capture is running on. This stream is set to the
|
||||||
|
current CUDA stream when the context manager is entered and reset to the
|
||||||
|
default stream when the context manager is exited. This is to ensure that
|
||||||
|
the graph capture is running on a separate stream from the default stream,
|
||||||
|
in order to explicitly distinguish the kernels to capture
|
||||||
|
from other kernels possibly launched on background in the default stream.
|
||||||
|
"""
|
||||||
|
with get_tp_group().graph_capture() as context, get_pp_group(
|
||||||
|
).graph_capture(context):
|
||||||
|
yield context
|
||||||
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
_ENABLE_CUSTOM_ALL_REDUCE = True
|
_ENABLE_CUSTOM_ALL_REDUCE = True
|
||||||
|
|
||||||
# Tensor model parallel group that the current rank belongs to.
|
|
||||||
_TP_DEVICE_GROUP: Optional[ProcessGroup] = None
|
|
||||||
_TP_CPU_GROUP: Optional[ProcessGroup] = None
|
|
||||||
_TP_PYNCCL_COMMUNICATOR = None
|
|
||||||
_TP_CA_COMMUNICATOR = None
|
|
||||||
# Pipeline model parallel group that the current rank belongs to.
|
|
||||||
_PP_DEVICE_GROUP: Optional[ProcessGroup] = None
|
|
||||||
_PP_CPU_GROUP: Optional[ProcessGroup] = None
|
|
||||||
_PP_PYNCCL_COMMUNICATOR = None
|
|
||||||
|
|
||||||
# when people blindly call `torch.distributed.all_reduce` etc,
|
|
||||||
# it will use this group. It is initialized with the `backend`
|
|
||||||
# parameter of `init_distributed_environment` below.
|
|
||||||
# Essentially, this is `torch.distributed.group.WORLD`.
|
|
||||||
# We leave a line here to note that this is device-specific.
|
|
||||||
# Note that this variable is not safe to use, because when users
|
|
||||||
# call `init_distributed_environment` first, and then destroy
|
|
||||||
# the process group themselves, this variable will keep a reference to the
|
|
||||||
# destroyed process group, which is not useful.
|
|
||||||
_DEVICE_WORLD_GROUP = None
|
|
||||||
|
|
||||||
# duing `init_distributed_environment`, we will also initialize a
|
|
||||||
# group with `gloo` backend, to allow direct coordination between
|
|
||||||
# processes through the CPU.
|
|
||||||
_CPU_WORLD_GROUP = None
|
|
||||||
|
|
||||||
# In summary, after calling `init_distributed_environment`, we will
|
|
||||||
# always have two groups: one for device-specific (and is the default)
|
|
||||||
# and one for CPU. All processes will be part of both groups.
|
|
||||||
|
|
||||||
# A list of global ranks for each pipeline group to ease calculation of the
|
|
||||||
# source rank when broadcasting from the first or last pipeline stage.
|
|
||||||
_PP_GLOBAL_RANKS: Optional[List[int]] = None
|
|
||||||
|
|
||||||
_LOCAL_RANK = -1
|
|
||||||
|
|
||||||
|
|
||||||
def set_custom_all_reduce(enable: bool):
|
def set_custom_all_reduce(enable: bool):
|
||||||
global _ENABLE_CUSTOM_ALL_REDUCE
|
global _ENABLE_CUSTOM_ALL_REDUCE
|
||||||
_ENABLE_CUSTOM_ALL_REDUCE = enable
|
_ENABLE_CUSTOM_ALL_REDUCE = enable
|
||||||
|
|
||||||
|
|
||||||
def get_pp_pynccl_communicator():
|
|
||||||
global _PP_PYNCCL_COMMUNICATOR
|
|
||||||
return _PP_PYNCCL_COMMUNICATOR
|
|
||||||
|
|
||||||
|
|
||||||
def get_tp_pynccl_communicator():
|
|
||||||
global _TP_PYNCCL_COMMUNICATOR
|
|
||||||
return _TP_PYNCCL_COMMUNICATOR
|
|
||||||
|
|
||||||
|
|
||||||
def get_tp_ca_communicator():
|
|
||||||
global _TP_CA_COMMUNICATOR
|
|
||||||
return _TP_CA_COMMUNICATOR
|
|
||||||
|
|
||||||
|
|
||||||
def get_local_rank():
|
|
||||||
global _LOCAL_RANK
|
|
||||||
return _LOCAL_RANK
|
|
||||||
|
|
||||||
|
|
||||||
def init_distributed_environment(
|
def init_distributed_environment(
|
||||||
world_size: int = -1,
|
world_size: int = -1,
|
||||||
rank: int = -1,
|
rank: int = -1,
|
||||||
@ -100,31 +537,29 @@ def init_distributed_environment(
|
|||||||
init_method=distributed_init_method,
|
init_method=distributed_init_method,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
rank=rank)
|
rank=rank)
|
||||||
global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP
|
# set the local rank
|
||||||
_DEVICE_WORLD_GROUP = torch.distributed.group.WORLD
|
# local_rank is not available in torch ProcessGroup,
|
||||||
|
# see https://github.com/pytorch/pytorch/issues/122816
|
||||||
|
if local_rank == -1:
|
||||||
|
# local rank not set, this usually happens in single-node
|
||||||
|
# setting, where we can use rank as local rank
|
||||||
|
if distributed_init_method == "env://":
|
||||||
|
local_rank = envs.LOCAL_RANK
|
||||||
|
else:
|
||||||
|
local_rank = rank
|
||||||
|
global _WORLD
|
||||||
|
if _WORLD is None:
|
||||||
ranks = list(range(torch.distributed.get_world_size()))
|
ranks = list(range(torch.distributed.get_world_size()))
|
||||||
_CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
|
_WORLD = GroupCoordinator(
|
||||||
backend="gloo")
|
group_ranks=[ranks],
|
||||||
# set the local rank
|
local_rank=local_rank,
|
||||||
# local_rank is not available in torch ProcessGroup,
|
torch_distributed_backend=backend,
|
||||||
# see https://github.com/pytorch/pytorch/issues/122816
|
use_pynccl=False,
|
||||||
if local_rank == -1:
|
use_custom_allreduce=False,
|
||||||
# local rank not set, this usually happens in single-node
|
)
|
||||||
# setting, where we can use rank as local rank
|
else:
|
||||||
if distributed_init_method == "env://":
|
assert _WORLD.world_size == torch.distributed.get_world_size(), (
|
||||||
local_rank = envs.LOCAL_RANK
|
"world group already initialized with a different world size")
|
||||||
else:
|
|
||||||
local_rank = rank
|
|
||||||
global _LOCAL_RANK
|
|
||||||
_LOCAL_RANK = local_rank
|
|
||||||
# A small all_reduce for warmup.
|
|
||||||
data = torch.zeros(1)
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
data = data.to(device=f"cuda:{local_rank}")
|
|
||||||
torch.distributed.all_reduce(data)
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
del data
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_model_parallel(
|
def initialize_model_parallel(
|
||||||
@ -157,8 +592,8 @@ def initialize_model_parallel(
|
|||||||
# Get world size and rank. Ensure some consistencies.
|
# Get world size and rank. Ensure some consistencies.
|
||||||
assert torch.distributed.is_initialized()
|
assert torch.distributed.is_initialized()
|
||||||
world_size: int = torch.distributed.get_world_size()
|
world_size: int = torch.distributed.get_world_size()
|
||||||
# get the backend of _DEVICE_WORLD_GROUP
|
backend = backend or torch.distributed.get_backend(
|
||||||
backend = backend or torch.distributed.get_backend()
|
get_world_group().device_group)
|
||||||
|
|
||||||
if (world_size !=
|
if (world_size !=
|
||||||
tensor_model_parallel_size * pipeline_model_parallel_size):
|
tensor_model_parallel_size * pipeline_model_parallel_size):
|
||||||
@ -167,63 +602,42 @@ def initialize_model_parallel(
|
|||||||
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
|
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
|
||||||
f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
|
f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
|
||||||
|
|
||||||
|
# Build the tensor model-parallel groups.
|
||||||
num_tensor_model_parallel_groups: int = (world_size //
|
num_tensor_model_parallel_groups: int = (world_size //
|
||||||
tensor_model_parallel_size)
|
tensor_model_parallel_size)
|
||||||
num_pipeline_model_parallel_groups: int = (world_size //
|
global _TP
|
||||||
pipeline_model_parallel_size)
|
assert _TP is None, ("tensor model parallel group is already initialized")
|
||||||
rank = torch.distributed.get_rank()
|
group_ranks = []
|
||||||
|
|
||||||
# Build the tensor model-parallel groups.
|
|
||||||
global _TP_DEVICE_GROUP, _TP_CPU_GROUP
|
|
||||||
global _TP_PYNCCL_COMMUNICATOR, _TP_CA_COMMUNICATOR
|
|
||||||
assert _TP_DEVICE_GROUP is None, (
|
|
||||||
"tensor model parallel group is already initialized")
|
|
||||||
for i in range(num_tensor_model_parallel_groups):
|
for i in range(num_tensor_model_parallel_groups):
|
||||||
ranks = list(
|
ranks = list(
|
||||||
range(i * tensor_model_parallel_size,
|
range(i * tensor_model_parallel_size,
|
||||||
(i + 1) * tensor_model_parallel_size))
|
(i + 1) * tensor_model_parallel_size))
|
||||||
group = torch.distributed.new_group(ranks, backend=backend)
|
group_ranks.append(ranks)
|
||||||
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
_TP = GroupCoordinator(
|
||||||
if rank in ranks:
|
group_ranks=group_ranks,
|
||||||
_TP_DEVICE_GROUP = group
|
local_rank=get_world_group().local_rank,
|
||||||
_TP_CPU_GROUP = cpu_group
|
torch_distributed_backend=backend,
|
||||||
|
use_pynccl=True,
|
||||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
|
||||||
if tensor_model_parallel_size > 1:
|
)
|
||||||
_TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
|
|
||||||
group=_TP_CPU_GROUP,
|
|
||||||
device=_LOCAL_RANK,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Initialize a custom fast all-reduce implementation.
|
|
||||||
if _ENABLE_CUSTOM_ALL_REDUCE:
|
|
||||||
from vllm.distributed.device_communicators.custom_all_reduce import (
|
|
||||||
CustomAllreduce)
|
|
||||||
_TP_CA_COMMUNICATOR = CustomAllreduce(
|
|
||||||
group=_TP_CPU_GROUP,
|
|
||||||
device=_LOCAL_RANK,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build the pipeline model-parallel groups.
|
# Build the pipeline model-parallel groups.
|
||||||
global _PP_DEVICE_GROUP, _PP_CPU_GROUP
|
num_pipeline_model_parallel_groups: int = (world_size //
|
||||||
global _PP_PYNCCL_COMMUNICATOR
|
pipeline_model_parallel_size)
|
||||||
global _PP_GLOBAL_RANKS
|
global _PP
|
||||||
assert _PP_DEVICE_GROUP is None, (
|
assert _PP is None, (
|
||||||
"pipeline model parallel group is already initialized")
|
"pipeline model parallel group is already initialized")
|
||||||
|
group_ranks = []
|
||||||
for i in range(num_pipeline_model_parallel_groups):
|
for i in range(num_pipeline_model_parallel_groups):
|
||||||
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
|
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
|
||||||
group = torch.distributed.new_group(ranks, backend=backend)
|
group_ranks.append(ranks)
|
||||||
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
|
_PP = GroupCoordinator(
|
||||||
if rank in ranks:
|
group_ranks=group_ranks,
|
||||||
_PP_DEVICE_GROUP = group
|
local_rank=get_world_group().local_rank,
|
||||||
_PP_CPU_GROUP = cpu_group
|
torch_distributed_backend=backend,
|
||||||
_PP_GLOBAL_RANKS = ranks
|
use_pynccl=True,
|
||||||
|
use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
|
||||||
if pipeline_model_parallel_size > 1:
|
)
|
||||||
_PP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
|
|
||||||
group=_PP_CPU_GROUP,
|
|
||||||
device=_LOCAL_RANK,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_model_parallel_initialized(
|
def ensure_model_parallel_initialized(
|
||||||
@ -235,8 +649,8 @@ def ensure_model_parallel_initialized(
|
|||||||
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
|
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
|
||||||
values if the model parallel groups are initialized.
|
values if the model parallel groups are initialized.
|
||||||
"""
|
"""
|
||||||
# get the backend of _DEVICE_WORLD_GROUP
|
backend = backend or torch.distributed.get_backend(
|
||||||
backend = backend or torch.distributed.get_backend()
|
get_world_group().device_group)
|
||||||
if not model_parallel_is_initialized():
|
if not model_parallel_is_initialized():
|
||||||
initialize_model_parallel(tensor_model_parallel_size,
|
initialize_model_parallel(tensor_model_parallel_size,
|
||||||
pipeline_model_parallel_size, backend)
|
pipeline_model_parallel_size, backend)
|
||||||
@ -247,137 +661,48 @@ def ensure_model_parallel_initialized(
|
|||||||
), ("tensor parallel group already initialized, but of unexpected size: "
|
), ("tensor parallel group already initialized, but of unexpected size: "
|
||||||
f"{get_tensor_model_parallel_world_size()=} vs. "
|
f"{get_tensor_model_parallel_world_size()=} vs. "
|
||||||
f"{tensor_model_parallel_size=}")
|
f"{tensor_model_parallel_size=}")
|
||||||
assert (get_pipeline_model_parallel_world_size(
|
pp_world_size = get_pp_group().world_size
|
||||||
) == pipeline_model_parallel_size), (
|
assert (pp_world_size == pipeline_model_parallel_size), (
|
||||||
"pipeline parallel group already initialized, but of unexpected size: "
|
"pipeline parallel group already initialized, but of unexpected size: "
|
||||||
f"{get_pipeline_model_parallel_world_size()=} vs. "
|
f"{pp_world_size=} vs. "
|
||||||
f"{pipeline_model_parallel_size=}")
|
f"{pipeline_model_parallel_size=}")
|
||||||
|
|
||||||
|
|
||||||
def model_parallel_is_initialized():
|
def model_parallel_is_initialized():
|
||||||
"""Check if tensor and pipeline parallel groups are initialized."""
|
"""Check if tensor and pipeline parallel groups are initialized."""
|
||||||
return (_TP_DEVICE_GROUP is not None and _PP_DEVICE_GROUP is not None)
|
return (_TP is not None and _PP is not None)
|
||||||
|
|
||||||
|
|
||||||
def get_cpu_world_group():
|
|
||||||
"""Get the CPU world group."""
|
|
||||||
assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized")
|
|
||||||
return _CPU_WORLD_GROUP
|
|
||||||
|
|
||||||
|
|
||||||
def get_tensor_model_parallel_group():
|
|
||||||
"""Get the tensor model parallel group the caller rank belongs to."""
|
|
||||||
assert _TP_DEVICE_GROUP is not None, (
|
|
||||||
"tensor model parallel group is not initialized")
|
|
||||||
return _TP_DEVICE_GROUP
|
|
||||||
|
|
||||||
|
|
||||||
def get_tensor_model_parallel_cpu_group():
|
|
||||||
"""Get the tensor model parallel cpu group the caller rank belongs to."""
|
|
||||||
assert _TP_CPU_GROUP is not None, (
|
|
||||||
"tensor model parallel cpu group is not initialized")
|
|
||||||
return _TP_CPU_GROUP
|
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_model_parallel_group():
|
|
||||||
"""Get the pipeline model parallel group the caller rank belongs to."""
|
|
||||||
assert _PP_DEVICE_GROUP is not None, (
|
|
||||||
"pipeline model parallel group is not initialized")
|
|
||||||
return _PP_DEVICE_GROUP
|
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_model_parallel_cpu_group():
|
|
||||||
"""Get the pipeline model parallel cpu group the caller rank belongs to."""
|
|
||||||
assert _PP_CPU_GROUP is not None, (
|
|
||||||
"pipeline model parallel cpu group is not initialized")
|
|
||||||
return _PP_CPU_GROUP
|
|
||||||
|
|
||||||
|
|
||||||
def get_tensor_model_parallel_world_size():
|
def get_tensor_model_parallel_world_size():
|
||||||
"""Return world size for the tensor model parallel group."""
|
"""Return world size for the tensor model parallel group."""
|
||||||
return torch.distributed.get_world_size(
|
return get_tp_group().world_size
|
||||||
group=get_tensor_model_parallel_group())
|
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_model_parallel_world_size():
|
|
||||||
"""Return world size for the pipeline model parallel group."""
|
|
||||||
return torch.distributed.get_world_size(
|
|
||||||
group=get_pipeline_model_parallel_group())
|
|
||||||
|
|
||||||
|
|
||||||
def get_tensor_model_parallel_rank():
|
def get_tensor_model_parallel_rank():
|
||||||
"""Return my rank for the tensor model parallel group."""
|
"""Return my rank for the tensor model parallel group."""
|
||||||
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
|
return get_tp_group().rank_in_group
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_model_parallel_rank():
|
|
||||||
"""Return my rank for the pipeline model parallel group."""
|
|
||||||
return torch.distributed.get_rank(
|
|
||||||
group=get_pipeline_model_parallel_group())
|
|
||||||
|
|
||||||
|
|
||||||
def get_tensor_model_parallel_src_rank():
|
|
||||||
"""Calculate the global rank corresponding to the first local rank
|
|
||||||
in the tensor model parallel group."""
|
|
||||||
global_rank = torch.distributed.get_rank()
|
|
||||||
local_world_size = get_tensor_model_parallel_world_size()
|
|
||||||
return (global_rank // local_world_size) * local_world_size
|
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_model_parallel_first_rank():
|
|
||||||
"""Return the global rank of the first process in the pipeline for the
|
|
||||||
current tensor parallel group"""
|
|
||||||
assert _PP_GLOBAL_RANKS is not None, (
|
|
||||||
"Pipeline parallel group is not initialized")
|
|
||||||
return _PP_GLOBAL_RANKS[0]
|
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_model_parallel_last_rank():
|
|
||||||
"""Return the global rank of the last process in the pipeline for the
|
|
||||||
current tensor parallel group"""
|
|
||||||
assert _PP_GLOBAL_RANKS is not None, (
|
|
||||||
"Pipeline parallel group is not initialized")
|
|
||||||
last_rank_local = get_pipeline_model_parallel_world_size() - 1
|
|
||||||
return _PP_GLOBAL_RANKS[last_rank_local]
|
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_model_parallel_next_rank():
|
|
||||||
"""Return the global rank that follows the caller in the pipeline"""
|
|
||||||
assert _PP_GLOBAL_RANKS is not None, (
|
|
||||||
"Pipeline parallel group is not initialized")
|
|
||||||
rank_in_pipeline = get_pipeline_model_parallel_rank()
|
|
||||||
world_size = get_pipeline_model_parallel_world_size()
|
|
||||||
return _PP_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
|
|
||||||
|
|
||||||
|
|
||||||
def get_pipeline_model_parallel_prev_rank():
|
|
||||||
"""Return the global rank that precedes the caller in the pipeline"""
|
|
||||||
assert _PP_GLOBAL_RANKS is not None, (
|
|
||||||
"Pipeline parallel group is not initialized")
|
|
||||||
rank_in_pipeline = get_pipeline_model_parallel_rank()
|
|
||||||
world_size = get_pipeline_model_parallel_world_size()
|
|
||||||
return _PP_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
|
|
||||||
|
|
||||||
|
|
||||||
def destroy_model_parallel():
|
def destroy_model_parallel():
|
||||||
"""Set the groups to none and destroy them."""
|
"""Set the groups to none and destroy them."""
|
||||||
global _TP_DEVICE_GROUP
|
global _TP
|
||||||
if _TP_DEVICE_GROUP:
|
if _TP:
|
||||||
torch.distributed.destroy_process_group(_TP_DEVICE_GROUP)
|
_TP.destroy()
|
||||||
_TP_DEVICE_GROUP = None
|
_TP = None
|
||||||
global _TP_CPU_GROUP
|
|
||||||
if _TP_CPU_GROUP:
|
|
||||||
torch.distributed.destroy_process_group(_TP_CPU_GROUP)
|
|
||||||
_TP_CPU_GROUP = None
|
|
||||||
global _TP_PYNCCL_COMMUNICATOR
|
|
||||||
_TP_PYNCCL_COMMUNICATOR = None
|
|
||||||
|
|
||||||
global _PP_DEVICE_GROUP
|
global _PP
|
||||||
if _PP_DEVICE_GROUP:
|
if _PP:
|
||||||
torch.distributed.destroy_process_group(_PP_DEVICE_GROUP)
|
_PP.destroy()
|
||||||
_PP_DEVICE_GROUP = None
|
_PP = None
|
||||||
global _PP_GLOBAL_RANKS
|
|
||||||
_PP_GLOBAL_RANKS = None
|
|
||||||
|
def destroy_distributed_environment():
|
||||||
|
global _WORLD
|
||||||
|
if _WORLD:
|
||||||
|
_WORLD.destroy()
|
||||||
|
_WORLD = None
|
||||||
|
if torch.distributed.is_initialized():
|
||||||
|
torch.distributed.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
def is_in_the_same_node(pg: ProcessGroup):
|
def is_in_the_same_node(pg: ProcessGroup):
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
|||||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||||
VisionLanguageConfig)
|
VisionLanguageConfig)
|
||||||
from vllm.distributed import broadcast_tensor_dict
|
from vllm.distributed import broadcast_tensor_dict
|
||||||
from vllm.distributed.communication_op import graph_capture
|
from vllm.distributed.parallel_state import graph_capture
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.layers import LoRAMapping
|
from vllm.lora.layers import LoRAMapping
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user