diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index 93a069d36c4b..96d08dc1a3c1 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -5,7 +5,6 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import dataclasses
-import datetime
 import os
 import pickle
 import socket
@@ -14,14 +13,14 @@ import time
 import uuid
 from collections import deque
 from collections.abc import Sequence
+from datetime import timedelta
 from typing import Any, Optional
 
 import torch
 from torch.distributed import ProcessGroup, TCPStore
 from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                 _get_default_timeout,
-                                                _unregister_process_group,
-                                                is_nccl_available)
+                                                _unregister_process_group)
 from torch.distributed.rendezvous import rendezvous
 
 import vllm.envs as envs
@@ -406,7 +405,7 @@ class StatelessProcessGroup:
             port=port,
             world_size=world_size,
             is_master=launch_server,
-            timeout=datetime.timedelta(seconds=store_timeout),
+            timeout=timedelta(seconds=store_timeout),
             use_libuv=False,  # for now: github.com/pytorch/pytorch/pull/150215
             master_listen_fd=listen_fd,
         )
@@ -419,6 +418,43 @@ class StatelessProcessGroup:
             data_expiration_seconds=data_expiration_seconds)
 
+
+def init_gloo_process_group(backend: Backend, prefix_store: PrefixStore,
+                            group_rank: int, group_size: int,
+                            timeout: timedelta) -> ProcessGroup:
+    """
+    Stateless init ProcessGroup with gloo backend compatible with
+    different torch versions.
+    """
+    if is_torch_equal_or_newer("2.6"):
+        pg = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+        )
+    else:
+        options = ProcessGroup.Options(backend=backend)
+        pg = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+            options,
+        )
+    from torch.distributed.distributed_c10d import ProcessGroupGloo
+    backend_class = ProcessGroupGloo(prefix_store,
+                                     group_rank,
+                                     group_size,
+                                     timeout=timeout)
+    backend_type = ProcessGroup.BackendType.GLOO
+    device = torch.device("cpu")
+    if is_torch_equal_or_newer("2.6"):
+        # _set_default_backend is supported in torch >= 2.6
+        pg._set_default_backend(backend_type)
+    backend_class._set_sequence_number_for_group()
+
+    pg._register_backend(device, backend_type, backend_class)
+    return pg
+
 
 def stateless_init_torch_distributed_process_group(
         host: str, port: int, rank: int, world_size: int,
         backend: str) -> ProcessGroup:
@@ -468,40 +504,19 @@ def stateless_init_torch_distributed_process_group(
     # different systems (e.g. RPC) in case the store is multi-tenant.
     prefix_store = PrefixStore(init_method, store)
 
-    pg: ProcessGroup = ProcessGroup(
-        prefix_store,
-        group_rank,
-        group_size,
-    )
-    if backend == "gloo":
-        from torch.distributed.distributed_c10d import ProcessGroupGloo
-        backend_class = ProcessGroupGloo(prefix_store,
-                                         group_rank,
-                                         group_size,
-                                         timeout=timeout)
-        backend_type = ProcessGroup.BackendType.GLOO
-        device = torch.device("cpu")
-    elif backend == "nccl":
-        assert is_nccl_available()
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-
-        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
-                                         backend_options)
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-    else:
-        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
-
-    pg._set_default_backend(backend_type)
-    backend_class._set_sequence_number_for_group()
-
-    pg._register_backend(device, backend_type, backend_class)
-
-    return pg
+    if backend == "gloo":
+        return init_gloo_process_group(backend=backend,
+                                       prefix_store=prefix_store,
+                                       group_rank=group_rank,
+                                       group_size=group_size,
+                                       timeout=timeout)
+    from vllm.platforms import current_platform
+    return current_platform.stateless_init_device_torch_dist_pg(
+        backend=backend,
+        prefix_store=prefix_store,
+        group_rank=group_rank,
+        group_size=group_size,
+        timeout=timeout)
 
 
 def stateless_destroy_torch_distributed_process_group(
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 8bb3dfe7457a..0bed44f73277 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -4,10 +4,13 @@ pynvml. However, it should not initialize cuda context.
 """
 
 import os
+from datetime import timedelta
 from functools import wraps
 from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
 
 import torch
+from torch.distributed import PrefixStore, ProcessGroup
+from torch.distributed.distributed_c10d import is_nccl_available
 from typing_extensions import ParamSpec
 
 # import custom ops, trigger op registration
@@ -316,6 +319,36 @@ class CudaPlatformBase(Platform):
     def get_piecewise_backend_cls(cls) -> str:
         return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"  # noqa
 
+    @classmethod
+    def stateless_init_device_torch_dist_pg(
+        cls,
+        backend: str,
+        prefix_store: PrefixStore,
+        group_rank: int,
+        group_size: int,
+        timeout: timedelta,
+    ) -> ProcessGroup:
+        assert is_nccl_available()
+        pg: ProcessGroup = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+        )
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+        pg._set_default_backend(backend_type)
+        backend_class._set_sequence_number_for_group()
+
+        pg._register_backend(device, backend_type, backend_class)
+        return pg
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 504c3b42a75d..5c4f7a2f7dc7 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -3,11 +3,13 @@ import enum
 import os
 import platform
 import random
+from datetime import timedelta
 from platform import uname
 from typing import TYPE_CHECKING, NamedTuple, Optional, Union
 
 import numpy as np
 import torch
+from torch.distributed import PrefixStore, ProcessGroup
 
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
@@ -486,6 +488,20 @@ class Platform:
         """
         return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend"  # noqa
 
+    @classmethod
+    def stateless_init_device_torch_dist_pg(
+        cls,
+        backend: str,
+        prefix_store: PrefixStore,
+        group_rank: int,
+        group_size: int,
+        timeout: timedelta,
+    ) -> ProcessGroup:
+        """
+        Init platform-specific torch distributed process group.
+        """
+        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index b5e742c65c9f..d544b4ab4b02 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -1,10 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
+from datetime import timedelta
 from functools import cache, lru_cache, wraps
 from typing import TYPE_CHECKING, Optional
 
 import torch
+from torch.distributed import PrefixStore, ProcessGroup
+from torch.distributed.distributed_c10d import is_nccl_available
 
 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -387,3 +390,33 @@ class RocmPlatform(Platform):
     @classmethod
     def get_piecewise_backend_cls(cls) -> str:
         return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"  # noqa
+
+    @classmethod
+    def stateless_init_device_torch_dist_pg(
+        cls,
+        backend: str,
+        prefix_store: PrefixStore,
+        group_rank: int,
+        group_size: int,
+        timeout: timedelta,
+    ) -> ProcessGroup:
+        assert is_nccl_available()
+        pg: ProcessGroup = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+        )
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+        pg._set_default_backend(backend_type)
+        backend_class._set_sequence_number_for_group()
+
+        pg._register_backend(device, backend_type, backend_class)
+        return pg
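
Note (outside the diff): the new Platform.stateless_init_device_torch_dist_pg hook is what lets an out-of-tree device plugin take over process-group construction instead of relying on the hard-coded gloo/NCCL branches removed from stateless_init_torch_distributed_process_group. Below is a minimal sketch of such an override; the MyPlatform class is hypothetical, it assumes torch >= 2.6, and it reuses the CPU/gloo wiring from init_gloo_process_group purely as a stand-in for a real device backend.

from datetime import timedelta

import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import ProcessGroupGloo

from vllm.platforms.interface import Platform


class MyPlatform(Platform):
    """Hypothetical out-of-tree platform; illustrative only."""

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        # Build the wrapper ProcessGroup (torch >= 2.6 constructor; see
        # init_gloo_process_group above for the pre-2.6 variant).
        pg = ProcessGroup(prefix_store, group_rank, group_size)
        # A real plugin would construct its device-specific backend here;
        # gloo on CPU is used only to keep the sketch self-contained.
        backend_class = ProcessGroupGloo(prefix_store,
                                         group_rank,
                                         group_size,
                                         timeout=timeout)
        backend_type = ProcessGroup.BackendType.GLOO
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()
        pg._register_backend(torch.device("cpu"), backend_type, backend_class)
        return pg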