diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index 2dba8c89e3217..2d7da4d177959 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -166,7 +166,7 @@ class TestCUDAGraphWrapper: self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL ) batch_descriptor = BatchDescriptor(num_tokens=10) - + stream = torch.cuda.Stream() # 0. global warmup with ( set_forward_context( @@ -175,7 +175,7 @@ class TestCUDAGraphWrapper: cudagraph_runtime_mode=CUDAGraphMode.NONE, batch_descriptor=None, ), - graph_capture(device=torch.device("cuda")) as graph_ctx, + torch.cuda.stream(stream), ): wrapper(self.input_tensor) @@ -188,7 +188,7 @@ class TestCUDAGraphWrapper: batch_descriptor=batch_descriptor, ), patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph, - graph_capture(device=torch.device("cuda")) as graph_ctx, + torch.cuda.stream(stream), ): output1 = wrapper(self.input_tensor) # capturing phase should generate a zero output