[Model Runner V2] Minor fix for cudagraph_utils (#29256)

This commit is contained in:
Woosuk Kwon 2025-11-22 20:12:50 -08:00 committed by GitHub
parent 389aa1b2eb
commit 20ee418adc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 14 deletions

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
from contextlib import contextmanager
from unittest.mock import patch
import numpy as np
import torch
@@ -140,6 +139,7 @@ class CudaGraphManager:
attn_metadata,
self.vllm_config,
num_tokens=batch_size,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
):
hidden_states = model(
@@ -148,15 +148,16 @@ class CudaGraphManager:
)
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
torch.cuda.synchronize()
# Capture the graph.
graph = torch.cuda.CUDAGraph()
with (
patch("torch.cuda.empty_cache", lambda: None),
set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=batch_size,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
),
torch.cuda.graph(graph, self.pool),
@@ -183,7 +184,7 @@ class CudaGraphManager:
if is_global_first_rank():
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
with freeze_gc(), graph_capture(device=self.device):
with graph_capture(device=self.device):
for batch_size in sizes_to_capture:
self.capture_graph(
batch_size,
@@ -199,13 +200,3 @@ class CudaGraphManager:
self.graphs[batch_size].replay()
assert self.hidden_states is not None
return self.hidden_states[:batch_size]
@contextmanager
def freeze_gc():
gc.collect()
gc.freeze()
try:
yield
finally:
gc.unfreeze()

View File

@@ -298,6 +298,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return 0
start_time = time.perf_counter()
torch.cuda.empty_cache()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
with self.maybe_setup_dummy_loras(self.lora_config):