[Core] Freeze gc during cuda graph capture to speed up init (#21146)

Signed-off-by: Codex <codex@openai.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Michael Goin committed 2025-07-23 20:20:14 -04:00 (via GitHub)
parent 82ec66f514
commit f3137cdd81
2 changed files with 23 additions and 1 deletion
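For background on the mechanism: CPython's gc.freeze() moves every object the collector currently tracks into a permanent generation that later collection passes skip, so the many long-lived objects created during engine init no longer get walked by collections triggered while graphs are captured. A rough illustration of the effect, as a standalone sketch (not vLLM code; timings vary by machine):

import gc
import time

# Build a large heap of long-lived, GC-tracked container objects,
# standing in for the objects alive after engine initialization.
data = [[i] for i in range(2_000_000)]

def timed_collect() -> float:
    start = time.perf_counter()
    gc.collect()
    return time.perf_counter() - start

baseline = timed_collect()   # a full pass walks every tracked object
gc.freeze()                  # move survivors into the permanent generation
frozen = timed_collect()     # the pass now skips the frozen heap
gc.unfreeze()
print(f"collect over tracked heap: {baseline * 1e3:.1f} ms")
print(f"collect over frozen heap:  {frozen * 1e3:.1f} ms")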

vllm/envs.py

@@ -140,6 +140,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
     VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
     VLLM_USE_CUDNN_PREFILL: bool = False
+    VLLM_ENABLE_CUDAGRAPH_GC: bool = False
     VLLM_LOOPBACK_IP: str = ""
@@ -968,6 +969,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_TRTLLM_DECODE_ATTENTION":
     lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),
+    # Controls garbage collection during CUDA graph capture.
+    # If set to 0 (default), enables GC freezing to speed up capture time.
+    # If set to 1, allows GC to run during capture.
+    "VLLM_ENABLE_CUDAGRAPH_GC":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))),
     # Used to force set up loopback IP
     "VLLM_LOOPBACK_IP":
     lambda: os.getenv("VLLM_LOOPBACK_IP", ""),
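Note that the lambda parses the raw string as an integer, so only integer-like values are accepted; a value like "true" raises ValueError at lookup time. A standalone sketch of the parsing semantics (not vLLM code):

import os

os.environ["VLLM_ENABLE_CUDAGRAPH_GC"] = "1"
flag = bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0")))
print(flag)  # True -> GC may run during capture; unset or "0" -> False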

vllm/v1/worker/gpu_model_runner.py

@@ -2439,10 +2439,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         start_time = time.perf_counter()
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]
 
+        @contextmanager
+        def freeze_gc():
+            # Optimize garbage collection during CUDA graph capture.
+            # Clean up, then freeze all remaining objects from being included
+            # in future collections.
+            gc.collect()
+            should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC
+            if should_freeze:
+                gc.freeze()
+            try:
+                yield
+            finally:
+                if should_freeze:
+                    gc.unfreeze()
+
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        with graph_capture(device=self.device):
+        with freeze_gc(), graph_capture(device=self.device):
             full_cg = self.full_cuda_graph
             # Only rank 0 should print progress bar during capture
             compilation_cases = reversed(self.cudagraph_batch_sizes)
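Pulled out of the runner, the same pattern works as a reusable context manager. A self-contained sketch (the enable_gc parameter and the demo block are illustrative stand-ins, not vLLM API):

import gc
from contextlib import contextmanager

@contextmanager
def freeze_gc(enable_gc: bool = False):
    # Collect first so garbage is not frozen into the permanent
    # generation, then freeze the remaining live objects so later
    # collections skip them.
    gc.collect()
    should_freeze = not enable_gc
    if should_freeze:
        gc.freeze()
    try:
        yield
    finally:
        if should_freeze:
            gc.unfreeze()

if __name__ == "__main__":
    with freeze_gc():
        # Work that should not pay for GC passes over the frozen heap.
        print("frozen objects:", gc.get_freeze_count())  # large count
    print("frozen objects:", gc.get_freeze_count())  # 0 after unfreeze

The gc.collect() before gc.freeze() matters: anything frozen is ignored by the collector until unfreeze, so collecting first keeps dead objects out of the permanent generation.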