From ea3890a5f0314e49d69afca45fe706504cb14029 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 12 Jun 2024 17:27:08 -0700 Subject: [PATCH] [Core][Distributed] code deduplication in tp&pp with coordinator(#5293) [Core][Distributed] add coordinator to reduce code duplication in tp and pp (#5293) --- tests/conftest.py | 4 +- tests/distributed/test_custom_all_reduce.py | 6 +- tests/distributed/test_pynccl.py | 12 +- tests/lora/conftest.py | 23 +- tests/worker/test_model_runner.py | 4 +- vllm/attention/backends/pallas.py | 2 +- vllm/distributed/communication_op.py | 311 +------ .../device_communicators/custom_all_reduce.py | 13 +- .../custom_all_reduce_utils.py | 7 +- .../device_communicators/pynccl.py | 11 +- vllm/distributed/parallel_state.py | 815 ++++++++++++------ vllm/worker/model_runner.py | 2 +- 12 files changed, 625 insertions(+), 585 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e0680467d78b9..29a4f126ff920 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,8 @@ from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq, from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig -from vllm.distributed import destroy_model_parallel +from vllm.distributed import (destroy_distributed_environment, + destroy_model_parallel) from vllm.inputs import TextPrompt from vllm.logger import init_logger from vllm.multimodal import MultiModalData @@ -54,6 +55,7 @@ def _read_prompts(filename: str) -> List[str]: def cleanup(): destroy_model_parallel() + destroy_distributed_environment() with contextlib.suppress(AssertionError): torch.distributed.destroy_process_group() gc.collect() diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 186f9faa6bfb6..3776c1f91a3f2 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -7,9 +7,9 @@ import torch import torch.distributed as 
dist from vllm.distributed.communication_op import ( # noqa - graph_capture, tensor_model_parallel_all_reduce) + tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, - get_tp_ca_communicator) + get_tp_group, graph_capture) from ..utils import (init_test_distributed_environment, multi_process_tensor_parallel) @@ -91,7 +91,7 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port): # communicate independently num_communication = rank // tp_size + 1 sz = 1024 - fa = get_tp_ca_communicator() + fa = get_tp_group().ca_comm inp = torch.ones(sz, dtype=torch.float32, device=device) out = inp for _ in range(num_communication): diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 0218295a3e3f9..b788e253ab9ef 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -6,10 +6,11 @@ import torch import torch.distributed from vllm.distributed.communication_op import ( # noqa - graph_capture, tensor_model_parallel_all_reduce) + tensor_model_parallel_all_reduce) from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + get_world_group, graph_capture, init_distributed_environment) from vllm.utils import update_environment_variables @@ -53,7 +54,8 @@ def worker_fn_wrapper(fn): @worker_fn_wrapper def worker_fn(): - pynccl_comm = PyNcclCommunicator() + pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, + device=get_world_group().device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) with pynccl_comm.change_state(enable=True): @@ -129,7 +131,8 @@ def test_pynccl_multiple_allreduce_with_vllm(): def worker_fn_with_cudagraph(): with torch.no_grad(): graph = torch.cuda.CUDAGraph() - pynccl_comm = PyNcclCommunicator() + pynccl_comm = 
PyNcclCommunicator(get_world_group().cpu_group, + device=get_world_group().device) # run something in the default stream to initialize torch engine a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}') torch.cuda.synchronize() @@ -154,7 +157,8 @@ def test_pynccl_with_cudagraph(): @worker_fn_wrapper def send_recv_worker_fn(): - pynccl_comm = PyNcclCommunicator() + pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, + device=get_world_group().device) if pynccl_comm.rank == 0: tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 400333066b9fa..522c635b82d9c 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -12,7 +12,10 @@ from huggingface_hub import snapshot_download import vllm from vllm.config import LoRAConfig -from vllm.distributed import destroy_model_parallel, initialize_model_parallel +from vllm.distributed import (destroy_distributed_environment, + destroy_model_parallel, + init_distributed_environment, + initialize_model_parallel) from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) @@ -35,6 +38,7 @@ LONG_LORA_INFOS = [{ def cleanup(): destroy_model_parallel() + destroy_distributed_environment() with contextlib.suppress(AssertionError): torch.distributed.destroy_process_group() gc.collect() @@ -64,15 +68,14 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): @pytest.fixture def dist_init(): - if not torch.distributed.is_initialized(): - temp_file = tempfile.mkstemp()[1] - torch.distributed.init_process_group( - backend="nccl", - world_size=1, - rank=0, - init_method=f"file://{temp_file}", - ) - torch.distributed.all_reduce(torch.zeros(1).cuda()) + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend="nccl", + ) initialize_model_parallel(1, 1) 
yield cleanup() diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 92de545acd53d..514a57e17ebf4 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,7 +1,8 @@ import pytest import torch -from vllm.distributed.parallel_state import init_distributed_environment +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + init_distributed_environment) from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata @@ -292,6 +293,7 @@ def distributed_init(): rank=0, distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}", local_rank=0) + ensure_model_parallel_initialized(1, 1) @pytest.mark.parametrize("batch_size", list(range(2, 128))) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index b203c5ec54c92..75f2465264ad3 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -110,7 +110,7 @@ class PallasAttentionBackendImpl(AttentionImpl): raise NotImplementedError("TPU version must be 4 or higher.") self.megacore_mode = None - tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower() + tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower() if not tpu_type.endswith("lite"): if self.num_kv_heads % 2 == 0: self.megacore_mode = "kv_head" diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 2b38ec472de66..32394a07b00b9 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -1,317 +1,32 @@ -from collections import namedtuple -from contextlib import contextmanager, nullcontext -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union import torch -from torch.distributed import ProcessGroup +import torch.distributed
-from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator, - get_tensor_model_parallel_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - get_tp_ca_communicator, - get_tp_pynccl_communicator) - - -@dataclass -class GraphCaptureContext: - stream: torch.cuda.Stream - - -@contextmanager -def graph_capture(): - """ - `graph_capture` is a context manager which should surround the code that - is capturing the CUDA graph. Its main purpose is to ensure that the - some operations will be run after the graph is captured, before the graph - is replayed. It returns a `GraphCaptureContext` object which contains the - necessary data for the graph capture. Currently, it only contains the - stream that the graph capture is running on. This stream is set to the - current CUDA stream when the context manager is entered and reset to the - default stream when the context manager is exited. This is to ensure that - the graph capture is running on a separate stream from the default stream, - in order to explicitly distinguish the kernels to capture - from other kernels possibly launched on background in the default stream. - """ - stream = torch.cuda.Stream() - graph_capture_context = GraphCaptureContext(stream) - ca_comm = get_tp_ca_communicator() - maybe_ca_context = nullcontext() if ca_comm is None else ca_comm.capture() - with torch.cuda.stream(stream), maybe_ca_context: - # In graph mode, we have to be very careful about the collective - # operations. The current status is: - # allreduce \ Mode | Eager | Graph | - # -------------------------------------------- - # custom allreduce | enabled | enabled | - # PyNccl | disabled| enabled | - # torch.distributed | enabled | disabled| - # - # Note that custom allreduce will have a runtime check, if the tensor - # size is too large, it will fallback to the next available option. - # In summary: When using CUDA graph, we use - # either custom all-reduce kernel or pynccl. 
When not using CUDA - # graph, we use either custom all-reduce kernel or PyTorch NCCL. - # We always prioritize using custom all-reduce kernel but fall back - # to PyTorch or pynccl if it is disabled or not supported. - tp_pynccl_comm = get_tp_pynccl_communicator() - pp_pynccl_comm = get_pp_pynccl_communicator() - if not tp_pynccl_comm: - maybe_tp_pynccl_context = nullcontext() - else: - maybe_tp_pynccl_context = tp_pynccl_comm.change_state( - enable=True, stream=torch.cuda.current_stream()) - if not pp_pynccl_comm: - maybe_pp_pynccl_context = nullcontext() - else: - maybe_pp_pynccl_context = pp_pynccl_comm.change_state( - enable=True, stream=torch.cuda.current_stream()) - with maybe_tp_pynccl_context, maybe_pp_pynccl_context: - yield graph_capture_context +from .parallel_state import get_tp_group def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: - """All-reduce the input tensor across model parallel group. - - NOTE: This operation will be applied in-place on the input tensor if - disable_custom_all_reduce is set to True. Otherwise, this operation may or - may not be applied in place depending on whether custom all reduce is - invoked for a particular tensor, which further depends on the tensor size - and GPU topology. - - TLDR: always assume this function modifies its input, but use the return - value as the output. - """ - ca_comm = get_tp_ca_communicator() - - # Bypass the function if we are using only 1 GPU. 
- if get_tensor_model_parallel_world_size() == 1: - return input_ - if ca_comm is not None: - out = ca_comm.custom_all_reduce(input_) - if out is not None: - return out - pynccl_comm = get_tp_pynccl_communicator() - if (pynccl_comm is not None and not pynccl_comm.disabled): - pynccl_comm.all_reduce(input_) - else: - torch.distributed.all_reduce(input_, - group=get_tensor_model_parallel_group()) - return input_ + """All-reduce the input tensor across model parallel group.""" + return get_tp_group().all_reduce(input_) def tensor_model_parallel_all_gather(input_: torch.Tensor, dim: int = -1) -> torch.Tensor: """All-gather the input tensor across model parallel group.""" - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - input_size = input_.size() - # Allocate output tensor. - output_tensor = torch.empty((world_size, ) + input_size, - dtype=input_.dtype, - device=input_.device) - # All-gather. - torch.distributed.all_gather_into_tensor( - output_tensor, input_, group=get_tensor_model_parallel_group()) - # Reshape - output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape(input_size[:dim] + - (world_size * input_size[dim], ) + - input_size[dim + 1:]) - return output_tensor + return get_tp_group().all_gather(input_, dim) def tensor_model_parallel_gather(input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor: - """Gather the input tensor across model parallel group. - - NOTE: We assume that the input tensor is on the same device across - all the ranks. - """ - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. 
- if world_size == 1: - return input_ - assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - # Allocate output tensor. - if get_tensor_model_parallel_rank() == dst: - gather_list = [torch.empty_like(input_) for _ in range(world_size)] - else: - gather_list = None - # Gather. - torch.distributed.gather(input_, - gather_list, - dst=dst, - group=get_tensor_model_parallel_group()) - if get_tensor_model_parallel_rank() == dst: - output_tensor = torch.cat(gather_list, dim=dim) - else: - output_tensor = None - return output_tensor + """Gather the input tensor across model parallel group.""" + return get_tp_group().gather(input_, dst, dim) -def broadcast(input_: torch.Tensor, - src: int = 0, - group: Optional[ProcessGroup] = None): - """Broadcast the input tensor.""" - group = group or torch.distributed.group.WORLD - ranks = torch.distributed.get_process_group_ranks(group) - assert src in ranks, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - world_size = torch.distributed.get_world_size(group=group) - if world_size == 1: - return input_ - # Broadcast. - torch.distributed.broadcast(input_, src=src, group=group) - return input_ - - -def broadcast_object_list(obj_list: List[Any], - src: int = 0, - group: Optional[ProcessGroup] = None): - """Broadcast the input object list.""" - group = group or torch.distributed.group.WORLD - ranks = torch.distributed.get_process_group_ranks(group) - assert src in ranks, f"Invalid src rank ({src})" - - # Bypass the function if we are using only 1 GPU. - world_size = torch.distributed.get_world_size(group=group) - if world_size == 1: - return obj_list - # Broadcast. 
- torch.distributed.broadcast_object_list(obj_list, src=src, group=group) - return obj_list - - -TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) - - -def _split_tensor_dict( - tensor_dict: Dict[Any, Union[torch.Tensor, Any]] -) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: - """Split the tensor dictionary into two parts: - 1. A list of (key, value) pairs. If the value is a tensor, it is replaced - by its metadata. - 2. A list of tensors. - """ - metadata_list = [] - tensor_list = [] - for key, value in tensor_dict.items(): - if isinstance(value, torch.Tensor): - # Note: we cannot use `value.device` here, - # because it contains not only the device type but also the device - # index (e.g. "cuda:0"). We only need the device type. - # receiving side will set the device index. - device = "cpu" if value.is_cpu else "cuda" - metadata_list.append( - (key, TensorMetadata(device, value.dtype, value.size()))) - tensor_list.append(value) - else: - metadata_list.append((key, value)) - return metadata_list, tensor_list - - -def broadcast_tensor_dict( - tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, - src: int = 0, - group: Optional[ProcessGroup] = None, - metadata_group: Optional[ProcessGroup] = None -) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: - """Broadcast the input tensor dictionary. - `group` is used to broadcast the tensors, while `metadata_group` is used - to broadcast the metadata of the dict (e.g. dict structure, tensor sizes, - dtypes). - """ - # Bypass the function if we are using only 1 GPU. 
- if (not torch.distributed.is_initialized() - or torch.distributed.get_world_size(group=group) == 1): +def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, + Any]]] = None, + src: int = 0): + if not torch.distributed.is_initialized(): return tensor_dict - - group = group or torch.distributed.group.WORLD - metadata_group = metadata_group or get_cpu_world_group() - ranks = torch.distributed.get_process_group_ranks(group) - assert src in ranks, f"Invalid src rank ({src})" - - rank = torch.distributed.get_rank() - if rank == src: - metadata_list: List[Tuple[Any, Any]] = [] - assert isinstance( - tensor_dict, - dict), (f"Expecting a dictionary, got {type(tensor_dict)}") - metadata_list, tensor_list = _split_tensor_dict(tensor_dict) - # `metadata_list` lives in CPU memory. - # `broadcast_object_list` involves serialization and deserialization, - # all happening on CPU. Therefore, we can use the CPU group. - torch.distributed.broadcast_object_list([metadata_list], - src=src, - group=metadata_group) - async_handles = [] - for tensor in tensor_list: - if tensor.numel() == 0: - # Skip broadcasting empty tensors. - continue - if tensor.is_cpu: - # use metadata_group for CPU tensors - handle = torch.distributed.broadcast(tensor, - src=src, - group=metadata_group, - async_op=True) - else: - # use group for GPU tensors - handle = torch.distributed.broadcast(tensor, - src=src, - group=group, - async_op=True) - async_handles.append(handle) - for async_handle in async_handles: - async_handle.wait() - - else: - recv_metadata_list = [None] - torch.distributed.broadcast_object_list(recv_metadata_list, - src=src, - group=metadata_group) - assert recv_metadata_list[0] is not None - tensor_dict = {} - async_handles = [] - for key, value in recv_metadata_list[0]: - if isinstance(value, TensorMetadata): - tensor = torch.empty(value.size, - dtype=value.dtype, - device=value.device) - if tensor.numel() == 0: - # Skip broadcasting empty tensors. 
- tensor_dict[key] = tensor - continue - if tensor.is_cpu: - # use metadata_group for CPU tensors - handle = torch.distributed.broadcast(tensor, - src=src, - group=metadata_group, - async_op=True) - else: - # use group for GPU tensors - handle = torch.distributed.broadcast(tensor, - src=src, - group=group, - async_op=True) - async_handles.append(handle) - tensor_dict[key] = tensor - else: - tensor_dict[key] = value - for async_handle in async_handles: - async_handle.wait() - return tensor_dict + return get_tp_group().broadcast_tensor_dict(tensor_dict, src) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index bbc2284f8a364..9a2b47594916f 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -9,8 +9,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.distributed.device_communicators.custom_all_reduce_utils import ( gpu_p2p_access_check) -from vllm.distributed.parallel_state import ( - get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node) +from vllm.distributed.parallel_state import is_in_the_same_node from vllm.logger import init_logger try: @@ -86,8 +85,8 @@ class CustomAllreduce: # max_size: max supported allreduce size def __init__(self, - group: Optional[ProcessGroup] = None, - device: Optional[Union[int, str, torch.device]] = None, + group: ProcessGroup, + device: Union[int, str, torch.device], max_size=8192 * 1024) -> None: """ Args: @@ -107,7 +106,6 @@ class CustomAllreduce: # e.g. 
in a non-cuda environment return - group = group or get_tensor_model_parallel_cpu_group() self.group = group assert dist.get_backend(group) != dist.Backend.NCCL, ( @@ -134,10 +132,7 @@ class CustomAllreduce: world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES)) return - if device is None: - local_rank = get_local_rank() - device = torch.device(f"cuda:{local_rank}") - elif isinstance(device, int): + if isinstance(device, int): device = torch.device(f"cuda:{device}") elif isinstance(device, str): device = torch.device(device) diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/custom_all_reduce_utils.py index 4b89a23dfc463..1fd0058f617f8 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py @@ -11,7 +11,6 @@ import torch.distributed as dist import torch.multiprocessing as mp import vllm.envs as envs -from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank from vllm.logger import init_logger logger = init_logger(__name__) @@ -162,7 +161,8 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json" ) os.makedirs(os.path.dirname(path), exist_ok=True) - if ((not is_distributed or get_local_rank() == 0) + from vllm.distributed.parallel_state import get_world_group + if ((not is_distributed or get_world_group().local_rank == 0) and (not os.path.exists(path))): # only the local master process (with local_rank == 0) can # enter this block to calculate the cache @@ -174,8 +174,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: with open(path, "w") as f: json.dump(cache, f, indent=4) if is_distributed: - cpu_world_group = get_cpu_world_group() - dist.barrier(cpu_world_group) + get_world_group().barrier() logger.info("reading GPU P2P access cache from %s", path) with open(path, "r") as f: cache = json.load(f) diff 
--git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index f5f1de0c71615..83eec264b6f81 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -9,7 +9,6 @@ from torch.distributed import ProcessGroup, ReduceOp from vllm.distributed.device_communicators.pynccl_wrapper import ( NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, ncclRedOpTypeEnum, ncclUniqueId) -from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank from vllm.logger import init_logger logger = init_logger(__name__) @@ -19,8 +18,8 @@ class PyNcclCommunicator: def __init__( self, - group: Optional[ProcessGroup] = None, - device: Optional[Union[int, str, torch.device]] = None, + group: ProcessGroup, + device: Union[int, str, torch.device], library_path: Optional[str] = None, ): """ @@ -35,7 +34,6 @@ class PyNcclCommunicator: is bind to a unique device. """ assert dist.is_initialized() - group = get_cpu_world_group() if group is None else group assert dist.get_backend(group) != dist.Backend.NCCL, ( "PyNcclCommunicator should be attached to a non-NCCL group.") self.group = group @@ -77,10 +75,7 @@ class PyNcclCommunicator: byte_list = tensor.tolist() for i, byte in enumerate(byte_list): self.unique_id.internal[i] = byte - if device is None: - local_rank = get_local_rank() - device = torch.device(f"cuda:{local_rank}") - elif isinstance(device, int): + if isinstance(device, int): device = torch.device(f"cuda:{device}") elif isinstance(device, str): device = torch.device(device) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b6d1eeff09786..f6a2fc9b05a84 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -2,83 +2,520 @@ # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. 
All rights reserved. -"""Tensor and pipeline parallel groups.""" +"""vLLM distributed state. +It takes over the control of the distributed environment from PyTorch. +The typical workflow is: + +- call `init_distributed_environment` to initialize the distributed environment. +- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to + initialize the model parallel groups. + +- any code dealing with the distributed stuff + +- call `destroy_model_parallel` to destroy the model parallel groups. +- call `destroy_distributed_environment` to destroy the distributed environment. + +If you only need to use the distributed environment without model/pipeline + parallelism, you can skip the model parallel initialization and destruction + steps. +""" import contextlib +from collections import namedtuple +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass from multiprocessing import resource_tracker, shared_memory -from typing import List, Optional +from typing import Any, Dict, List, Optional, Tuple, Union import torch -from torch.distributed import ProcessGroup +from torch.distributed import Backend, ProcessGroup import vllm.envs as envs from vllm.logger import init_logger + +@dataclass +class GraphCaptureContext: + stream: torch.cuda.Stream + + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[Any, Union[torch.Tensor, Any]] +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + """ + metadata_list = [] + tensor_list = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. 
+ # receiving side will set the device index. + device = "cpu" if value.is_cpu else "cuda" + metadata_list.append( + (key, TensorMetadata(device, value.dtype, value.size()))) + tensor_list.append(value) + else: + metadata_list.append((key, value)) + return metadata_list, tensor_list + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and cuda graph mode). + """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + use_pynccl: bool # a hint of whether to use PyNccl + use_custom_allreduce: bool # a hint of whether to use CustomAllreduce + # communicators are only created for world size > 1 + pynccl_comm: Optional[Any] # PyNccl communicator + ca_comm: Optional[Any] # Custom allreduce communicator + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + use_pynccl: bool, + use_custom_allreduce: bool, + ): + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = 
torch.distributed.new_group( + ranks, backend=torch_distributed_backend) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + if torch.cuda.is_available(): + self.device = torch.device(f"cuda:{local_rank}") + else: + self.device = torch.device("cpu") + + self.use_pynccl = use_pynccl + self.use_custom_allreduce = use_custom_allreduce + + # lazy import to avoid documentation build error + from vllm.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce) + from vllm.distributed.device_communicators.pynccl import ( + PyNcclCommunicator) + + self.pynccl_comm: Optional[PyNcclCommunicator] + if use_pynccl and self.world_size > 1: + self.pynccl_comm = PyNcclCommunicator( + group=self.cpu_group, + device=self.device, + ) + else: + self.pynccl_comm = None + + self.ca_comm: Optional[CustomAllreduce] + if use_custom_allreduce and self.world_size > 1: + # Initialize a custom fast all-reduce implementation. 
+ self.ca_comm = CustomAllreduce( + group=self.cpu_group, + device=self.device, + ) + else: + self.ca_comm = None + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @contextmanager + def graph_capture( + self, graph_capture_context: Optional[GraphCaptureContext] = None): + if graph_capture_context is None: + stream = torch.cuda.Stream() + graph_capture_context = GraphCaptureContext(stream) + else: + stream = graph_capture_context.stream + + ca_comm = self.ca_comm + maybe_ca_context = nullcontext( + ) if ca_comm is None else ca_comm.capture() + with torch.cuda.stream(stream), maybe_ca_context: + # In graph mode, we have to be very careful about the collective + # operations. The current status is: + # allreduce \ Mode | Eager | Graph | + # -------------------------------------------- + # custom allreduce | enabled | enabled | + # PyNccl | disabled| enabled | + # torch.distributed | enabled | disabled| + # + # Note that custom allreduce will have a runtime check, if the + # tensor size is too large, it will fallback to the next + # available option. + # In summary: When using CUDA graph, we use + # either custom all-reduce kernel or pynccl. When not using + # CUDA graph, we use either custom all-reduce kernel or + # PyTorch NCCL. 
We always prioritize using custom all-reduce + # kernel but fall back to PyTorch or pynccl if it is + # disabled or not supported. + pynccl_comm = self.pynccl_comm + maybe_pynccl_context: Any + if not pynccl_comm: + maybe_pynccl_context = nullcontext() + else: + maybe_pynccl_context = pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream()) + with maybe_pynccl_context: + yield graph_capture_context + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + NOTE: This operation will be applied in-place or out-of-place. + Always assume this function modifies its input, but use the return + value as the output. + """ + ca_comm = self.ca_comm + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + if ca_comm is not None: + out = ca_comm.custom_all_reduce(input_) + if out is not None: + return out + pynccl_comm = self.pynccl_comm + if (pynccl_comm is not None and not pynccl_comm.disabled): + pynccl_comm.all_reduce(input_) + else: + torch.distributed.all_reduce(input_, group=self.device_group) + return input_ + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # Allocate output tensor. + output_tensor = torch.empty((world_size, ) + input_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. 
+ torch.distributed.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (world_size * + input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> torch.Tensor: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather(input_, + gather_list, + dst=self.ranks[dst], + group=self.device_group) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast(input_, + src=self.ranks[src], + group=self.device_group) + return input_ + + def broadcast_object_list(self, + obj_list: List[Any], + src: int = 0, + group: Optional[ProcessGroup] = None): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. 
+ """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list(obj_list, + src=self.ranks[src], + group=self.device_group) + return obj_list + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None + ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if (not torch.distributed.is_initialized() or self.world_size == 1): + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + src = self.ranks[src] + + rank = self.rank + if rank == src: + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + torch.distributed.broadcast_object_list([metadata_list], + src=src, + group=metadata_group) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
+ continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=src, + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=src, + group=group, + async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + recv_metadata_list = [None] + torch.distributed.broadcast_object_list(recv_metadata_list, + src=src, + group=metadata_group) + assert recv_metadata_list[0] is not None + tensor_dict = {} + async_handles = [] + for key, value in recv_metadata_list[0]: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, + src=src, + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=src, + group=group, + async_op=True) + async_handles.append(handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. 
+ """ + torch.distributed.barrier(group=self.cpu_group) + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + if self.pynccl_comm is not None: + self.pynccl_comm = None + if self.ca_comm is not None: + self.ca_comm = None + + +_WORLD: Optional[GroupCoordinator] = None + + +def get_world_group() -> GroupCoordinator: + assert _WORLD is not None, ("world group is not initialized") + return _WORLD + + +_TP: Optional[GroupCoordinator] = None + + +def get_tp_group() -> GroupCoordinator: + assert _TP is not None, ("tensor model parallel group is not initialized") + return _TP + + +# kept for backward compatibility +get_tensor_model_parallel_group = get_tp_group + +_PP: Optional[GroupCoordinator] = None + + +def get_pp_group() -> GroupCoordinator: + assert _PP is not None, ( + "pipeline model parallel group is not initialized") + return _PP + + +# kept for backward compatibility +get_pipeline_model_parallel_group = get_pp_group + + +@contextmanager +def graph_capture(): + """ + `graph_capture` is a context manager which should surround the code that + is capturing the CUDA graph. Its main purpose is to ensure that the + some operations will be run after the graph is captured, before the graph + is replayed. It returns a `GraphCaptureContext` object which contains the + necessary data for the graph capture. Currently, it only contains the + stream that the graph capture is running on. This stream is set to the + current CUDA stream when the context manager is entered and reset to the + default stream when the context manager is exited. This is to ensure that + the graph capture is running on a separate stream from the default stream, + in order to explicitly distinguish the kernels to capture + from other kernels possibly launched on background in the default stream. 
+ """ + with get_tp_group().graph_capture() as context, get_pp_group( + ).graph_capture(context): + yield context + + logger = init_logger(__name__) _ENABLE_CUSTOM_ALL_REDUCE = True -# Tensor model parallel group that the current rank belongs to. -_TP_DEVICE_GROUP: Optional[ProcessGroup] = None -_TP_CPU_GROUP: Optional[ProcessGroup] = None -_TP_PYNCCL_COMMUNICATOR = None -_TP_CA_COMMUNICATOR = None -# Pipeline model parallel group that the current rank belongs to. -_PP_DEVICE_GROUP: Optional[ProcessGroup] = None -_PP_CPU_GROUP: Optional[ProcessGroup] = None -_PP_PYNCCL_COMMUNICATOR = None - -# when people blindly call `torch.distributed.all_reduce` etc, -# it will use this group. It is initialized with the `backend` -# parameter of `init_distributed_environment` below. -# Essentially, this is `torch.distributed.group.WORLD`. -# We leave a line here to note that this is device-specific. -# Note that this variable is not safe to use, because when users -# call `init_distributed_environment` first, and then destroy -# the process group themselves, this variable will keep a reference to the -# destroyed process group, which is not useful. -_DEVICE_WORLD_GROUP = None - -# duing `init_distributed_environment`, we will also initialize a -# group with `gloo` backend, to allow direct coordination between -# processes through the CPU. -_CPU_WORLD_GROUP = None - -# In summary, after calling `init_distributed_environment`, we will -# always have two groups: one for device-specific (and is the default) -# and one for CPU. All processes will be part of both groups. - -# A list of global ranks for each pipeline group to ease calculation of the -# source rank when broadcasting from the first or last pipeline stage. 
-_PP_GLOBAL_RANKS: Optional[List[int]] = None - -_LOCAL_RANK = -1 - def set_custom_all_reduce(enable: bool): global _ENABLE_CUSTOM_ALL_REDUCE _ENABLE_CUSTOM_ALL_REDUCE = enable -def get_pp_pynccl_communicator(): - global _PP_PYNCCL_COMMUNICATOR - return _PP_PYNCCL_COMMUNICATOR - - -def get_tp_pynccl_communicator(): - global _TP_PYNCCL_COMMUNICATOR - return _TP_PYNCCL_COMMUNICATOR - - -def get_tp_ca_communicator(): - global _TP_CA_COMMUNICATOR - return _TP_CA_COMMUNICATOR - - -def get_local_rank(): - global _LOCAL_RANK - return _LOCAL_RANK - - def init_distributed_environment( world_size: int = -1, rank: int = -1, @@ -100,31 +537,29 @@ def init_distributed_environment( init_method=distributed_init_method, world_size=world_size, rank=rank) - global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP - _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = envs.LOCAL_RANK + else: + local_rank = rank + global _WORLD + if _WORLD is None: ranks = list(range(torch.distributed.get_world_size())) - _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks, - backend="gloo") - # set the local rank - # local_rank is not available in torch ProcessGroup, - # see https://github.com/pytorch/pytorch/issues/122816 - if local_rank == -1: - # local rank not set, this usually happens in single-node - # setting, where we can use rank as local rank - if distributed_init_method == "env://": - local_rank = envs.LOCAL_RANK - else: - local_rank = rank - global _LOCAL_RANK - _LOCAL_RANK = local_rank - # A small all_reduce for warmup. 
- data = torch.zeros(1) - if torch.cuda.is_available(): - data = data.to(device=f"cuda:{local_rank}") - torch.distributed.all_reduce(data) - if torch.cuda.is_available(): - torch.cuda.synchronize() - del data + _WORLD = GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + use_pynccl=False, + use_custom_allreduce=False, + ) + else: + assert _WORLD.world_size == torch.distributed.get_world_size(), ( + "world group already initialized with a different world size") def initialize_model_parallel( @@ -157,8 +592,8 @@ def initialize_model_parallel( # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() world_size: int = torch.distributed.get_world_size() - # get the backend of _DEVICE_WORLD_GROUP - backend = backend or torch.distributed.get_backend() + backend = backend or torch.distributed.get_backend( + get_world_group().device_group) if (world_size != tensor_model_parallel_size * pipeline_model_parallel_size): @@ -167,63 +602,42 @@ def initialize_model_parallel( f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") + # Build the tensor model-parallel groups. num_tensor_model_parallel_groups: int = (world_size // tensor_model_parallel_size) - num_pipeline_model_parallel_groups: int = (world_size // - pipeline_model_parallel_size) - rank = torch.distributed.get_rank() - - # Build the tensor model-parallel groups. 
- global _TP_DEVICE_GROUP, _TP_CPU_GROUP - global _TP_PYNCCL_COMMUNICATOR, _TP_CA_COMMUNICATOR - assert _TP_DEVICE_GROUP is None, ( - "tensor model parallel group is already initialized") + global _TP + assert _TP is None, ("tensor model parallel group is already initialized") + group_ranks = [] for i in range(num_tensor_model_parallel_groups): ranks = list( range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) - group = torch.distributed.new_group(ranks, backend=backend) - cpu_group = torch.distributed.new_group(ranks, backend="gloo") - if rank in ranks: - _TP_DEVICE_GROUP = group - _TP_CPU_GROUP = cpu_group - - from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator - if tensor_model_parallel_size > 1: - _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator( - group=_TP_CPU_GROUP, - device=_LOCAL_RANK, - ) - - # Initialize a custom fast all-reduce implementation. - if _ENABLE_CUSTOM_ALL_REDUCE: - from vllm.distributed.device_communicators.custom_all_reduce import ( - CustomAllreduce) - _TP_CA_COMMUNICATOR = CustomAllreduce( - group=_TP_CPU_GROUP, - device=_LOCAL_RANK, - ) + group_ranks.append(ranks) + _TP = GroupCoordinator( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + torch_distributed_backend=backend, + use_pynccl=True, + use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE, + ) # Build the pipeline model-parallel groups. 
- global _PP_DEVICE_GROUP, _PP_CPU_GROUP - global _PP_PYNCCL_COMMUNICATOR - global _PP_GLOBAL_RANKS - assert _PP_DEVICE_GROUP is None, ( + num_pipeline_model_parallel_groups: int = (world_size // + pipeline_model_parallel_size) + global _PP + assert _PP is None, ( "pipeline model parallel group is already initialized") + group_ranks = [] for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) - group = torch.distributed.new_group(ranks, backend=backend) - cpu_group = torch.distributed.new_group(ranks, backend="gloo") - if rank in ranks: - _PP_DEVICE_GROUP = group - _PP_CPU_GROUP = cpu_group - _PP_GLOBAL_RANKS = ranks - - if pipeline_model_parallel_size > 1: - _PP_PYNCCL_COMMUNICATOR = PyNcclCommunicator( - group=_PP_CPU_GROUP, - device=_LOCAL_RANK, - ) + group_ranks.append(ranks) + _PP = GroupCoordinator( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + torch_distributed_backend=backend, + use_pynccl=True, + use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE, + ) def ensure_model_parallel_initialized( @@ -235,8 +649,8 @@ def ensure_model_parallel_initialized( or ensure tensor-parallel and pipeline-parallel sizes are equal to expected values if the model parallel groups are initialized. """ - # get the backend of _DEVICE_WORLD_GROUP - backend = backend or torch.distributed.get_backend() + backend = backend or torch.distributed.get_backend( + get_world_group().device_group) if not model_parallel_is_initialized(): initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) @@ -247,137 +661,48 @@ def ensure_model_parallel_initialized( ), ("tensor parallel group already initialized, but of unexpected size: " f"{get_tensor_model_parallel_world_size()=} vs. 
" f"{tensor_model_parallel_size=}") - assert (get_pipeline_model_parallel_world_size( - ) == pipeline_model_parallel_size), ( + pp_world_size = get_pp_group().world_size + assert (pp_world_size == pipeline_model_parallel_size), ( "pipeline parallel group already initialized, but of unexpected size: " - f"{get_pipeline_model_parallel_world_size()=} vs. " + f"{pp_world_size=} vs. " f"{pipeline_model_parallel_size=}") def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" - return (_TP_DEVICE_GROUP is not None and _PP_DEVICE_GROUP is not None) - - -def get_cpu_world_group(): - """Get the CPU world group.""" - assert _CPU_WORLD_GROUP is not None, ("CPU world group is not initialized") - return _CPU_WORLD_GROUP - - -def get_tensor_model_parallel_group(): - """Get the tensor model parallel group the caller rank belongs to.""" - assert _TP_DEVICE_GROUP is not None, ( - "tensor model parallel group is not initialized") - return _TP_DEVICE_GROUP - - -def get_tensor_model_parallel_cpu_group(): - """Get the tensor model parallel cpu group the caller rank belongs to.""" - assert _TP_CPU_GROUP is not None, ( - "tensor model parallel cpu group is not initialized") - return _TP_CPU_GROUP - - -def get_pipeline_model_parallel_group(): - """Get the pipeline model parallel group the caller rank belongs to.""" - assert _PP_DEVICE_GROUP is not None, ( - "pipeline model parallel group is not initialized") - return _PP_DEVICE_GROUP - - -def get_pipeline_model_parallel_cpu_group(): - """Get the pipeline model parallel cpu group the caller rank belongs to.""" - assert _PP_CPU_GROUP is not None, ( - "pipeline model parallel cpu group is not initialized") - return _PP_CPU_GROUP + return (_TP is not None and _PP is not None) def get_tensor_model_parallel_world_size(): """Return world size for the tensor model parallel group.""" - return torch.distributed.get_world_size( - group=get_tensor_model_parallel_group()) - - -def 
get_pipeline_model_parallel_world_size(): - """Return world size for the pipeline model parallel group.""" - return torch.distributed.get_world_size( - group=get_pipeline_model_parallel_group()) + return get_tp_group().world_size def get_tensor_model_parallel_rank(): """Return my rank for the tensor model parallel group.""" - return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) - - -def get_pipeline_model_parallel_rank(): - """Return my rank for the pipeline model parallel group.""" - return torch.distributed.get_rank( - group=get_pipeline_model_parallel_group()) - - -def get_tensor_model_parallel_src_rank(): - """Calculate the global rank corresponding to the first local rank - in the tensor model parallel group.""" - global_rank = torch.distributed.get_rank() - local_world_size = get_tensor_model_parallel_world_size() - return (global_rank // local_world_size) * local_world_size - - -def get_pipeline_model_parallel_first_rank(): - """Return the global rank of the first process in the pipeline for the - current tensor parallel group""" - assert _PP_GLOBAL_RANKS is not None, ( - "Pipeline parallel group is not initialized") - return _PP_GLOBAL_RANKS[0] - - -def get_pipeline_model_parallel_last_rank(): - """Return the global rank of the last process in the pipeline for the - current tensor parallel group""" - assert _PP_GLOBAL_RANKS is not None, ( - "Pipeline parallel group is not initialized") - last_rank_local = get_pipeline_model_parallel_world_size() - 1 - return _PP_GLOBAL_RANKS[last_rank_local] - - -def get_pipeline_model_parallel_next_rank(): - """Return the global rank that follows the caller in the pipeline""" - assert _PP_GLOBAL_RANKS is not None, ( - "Pipeline parallel group is not initialized") - rank_in_pipeline = get_pipeline_model_parallel_rank() - world_size = get_pipeline_model_parallel_world_size() - return _PP_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] - - -def get_pipeline_model_parallel_prev_rank(): - """Return the 
global rank that precedes the caller in the pipeline""" - assert _PP_GLOBAL_RANKS is not None, ( - "Pipeline parallel group is not initialized") - rank_in_pipeline = get_pipeline_model_parallel_rank() - world_size = get_pipeline_model_parallel_world_size() - return _PP_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] + return get_tp_group().rank_in_group def destroy_model_parallel(): """Set the groups to none and destroy them.""" - global _TP_DEVICE_GROUP - if _TP_DEVICE_GROUP: - torch.distributed.destroy_process_group(_TP_DEVICE_GROUP) - _TP_DEVICE_GROUP = None - global _TP_CPU_GROUP - if _TP_CPU_GROUP: - torch.distributed.destroy_process_group(_TP_CPU_GROUP) - _TP_CPU_GROUP = None - global _TP_PYNCCL_COMMUNICATOR - _TP_PYNCCL_COMMUNICATOR = None + global _TP + if _TP: + _TP.destroy() + _TP = None - global _PP_DEVICE_GROUP - if _PP_DEVICE_GROUP: - torch.distributed.destroy_process_group(_PP_DEVICE_GROUP) - _PP_DEVICE_GROUP = None - global _PP_GLOBAL_RANKS - _PP_GLOBAL_RANKS = None + global _PP + if _PP: + _PP.destroy() + _PP = None + + +def destroy_distributed_environment(): + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() def is_in_the_same_node(pg: ProcessGroup): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index de616ef1ded96..476e9ba3bb463 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -13,7 +13,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict -from vllm.distributed.communication_op import graph_capture +from vllm.distributed.parallel_state import graph_capture from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest