[misc] fix debugging code (#13487)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2025-02-19 01:37:11 +08:00 committed by GitHub
parent 4fb8142a0e
commit 7b203b7694
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -94,20 +94,20 @@ pynccl.disabled = False
s = torch.cuda.Stream()
with torch.cuda.stream(s):
data.fill_(1)
pynccl.all_reduce(data, stream=s)
value = data.mean().item()
out = pynccl.all_reduce(data, stream=s)
value = out.mean().item()
assert value == world_size, f"Expected {world_size}, got {value}"
print("vLLM NCCL is successful!")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(cuda_graph=g, stream=s):
pynccl.all_reduce(data, stream=torch.cuda.current_stream())
out = pynccl.all_reduce(data, stream=torch.cuda.current_stream())
data.fill_(1)
g.replay()
torch.cuda.current_stream().synchronize()
value = data.mean().item()
value = out.mean().item()
assert value == world_size, f"Expected {world_size}, got {value}"
print("vLLM NCCL with cuda graph is successful!")