mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 13:15:45 +08:00
[BugFix] Fix multi-node offline data-parallel (#18981)
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
parent
8bf507d766
commit
9a1b9b99d7
@ -97,10 +97,14 @@ def main(
|
|||||||
# with DP, each rank should process different prompts.
|
# with DP, each rank should process different prompts.
|
||||||
# usually all the DP ranks process a full dataset,
|
# usually all the DP ranks process a full dataset,
|
||||||
# and each rank processes a different part of the dataset.
|
# and each rank processes a different part of the dataset.
|
||||||
promts_per_rank = len(prompts) // dp_size
|
floor = len(prompts) // dp_size
|
||||||
start = global_dp_rank * promts_per_rank
|
remainder = len(prompts) % dp_size
|
||||||
end = start + promts_per_rank
|
|
||||||
prompts = prompts[start:end]
|
# Distribute prompts into near-even groups (first `remainder` ranks get one extra).
|
||||||
|
def start(rank):
|
||||||
|
return rank * floor + min(rank, remainder)
|
||||||
|
|
||||||
|
prompts = prompts[start(global_dp_rank) : start(global_dp_rank + 1)]
|
||||||
if len(prompts) == 0:
|
if len(prompts) == 0:
|
||||||
# if any rank has no prompts to process,
|
# if any rank has no prompts to process,
|
||||||
# we need to set a placeholder prompt
|
# we need to set a placeholder prompt
|
||||||
|
|||||||
@ -363,6 +363,7 @@ class MPClient(EngineCoreClient):
|
|||||||
local_engine_count = parallel_config.data_parallel_size_local
|
local_engine_count = parallel_config.data_parallel_size_local
|
||||||
local_start_index = parallel_config.data_parallel_rank_local
|
local_start_index = parallel_config.data_parallel_rank_local
|
||||||
dp_size = parallel_config.data_parallel_size
|
dp_size = parallel_config.data_parallel_size
|
||||||
|
dp_rank = parallel_config.data_parallel_rank
|
||||||
|
|
||||||
# SPMD mode is where there is an LLM instance per DP rank and
|
# SPMD mode is where there is an LLM instance per DP rank and
|
||||||
# one core engine per LLM, see
|
# one core engine per LLM, see
|
||||||
@ -370,11 +371,9 @@ class MPClient(EngineCoreClient):
|
|||||||
spmd_mode = local_start_index is not None
|
spmd_mode = local_start_index is not None
|
||||||
if spmd_mode:
|
if spmd_mode:
|
||||||
assert local_engine_count == 1
|
assert local_engine_count == 1
|
||||||
self.core_engines = [
|
self.core_engines = [CoreEngine(index=dp_rank, local=True)]
|
||||||
CoreEngine(index=local_start_index, local=True)
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
assert parallel_config.data_parallel_rank == 0
|
assert dp_rank == 0
|
||||||
local_start_index = 0
|
local_start_index = 0
|
||||||
self.core_engines = [
|
self.core_engines = [
|
||||||
CoreEngine(index=i, local=(i < local_engine_count))
|
CoreEngine(index=i, local=(i < local_engine_count))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user