[Misc] Add parallel state node_count function (#20045)
Signed-off-by: Nick Hill <nhill@redhat.com>
parent 4734704b30
commit c40692bf9a
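In short: this commit adds a _node_count() helper to parallel_state, caches its result in a _NODE_COUNT global inside init_distributed_environment(), and exposes it via get_node_count(). A minimal usage sketch (assuming the usual torchrun-provided env vars and that init_distributed_environment() is callable with its defaults; illustration only, not part of this commit):

    # Hypothetical caller; run under torchrun so RANK/WORLD_SIZE etc. are set.
    from vllm.distributed.parallel_state import (get_node_count,
                                                 init_distributed_environment)

    init_distributed_environment()  # computes and caches the node count
    print(f"Distributed environment spans {get_node_count()} node(s)")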
@@ -619,11 +619,13 @@ steps:
   commands:
   - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
+  - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
   - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
   - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
   - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
   - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
+  - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
   - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
 
 - label: Distributed Tests (2 GPUs) # 40min
tests/distributed/test_node_count.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+
+import torch.distributed as dist
+
+from vllm.distributed.parallel_state import _node_count
+from vllm.distributed.utils import StatelessProcessGroup
+from vllm.utils import get_ip, get_open_port
+
+if __name__ == "__main__":
+    dist.init_process_group(backend="gloo")
+
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+
+    if rank == 0:
+        port = get_open_port()
+        ip = get_ip()
+        dist.broadcast_object_list([ip, port], src=0)
+    else:
+        recv = [None, None]
+        dist.broadcast_object_list(recv, src=0)
+        ip, port = recv
+
+    stateless_pg = StatelessProcessGroup.create(ip, port, rank, world_size)
+
+    for pg in [dist.group.WORLD, stateless_pg]:
+        test_result = _node_count(pg)
+
+        # Expected node count based on environment variable
+        expected = int(os.environ.get("NUM_NODES", "1"))
+
+        assert test_result == expected, \
+            f"Expected {expected} nodes, got {test_result}"
+
+        if pg == dist.group.WORLD:
+            print(f"Node count test passed! Got {test_result} nodes "
+                  f"when using torch distributed!")
+        else:
+            print(f"Node count test passed! Got {test_result} nodes "
+                  f"when using StatelessProcessGroup!")
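The test can be smoke-checked on one machine before running the two-node CI path above; a hypothetical single-host invocation (not part of this commit):

    # Two processes on one host: both ranks share a node, so the expected
    # count is 1 (the NUM_NODES default in the test).
    torchrun --nnodes 1 --nproc-per-node=2 tests/distributed/test_node_count.py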
vllm/distributed/parallel_state.py
@@ -802,6 +802,7 @@ class GroupCoordinator:
 
 
 _WORLD: Optional[GroupCoordinator] = None
+_NODE_COUNT: Optional[int] = None
 
 
 def get_world_group() -> GroupCoordinator:
@@ -961,10 +962,13 @@ def init_distributed_environment(
         local_rank = envs.LOCAL_RANK
     else:
         local_rank = rank
-    global _WORLD
+    global _WORLD, _NODE_COUNT
     if _WORLD is None:
         ranks = list(range(torch.distributed.get_world_size()))
         _WORLD = init_world_group(ranks, local_rank, backend)
+        _NODE_COUNT = _node_count(_WORLD.cpu_group)
+        logger.debug("Detected %d nodes in the distributed environment",
+                     _NODE_COUNT)
     else:
         assert _WORLD.world_size == torch.distributed.get_world_size(), (
             "world group already initialized with a different world size")
@@ -1164,6 +1168,13 @@ def get_tensor_model_parallel_rank():
     return get_tp_group().rank_in_group
 
 
+def get_node_count() -> int:
+    """Return the total number of nodes in the distributed environment."""
+    assert _NODE_COUNT is not None, (
+        "distributed environment is not initialized")
+    return _NODE_COUNT
+
+
 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
     global _TP
@@ -1189,10 +1200,11 @@ def destroy_model_parallel():
 
 
 def destroy_distributed_environment():
-    global _WORLD
+    global _WORLD, _NODE_COUNT
     if _WORLD:
         _WORLD.destroy()
     _WORLD = None
+    _NODE_COUNT = None
     if torch.distributed.is_initialized():
         torch.distributed.destroy_process_group()
 
@@ -1301,3 +1313,42 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
         aggregated_data += rank_data
 
     return [x == 1 for x in aggregated_data.tolist()]
+
+
+def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
+    """
+    Returns the total number of nodes in the process group.
+
+    Args:
+        pg: The process group to analyze
+
+    Returns:
+        int: The total number of nodes
+    """
+    if isinstance(pg, ProcessGroup):
+        world_size = torch.distributed.get_world_size(group=pg)
+    else:
+        world_size = pg.world_size
+
+    if world_size == 1:
+        return 1
+
+    # Build node assignment map
+    node_assignment = [0] * world_size  # rank -> node_id
+    next_node_id = 0
+
+    for current_rank in range(world_size):
+        if node_assignment[current_rank] != 0:
+            continue  # Already assigned to a node
+
+        # Assign current rank to a new node
+        next_node_id += 1
+        node_assignment[current_rank] = next_node_id
+
+        # Find all ranks on the same node as current_rank
+        same_node_flags = in_the_same_node_as(pg, current_rank)
+        for other_rank, is_same_node in enumerate(same_node_flags):
+            if is_same_node and node_assignment[other_rank] == 0:
+                node_assignment[other_rank] = next_node_id
+
+    return next_node_id
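To see what the assignment loop computes, consider a hypothetical 4-rank job with ranks 0-1 on one host and ranks 2-3 on another. A self-contained trace with a stubbed in_the_same_node_as (illustration only; the stub stands in for the real collective):

    # Stub: ranks 0-1 share a node; ranks 2-3 share another.
    def in_the_same_node_as(pg, rank):
        groups = [[0, 1], [2, 3]]
        members = next(g for g in groups if rank in g)
        return [r in members for r in range(4)]

    world_size = 4
    node_assignment = [0] * world_size  # 0 means "not yet assigned"
    next_node_id = 0
    for current_rank in range(world_size):
        if node_assignment[current_rank] != 0:
            continue
        next_node_id += 1
        node_assignment[current_rank] = next_node_id
        for other_rank, same in enumerate(in_the_same_node_as(None, current_rank)):
            if same and node_assignment[other_rank] == 0:
                node_assignment[other_rank] = next_node_id

    print(node_assignment, next_node_id)  # -> [1, 1, 2, 2] 2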