mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 12:17:54 +08:00
[distributed] fix dp group (#15355)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
cbcdf2c609
commit
9606d572ed
@ -897,29 +897,22 @@ def initialize_model_parallel(
|
|||||||
get_world_group().device_group)
|
get_world_group().device_group)
|
||||||
|
|
||||||
data_parallel_size = 1
|
data_parallel_size = 1
|
||||||
has_external_dp = False
|
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
config = get_current_vllm_config()
|
config = get_current_vllm_config()
|
||||||
if config is not None:
|
if config is not None:
|
||||||
if config.parallel_config.world_size != world_size:
|
data_parallel_size = config.parallel_config.data_parallel_size
|
||||||
# detect external data parallelism.
|
|
||||||
# dp in vllm means all dp instances need to run together.
|
|
||||||
# if the world size does not match, it means this dp is external,
|
|
||||||
# and the dp instances can run independently, e.g. in rlhf workflow
|
|
||||||
# from https://github.com/volcengine/verl .
|
|
||||||
# in that case, we treat the remaining dimensions as if they are
|
|
||||||
# data parallel, and create a dummy dp group that is not used.
|
|
||||||
data_parallel_size = world_size // (pipeline_model_parallel_size *
|
|
||||||
tensor_model_parallel_size)
|
|
||||||
has_external_dp = True
|
|
||||||
else:
|
|
||||||
data_parallel_size = config.parallel_config.data_parallel_size
|
|
||||||
|
|
||||||
# the layout order is: DP x PP x TP
|
# the layout order is: ExternalDP x DP x PP x TP
|
||||||
|
# ExternalDP is the data parallel group that is not part of the model,
|
||||||
|
# every dp rank can generate independently (in verl integration).
|
||||||
|
# DP is the data parallel group that is part of the model,
|
||||||
|
# all the ranks in the same DP group should generate simultaneously,
|
||||||
|
# i.e. `generate` must be called together by all ranks in the same DP group,
|
||||||
|
# otherwise it will cause deadlock.
|
||||||
# to get group_ranks for each dimension, transpose that dimension to the
|
# to get group_ranks for each dimension, transpose that dimension to the
|
||||||
# last dimension, then reshape to 2D, then unbind the last dimension
|
# last dimension, then reshape to 2D, then unbind the last dimension
|
||||||
all_ranks = torch.arange(world_size).reshape(
|
all_ranks = torch.arange(world_size).reshape(
|
||||||
data_parallel_size, pipeline_model_parallel_size,
|
-1, data_parallel_size, pipeline_model_parallel_size,
|
||||||
tensor_model_parallel_size) # noqa
|
tensor_model_parallel_size) # noqa
|
||||||
|
|
||||||
# Build the tensor model-parallel groups.
|
# Build the tensor model-parallel groups.
|
||||||
@ -939,7 +932,7 @@ def initialize_model_parallel(
|
|||||||
global _PP
|
global _PP
|
||||||
assert _PP is None, (
|
assert _PP is None, (
|
||||||
"pipeline model parallel group is already initialized")
|
"pipeline model parallel group is already initialized")
|
||||||
group_ranks = all_ranks.transpose(1, 2).reshape(
|
group_ranks = all_ranks.transpose(2, 3).reshape(
|
||||||
-1, pipeline_model_parallel_size).unbind(0)
|
-1, pipeline_model_parallel_size).unbind(0)
|
||||||
group_ranks = [x.tolist() for x in group_ranks]
|
group_ranks = [x.tolist() for x in group_ranks]
|
||||||
_PP = init_model_parallel_group(group_ranks,
|
_PP = init_model_parallel_group(group_ranks,
|
||||||
@ -949,16 +942,10 @@ def initialize_model_parallel(
|
|||||||
|
|
||||||
global _DP
|
global _DP
|
||||||
assert _DP is None, ("data parallel group is already initialized")
|
assert _DP is None, ("data parallel group is already initialized")
|
||||||
group_ranks = all_ranks.transpose(0,
|
group_ranks = all_ranks.transpose(1,
|
||||||
2).reshape(-1,
|
3).reshape(-1,
|
||||||
data_parallel_size).unbind(0)
|
data_parallel_size).unbind(0)
|
||||||
group_ranks = [x.tolist() for x in group_ranks]
|
group_ranks = [x.tolist() for x in group_ranks]
|
||||||
if has_external_dp:
|
|
||||||
# create a dummy dp group that is not actually used,
|
|
||||||
# since this dp is external.
|
|
||||||
# a dummy dp group means every rank is a group itself.
|
|
||||||
# this way, no communication is needed, no memory is wasted.
|
|
||||||
group_ranks = [[x] for x in range(world_size)]
|
|
||||||
_DP = init_model_parallel_group(group_ranks,
|
_DP = init_model_parallel_group(group_ranks,
|
||||||
get_world_group().local_rank,
|
get_world_group().local_rank,
|
||||||
backend,
|
backend,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user