From 695d78c4713c04b4535e4c5e083c3aa2f3ea2b10 Mon Sep 17 00:00:00 2001
From: Amir Samani
Date: Mon, 24 Nov 2025 11:23:23 -0800
Subject: [PATCH] tests: run cudagraph dispatch warmup and capture under
 graph_capture

Signed-off-by: Amir Samani
---
 tests/v1/cudagraph/test_cudagraph_dispatch.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py
index bb953e5c70c8c..2dba8c89e3217 100644
--- a/tests/v1/cudagraph/test_cudagraph_dispatch.py
+++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py
@@ -17,6 +17,7 @@ from vllm.config import (
     SchedulerConfig,
     VllmConfig,
 )
+from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.platforms import current_platform
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
@@ -167,11 +168,14 @@ class TestCUDAGraphWrapper:
         batch_descriptor = BatchDescriptor(num_tokens=10)
 
         # 0. global warmup
-        with set_forward_context(
-            attn_metadata=None,
-            vllm_config=self.vllm_config,
-            cudagraph_runtime_mode=CUDAGraphMode.NONE,
-            batch_descriptor=None,
+        with (
+            set_forward_context(
+                attn_metadata=None,
+                vllm_config=self.vllm_config,
+                cudagraph_runtime_mode=CUDAGraphMode.NONE,
+                batch_descriptor=None,
+            ),
+            graph_capture(device=torch.device("cuda")) as graph_ctx,
         ):
             wrapper(self.input_tensor)
 
@@ -184,6 +188,7 @@ class TestCUDAGraphWrapper:
                 batch_descriptor=batch_descriptor,
             ),
             patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
+            graph_capture(device=torch.device("cuda")) as graph_ctx,
         ):
             output1 = wrapper(self.input_tensor)
             # capturing phase should generate a zero output