diff --git a/tests/utils.py b/tests/utils.py index 8eb12e9c4866e..b08009fda0ef8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,7 +4,6 @@ import asyncio import copy import functools -import gc import importlib import json import os @@ -174,11 +173,10 @@ class RemoteOpenAIServer: # GPU memory cleanup try: if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - torch.cuda.synchronize() - # Small delay to ensure cleanup completes - time.sleep(0.5) + devices_to_clear = list(range(torch.cuda.device_count())) + if devices_to_clear: + wait_for_gpu_memory_to_clear(devices=devices_to_clear, + threshold_ratio=0.05) except Exception as e: print(f"GPU cleanup warning: {e}") @@ -280,11 +278,10 @@ class RemoteOpenAIServerCustom(RemoteOpenAIServer): # GPU memory cleaning try: if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - torch.cuda.synchronize() - # Small delay to ensure cleanup completes - time.sleep(0.5) + devices_to_clear = list(range(torch.cuda.device_count())) + if devices_to_clear: + wait_for_gpu_memory_to_clear(devices=devices_to_clear, + threshold_ratio=0.05) except Exception as e: print(f"GPU cleanup warning: {e}")