[Platform] allow platform to init dp group (#22243)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Authored by wangxiyuan on 2025-10-15 17:32:17 +08:00, committed by GitHub
parent 7f83b4ee8e
commit db1764e4e0
5 changed files with 15 additions and 83 deletions


@@ -334,7 +334,7 @@ class ParallelConfig:
                     self.get_next_dp_init_port(),
                     self.data_parallel_rank,
                     self.data_parallel_size,
-                    backend="gloo",
+                    backend=current_platform.dist_backend,
                 )
             except DistNetworkError as e:
                 # We only want to retry when the root cause is EADDRINUSE.
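
The change above stops hard-coding gloo for the data-parallel coordination group and instead uses whatever backend the active platform advertises through Platform.dist_backend. A minimal sketch of the resulting call, taken out of the retry loop shown here; the host/port/rank/size values are made up, and the positional order follows the hunk above:

from vllm.distributed.utils import stateless_init_torch_distributed_process_group
from vllm.platforms import current_platform

# Hypothetical standalone values; inside vLLM these come from
# ParallelConfig (data_parallel_master_ip, get_next_dp_init_port(), ...).
dp_group = stateless_init_torch_distributed_process_group(
    "127.0.0.1",                            # data_parallel_master_ip
    29500,                                  # get_next_dp_init_port()
    0,                                      # data_parallel_rank
    2,                                      # data_parallel_size
    backend=current_platform.dist_backend,  # e.g. "nccl" on CUDA, "gloo" on CPU
)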


@@ -415,7 +415,6 @@ class StatelessProcessGroup:
 def init_gloo_process_group(
-    backend: Backend,
     prefix_store: PrefixStore,
     group_rank: int,
     group_size: int,
@@ -432,7 +431,7 @@ def init_gloo_process_group(
             group_size,
         )
     else:
-        options = ProcessGroup.Options(backend=backend)
+        options = ProcessGroup.Options(backend="gloo")
         pg = ProcessGroup(
             prefix_store,
             group_rank,
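
Because init_gloo_process_group no longer accepts a backend argument and always builds a gloo-backed group, callers only pass the store, rank, size and timeout, as the fallback further down does. A hypothetical standalone use under those assumptions (single process, made-up port and "vllm_dp/" prefix):

from datetime import timedelta

from torch.distributed import PrefixStore, TCPStore

from vllm.distributed.utils import init_gloo_process_group

# Single-process example so the store does not wait for peers.
store = TCPStore("127.0.0.1", 29501, world_size=1, is_master=True)
pg = init_gloo_process_group(
    prefix_store=PrefixStore("vllm_dp/", store),
    group_rank=0,
    group_size=1,
    timeout=timedelta(seconds=300),
)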
@@ -504,24 +503,25 @@ def stateless_init_torch_distributed_process_group(
     # Use a PrefixStore to avoid accidental overrides of keys used by
     # different systems (e.g. RPC) in case the store is multi-tenant.
     prefix_store = PrefixStore(init_method, store)
-    if backend == "gloo":
-        return init_gloo_process_group(
+    try:
+        from vllm.platforms import current_platform
+        return current_platform.stateless_init_device_torch_dist_pg(
             backend=backend,
             prefix_store=prefix_store,
             group_rank=group_rank,
             group_size=group_size,
             timeout=timeout,
         )
-    from vllm.platforms import current_platform
-
-    return current_platform.stateless_init_device_torch_dist_pg(
-        backend=backend,
-        prefix_store=prefix_store,
-        group_rank=group_rank,
-        group_size=group_size,
-        timeout=timeout,
-    )
+    except NotImplementedError:
+        # If platform doesn't implement stateless_init_device_torch_dist_pg, it
+        # will raise a NotImplementedError. In this case, we fall back to gloo.
+        return init_gloo_process_group(
+            prefix_store=prefix_store,
+            group_rank=group_rank,
+            group_size=group_size,
+            timeout=timeout,
+        )
 
 
 def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
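
This is the heart of the change: vllm/distributed/utils.py now always asks the current platform for a device-specific stateless process group and only falls back to gloo when the platform raises NotImplementedError (the new default in vllm/platforms/interface.py below), so out-of-tree platforms can supply their own DP-group initialization. A rough sketch of such a plugin platform; the class name, the "xccl" backend string, and the method body are hypothetical, only the hook's name and signature come from this diff:

from datetime import timedelta

from torch.distributed import PrefixStore, ProcessGroup

from vllm.platforms.interface import Platform


class MyAcceleratorPlatform(Platform):
    # Hypothetical out-of-tree platform, registered via vLLM's platform
    # plugin mechanism so that it becomes `current_platform`.
    dist_backend: str = "xccl"  # made-up device backend name

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        # Start from a bare, store-backed ProcessGroup and register the
        # device backend on it (compare the NCCL wiring removed below).
        pg = ProcessGroup(prefix_store, group_rank, group_size)
        ...  # attach the device-specific backend here
        return pg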


@@ -6,13 +6,10 @@ pynvml. However, it should not initialize cuda context.
 import os
 from collections.abc import Callable
-from datetime import timedelta
 from functools import cache, wraps
 from typing import TYPE_CHECKING, TypeVar
 
 import torch
-from torch.distributed import PrefixStore, ProcessGroup
-from torch.distributed.distributed_c10d import is_nccl_available
 from typing_extensions import ParamSpec
 
 # import custom ops, trigger op registration
@@ -455,37 +452,6 @@ class CudaPlatformBase(Platform):
     def get_static_graph_wrapper_cls(cls) -> str:
         return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
 
-    @classmethod
-    def stateless_init_device_torch_dist_pg(
-        cls,
-        backend: str,
-        prefix_store: PrefixStore,
-        group_rank: int,
-        group_size: int,
-        timeout: timedelta,
-    ) -> ProcessGroup:
-        assert is_nccl_available()
-        pg: ProcessGroup = ProcessGroup(
-            prefix_store,
-            group_rank,
-            group_size,
-        )
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-        backend_class = ProcessGroupNCCL(
-            prefix_store, group_rank, group_size, backend_options
-        )
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-        pg._set_default_backend(backend_type)
-        backend_class._set_sequence_number_for_group()
-        pg._register_backend(device, backend_type, backend_class)
-        return pg
-
     @classmethod
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
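
For reference, the method removed above (and its twin removed from RocmPlatform further down) is the recipe a platform-specific implementation of this hook can follow: build a bare, store-backed ProcessGroup, then attach a device backend through private c10d APIs. A condensed, commented restatement of those steps; the function name is hypothetical, and _timeout, _set_default_backend, _register_backend and _set_sequence_number_for_group are PyTorch internals that may change between releases:

from datetime import timedelta

import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import ProcessGroupNCCL, is_nccl_available


def build_stateless_nccl_pg(
    prefix_store: PrefixStore,
    group_rank: int,
    group_size: int,
    timeout: timedelta,
) -> ProcessGroup:
    assert is_nccl_available()

    # 1. A ProcessGroup shell tied to the store, with no backend attached yet.
    pg = ProcessGroup(prefix_store, group_rank, group_size)

    # 2. The actual NCCL backend, configured with the requested timeout.
    backend_options = ProcessGroupNCCL.Options()
    backend_options._timeout = timeout
    backend_class = ProcessGroupNCCL(
        prefix_store, group_rank, group_size, backend_options
    )
    backend_class._set_sequence_number_for_group()

    # 3. Make NCCL the default backend and register it for CUDA devices.
    backend_type = ProcessGroup.BackendType.NCCL
    pg._set_default_backend(backend_type)
    pg._register_backend(torch.device("cuda"), backend_type, backend_class)
    return pg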


@@ -551,7 +551,7 @@ class Platform:
         """
         Init platform-specific torch distributed process group.
         """
-        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
+        raise NotImplementedError
 
     @classmethod
     def is_kv_cache_dtype_supported(
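
Previously the base Platform aborted with a RuntimeError for any backend it did not recognize; it now signals "no device-specific implementation", which is exactly what the try/except in vllm/distributed/utils.py above keys on, while genuine initialization failures keep propagating. A tiny sketch of the resulting contract with a hypothetical subclass that does not override the hook:

from datetime import timedelta

from vllm.platforms.interface import Platform


class BarePlatform(Platform):
    """Hypothetical platform without a device-specific DP-group hook."""


try:
    BarePlatform.stateless_init_device_torch_dist_pg(
        "fake-backend", None, 0, 1, timedelta(seconds=30)
    )
except NotImplementedError:
    # The generic helper catches exactly this and builds a gloo group instead.
    pass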


@@ -2,13 +2,10 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import os
-from datetime import timedelta
 from functools import cache, lru_cache, wraps
 from typing import TYPE_CHECKING
 
 import torch
-from torch.distributed import PrefixStore, ProcessGroup
-from torch.distributed.distributed_c10d import is_nccl_available
 
 import vllm.envs as envs
 from vllm.logger import init_logger
@@ -476,37 +473,6 @@ class RocmPlatform(Platform):
     def get_static_graph_wrapper_cls(cls) -> str:
         return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
 
-    @classmethod
-    def stateless_init_device_torch_dist_pg(
-        cls,
-        backend: str,
-        prefix_store: PrefixStore,
-        group_rank: int,
-        group_size: int,
-        timeout: timedelta,
-    ) -> ProcessGroup:
-        assert is_nccl_available()
-        pg: ProcessGroup = ProcessGroup(
-            prefix_store,
-            group_rank,
-            group_size,
-        )
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-        backend_class = ProcessGroupNCCL(
-            prefix_store, group_rank, group_size, backend_options
-        )
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-        pg._set_default_backend(backend_type)
-        backend_class._set_sequence_number_for_group()
-        pg._register_backend(device, backend_type, backend_class)
-        return pg
-
     @classmethod
     def device_count(cls) -> int:
         return cuda_device_count_stateless()