Extract the duplicated GPU-memory cleanup in tests/utils.py into one helper.

The helper uses a single leading underscore (_cleanup_gpu_memory). A
double-underscore name is mangled per class, so the call added to
RemoteOpenAIServerCustom would look up
_RemoteOpenAIServerCustom__cleanup_gpu_memory, which the base class does not
define, and fail with AttributeError.

Note: the post-image hash on the index line was not recomputed, and
indentation was reconstructed from a whitespace-collapsed copy of the patch.

diff --git a/tests/utils.py b/tests/utils.py
index b061caf6a4489..38bdb4007d2ba 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -170,14 +170,19 @@ class RemoteOpenAIServer:
         except subprocess.TimeoutExpired:
             # force kill if needed
             self.proc.kill()
-        # GPU memory cleanup
+        self._cleanup_gpu_memory()
+
+    @staticmethod
+    def _cleanup_gpu_memory():
+        """Wait for GPU memory to clear; failures only print a warning."""
         try:
-            if current_platform.is_cuda() or current_platform.is_rocm():
+            if current_platform.is_cuda_alike():
                 num_devices = cuda_device_count_stateless()
                 if num_devices > 0:
                     wait_for_gpu_memory_to_clear(devices=list(
                         range(num_devices)),
-                                                 threshold_ratio=0.05)
+                                                 threshold_ratio=0.05,
+                                                 timeout_s=60)
         except Exception as e:
             print(f"GPU cleanup warning: {e}")
 
@@ -276,16 +281,7 @@ class RemoteOpenAIServerCustom(RemoteOpenAIServer):
             # force kill if needed
             self.proc.kill()
 
-        # GPU memory cleaning
-        try:
-            if current_platform.is_cuda() or current_platform.is_rocm():
-                num_devices = cuda_device_count_stateless()
-                if num_devices > 0:
-                    wait_for_gpu_memory_to_clear(devices=list(
-                        range(num_devices)),
-                                                 threshold_ratio=0.05)
-        except Exception as e:
-            print(f"GPU cleanup warning: {e}")
+        self._cleanup_gpu_memory()
 
 
 def _test_completion(