From 28579b55faf47e267ca7ff8b49e2f2272b88ec4f Mon Sep 17 00:00:00 2001 From: Amir Samani Date: Fri, 21 Nov 2025 14:20:08 -0800 Subject: [PATCH 1/4] add stream to cuda graph capture Signed-off-by: Amir Samani --- benchmarks/kernels/benchmark_device_communicators.py | 2 +- vllm/compilation/cuda_graph.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/benchmarks/kernels/benchmark_device_communicators.py b/benchmarks/kernels/benchmark_device_communicators.py index b414efa6e330b..7b453fe7b6809 100644 --- a/benchmarks/kernels/benchmark_device_communicators.py +++ b/benchmarks/kernels/benchmark_device_communicators.py @@ -293,7 +293,7 @@ class CommunicatorBenchmark: graph = torch.cuda.CUDAGraph() graph_pool = torch.cuda.graph_pool_handle() set_graph_pool_id(graph_pool) - with torch.cuda.graph(graph, pool=graph_pool): + with torch.cuda.graph(graph, pool=graph_pool, stream=stream): for _ in range(CUDA_GRAPH_CAPTURE_CYCLES): allreduce_fn(graph_input) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index a2e0abfebc2c9..f936084939d36 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -169,7 +169,11 @@ class CUDAGraphWrapper: else: set_graph_pool_id(current_platform.graph_pool_handle()) # mind-exploding: carefully manage the reference and memory. 
- with torch.cuda.graph(cudagraph, pool=self.graph_pool): + with torch.cuda.graph( + cudagraph, + pool=self.graph_pool, + stream=torch.cuda.current_stream(), + ): # `output` is managed by pytorch's cudagraph pool output = self.runnable(*args, **kwargs) if self.cudagraph_options.weak_ref_output: From 695d78c4713c04b4535e4c5e083c3aa2f3ea2b10 Mon Sep 17 00:00:00 2001 From: Amir Samani Date: Mon, 24 Nov 2025 11:23:23 -0800 Subject: [PATCH 2/4] wip Signed-off-by: Amir Samani --- tests/v1/cudagraph/test_cudagraph_dispatch.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index bb953e5c70c8c..2dba8c89e3217 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -17,6 +17,7 @@ from vllm.config import ( SchedulerConfig, VllmConfig, ) +from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.platforms import current_platform from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher @@ -167,11 +168,14 @@ class TestCUDAGraphWrapper: batch_descriptor = BatchDescriptor(num_tokens=10) # 0. 
global warmup - with set_forward_context( - attn_metadata=None, - vllm_config=self.vllm_config, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - batch_descriptor=None, + with ( + set_forward_context( + attn_metadata=None, + vllm_config=self.vllm_config, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + batch_descriptor=None, + ), + graph_capture(device=torch.device("cuda")) as graph_ctx, ): wrapper(self.input_tensor) @@ -184,6 +188,7 @@ class TestCUDAGraphWrapper: batch_descriptor=batch_descriptor, ), patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph, + graph_capture(device=torch.device("cuda")) as graph_ctx, ): output1 = wrapper(self.input_tensor) # capturing phase should generate a zero output From dfb5ce44dc043420d6d2358f3ec51816fb4cebf8 Mon Sep 17 00:00:00 2001 From: Amir Samani Date: Mon, 24 Nov 2025 13:48:31 -0800 Subject: [PATCH 3/4] wip Signed-off-by: Amir Samani --- tests/v1/cudagraph/test_cudagraph_dispatch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index 2dba8c89e3217..2d7da4d177959 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -166,7 +166,7 @@ class TestCUDAGraphWrapper: self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL ) batch_descriptor = BatchDescriptor(num_tokens=10) - + stream = torch.cuda.Stream() # 0. 
global warmup with ( set_forward_context( @@ -175,7 +175,7 @@ class TestCUDAGraphWrapper: cudagraph_runtime_mode=CUDAGraphMode.NONE, batch_descriptor=None, ), - graph_capture(device=torch.device("cuda")) as graph_ctx, + torch.cuda.stream(stream), ): wrapper(self.input_tensor) @@ -188,7 +188,7 @@ class TestCUDAGraphWrapper: batch_descriptor=batch_descriptor, ), patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph, - graph_capture(device=torch.device("cuda")) as graph_ctx, + torch.cuda.stream(stream), ): output1 = wrapper(self.input_tensor) # capturing phase should generate a zero output From d1f2b3b9953e450ab191c5f0232fc0813484c41e Mon Sep 17 00:00:00 2001 From: Amir Samani Date: Mon, 24 Nov 2025 14:12:12 -0800 Subject: [PATCH 4/4] wip Signed-off-by: Amir Samani --- tests/v1/cudagraph/test_cudagraph_dispatch.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index 2d7da4d177959..0ee58415787c8 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -17,7 +17,6 @@ from vllm.config import ( SchedulerConfig, VllmConfig, ) -from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.platforms import current_platform from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher