[BugFix] Fix multi-node offline data parallel (#19937)

Signed-off-by: Nick Hill <nhill@redhat.com>
Author: Nick Hill <nhill@redhat.com> (committed via GitHub)
Date: 2025-06-24 12:45:20 -07:00
commit 8619e7158c (parent c635c5f744)
5 changed files with 31 additions and 4 deletions

.buildkite/test-pipeline.yaml

@@ -615,13 +615,16 @@ steps:
   - vllm/executor/
   - vllm/model_executor/models/
   - tests/distributed/
+  - tests/examples/offline_inference/data_parallel.py
   commands:
   - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
+  - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
   - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
+  - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
 - label: Distributed Tests (2 GPUs) # 40min
   mirror_hardwares: [amdexperimental]
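The two new CI commands exercise the offline SPMD path: the same data_parallel.py script is launched once per node, with --node-rank distinguishing them and rank 0's address serving as the DP master. A minimal sketch of the per-rank setup that example performs (env var names are the ones used by vLLM's data_parallel.py; the model and addresses here are illustrative, not prescribed by this commit):

    import os
    from vllm import LLM, SamplingParams

    # One process per DP rank; on the second node VLLM_DP_RANK would be "1".
    os.environ["VLLM_DP_RANK"] = "0"
    os.environ["VLLM_DP_SIZE"] = "2"                   # total ranks across both nodes
    os.environ["VLLM_DP_MASTER_IP"] = "192.168.10.10"  # rank 0's address
    os.environ["VLLM_DP_MASTER_PORT"] = "12345"

    # Each rank builds its own LLM instance; the core engines coordinate
    # waves among themselves rather than through a central coordinator.
    llm = LLM(model="facebook/opt-125m", enforce_eager=True)
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))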

vllm/entrypoints/llm.py

@@ -1568,6 +1568,8 @@ class LLM:
                             pbar.update(n)
                         else:
                             pbar.update(1)
+                        if pbar.n == num_requests:
+                            pbar.refresh()
 
         if use_tqdm:
             pbar.close()
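tqdm rate-limits redraws (its mininterval parameter), so the last update of a run can be counted without ever being drawn; the added refresh() forces the bar to render its final state once the counter reaches num_requests. A standalone illustration of the same pattern (plain tqdm, nothing vLLM-specific):

    from tqdm import tqdm

    total = 100
    pbar = tqdm(total=total, mininterval=60.0)  # redraws at most once a minute
    for _ in range(total):
        pbar.update(1)                          # counted, but may not redraw
    if pbar.n == total:
        pbar.refresh()                          # force the final 100% render
    pbar.close()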

vllm/v1/engine/core.py

@@ -877,12 +877,16 @@ class DPEngineCoreProc(EngineCoreProc):
                 local_unfinished_reqs)
 
             if not self.engines_running:
-                if self.dp_rank == 0:
+                if self.dp_rank == 0 or not self.has_coordinator:
                     # Notify client that we are pausing the loop.
                     logger.debug("Wave %d finished, pausing engine loop.",
                                  self.current_wave)
+                    # In the coordinator case, dp rank 0 sends updates to the
+                    # coordinator. Otherwise (offline spmd case), each rank
+                    # sends the update to its colocated front-end process.
+                    client_index = -1 if self.has_coordinator else 0
                     self.output_queue.put_nowait(
-                        (-1,
+                        (client_index,
                          EngineCoreOutputs(wave_complete=self.current_wave)))
                 self.current_wave += 1
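With a coordinator (online serving), only DP rank 0 reports wave completion, addressed via client index -1; without one (offline SPMD), every rank reports to its own colocated front-end at index 0. A small function capturing that routing rule (names mirror the diff, but this is a simplification, not the actual DPEngineCoreProc code):

    from typing import Optional

    def wave_complete_target(dp_rank: int, has_coordinator: bool) -> Optional[int]:
        """Client index a wave_complete notification goes to,
        or None if this rank stays silent."""
        if has_coordinator:
            # Online case: rank 0 alone notifies the coordinator (-1).
            return -1 if dp_rank == 0 else None
        # Offline SPMD case: each rank notifies its colocated front-end (0).
        return 0

    assert wave_complete_target(0, has_coordinator=True) == -1
    assert wave_complete_target(1, has_coordinator=True) is None
    assert wave_complete_target(1, has_coordinator=False) == 0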

vllm/v1/engine/core_client.py

@@ -155,6 +155,11 @@ class EngineCoreClient(ABC):
                        kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
         raise NotImplementedError
 
+    def dp_engines_running(self) -> bool:
+        """Returns True if data parallel engines are collectively in a
+        running state."""
+        raise NotImplementedError
+
     async def get_output_async(self) -> EngineCoreOutputs:
         raise NotImplementedError
@@ -282,6 +287,9 @@ class InprocClient(EngineCoreClient):
                        kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
         return self.engine_core.collective_rpc(method, timeout, args, kwargs)
 
+    def dp_engines_running(self) -> bool:
+        return False
+
 
 @dataclass
 class BackgroundResources:
@@ -384,6 +392,9 @@ class MPClient(EngineCoreClient):
         dp_size = parallel_config.data_parallel_size
         dp_rank = parallel_config.data_parallel_rank
 
+        # State used for data parallel.
+        self.engines_running = False
+
         # SPMD mode is where there is an LLM instance per DP rank and
         # one core engine per LLM, see
         # examples/offline_inference/data_parallel.py.
@@ -539,6 +550,9 @@ class MPClient(EngineCoreClient):
         while self.pending_messages and self.pending_messages[-1][0].done:
             self.pending_messages.pop()
 
+    def dp_engines_running(self) -> bool:
+        return self.engines_running
+
 
 def _process_utility_output(output: UtilityOutput,
                             utility_results: dict[int, AnyFuture]):
@@ -562,6 +576,7 @@ class SyncMPClient(MPClient):
             log_stats=log_stats,
         )
+        self.is_dp = self.vllm_config.parallel_config.data_parallel_size > 1
 
         self.outputs_queue = queue.Queue[Union[EngineCoreOutputs, Exception]]()
         # Ensure that the outputs socket processing thread does not have
@@ -623,6 +638,8 @@ class SyncMPClient(MPClient):
         outputs = self.outputs_queue.get()
         if isinstance(outputs, Exception):
             raise self._format_exception(outputs) from None
+        if outputs.wave_complete is not None:
+            self.engines_running = False
         return outputs
 
     def _send_input(self, request_type: EngineCoreRequestType, request: Any):
@@ -650,6 +667,8 @@ class SyncMPClient(MPClient):
         return future.result()
 
     def add_request(self, request: EngineCoreRequest) -> None:
+        if self.is_dp:
+            self.engines_running = True
         self._send_input(EngineCoreRequestType.ADD, request)
 
     def abort_requests(self, request_ids: list[str]) -> None:
@@ -911,7 +930,6 @@ class DPAsyncMPClient(AsyncMPClient):
                  client_addresses: Optional[dict[str, str]] = None,
                  client_index: int = 0):
         self.current_wave = 0
-        self.engines_running = False
 
         # To route aborts to the correct engine.
         self.reqs_in_flight: dict[str, CoreEngine] = {}
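Taken together, the client now tracks a simple running/paused state for the DP engines: add_request marks them running, and a wave_complete output marks them paused (the state itself moves up from DPAsyncMPClient to MPClient so the sync client shares it). A toy model of that handshake (illustrative only; the real state lives on MPClient in this file):

    from typing import Optional

    class ToyDPClient:
        """Mimics the engines_running bookkeeping added in this commit."""

        def __init__(self, is_dp: bool):
            self.is_dp = is_dp
            self.engines_running = False

        def add_request(self) -> None:
            if self.is_dp:
                # Any new request (re)starts the current DP wave.
                self.engines_running = True

        def on_outputs(self, wave_complete: Optional[int]) -> None:
            if wave_complete is not None:
                # The engine reported the wave done; pause until new work.
                self.engines_running = False

        def dp_engines_running(self) -> bool:
            return self.engines_running

    client = ToyDPClient(is_dp=True)
    client.add_request()
    assert client.dp_engines_running()
    client.on_outputs(wave_complete=0)
    assert not client.dp_engines_running()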

vllm/v1/engine/llm_engine.py

@@ -160,7 +160,7 @@ class LLMEngine:
     def has_unfinished_requests(self) -> bool:
         has_unfinished = self.output_processor.has_unfinished_requests()
         if self.dp_group is None:
-            return has_unfinished
+            return has_unfinished or self.engine_core.dp_engines_running()
         return self.has_unfinished_requests_dp(has_unfinished)
 
     def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
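This one-line change is what makes the multi-node offline case terminate correctly: in SPMD mode each rank drives its own LLM, and a rank whose local requests finish early must keep stepping, and keep reporting itself unfinished, until every rank's wave completes, or collective operations on the other ranks would hang. A compact restatement of the fixed condition (toy names assumed for illustration; the real check is the single line above):

    def keep_stepping(local_unfinished: bool, dp_engines_running: bool) -> bool:
        # A rank stays in its generate loop while it has local work OR any
        # peer DP rank is still in the current wave.
        return local_unfinished or dp_engines_running

    assert keep_stepping(False, True)        # peers still busy: keep stepping
    assert not keep_stepping(False, False)   # wave fully complete: safe to stop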