mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 18:38:38 +08:00
[doc] update test script to include cudagraph (#7501)
This commit is contained in:
parent
dd164d72f3
commit
199adbb7cf
@ -30,24 +30,59 @@ Here are some common issues that can cause hangs:
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
# Test PyTorch NCCL
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
dist.init_process_group(backend="nccl")
|
dist.init_process_group(backend="nccl")
|
||||||
local_rank = dist.get_rank() % torch.cuda.device_count()
|
local_rank = dist.get_rank() % torch.cuda.device_count()
|
||||||
data = torch.FloatTensor([1,] * 128).to(f"cuda:{local_rank}")
|
torch.cuda.set_device(local_rank)
|
||||||
|
data = torch.FloatTensor([1,] * 128).to("cuda")
|
||||||
dist.all_reduce(data, op=dist.ReduceOp.SUM)
|
dist.all_reduce(data, op=dist.ReduceOp.SUM)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
value = data.mean().item()
|
value = data.mean().item()
|
||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
assert value == world_size, f"Expected {world_size}, got {value}"
|
assert value == world_size, f"Expected {world_size}, got {value}"
|
||||||
|
|
||||||
|
print("PyTorch NCCL is successful!")
|
||||||
|
|
||||||
|
# Test PyTorch GLOO
|
||||||
gloo_group = dist.new_group(ranks=list(range(world_size)), backend="gloo")
|
gloo_group = dist.new_group(ranks=list(range(world_size)), backend="gloo")
|
||||||
cpu_data = torch.FloatTensor([1,] * 128)
|
cpu_data = torch.FloatTensor([1,] * 128)
|
||||||
dist.all_reduce(cpu_data, op=dist.ReduceOp.SUM, group=gloo_group)
|
dist.all_reduce(cpu_data, op=dist.ReduceOp.SUM, group=gloo_group)
|
||||||
value = cpu_data.mean().item()
|
value = cpu_data.mean().item()
|
||||||
assert value == world_size, f"Expected {world_size}, got {value}"
|
assert value == world_size, f"Expected {world_size}, got {value}"
|
||||||
|
|
||||||
print("sanity check is successful!")
|
print("PyTorch GLOO is successful!")
|
||||||
|
|
||||||
|
# Test vLLM NCCL, with cuda graph
|
||||||
|
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||||
|
|
||||||
|
pynccl = PyNcclCommunicator(group=gloo_group, device=local_rank)
|
||||||
|
pynccl.disabled = False
|
||||||
|
|
||||||
|
s = torch.cuda.Stream()
|
||||||
|
with torch.cuda.stream(s):
|
||||||
|
data.fill_(1)
|
||||||
|
pynccl.all_reduce(data, stream=s)
|
||||||
|
value = data.mean().item()
|
||||||
|
assert value == world_size, f"Expected {world_size}, got {value}"
|
||||||
|
|
||||||
|
print("vLLM NCCL is successful!")
|
||||||
|
|
||||||
|
g = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(cuda_graph=g, stream=s):
|
||||||
|
pynccl.all_reduce(data, stream=torch.cuda.current_stream())
|
||||||
|
|
||||||
|
data.fill_(1)
|
||||||
|
g.replay()
|
||||||
|
torch.cuda.current_stream().synchronize()
|
||||||
|
value = data.mean().item()
|
||||||
|
assert value == world_size, f"Expected {world_size}, got {value}"
|
||||||
|
|
||||||
|
print("vLLM NCCL with cuda graph is successful!")
|
||||||
|
|
||||||
|
dist.destroy_process_group(gloo_group)
|
||||||
|
dist.destroy_process_group()
|
||||||
|
|
||||||
.. tip::
Loading…
x
Reference in New Issue
Block a user