diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 9941cacae8ab..efb4a8c0054f 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3517,7 +3517,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         compilation_counter.num_gpu_runner_capture_triggers += 1
 
         start_time = time.perf_counter()
-        start_free_gpu_memory = torch.cuda.mem_get_info()[0]
 
         @contextmanager
         def freeze_gc():
@@ -3540,6 +3539,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # can reuse the memory pool allocated for the large shapes.
         set_cudagraph_capturing_enabled(True)
         with freeze_gc(), graph_capture(device=self.device):
+            start_free_gpu_memory = torch.cuda.mem_get_info()[0]
             cudagraph_mode = self.compilation_config.cudagraph_mode
             assert cudagraph_mode is not None
             if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
@@ -3568,6 +3568,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     cudagraph_runtime_mode=CUDAGraphMode.FULL,
                     uniform_decode=True)
 
+        torch.cuda.synchronize()
+        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
+
         # Disable cudagraph capturing globally, so any unexpected cudagraph
         # capturing will be detected and raise an error after here.
         # Note: We don't put it into graph_capture context manager because
@@ -3576,7 +3579,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         set_cudagraph_capturing_enabled(False)
 
         end_time = time.perf_counter()
-        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
         elapsed_time = end_time - start_time
         cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
         # This usually takes 5~20 seconds.
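
In short, this diff narrows the measurement window so that cuda_graph_size reflects only
memory allocated during CUDA graph capture: the first mem_get_info() reading is taken once
the capture context is entered, and the added torch.cuda.synchronize() makes sure all capture
work has completed before the second reading, rather than sampling free memory at the top and
bottom of the whole method. Below is a minimal standalone sketch of the same bracketing
pattern; the measure_capture_memory helper and its capture_fn parameter are hypothetical
illustrations, not part of vLLM, and it assumes a CUDA device is available.

import torch

def measure_capture_memory(capture_fn):
    """Return (result, bytes_consumed) for a capture callable.

    torch.cuda.mem_get_info()[0] is the free device memory in bytes.
    Reading it immediately around the capture region, with a
    synchronize before the second reading, keeps unrelated or
    still-pending allocations out of the measurement; this mirrors
    the reasoning behind the diff above.
    """
    torch.cuda.synchronize()
    start_free = torch.cuda.mem_get_info()[0]
    result = capture_fn()
    # Capture may leave work queued on the device; wait for it so its
    # allocations are visible before the second free-memory snapshot.
    torch.cuda.synchronize()
    end_free = torch.cuda.mem_get_info()[0]
    return result, start_free - end_free

Example use (capture_all_graphs is likewise a placeholder for whatever performs the actual
graph capture): graphs, nbytes = measure_capture_memory(capture_all_graphs). Note that
mem_get_info() reports device-wide free memory, so any concurrent allocator activity still
shows up in the delta; tightening the window, as the diff does, only reduces that noise.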