[BugFix] Fix multi-node offline data-parallel (#18981)

Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Nick Hill 2025-05-31 08:34:52 -07:00 committed by GitHub
parent 8bf507d766
commit 9a1b9b99d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 8 deletions

View File

@ -97,10 +97,14 @@ def main(
# with DP, each rank should process different prompts.
# usually all the DP ranks process a full dataset,
# and each rank processes a different part of the dataset.
promts_per_rank = len(prompts) // dp_size
start = global_dp_rank * promts_per_rank
end = start + promts_per_rank
prompts = prompts[start:end]
floor = len(prompts) // dp_size
remainder = len(prompts) % dp_size
# Distribute prompts into even groups.
def start(rank):
return rank * floor + min(rank, remainder)
prompts = prompts[start(global_dp_rank) : start(global_dp_rank + 1)]
if len(prompts) == 0:
# if any rank has no prompts to process,
# we need to set a placeholder prompt

View File

@ -363,6 +363,7 @@ class MPClient(EngineCoreClient):
local_engine_count = parallel_config.data_parallel_size_local
local_start_index = parallel_config.data_parallel_rank_local
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
# SPMD mode is where there is an LLM instance per DP rank and
# one core engine per LLM, see
@ -370,11 +371,9 @@ class MPClient(EngineCoreClient):
spmd_mode = local_start_index is not None
if spmd_mode:
assert local_engine_count == 1
self.core_engines = [
CoreEngine(index=local_start_index, local=True)
]
self.core_engines = [CoreEngine(index=dp_rank, local=True)]
else:
assert parallel_config.data_parallel_rank == 0
assert dp_rank == 0
local_start_index = 0
self.core_engines = [
CoreEngine(index=i, local=(i < local_engine_count))