[Core] Freeze gc during cuda graph capture to speed up init (#21146)
Signed-off-by: Codex <codex@openai.com>
Signed-off-by: mgoin <mgoin64@gmail.com>

parent 82ec66f514
commit f3137cdd81
@@ -140,6 +140,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
     VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120
     VLLM_USE_CUDNN_PREFILL: bool = False
+    VLLM_ENABLE_CUDAGRAPH_GC: bool = False
     VLLM_LOOPBACK_IP: str = ""
@@ -968,6 +969,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_TRTLLM_DECODE_ATTENTION":
     lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None),

+    # Controls garbage collection during CUDA graph capture.
+    # If set to 0 (default), enables GC freezing to speed up capture time.
+    # If set to 1, allows GC to run during capture.
+    "VLLM_ENABLE_CUDAGRAPH_GC":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))),
+
     # Used to force set up loopback IP
     "VLLM_LOOPBACK_IP":
     lambda: os.getenv("VLLM_LOOPBACK_IP", ""),
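For illustration, a minimal standalone sketch of the parsing rule added above, assuming only the standard library. The enable_cudagraph_gc helper is hypothetical and exists only for this demo; the parsing expression itself is copied from the diff:

import os

def enable_cudagraph_gc() -> bool:
    # Hypothetical helper mirroring the new envs.py lambda:
    # unset or "0" -> False (GC frozen during capture), "1" -> True.
    return bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0")))

assert enable_cudagraph_gc() is False           # default: freezing enabled
os.environ["VLLM_ENABLE_CUDAGRAPH_GC"] = "1"    # opt back in to GC runs
assert enable_cudagraph_gc() is True

Note that any value int() cannot parse (e.g. "true") raises ValueError, so only integer strings are safe settings.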
@@ -2439,10 +2439,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         start_time = time.perf_counter()
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]

+        @contextmanager
+        def freeze_gc():
+            # Optimize garbage collection during CUDA graph capture.
+            # Clean up, then freeze all remaining objects from being included
+            # in future collections.
+            gc.collect()
+            should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC
+            if should_freeze:
+                gc.freeze()
+            try:
+                yield
+            finally:
+                if should_freeze:
+                    gc.unfreeze()
+
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        with graph_capture(device=self.device):
+        with freeze_gc(), graph_capture(device=self.device):
             full_cg = self.full_cuda_graph
             # Only rank 0 should print progress bar during capture
             compilation_cases = reversed(self.cudagraph_batch_sizes)
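The freeze pattern introduced above can be exercised outside of vLLM as well. A minimal standalone sketch, assuming only the Python standard library (gc.freeze()/gc.unfreeze() require Python 3.7+; the list workload is invented for this demo and is not part of the commit):

import gc
import time
from contextlib import contextmanager

@contextmanager
def freeze_gc(enabled: bool = True):
    # Same shape as the context manager in the diff: collect once, then
    # move every surviving object into the permanent generation so later
    # collections skip them.
    gc.collect()
    if enabled:
        gc.freeze()
    try:
        yield
    finally:
        if enabled:
            gc.unfreeze()

junk = [[i] * 10 for i in range(200_000)]  # long-lived objects to freeze
with freeze_gc():
    t0 = time.perf_counter()
    gc.collect()  # stand-in for the allocation-heavy capture loop
    print(f"collect with frozen heap: {time.perf_counter() - t0:.4f}s")

The design point is that gc.collect() first reclaims whatever is already garbage, and gc.freeze() then parks the survivors where cyclic collections triggered during capture will not traverse them; gc.unfreeze() restores normal behavior once capture is done.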