diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py
index 763bd61834625..654bd60e558b1 100644
--- a/vllm/v1/worker/gpu/cudagraph_utils.py
+++ b/vllm/v1/worker/gpu/cudagraph_utils.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import gc
-from contextlib import contextmanager
+from unittest.mock import patch
 
 import numpy as np
 import torch
@@ -140,6 +139,7 @@ class CudaGraphManager:
             attn_metadata,
             self.vllm_config,
             num_tokens=batch_size,
+            cudagraph_runtime_mode=CUDAGraphMode.NONE,
             num_tokens_across_dp=num_tokens_across_dp,
         ):
             hidden_states = model(
@@ -148,15 +148,16 @@ class CudaGraphManager:
             )
             if self.hidden_states is None:
                 self.hidden_states = torch.empty_like(hidden_states)
-        torch.cuda.synchronize()
 
         # Capture the graph.
         graph = torch.cuda.CUDAGraph()
         with (
+            patch("torch.cuda.empty_cache", lambda: None),
             set_forward_context(
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=batch_size,
+                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                 num_tokens_across_dp=num_tokens_across_dp,
             ),
             torch.cuda.graph(graph, self.pool),
@@ -183,7 +184,7 @@ class CudaGraphManager:
         if is_global_first_rank():
             sizes_to_capture = tqdm(sizes_to_capture,
                                     desc="Capturing CUDA graphs")
-        with freeze_gc(), graph_capture(device=self.device):
+        with graph_capture(device=self.device):
             for batch_size in sizes_to_capture:
                 self.capture_graph(
                     batch_size,
@@ -199,13 +200,3 @@ class CudaGraphManager:
             self.graphs[batch_size].replay()
         assert self.hidden_states is not None
         return self.hidden_states[:batch_size]
-
-
-@contextmanager
-def freeze_gc():
-    gc.collect()
-    gc.freeze()
-    try:
-        yield
-    finally:
-        gc.unfreeze()
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 9ca37ff282d82..9d6e2cf92a8cc 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -298,6 +298,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             return 0
 
         start_time = time.perf_counter()
+        torch.cuda.empty_cache()
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]
 
         with self.maybe_setup_dummy_loras(self.lora_config):
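
Note on the capture change above: `torch.cuda.empty_cache()` frees cached allocator blocks via `cudaFree`, which is not safe to call while a CUDA graph is being captured, so the diff stubs it out with `unittest.mock.patch` for the duration of capture rather than relying on the removed `freeze_gc()` helper. A minimal, self-contained sketch of that pattern follows; the `capture_one_graph`, `model`, and `static_input` names are illustrative and not part of the diff:

```python
from unittest.mock import patch

import torch


def capture_one_graph(
    model: torch.nn.Module, static_input: torch.Tensor
) -> torch.cuda.CUDAGraph:
    # Warm up eagerly first so lazy allocations, autotuning, etc. happen
    # outside of graph capture.
    model(static_input)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    # Replace torch.cuda.empty_cache with a no-op while capturing: any code
    # path that calls it mid-capture would otherwise invalidate the capture.
    with patch("torch.cuda.empty_cache", lambda: None), torch.cuda.graph(graph):
        model(static_input)  # recorded into the graph, not executed eagerly
    return graph
```

Replay then amounts to copying new data into `static_input` and calling `graph.replay()`, which is what the manager's replay path above does with its cached `hidden_states` buffer.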