create is_in_the_same_node on cpu (#26832)

Co-authored-by: Lunwen He <lunwenh@meta.com>
This commit is contained in:
Lunwen He 2025-10-20 19:04:14 -07:00 committed by GitHub
parent 163965d183
commit 0eb8f2b880
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 26 additions and 8 deletions

View File

@ -1081,6 +1081,7 @@ steps:
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- VLLM_TEST_SAME_HOST=1 VLLM_TEST_WITH_DEFAULT_DEVICE_SET=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- pytest -v -s distributed/test_sequence_parallel.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
- pytest -v -s v1/worker/test_worker_memory_snapshot.py

View File

@ -977,6 +977,7 @@ steps:
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- VLLM_TEST_SAME_HOST=1 VLLM_TEST_WITH_DEFAULT_DEVICE_SET=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- pytest -v -s distributed/test_sequence_parallel.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
- pytest -v -s v1/worker/test_worker_memory_snapshot.py

View File

@ -3,12 +3,25 @@
import os
import torch
import torch.distributed as dist
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.distributed.utils import StatelessProcessGroup
from vllm.utils.network_utils import get_ip, get_open_port
def _run_test(pg):
test_result = all(in_the_same_node_as(pg, source_rank=0))
expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
assert test_result == expected, f"Expected {expected}, got {test_result}"
if pg == dist.group.WORLD:
print("Same node test passed! when using torch distributed!")
else:
print("Same node test passed! when using StatelessProcessGroup!")
if __name__ == "__main__":
dist.init_process_group(backend="gloo")
@ -25,11 +38,12 @@ if __name__ == "__main__":
stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size())
for pg in [dist.group.WORLD, stateless_pg]:
test_result = all(in_the_same_node_as(pg, source_rank=0))
expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
assert test_result == expected, f"Expected {expected}, got {test_result}"
if pg == dist.group.WORLD:
print("Same node test passed! when using torch distributed!")
if os.environ.get("VLLM_TEST_WITH_DEFAULT_DEVICE_SET", "0") == "1":
default_devices = ["cpu"]
if torch.cuda.is_available():
default_devices.append("cuda")
for device in default_devices:
torch.set_default_device(device)
_run_test(pg)
else:
print("Same node test passed! when using StatelessProcessGroup!")
_run_test(pg)

View File

@ -1526,7 +1526,9 @@ def in_the_same_node_as(
ranks = list(range(world_size))
# local tensor in each process to store the result
is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)
is_in_the_same_node = torch.tensor(
[0] * world_size, dtype=torch.int32, device="cpu"
)
magic_message = b"magic_message"
shm = None