[Platform] allow platform to init dp group (#22243)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
parent 7f83b4ee8e
commit db1764e4e0
@@ -334,7 +334,7 @@ class ParallelConfig:
                     self.get_next_dp_init_port(),
                     self.data_parallel_rank,
                     self.data_parallel_size,
-                    backend="gloo",
+                    backend=current_platform.dist_backend,
                 )
             except DistNetworkError as e:
                 # We only want to retry when the root cause is EADDRINUSE.

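The hunk above replaces the hardcoded "gloo" backend with whatever the active platform advertises through its dist_backend attribute. A minimal, self-contained sketch of that selection pattern; the _Toy* classes below are illustrative stand-ins, not vLLM's real Platform hierarchy, and the backend strings are only typical values:

# Illustration only: toy platform objects, not vLLM's Platform classes.
class _ToyGpuPlatform:
    dist_backend = "nccl"   # accelerator platforms typically advertise nccl/hccl/etc.


class _ToyCpuPlatform:
    dist_backend = "gloo"   # CPU-only platforms keep the old default


def pick_dp_backend(platform) -> str:
    # Mirrors the change above: ask the platform instead of hardcoding "gloo".
    return platform.dist_backend


assert pick_dp_backend(_ToyGpuPlatform()) == "nccl"
assert pick_dp_backend(_ToyCpuPlatform()) == "gloo"
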
@@ -415,7 +415,6 @@ class StatelessProcessGroup:


 def init_gloo_process_group(
-    backend: Backend,
     prefix_store: PrefixStore,
     group_rank: int,
     group_size: int,

@@ -432,7 +431,7 @@ def init_gloo_process_group(
             group_size,
         )
     else:
-        options = ProcessGroup.Options(backend=backend)
+        options = ProcessGroup.Options(backend="gloo")
         pg = ProcessGroup(
             prefix_store,
             group_rank,

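With the backend parameter removed, init_gloo_process_group is unambiguously gloo-only. For contrast, here is a single-rank sketch of the stateful public torch.distributed API, which installs a process-wide default group and is what the stateless helpers in this file avoid; the host and port are arbitrary example values:

from datetime import timedelta

import torch.distributed as dist
from torch.distributed import TCPStore

# Single-rank demo: this process is both the store server and the only member.
store = TCPStore("127.0.0.1", 29544, world_size=1, is_master=True,
                 timeout=timedelta(seconds=30))

# Stateful path: registers a global default process group for this process.
dist.init_process_group(backend="gloo", store=store, rank=0, world_size=1)
print(dist.get_backend())  # gloo
dist.destroy_process_group()
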
@@ -504,24 +503,25 @@ def stateless_init_torch_distributed_process_group(
     # Use a PrefixStore to avoid accidental overrides of keys used by
     # different systems (e.g. RPC) in case the store is multi-tenant.
     prefix_store = PrefixStore(init_method, store)
-
-    if backend == "gloo":
-        return init_gloo_process_group(
-            backend=backend,
-            prefix_store=prefix_store,
-            group_rank=group_rank,
-            group_size=group_size,
-            timeout=timeout,
-        )
-    from vllm.platforms import current_platform
-
-    return current_platform.stateless_init_device_torch_dist_pg(
-        backend=backend,
-        prefix_store=prefix_store,
-        group_rank=group_rank,
-        group_size=group_size,
-        timeout=timeout,
-    )
+    try:
+        from vllm.platforms import current_platform
+
+        return current_platform.stateless_init_device_torch_dist_pg(
+            backend=backend,
+            prefix_store=prefix_store,
+            group_rank=group_rank,
+            group_size=group_size,
+            timeout=timeout,
+        )
+    except NotImplementedError:
+        # If platform doesn't implement stateless_init_device_torch_dist_pg, it
+        # will raise a NotImplementedError. In this case, we fall back to gloo.
+        return init_gloo_process_group(
+            prefix_store=prefix_store,
+            group_rank=group_rank,
+            group_size=group_size,
+            timeout=timeout,
+        )


 def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:

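The comment at the top of this hunk is about key namespacing: the init_method string becomes a prefix so several subsystems can share one underlying store without clobbering each other's keys. A tiny runnable sketch of that plumbing, using an in-memory HashStore in place of the TCP-based store vLLM builds elsewhere; the prefix and keys are made up for illustration:

from torch.distributed import HashStore, PrefixStore

# In-process key/value store stands in for the real rendezvous store.
store = HashStore()

# Same idea as `prefix_store = PrefixStore(init_method, store)` above:
# every key written through prefix_store is transparently namespaced.
prefix_store = PrefixStore("tcp://127.0.0.1:29555/", store)
prefix_store.set("dp_group/rank0", "ready")
print(prefix_store.get("dp_group/rank0"))  # b'ready'
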
@@ -6,13 +6,10 @@ pynvml. However, it should not initialize cuda context.

 import os
 from collections.abc import Callable
-from datetime import timedelta
 from functools import cache, wraps
 from typing import TYPE_CHECKING, TypeVar

 import torch
-from torch.distributed import PrefixStore, ProcessGroup
-from torch.distributed.distributed_c10d import is_nccl_available
 from typing_extensions import ParamSpec

 # import custom ops, trigger op registration

@@ -455,37 +452,6 @@ class CudaPlatformBase(Platform):
     def get_static_graph_wrapper_cls(cls) -> str:
         return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

-    @classmethod
-    def stateless_init_device_torch_dist_pg(
-        cls,
-        backend: str,
-        prefix_store: PrefixStore,
-        group_rank: int,
-        group_size: int,
-        timeout: timedelta,
-    ) -> ProcessGroup:
-        assert is_nccl_available()
-        pg: ProcessGroup = ProcessGroup(
-            prefix_store,
-            group_rank,
-            group_size,
-        )
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-
-        backend_class = ProcessGroupNCCL(
-            prefix_store, group_rank, group_size, backend_options
-        )
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-        pg._set_default_backend(backend_type)
-        backend_class._set_sequence_number_for_group()
-
-        pg._register_backend(device, backend_type, backend_class)
-        return pg
-
     @classmethod
     def device_count(cls) -> int:
         return cuda_device_count_stateless()

@@ -551,7 +551,7 @@ class Platform:
         """
         Init platform-specific torch distributed process group.
         """
-        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
+        raise NotImplementedError

     @classmethod
     def is_kv_cache_dtype_supported(

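The base Platform now signals "no device-specific implementation" with NotImplementedError rather than failing hard with RuntimeError, which is exactly what the try/except added in the utils hunk keys on. A self-contained toy of that contract; the classes and return values below are placeholders, not vLLM's actual API:

class _ToyBasePlatform:
    def stateless_init_device_torch_dist_pg(self, **kwargs):
        # New default behaviour: "not supported here", so callers may fall back.
        raise NotImplementedError


class _ToyPluginPlatform(_ToyBasePlatform):
    def stateless_init_device_torch_dist_pg(self, **kwargs):
        return "device-specific process group"


def _toy_gloo_fallback(**kwargs):
    return "gloo process group"


def init_pg(platform, **kwargs):
    try:
        return platform.stateless_init_device_torch_dist_pg(**kwargs)
    except NotImplementedError:
        # Same shape as the fallback branch in the utils hunk above.
        return _toy_gloo_fallback(**kwargs)


print(init_pg(_ToyPluginPlatform()))  # device-specific process group
print(init_pg(_ToyBasePlatform()))    # gloo process group
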
@@ -2,13 +2,10 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import os
-from datetime import timedelta
 from functools import cache, lru_cache, wraps
 from typing import TYPE_CHECKING

 import torch
-from torch.distributed import PrefixStore, ProcessGroup
-from torch.distributed.distributed_c10d import is_nccl_available

 import vllm.envs as envs
 from vllm.logger import init_logger

@@ -476,37 +473,6 @@ class RocmPlatform(Platform):
     def get_static_graph_wrapper_cls(cls) -> str:
         return "vllm.compilation.cuda_graph.CUDAGraphWrapper"

-    @classmethod
-    def stateless_init_device_torch_dist_pg(
-        cls,
-        backend: str,
-        prefix_store: PrefixStore,
-        group_rank: int,
-        group_size: int,
-        timeout: timedelta,
-    ) -> ProcessGroup:
-        assert is_nccl_available()
-        pg: ProcessGroup = ProcessGroup(
-            prefix_store,
-            group_rank,
-            group_size,
-        )
-        from torch.distributed.distributed_c10d import ProcessGroupNCCL
-
-        backend_options = ProcessGroupNCCL.Options()
-        backend_options._timeout = timeout
-
-        backend_class = ProcessGroupNCCL(
-            prefix_store, group_rank, group_size, backend_options
-        )
-        backend_type = ProcessGroup.BackendType.NCCL
-        device = torch.device("cuda")
-        pg._set_default_backend(backend_type)
-        backend_class._set_sequence_number_for_group()
-
-        pg._register_backend(device, backend_type, backend_class)
-        return pg
-
     @classmethod
     def device_count(cls) -> int:
         return cuda_device_count_stateless()

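For reference, the two removed methods (CUDA and ROCm were identical) shared one construction pattern; the sketch below collapses it into a single free function. It assumes a PyTorch build with NCCL, a CUDA device, and an already-populated PrefixStore, and the underscore-prefixed calls are private torch internals taken verbatim from the removed code:

from datetime import timedelta

import torch
from torch.distributed import PrefixStore, ProcessGroup
from torch.distributed.distributed_c10d import ProcessGroupNCCL, is_nccl_available


def nccl_stateless_pg(prefix_store: PrefixStore, group_rank: int,
                      group_size: int, timeout: timedelta) -> ProcessGroup:
    # Build a ProcessGroup by hand and register an NCCL backend on it,
    # without touching the global default process group.
    assert is_nccl_available()
    pg: ProcessGroup = ProcessGroup(prefix_store, group_rank, group_size)

    backend_options = ProcessGroupNCCL.Options()
    backend_options._timeout = timeout

    backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                     backend_options)
    backend_type = ProcessGroup.BackendType.NCCL
    device = torch.device("cuda")
    pg._set_default_backend(backend_type)
    backend_class._set_sequence_number_for_group()
    pg._register_backend(device, backend_type, backend_class)
    return pg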