diff --git a/benchmarks/kernels/benchmark_device_communicators.py b/benchmarks/kernels/benchmark_device_communicators.py
index b414efa6e330b..7b453fe7b6809 100644
--- a/benchmarks/kernels/benchmark_device_communicators.py
+++ b/benchmarks/kernels/benchmark_device_communicators.py
@@ -293,7 +293,7 @@ class CommunicatorBenchmark:
             graph = torch.cuda.CUDAGraph()
             graph_pool = torch.cuda.graph_pool_handle()
             set_graph_pool_id(graph_pool)
-            with torch.cuda.graph(graph, pool=graph_pool):
+            with torch.cuda.graph(graph, pool=graph_pool, stream=stream):
                 for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
                     allreduce_fn(graph_input)
 
diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py
index f9d3e8d0532b5..3c97511b3450b 100644
--- a/tests/v1/cudagraph/test_cudagraph_dispatch.py
+++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py
@@ -189,13 +189,16 @@ class TestCUDAGraphWrapper:
             self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
         )
         batch_descriptor = BatchDescriptor(num_tokens=10)
-
+        stream = torch.cuda.Stream()
         # 0. global warmup
-        with set_forward_context(
-            attn_metadata=None,
-            vllm_config=self.vllm_config,
-            cudagraph_runtime_mode=CUDAGraphMode.NONE,
-            batch_descriptor=None,
+        with (
+            set_forward_context(
+                attn_metadata=None,
+                vllm_config=self.vllm_config,
+                cudagraph_runtime_mode=CUDAGraphMode.NONE,
+                batch_descriptor=None,
+            ),
+            torch.cuda.stream(stream),
         ):
             wrapper(self.input_tensor)
 
@@ -208,6 +211,7 @@ class TestCUDAGraphWrapper:
                 batch_descriptor=batch_descriptor,
             ),
             patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
+            torch.cuda.stream(stream),
         ):
             output1 = wrapper(self.input_tensor)
             # capturing phase should generate a zero output
diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py
index 0748643a5299f..84209b1a18138 100644
--- a/vllm/compilation/cuda_graph.py
+++ b/vllm/compilation/cuda_graph.py
@@ -263,7 +263,11 @@ class CUDAGraphWrapper:
             else:
                 set_graph_pool_id(current_platform.graph_pool_handle())
             # mind-exploding: carefully manage the reference and memory.
-            with torch.cuda.graph(cudagraph, pool=self.graph_pool):
+            with torch.cuda.graph(
+                cudagraph,
+                pool=self.graph_pool,
+                stream=torch.cuda.current_stream(),
+            ):
                 # `output` is managed by pytorch's cudagraph pool
                 output = self.runnable(*args, **kwargs)
                 if self.cudagraph_options.weak_ref_output: