diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index e58cf5911282e..0e9b0fbe2c028 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -19,7 +19,6 @@ On the client side, run: import argparse import asyncio import contextlib -import gc import importlib.util import json import os @@ -49,6 +48,7 @@ from vllm.benchmarks.lib.endpoint_request_func import ( from vllm.benchmarks.lib.ready_checker import wait_for_endpoint from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.utils.gc_utils import freeze_gc_heap MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -1414,8 +1414,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: percentile_metrics: str = args.percentile_metrics or default_percentile_metrics # Avoid GC processing "static" data - reduce pause times. - gc.collect() - gc.freeze() + freeze_gc_heap() benchmark_result = await benchmark( task_type=task_type, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a9b01e82562b9..c78e6a32733c1 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1483,6 +1483,9 @@ def destroy_distributed_environment(): def cleanup_dist_env_and_memory(shutdown_ray: bool = False): + # Ensure all objects are not frozen before cleanup + gc.unfreeze() + destroy_model_parallel() destroy_distributed_environment() if shutdown_ray: diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index c8c8d5c034d55..51191879e4780 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import asyncio -import gc import hashlib import importlib import inspect @@ -118,6 +116,7 @@ from vllm.reasoning import ReasoningParserManager from 
vllm.tasks import POOLING_TASKS from vllm.usage.usage_lib import UsageContext from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.gc_utils import freeze_gc_heap from vllm.utils.network_utils import is_valid_ipv6_address from vllm.utils.system_utils import decorate_logs, set_ulimit from vllm.v1.engine.exceptions import EngineDeadError @@ -153,8 +152,7 @@ async def lifespan(app: FastAPI): # Mark the startup heap as static so that it's ignored by GC. # Reduces pause times of oldest generation collections. - gc.collect() - gc.freeze() + freeze_gc_heap() try: yield finally: diff --git a/vllm/utils/gc_utils.py b/vllm/utils/gc_utils.py index 4dd85ef26f34a..160ac9ac263a9 100644 --- a/vllm/utils/gc_utils.py +++ b/vllm/utils/gc_utils.py @@ -89,6 +89,21 @@ class GCDebugger: ) +def freeze_gc_heap() -> None: + """ + Freeze all objects tracked by the garbage collector. It should be invoked + after server init / warmup, to reduce GC overhead from static objects + during serving time. + """ + # Ensure all static objects are pushed down to the oldest generation for + # freeze + gc.collect(0) + gc.collect(1) + gc.collect(2) + # Freeze all GC tracked objects + gc.freeze() + + def maybe_attach_gc_debug_callback() -> None: """ Attached a callback for GC debug when VLLM_GC_DEBUG is enabled. 
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index c3efd52130cce..ffb5232e770d1 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import gc import os import queue import signal @@ -27,7 +26,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.cache import engine_receiver_cache_from_config from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.transformers_utils.config import maybe_register_config_serialize_by_value -from vllm.utils.gc_utils import maybe_attach_gc_debug_callback +from vllm.utils.gc_utils import ( + freeze_gc_heap, + maybe_attach_gc_debug_callback, +) from vllm.utils.hashing import get_hash_fn_by_name from vllm.utils.network_utils import make_zmq_socket from vllm.utils.system_utils import decorate_logs, set_process_title @@ -197,6 +199,10 @@ class EngineCore: self.step if self.batch_queue is None else self.step_with_batch_queue ) + # Mark the startup heap as static so that it's ignored by GC. + # Reduces pause times of oldest generation collections. + freeze_gc_heap() + def _initialize_kv_caches( self, vllm_config: VllmConfig ) -> tuple[int, int, KVCacheConfig]: @@ -651,11 +657,6 @@ class EngineCoreProc(EngineCore): assert addresses.coordinator_input is not None logger.info("Waiting for READY message from DP Coordinator...") - # Mark the startup heap as static so that it's ignored by GC. - # Reduces pause times of oldest generation collections. - gc.collect() - gc.freeze() - # If enable, attach GC debugger after static variable freeze. maybe_attach_gc_debug_callback()