[Core] Freeze gc during cuda graph capture to speed up init (#21146)
Signed-off-by: Codex <codex@openai.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent 82ec66f514
commit f3137cdd81
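For background on why freezing helps: gc.freeze() moves every object currently tracked by the collector into a permanent generation that later collections skip, so a full gc.collect() has far fewer objects to traverse. A minimal standalone sketch of that effect (plain CPython, not vLLM code; the object count is arbitrary):

import gc
import time

# Create many tracked container objects, as a long-lived process would hold.
objs = [{"i": i} for i in range(1_000_000)]

t0 = time.perf_counter()
gc.collect()
print(f"collect with objects tracked: {time.perf_counter() - t0:.3f}s")

# Freeze: move every surviving object into the permanent generation so
# later collections no longer traverse them.
gc.collect()
gc.freeze()
t0 = time.perf_counter()
gc.collect()
print(f"collect after gc.freeze():    {time.perf_counter() - t0:.3f}s")

gc.unfreeze()
del objs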
@@ -140,6 +140,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
     VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
     VLLM_USE_CUDNN_PREFILL: bool = False
+    VLLM_ENABLE_CUDAGRAPH_GC: bool = False
     VLLM_LOOPBACK_IP: str = ""
@@ -968,6 +969,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_TRTLLM_DECODE_ATTENTION":
     lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),

+    # Controls garbage collection during CUDA graph capture.
+    # If set to 0 (default), enables GC freezing to speed up capture time.
+    # If set to 1, allows GC to run during capture.
+    "VLLM_ENABLE_CUDAGRAPH_GC":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))),
+
     # Used to force set up loopback IP
     "VLLM_LOOPBACK_IP":
     lambda: os.getenv("VLLM_LOOPBACK_IP", ""),
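The lambda above parses the environment variable into a boolean. A quick illustration of the mapping (the variable name comes from the diff; the surrounding script is just an example):

import os

# Default "0": freeze GC during capture; "1": let GC run during capture.
for raw in ("0", "1"):
    os.environ["VLLM_ENABLE_CUDAGRAPH_GC"] = raw
    enabled = bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0")))
    print(raw, "->", "GC allowed during capture" if enabled else "GC frozen during capture")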
@@ -2439,10 +2439,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         start_time = time.perf_counter()
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]

+        @contextmanager
+        def freeze_gc():
+            # Optimize garbage collection during CUDA graph capture.
+            # Clean up, then freeze all remaining objects from being included
+            # in future collections.
+            gc.collect()
+            should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC
+            if should_freeze:
+                gc.freeze()
+            try:
+                yield
+            finally:
+                if should_freeze:
+                    gc.unfreeze()
+
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        with graph_capture(device=self.device):
+        with freeze_gc(), graph_capture(device=self.device):
             full_cg = self.full_cuda_graph
             # Only rank 0 should print progress bar during capture
             compilation_cases = reversed(self.cudagraph_batch_sizes)
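The same pattern, pulled out of the runner as a self-contained sketch: the enable_gc flag stands in for envs.VLLM_ENABLE_CUDAGRAPH_GC, and the body of the with block stands in for the graph-capture loop.

import gc
from contextlib import contextmanager

@contextmanager
def freeze_gc(enable_gc: bool = False):
    # Collect once up front, then exclude every surviving tracked object
    # from future collections while the block runs.
    gc.collect()
    should_freeze = not enable_gc
    if should_freeze:
        gc.freeze()
    try:
        yield
    finally:
        if should_freeze:
            gc.unfreeze()

with freeze_gc():
    # Stand-in for the CUDA graph capture work.
    print("objects frozen:", gc.get_freeze_count())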