mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:54:56 +08:00
[Misc] Add parallel state node_count function (#20045)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
4734704b30
commit
c40692bf9a
@ -619,11 +619,13 @@ steps:
|
||||
commands:
|
||||
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
|
||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
- NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
|
||||
- python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
|
||||
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
|
||||
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
|
||||
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
|
||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
- NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
|
||||
- python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
|
||||
|
||||
- label: Distributed Tests (2 GPUs) # 40min
|
||||
|
||||
43
tests/distributed/test_node_count.py
Normal file
43
tests/distributed/test_node_count.py
Normal file
@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.parallel_state import _node_count
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.utils import get_ip, get_open_port
|
||||
|
||||
if __name__ == "__main__":
|
||||
dist.init_process_group(backend="gloo")
|
||||
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
if rank == 0:
|
||||
port = get_open_port()
|
||||
ip = get_ip()
|
||||
dist.broadcast_object_list([ip, port], src=0)
|
||||
else:
|
||||
recv = [None, None]
|
||||
dist.broadcast_object_list(recv, src=0)
|
||||
ip, port = recv
|
||||
|
||||
stateless_pg = StatelessProcessGroup.create(ip, port, rank, world_size)
|
||||
|
||||
for pg in [dist.group.WORLD, stateless_pg]:
|
||||
test_result = _node_count(pg)
|
||||
|
||||
# Expected node count based on environment variable)
|
||||
expected = int(os.environ.get("NUM_NODES", "1"))
|
||||
|
||||
assert test_result == expected, \
|
||||
f"Expected {expected} nodes, got {test_result}"
|
||||
|
||||
if pg == dist.group.WORLD:
|
||||
print(f"Node count test passed! Got {test_result} nodes "
|
||||
f"when using torch distributed!")
|
||||
else:
|
||||
print(f"Node count test passed! Got {test_result} nodes "
|
||||
f"when using StatelessProcessGroup!")
|
||||
@ -802,6 +802,7 @@ class GroupCoordinator:
|
||||
|
||||
|
||||
_WORLD: Optional[GroupCoordinator] = None
|
||||
_NODE_COUNT: Optional[int] = None
|
||||
|
||||
|
||||
def get_world_group() -> GroupCoordinator:
|
||||
@ -961,10 +962,13 @@ def init_distributed_environment(
|
||||
local_rank = envs.LOCAL_RANK
|
||||
else:
|
||||
local_rank = rank
|
||||
global _WORLD
|
||||
global _WORLD, _NODE_COUNT
|
||||
if _WORLD is None:
|
||||
ranks = list(range(torch.distributed.get_world_size()))
|
||||
_WORLD = init_world_group(ranks, local_rank, backend)
|
||||
_NODE_COUNT = _node_count(_WORLD.cpu_group)
|
||||
logger.debug("Detected %d nodes in the distributed environment",
|
||||
_NODE_COUNT)
|
||||
else:
|
||||
assert _WORLD.world_size == torch.distributed.get_world_size(), (
|
||||
"world group already initialized with a different world size")
|
||||
@ -1164,6 +1168,13 @@ def get_tensor_model_parallel_rank():
|
||||
return get_tp_group().rank_in_group
|
||||
|
||||
|
||||
def get_node_count() -> int:
|
||||
"""Return the total number of nodes in the distributed environment. """
|
||||
assert _NODE_COUNT is not None, (
|
||||
"distributed environment is not initialized")
|
||||
return _NODE_COUNT
|
||||
|
||||
|
||||
def destroy_model_parallel():
|
||||
"""Set the groups to none and destroy them."""
|
||||
global _TP
|
||||
@ -1189,10 +1200,11 @@ def destroy_model_parallel():
|
||||
|
||||
|
||||
def destroy_distributed_environment():
|
||||
global _WORLD
|
||||
global _WORLD, _NODE_COUNT
|
||||
if _WORLD:
|
||||
_WORLD.destroy()
|
||||
_WORLD = None
|
||||
_NODE_COUNT = None
|
||||
if torch.distributed.is_initialized():
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
@ -1301,3 +1313,42 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
|
||||
aggregated_data += rank_data
|
||||
|
||||
return [x == 1 for x in aggregated_data.tolist()]
|
||||
|
||||
|
||||
def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
|
||||
"""
|
||||
Returns the total number of nodes in the process group.
|
||||
|
||||
Args:
|
||||
pg: The process group to analyze
|
||||
|
||||
Returns:
|
||||
int: The total number of nodes
|
||||
"""
|
||||
if isinstance(pg, ProcessGroup):
|
||||
world_size = torch.distributed.get_world_size(group=pg)
|
||||
else:
|
||||
world_size = pg.world_size
|
||||
|
||||
if world_size == 1:
|
||||
return 1
|
||||
|
||||
# Build node assignment map
|
||||
node_assignment = [0] * world_size # rank -> node_id
|
||||
next_node_id = 0
|
||||
|
||||
for current_rank in range(world_size):
|
||||
if node_assignment[current_rank] != 0:
|
||||
continue # Already assigned to a node
|
||||
|
||||
# Assign current rank to a new node
|
||||
next_node_id += 1
|
||||
node_assignment[current_rank] = next_node_id
|
||||
|
||||
# Find all ranks on the same node as current_rank
|
||||
same_node_flags = in_the_same_node_as(pg, current_rank)
|
||||
for other_rank, is_same_node in enumerate(same_node_flags):
|
||||
if is_same_node and node_assignment[other_rank] == 0:
|
||||
node_assignment[other_rank] = next_node_id
|
||||
|
||||
return next_node_id
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user