From b82662d9523d9aa1386d8d1de410426781a1fa3b Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Sat, 15 Mar 2025 20:26:19 -0700
Subject: [PATCH] [BugFix] Fix torch distributed stateless PG backend init
 (#14870)

Signed-off-by: Nick Hill
---
 examples/offline_inference/data_parallel.py | 5 +++++
 vllm/distributed/utils.py                   | 6 +++---
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py
index b00519314d8bd..b73770ce382cf 100644
--- a/examples/offline_inference/data_parallel.py
+++ b/examples/offline_inference/data_parallel.py
@@ -76,5 +76,10 @@ if __name__ == "__main__":
                              GPUs_per_dp_rank))
         proc.start()
         procs.append(proc)
+    exit_code = 0
     for proc in procs:
         proc.join()
+        if proc.exitcode:
+            exit_code = proc.exitcode
+
+    exit(exit_code)
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index 25202062e9757..84899358a6d66 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -299,13 +299,10 @@ def stateless_init_torch_distributed_process_group(
     # different systems (e.g. RPC) in case the store is multi-tenant.
     prefix_store = PrefixStore(init_method, store)
 
-    pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)
-
     pg: ProcessGroup = ProcessGroup(
         prefix_store,
         group_rank,
         group_size,
-        pg_options,
     )
 
     if backend == "gloo":
@@ -327,7 +324,10 @@ def stateless_init_torch_distributed_process_group(
                                             backend_options)
         backend_type = ProcessGroup.BackendType.NCCL
         device = torch.device("cuda")
+    else:
+        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
 
+    pg._set_default_backend(backend_type)
     backend_class._set_sequence_number_for_group()
     pg._register_backend(device, backend_type, backend_class)
 
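
The data_parallel.py hunk makes the parent process propagate a failing
worker's exit code instead of always exiting 0, so launcher scripts and CI
can detect a crashed rank. A minimal standalone sketch of that pattern; the
worker function and rank count here are illustrative, not part of the patch:

    import sys
    from multiprocessing import Process

    def worker(rank: int) -> None:
        # Stand-in for the per-rank main(); simulate one rank failing.
        if rank == 1:
            sys.exit(3)

    if __name__ == "__main__":
        procs = [Process(target=worker, args=(rank,)) for rank in range(4)]
        for proc in procs:
            proc.start()
        exit_code = 0
        for proc in procs:
            proc.join()
            # exitcode is nonzero on failure (negative if killed by a signal);
            # remember the last failure seen so the parent can surface it.
            if proc.exitcode:
                exit_code = proc.exitcode
        sys.exit(exit_code)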
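
The vllm/distributed/utils.py hunks track a PyTorch API change: in recent
torch the ProcessGroup constructor no longer accepts a ProcessGroup.Options
argument, so the group is built from only (store, rank, size) and the default
backend must be set explicitly via pg._set_default_backend(). A minimal
sketch of the patched gloo path, assuming a torch version with this newer
constructor; the host, port, "sketch/" prefix, and function name are
illustrative placeholders, not part of the patch:

    from datetime import timedelta

    import torch
    from torch.distributed import PrefixStore, ProcessGroup, TCPStore
    from torch.distributed.distributed_c10d import ProcessGroupGloo

    def init_stateless_gloo_pg(host: str, port: int, rank: int,
                               world_size: int) -> ProcessGroup:
        timeout = timedelta(seconds=60)
        store = TCPStore(host, port, world_size, is_master=(rank == 0),
                         timeout=timeout)
        prefix_store = PrefixStore("sketch/", store)

        # Newer torch: constructor takes only (store, rank, size); there is
        # no Options object any more.
        pg = ProcessGroup(prefix_store, rank, world_size)

        backend_class = ProcessGroupGloo(prefix_store, rank, world_size,
                                         timeout=timeout)
        backend_type = ProcessGroup.BackendType.GLOO

        # The key addition from the patch: explicitly register gloo as the
        # group's default backend, then wire the backend into the group.
        pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()
        pg._register_backend(torch.device("cpu"), backend_type, backend_class)
        return pg

The explicit else/raise in the patch serves the same purpose as hedging here:
any backend other than "gloo" or "nccl" previously fell through with
backend_type and device unbound, producing a confusing NameError rather than
a clear unsupported-backend error.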