[Misc] Add parallel state node_count function (#20045)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-06-25 13:38:53 -07:00 committed by GitHub
parent 4734704b30
commit c40692bf9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 98 additions and 2 deletions

View File

@ -619,11 +619,13 @@ steps:
commands:
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
- NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
- python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
- NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
- python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
- label: Distributed Tests (2 GPUs) # 40min

View File

@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import torch.distributed as dist
from vllm.distributed.parallel_state import _node_count
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils import get_ip, get_open_port
if __name__ == "__main__":
dist.init_process_group(backend="gloo")
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
port = get_open_port()
ip = get_ip()
dist.broadcast_object_list([ip, port], src=0)
else:
recv = [None, None]
dist.broadcast_object_list(recv, src=0)
ip, port = recv
stateless_pg = StatelessProcessGroup.create(ip, port, rank, world_size)
for pg in [dist.group.WORLD, stateless_pg]:
test_result = _node_count(pg)
# Expected node count based on environment variable)
expected = int(os.environ.get("NUM_NODES", "1"))
assert test_result == expected, \
f"Expected {expected} nodes, got {test_result}"
if pg == dist.group.WORLD:
print(f"Node count test passed! Got {test_result} nodes "
f"when using torch distributed!")
else:
print(f"Node count test passed! Got {test_result} nodes "
f"when using StatelessProcessGroup!")

View File

@ -802,6 +802,7 @@ class GroupCoordinator:
_WORLD: Optional[GroupCoordinator] = None
_NODE_COUNT: Optional[int] = None
def get_world_group() -> GroupCoordinator:
@ -961,10 +962,13 @@ def init_distributed_environment(
local_rank = envs.LOCAL_RANK
else:
local_rank = rank
global _WORLD
global _WORLD, _NODE_COUNT
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
_NODE_COUNT = _node_count(_WORLD.cpu_group)
logger.debug("Detected %d nodes in the distributed environment",
_NODE_COUNT)
else:
assert _WORLD.world_size == torch.distributed.get_world_size(), (
"world group already initialized with a different world size")
@ -1164,6 +1168,13 @@ def get_tensor_model_parallel_rank():
return get_tp_group().rank_in_group
def get_node_count() -> int:
"""Return the total number of nodes in the distributed environment. """
assert _NODE_COUNT is not None, (
"distributed environment is not initialized")
return _NODE_COUNT
def destroy_model_parallel():
"""Set the groups to none and destroy them."""
global _TP
@ -1189,10 +1200,11 @@ def destroy_model_parallel():
def destroy_distributed_environment():
global _WORLD
global _WORLD, _NODE_COUNT
if _WORLD:
_WORLD.destroy()
_WORLD = None
_NODE_COUNT = None
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
@ -1301,3 +1313,42 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
aggregated_data += rank_data
return [x == 1 for x in aggregated_data.tolist()]
def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
"""
Returns the total number of nodes in the process group.
Args:
pg: The process group to analyze
Returns:
int: The total number of nodes
"""
if isinstance(pg, ProcessGroup):
world_size = torch.distributed.get_world_size(group=pg)
else:
world_size = pg.world_size
if world_size == 1:
return 1
# Build node assignment map
node_assignment = [0] * world_size # rank -> node_id
next_node_id = 0
for current_rank in range(world_size):
if node_assignment[current_rank] != 0:
continue # Already assigned to a node
# Assign current rank to a new node
next_node_id += 1
node_assignment[current_rank] = next_node_id
# Find all ranks on the same node as current_rank
same_node_flags = in_the_same_node_as(pg, current_rank)
for other_rank, is_same_node in enumerate(same_node_flags):
if is_same_node and node_assignment[other_rank] == 0:
node_assignment[other_rank] = next_node_id
return next_node_id