diff --git a/benchmarks/kernels/benchmark_device_communicators.py b/benchmarks/kernels/benchmark_device_communicators.py
index b414efa6e330b..7b453fe7b6809 100644
--- a/benchmarks/kernels/benchmark_device_communicators.py
+++ b/benchmarks/kernels/benchmark_device_communicators.py
@@ -293,7 +293,7 @@ class CommunicatorBenchmark:
             graph = torch.cuda.CUDAGraph()
             graph_pool = torch.cuda.graph_pool_handle()
             set_graph_pool_id(graph_pool)
-            with torch.cuda.graph(graph, pool=graph_pool):
+            with torch.cuda.graph(graph, pool=graph_pool, stream=stream):
                 for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
                     allreduce_fn(graph_input)

diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py
index a2e0abfebc2c9..f936084939d36 100644
--- a/vllm/compilation/cuda_graph.py
+++ b/vllm/compilation/cuda_graph.py
@@ -169,7 +169,11 @@ class CUDAGraphWrapper:
            else:
                set_graph_pool_id(current_platform.graph_pool_handle())
            # mind-exploding: carefully manage the reference and memory.
-            with torch.cuda.graph(cudagraph, pool=self.graph_pool):
+            with torch.cuda.graph(
+                cudagraph,
+                pool=self.graph_pool,
+                stream=torch.cuda.current_stream(),
+            ):
                # `output` is managed by pytorch's cudagraph pool
                output = self.runnable(*args, **kwargs)
                if self.cudagraph_options.weak_ref_output:
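Both hunks make the capture stream explicit: when `stream=` is omitted, `torch.cuda.graph` creates and uses its own internal side stream for capture, whereas passing a stream pins capture to it, which presumably matters here so the captured kernels are recorded on the stream they will actually run on. Below is a minimal, self-contained sketch of this capture-on-an-explicit-stream pattern, separate from the files touched by the diff; `capture_on_stream` and the toy workload are hypothetical names, not vLLM APIs.

```python
import torch


def capture_on_stream(fn, example_input: torch.Tensor):
    """Capture `fn(example_input)` into a CUDA graph on an explicit stream."""
    # CUDA graph capture is illegal on the default stream, so create a
    # dedicated stream and use it for both warmup and capture.
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        for _ in range(3):
            fn(example_input)  # warmup: keep lazy init out of the graph
    torch.cuda.current_stream().wait_stream(stream)

    graph = torch.cuda.CUDAGraph()
    pool = torch.cuda.graph_pool_handle()
    # Passing `stream=` pins capture to our stream instead of the side
    # stream torch.cuda.graph would otherwise create internally.
    with torch.cuda.graph(graph, pool=pool, stream=stream):
        output = fn(example_input)  # `output` lives in the graph's pool
    return graph, output


if __name__ == "__main__":
    static_input = torch.ones(1024, device="cuda")
    graph, static_output = capture_on_stream(lambda x: x * 2 + 1, static_input)
    static_input.fill_(2.0)   # mutate the captured input buffer in place
    graph.replay()            # rerun the recorded kernels
    torch.cuda.synchronize()
    print(static_output[:4])  # tensor([5., 5., 5., 5.], device='cuda:0')
```

Note the division of labor between the two hunks: the benchmark already holds a `stream` variable at the capture site, so it passes that directly, while `CUDAGraphWrapper` forwards `torch.cuda.current_stream()`, deferring to whatever stream context its caller has installed.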