add stream to CUDA graph capture

Signed-off-by: Amir Samani <asamani@nvidia.com>
This commit is contained in:
Amir Samani 2025-11-21 14:20:08 -08:00
parent 57430fc95c
commit 28579b55fa
2 changed files with 6 additions and 2 deletions

View File

@@ -293,7 +293,7 @@ class CommunicatorBenchmark:
graph = torch.cuda.CUDAGraph()
graph_pool = torch.cuda.graph_pool_handle()
set_graph_pool_id(graph_pool)
with torch.cuda.graph(graph, pool=graph_pool):
with torch.cuda.graph(graph, pool=graph_pool, stream=stream):
for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
allreduce_fn(graph_input)

View File

@@ -169,7 +169,11 @@ class CUDAGraphWrapper:
else:
set_graph_pool_id(current_platform.graph_pool_handle())
# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
with torch.cuda.graph(
cudagraph,
pool=self.graph_pool,
stream=torch.cuda.current_stream(),
):
# `output` is managed by pytorch's cudagraph pool
output = self.runnable(*args, **kwargs)
if self.cudagraph_options.weak_ref_output: