[Platform][Dist] Make torch distributed process group extendable (#18763)

Signed-off-by: Mengqing Cao <cmq0113@163.com>
Mengqing Cao 2025-05-28 18:52:34 +08:00 committed by GitHub
parent ce75efeecb
commit d781930f90
4 changed files with 134 additions and 37 deletions

vllm/distributed/utils.py

@@ -5,7 +5,6 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import dataclasses
-import datetime
 import os
 import pickle
 import socket
@@ -14,14 +13,14 @@ import time
 import uuid
 from collections import deque
 from collections.abc import Sequence
+from datetime import timedelta
 from typing import Any, Optional

 import torch
 from torch.distributed import ProcessGroup, TCPStore
 from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                 _get_default_timeout,
-                                                _unregister_process_group,
-                                                is_nccl_available)
+                                                _unregister_process_group)
 from torch.distributed.rendezvous import rendezvous

 import vllm.envs as envs
@@ -406,7 +405,7 @@ class StatelessProcessGroup:
             port=port,
             world_size=world_size,
             is_master=launch_server,
-            timeout=datetime.timedelta(seconds=store_timeout),
+            timeout=timedelta(seconds=store_timeout),
             use_libuv=False,  # for now: github.com/pytorch/pytorch/pull/150215
             master_listen_fd=listen_fd,
         )
@@ -419,6 +418,43 @@ class StatelessProcessGroup:
             data_expiration_seconds=data_expiration_seconds)


+def init_gloo_process_group(backend: Backend, prefix_store: PrefixStore,
+                            group_rank: int, group_size: int,
+                            timeout: timedelta) -> ProcessGroup:
+    """
+    Stateless init ProcessGroup with gloo backend compatible with
+    different torch versions.
+    """
+    if is_torch_equal_or_newer("2.6"):
+        pg = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+        )
+    else:
+        options = ProcessGroup.Options(backend=backend)
+        pg = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+            options,
+        )
+    from torch.distributed.distributed_c10d import ProcessGroupGloo
+    backend_class = ProcessGroupGloo(prefix_store,
+                                     group_rank,
+                                     group_size,
+                                     timeout=timeout)
+    backend_type = ProcessGroup.BackendType.GLOO
+    device = torch.device("cpu")
+    if is_torch_equal_or_newer("2.6"):
+        # _set_default_backend is supported in torch >= 2.6
+        pg._set_default_backend(backend_type)
+    backend_class._set_sequence_number_for_group()
+    pg._register_backend(device, backend_type, backend_class)
+    return pg
+
+
 def stateless_init_torch_distributed_process_group(
         host: str, port: int, rank: int, world_size: int,
         backend: str) -> ProcessGroup:
@@ -468,40 +504,19 @@ def stateless_init_torch_distributed_process_group(
     # different systems (e.g. RPC) in case the store is multi-tenant.
     prefix_store = PrefixStore(init_method, store)

-    pg: ProcessGroup = ProcessGroup(
-        prefix_store,
-        group_rank,
-        group_size,
-    )
-
     if backend == "gloo":
-        from torch.distributed.distributed_c10d import ProcessGroupGloo
-        backend_class = ProcessGroupGloo(prefix_store,
-                                         group_rank,
-                                         group_size,
-                                         timeout=timeout)
-        backend_type = ProcessGroup.BackendType.GLOO
-        device = torch.device("cpu")
-    elif backend == "nccl":
-        assert is_nccl_available()
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-
-        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
-                                         backend_options)
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-    else:
-        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
-
-    pg._set_default_backend(backend_type)
-    backend_class._set_sequence_number_for_group()
-    pg._register_backend(device, backend_type, backend_class)
-
-    return pg
+        return init_gloo_process_group(backend=backend,
+                                       prefix_store=prefix_store,
+                                       group_rank=group_rank,
+                                       group_size=group_size,
+                                       timeout=timeout)
+    from vllm.platforms import current_platform
+    return current_platform.stateless_init_device_torch_dist_pg(
+        backend=backend,
+        prefix_store=prefix_store,
+        group_rank=group_rank,
+        group_size=group_size,
+        timeout=timeout)


 def stateless_destroy_torch_distributed_process_group(
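After this refactor, stateless_init_torch_distributed_process_group keeps only the backend-agnostic rendezvous logic: "gloo" goes through the new init_gloo_process_group helper, and every other backend name is delegated to the active platform via current_platform.stateless_init_device_torch_dist_pg. Below is a minimal single-process usage sketch, not part of the diff, assuming these helpers live in vllm.distributed.utils as in current vLLM and that the chosen localhost port is free.

# Illustrative sketch only: single process, gloo backend, hypothetical free port.
from vllm.distributed.utils import (
    stateless_destroy_torch_distributed_process_group,
    stateless_init_torch_distributed_process_group)

# "gloo" is handled by init_gloo_process_group (CPU path); any other backend,
# e.g. "nccl" on a CUDA platform, is forwarded to
# current_platform.stateless_init_device_torch_dist_pg.
pg = stateless_init_torch_distributed_process_group(
    host="127.0.0.1",
    port=29555,  # assumed to be a free port
    rank=0,
    world_size=1,
    backend="gloo",
)
assert pg.rank() == 0 and pg.size() == 1
stateless_destroy_torch_distributed_process_group(pg)

The net effect of the dispatch is that this file no longer hard-codes any device-specific c10d backend; the NCCL handling moves into the platform classes in the files below.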

vllm/platforms/cuda.py

@@ -4,10 +4,13 @@ pynvml. However, it should not initialize cuda context.
 """

 import os
+from datetime import timedelta
 from functools import wraps
 from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union

 import torch
+from torch.distributed import PrefixStore, ProcessGroup
+from torch.distributed.distributed_c10d import is_nccl_available
 from typing_extensions import ParamSpec

 # import custom ops, trigger op registration
@@ -316,6 +319,36 @@ class CudaPlatformBase(Platform):
     def get_piecewise_backend_cls(cls) -> str:
         return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"  # noqa

+    @classmethod
+    def stateless_init_device_torch_dist_pg(
+        cls,
+        backend: str,
+        prefix_store: PrefixStore,
+        group_rank: int,
+        group_size: int,
+        timeout: timedelta,
+    ) -> ProcessGroup:
+        assert is_nccl_available()
+        pg: ProcessGroup = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+        )
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+        pg._set_default_backend(backend_type)
+        backend_class._set_sequence_number_for_group()
+        pg._register_backend(device, backend_type, backend_class)
+        return pg
+

 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

vllm/platforms/interface.py

@@ -3,11 +3,13 @@ import enum
 import os
 import platform
 import random
+from datetime import timedelta
 from platform import uname
 from typing import TYPE_CHECKING, NamedTuple, Optional, Union

 import numpy as np
 import torch
+from torch.distributed import PrefixStore, ProcessGroup

 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
@@ -486,6 +488,20 @@ class Platform:
         """
         return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend"  # noqa

+    @classmethod
+    def stateless_init_device_torch_dist_pg(
+        cls,
+        backend: str,
+        prefix_store: PrefixStore,
+        group_rank: int,
+        group_size: int,
+        timeout: timedelta,
+    ) -> ProcessGroup:
+        """
+        Init platform-specific torch distributed process group.
+        """
+        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
+

 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
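The base-class hook is deliberately a stub that raises for unknown backends; overriding it is what makes the process-group setup extendable to out-of-tree platforms. The following is a hypothetical sketch of such an override: the class name MyPlatform is invented, Gloo stands in for a real device backend so the sketch needs no special hardware, the base class is assumed to live in vllm.platforms.interface (with PlatformEnum.OOT as vLLM's out-of-tree enum value), and torch >= 2.6 is assumed so ProcessGroup can be constructed without Options, as in init_gloo_process_group above.

# Hypothetical out-of-tree platform overriding the new hook (illustrative).
from datetime import timedelta

import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import ProcessGroupGloo

from vllm.platforms.interface import Platform, PlatformEnum


class MyPlatform(Platform):
    _enum = PlatformEnum.OOT  # enum value used for out-of-tree platforms

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        # Same pattern as the CUDA implementation above (ROCm below is
        # identical): create the group, then build and register the
        # device-specific c10d backend for it.
        pg = ProcessGroup(prefix_store, group_rank, group_size)
        backend_class = ProcessGroupGloo(prefix_store, group_rank, group_size,
                                         timeout=timeout)
        backend_type = ProcessGroup.BackendType.GLOO
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()
        pg._register_backend(torch.device("cpu"), backend_type, backend_class)
        return pg

A real plugin would substitute its own c10d backend and device here and expose the class through vLLM's platform plugin mechanism so that current_platform resolves to it.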

vllm/platforms/rocm.py

@@ -1,10 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0

 import os
+from datetime import timedelta
 from functools import cache, lru_cache, wraps
 from typing import TYPE_CHECKING, Optional

 import torch
+from torch.distributed import PrefixStore, ProcessGroup
+from torch.distributed.distributed_c10d import is_nccl_available

 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -387,3 +390,33 @@ class RocmPlatform(Platform):
     @classmethod
     def get_piecewise_backend_cls(cls) -> str:
         return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend"  # noqa
+
+    @classmethod
+    def stateless_init_device_torch_dist_pg(
+        cls,
+        backend: str,
+        prefix_store: PrefixStore,
+        group_rank: int,
+        group_size: int,
+        timeout: timedelta,
+    ) -> ProcessGroup:
+        assert is_nccl_available()
+        pg: ProcessGroup = ProcessGroup(
+            prefix_store,
+            group_rank,
+            group_size,
+        )
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+        pg._set_default_backend(backend_type)
+        backend_class._set_sequence_number_for_group()
+        pg._register_backend(device, backend_type, backend_class)
+        return pg