From dfb5ce44dc043420d6d2358f3ec51816fb4cebf8 Mon Sep 17 00:00:00 2001
From: Amir Samani
Date: Mon, 24 Nov 2025 13:48:31 -0800
Subject: [PATCH] wip

Signed-off-by: Amir Samani
---
 tests/v1/cudagraph/test_cudagraph_dispatch.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py
index 2dba8c89e3217..2d7da4d177959 100644
--- a/tests/v1/cudagraph/test_cudagraph_dispatch.py
+++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py
@@ -166,7 +166,7 @@ class TestCUDAGraphWrapper:
             self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
         )
         batch_descriptor = BatchDescriptor(num_tokens=10)
-
+        stream = torch.cuda.Stream()
         # 0. global warmup
         with (
             set_forward_context(
@@ -175,7 +175,7 @@ class TestCUDAGraphWrapper:
                 cudagraph_runtime_mode=CUDAGraphMode.NONE,
                 batch_descriptor=None,
             ),
-            graph_capture(device=torch.device("cuda")) as graph_ctx,
+            torch.cuda.stream(stream),
         ):
             wrapper(self.input_tensor)

@@ -188,7 +188,7 @@ class TestCUDAGraphWrapper:
                 batch_descriptor=batch_descriptor,
             ),
             patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
-            graph_capture(device=torch.device("cuda")) as graph_ctx,
+            torch.cuda.stream(stream),
         ):
             output1 = wrapper(self.input_tensor)
             # capturing phase should generate a zero output