mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-07-06 07:17:09 +08:00
[Platform] allow platform to init dp group (#22243)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
parent
7f83b4ee8e
commit
db1764e4e0
@ -334,7 +334,7 @@ class ParallelConfig:
|
|||||||
self.get_next_dp_init_port(),
|
self.get_next_dp_init_port(),
|
||||||
self.data_parallel_rank,
|
self.data_parallel_rank,
|
||||||
self.data_parallel_size,
|
self.data_parallel_size,
|
||||||
backend="gloo",
|
backend=current_platform.dist_backend,
|
||||||
)
|
)
|
||||||
except DistNetworkError as e:
|
except DistNetworkError as e:
|
||||||
# We only want to retry when the root cause is EADDRINUSE.
|
# We only want to retry when the root cause is EADDRINUSE.
|
||||||
|
|||||||
@ -415,7 +415,6 @@ class StatelessProcessGroup:
|
|||||||
|
|
||||||
|
|
||||||
def init_gloo_process_group(
|
def init_gloo_process_group(
|
||||||
backend: Backend,
|
|
||||||
prefix_store: PrefixStore,
|
prefix_store: PrefixStore,
|
||||||
group_rank: int,
|
group_rank: int,
|
||||||
group_size: int,
|
group_size: int,
|
||||||
@ -432,7 +431,7 @@ def init_gloo_process_group(
|
|||||||
group_size,
|
group_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
options = ProcessGroup.Options(backend=backend)
|
options = ProcessGroup.Options(backend="gloo")
|
||||||
pg = ProcessGroup(
|
pg = ProcessGroup(
|
||||||
prefix_store,
|
prefix_store,
|
||||||
group_rank,
|
group_rank,
|
||||||
@ -504,24 +503,25 @@ def stateless_init_torch_distributed_process_group(
|
|||||||
# Use a PrefixStore to avoid accidental overrides of keys used by
|
# Use a PrefixStore to avoid accidental overrides of keys used by
|
||||||
# different systems (e.g. RPC) in case the store is multi-tenant.
|
# different systems (e.g. RPC) in case the store is multi-tenant.
|
||||||
prefix_store = PrefixStore(init_method, store)
|
prefix_store = PrefixStore(init_method, store)
|
||||||
|
try:
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
if backend == "gloo":
|
return current_platform.stateless_init_device_torch_dist_pg(
|
||||||
return init_gloo_process_group(
|
|
||||||
backend=backend,
|
backend=backend,
|
||||||
prefix_store=prefix_store,
|
prefix_store=prefix_store,
|
||||||
group_rank=group_rank,
|
group_rank=group_rank,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
except NotImplementedError:
|
||||||
|
# If platform doesn't implement stateless_init_device_torch_dist_pg, it
|
||||||
return current_platform.stateless_init_device_torch_dist_pg(
|
# will raise a NotImplementedError. In this case, we fall back to gloo.
|
||||||
backend=backend,
|
return init_gloo_process_group(
|
||||||
prefix_store=prefix_store,
|
prefix_store=prefix_store,
|
||||||
group_rank=group_rank,
|
group_rank=group_rank,
|
||||||
group_size=group_size,
|
group_size=group_size,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
|
def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
|
||||||
|
|||||||
@ -6,13 +6,10 @@ pynvml. However, it should not initialize cuda context.
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from datetime import timedelta
|
|
||||||
from functools import cache, wraps
|
from functools import cache, wraps
|
||||||
from typing import TYPE_CHECKING, TypeVar
|
from typing import TYPE_CHECKING, TypeVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.distributed import PrefixStore, ProcessGroup
|
|
||||||
from torch.distributed.distributed_c10d import is_nccl_available
|
|
||||||
from typing_extensions import ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
# import custom ops, trigger op registration
|
# import custom ops, trigger op registration
|
||||||
@ -455,37 +452,6 @@ class CudaPlatformBase(Platform):
|
|||||||
def get_static_graph_wrapper_cls(cls) -> str:
|
def get_static_graph_wrapper_cls(cls) -> str:
|
||||||
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
|
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def stateless_init_device_torch_dist_pg(
|
|
||||||
cls,
|
|
||||||
backend: str,
|
|
||||||
prefix_store: PrefixStore,
|
|
||||||
group_rank: int,
|
|
||||||
group_size: int,
|
|
||||||
timeout: timedelta,
|
|
||||||
) -> ProcessGroup:
|
|
||||||
assert is_nccl_available()
|
|
||||||
pg: ProcessGroup = ProcessGroup(
|
|
||||||
prefix_store,
|
|
||||||
group_rank,
|
|
||||||
group_size,
|
|
||||||
)
|
|
||||||
from torch.distributed.distributed_c10d import ProcessGroupNCCL
|
|
||||||
|
|
||||||
backend_options = ProcessGroupNCCL.Options()
|
|
||||||
backend_options._timeout = timeout
|
|
||||||
|
|
||||||
backend_class = ProcessGroupNCCL(
|
|
||||||
prefix_store, group_rank, group_size, backend_options
|
|
||||||
)
|
|
||||||
backend_type = ProcessGroup.BackendType.NCCL
|
|
||||||
device = torch.device("cuda")
|
|
||||||
pg._set_default_backend(backend_type)
|
|
||||||
backend_class._set_sequence_number_for_group()
|
|
||||||
|
|
||||||
pg._register_backend(device, backend_type, backend_class)
|
|
||||||
return pg
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def device_count(cls) -> int:
|
def device_count(cls) -> int:
|
||||||
return cuda_device_count_stateless()
|
return cuda_device_count_stateless()
|
||||||
|
|||||||
@ -551,7 +551,7 @@ class Platform:
|
|||||||
"""
|
"""
|
||||||
Init platform-specific torch distributed process group.
|
Init platform-specific torch distributed process group.
|
||||||
"""
|
"""
|
||||||
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_kv_cache_dtype_supported(
|
def is_kv_cache_dtype_supported(
|
||||||
|
|||||||
@ -2,13 +2,10 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from datetime import timedelta
|
|
||||||
from functools import cache, lru_cache, wraps
|
from functools import cache, lru_cache, wraps
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.distributed import PrefixStore, ProcessGroup
|
|
||||||
from torch.distributed.distributed_c10d import is_nccl_available
|
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -476,37 +473,6 @@ class RocmPlatform(Platform):
|
|||||||
def get_static_graph_wrapper_cls(cls) -> str:
|
def get_static_graph_wrapper_cls(cls) -> str:
|
||||||
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
|
return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def stateless_init_device_torch_dist_pg(
|
|
||||||
cls,
|
|
||||||
backend: str,
|
|
||||||
prefix_store: PrefixStore,
|
|
||||||
group_rank: int,
|
|
||||||
group_size: int,
|
|
||||||
timeout: timedelta,
|
|
||||||
) -> ProcessGroup:
|
|
||||||
assert is_nccl_available()
|
|
||||||
pg: ProcessGroup = ProcessGroup(
|
|
||||||
prefix_store,
|
|
||||||
group_rank,
|
|
||||||
group_size,
|
|
||||||
)
|
|
||||||
from torch.distributed.distributed_c10d import ProcessGroupNCCL
|
|
||||||
|
|
||||||
backend_options = ProcessGroupNCCL.Options()
|
|
||||||
backend_options._timeout = timeout
|
|
||||||
|
|
||||||
backend_class = ProcessGroupNCCL(
|
|
||||||
prefix_store, group_rank, group_size, backend_options
|
|
||||||
)
|
|
||||||
backend_type = ProcessGroup.BackendType.NCCL
|
|
||||||
device = torch.device("cuda")
|
|
||||||
pg._set_default_backend(backend_type)
|
|
||||||
backend_class._set_sequence_number_for_group()
|
|
||||||
|
|
||||||
pg._register_backend(device, backend_type, backend_class)
|
|
||||||
return pg
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def device_count(cls) -> int:
|
def device_count(cls) -> int:
|
||||||
return cuda_device_count_stateless()
|
return cuda_device_count_stateless()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user