From db1764e4e05b06c93073b9f26df7b1f3b684e638 Mon Sep 17 00:00:00 2001
From: wangxiyuan
Date: Wed, 15 Oct 2025 17:32:17 +0800
Subject: [PATCH] [Platform] allow platform to init dp group (#22243)

Signed-off-by: wangxiyuan
---
 vllm/config/parallel.py     |  2 +-
 vllm/distributed/utils.py   | 26 +++++++++++++-------------
 vllm/platforms/cuda.py      | 34 ----------------------------------
 vllm/platforms/interface.py |  2 +-
 vllm/platforms/rocm.py      | 34 ----------------------------------
 5 files changed, 15 insertions(+), 83 deletions(-)

diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index b7ef0fef68330..944a1e8666f4b 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -334,7 +334,7 @@ class ParallelConfig:
                 self.get_next_dp_init_port(),
                 self.data_parallel_rank,
                 self.data_parallel_size,
-                backend="gloo",
+                backend=current_platform.dist_backend,
             )
         except DistNetworkError as e:
             # We only want to retry when the root cause is EADDRINUSE.
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index 0a1e04ec10f99..a3d9dbe83a124 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -415,7 +415,6 @@ class StatelessProcessGroup:
 
 
 def init_gloo_process_group(
-    backend: Backend,
     prefix_store: PrefixStore,
     group_rank: int,
     group_size: int,
@@ -432,7 +431,7 @@ def init_gloo_process_group(
             group_size,
         )
     else:
-        options = ProcessGroup.Options(backend=backend)
+        options = ProcessGroup.Options(backend="gloo")
         pg = ProcessGroup(
             prefix_store,
             group_rank,
@@ -504,24 +503,25 @@ def stateless_init_torch_distributed_process_group(
     # Use a PrefixStore to avoid accidental overrides of keys used by
     # different systems (e.g. RPC) in case the store is multi-tenant.
     prefix_store = PrefixStore(init_method, store)
+    try:
+        from vllm.platforms import current_platform
 
-    if backend == "gloo":
-        return init_gloo_process_group(
+        return current_platform.stateless_init_device_torch_dist_pg(
             backend=backend,
             prefix_store=prefix_store,
             group_rank=group_rank,
             group_size=group_size,
             timeout=timeout,
         )
-    from vllm.platforms import current_platform
-
-    return current_platform.stateless_init_device_torch_dist_pg(
-        backend=backend,
-        prefix_store=prefix_store,
-        group_rank=group_rank,
-        group_size=group_size,
-        timeout=timeout,
-    )
+    except NotImplementedError:
+        # If platform doesn't implement stateless_init_device_torch_dist_pg, it
+        # will raise a NotImplementedError. In this case, we fall back to gloo.
+        return init_gloo_process_group(
+            prefix_store=prefix_store,
+            group_rank=group_rank,
+            group_size=group_size,
+            timeout=timeout,
+        )
 
 
 def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 0252c3acb08c1..04c2bbb43805b 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -6,13 +6,10 @@ pynvml. However, it should not initialize cuda context.
 
 import os
 from collections.abc import Callable
-from datetime import timedelta
 from functools import cache, wraps
 from typing import TYPE_CHECKING, TypeVar
 
 import torch
-from torch.distributed import PrefixStore, ProcessGroup
-from torch.distributed.distributed_c10d import is_nccl_available
 from typing_extensions import ParamSpec
 
 # import custom ops, trigger op registration
@@ -455,37 +452,6 @@ class CudaPlatformBase(Platform):
     def get_static_graph_wrapper_cls(cls) -> str:
         return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
 
-    @classmethod
-    def stateless_init_device_torch_dist_pg(
-        cls,
-        backend: str,
-        prefix_store: PrefixStore,
-        group_rank: int,
-        group_size: int,
-        timeout: timedelta,
-    ) -> ProcessGroup:
-        assert is_nccl_available()
-        pg: ProcessGroup = ProcessGroup(
-            prefix_store,
-            group_rank,
-            group_size,
-        )
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-
-        backend_class = ProcessGroupNCCL(
-            prefix_store, group_rank, group_size, backend_options
-        )
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-        pg._set_default_backend(backend_type)
-        backend_class._set_sequence_number_for_group()
-
-        pg._register_backend(device, backend_type, backend_class)
-        return pg
-
     @classmethod
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 9b8d75ac22fe0..f08e62a4aa9c2 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -551,7 +551,7 @@ class Platform:
         """
         Init platform-specific torch distributed process group.
         """
-        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
+        raise NotImplementedError
 
     @classmethod
     def is_kv_cache_dtype_supported(
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 81745257d0ae2..8fa07b10d34aa 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -2,13 +2,10 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import os
-from datetime import timedelta
 from functools import cache, lru_cache, wraps
 from typing import TYPE_CHECKING
 
 import torch
-from torch.distributed import PrefixStore, ProcessGroup
-from torch.distributed.distributed_c10d import is_nccl_available
 
 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -476,37 +473,6 @@ class RocmPlatform(Platform):
     def get_static_graph_wrapper_cls(cls) -> str:
         return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
 
-    @classmethod
-    def stateless_init_device_torch_dist_pg(
-        cls,
-        backend: str,
-        prefix_store: PrefixStore,
-        group_rank: int,
-        group_size: int,
-        timeout: timedelta,
-    ) -> ProcessGroup:
-        assert is_nccl_available()
-        pg: ProcessGroup = ProcessGroup(
-            prefix_store,
-            group_rank,
-            group_size,
-        )
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-
-        backend_class = ProcessGroupNCCL(
-            prefix_store, group_rank, group_size, backend_options
-        )
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-        pg._set_default_backend(backend_type)
-        backend_class._set_sequence_number_for_group()
-
-        pg._register_backend(device, backend_type, backend_class)
-        return pg
-
     @classmethod
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
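
Note for out-of-tree platform plugins (not part of the patch above): with this change,
Platform.stateless_init_device_torch_dist_pg raises NotImplementedError by default and
stateless_init_torch_distributed_process_group falls back to gloo, so a platform only needs
to override the hook to get a device-native data-parallel group. The sketch below is one way
such an override could look for a CUDA-like, NCCL-backed device; it simply mirrors the
implementation this patch removes from cuda.py/rocm.py. "MyPlatform" is a hypothetical
placeholder, and other required Platform attributes are omitted for brevity.

    from datetime import timedelta

    import torch
    from torch.distributed import PrefixStore, ProcessGroup
    from torch.distributed.distributed_c10d import is_nccl_available

    from vllm.platforms.interface import Platform


    class MyPlatform(Platform):
        """Hypothetical out-of-tree platform overriding the new hook."""

        # Read by ParallelConfig via current_platform.dist_backend
        # (see the vllm/config/parallel.py hunk above).
        dist_backend: str = "nccl"

        @classmethod
        def stateless_init_device_torch_dist_pg(
            cls,
            backend: str,
            prefix_store: PrefixStore,
            group_rank: int,
            group_size: int,
            timeout: timedelta,
        ) -> ProcessGroup:
            # Same wiring the patch removes from CudaPlatformBase: build a bare
            # ProcessGroup, attach an NCCL backend, and register it for the device.
            assert is_nccl_available()
            pg: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size)

            from torch.distributed.distributed_c10d import ProcessGroupNCCL

            backend_options = ProcessGroupNCCL.Options()
            backend_options._timeout = timeout

            backend_class = ProcessGroupNCCL(
                prefix_store, group_rank, group_size, backend_options
            )
            backend_type = ProcessGroup.BackendType.NCCL
            device = torch.device("cuda")
            pg._set_default_backend(backend_type)
            backend_class._set_sequence_number_for_group()

            pg._register_backend(device, backend_type, backend_class)
            return pg

Platforms that do not override the hook keep working unchanged: the NotImplementedError from
the base class is caught in stateless_init_torch_distributed_process_group and the gloo path
is used instead, which is also why init_gloo_process_group no longer takes a backend argument.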