diff --git a/tests/utils.py b/tests/utils.py index b08009fda0ef8..b061caf6a4489 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -172,10 +172,11 @@ class RemoteOpenAIServer: self.proc.kill() # GPU memory cleanup try: - if torch.cuda.is_available(): - devices_to_clear = list(range(torch.cuda.device_count())) - if devices_to_clear: - wait_for_gpu_memory_to_clear(devices=devices_to_clear, + if current_platform.is_cuda() or current_platform.is_rocm(): + num_devices = cuda_device_count_stateless() + if num_devices > 0: + wait_for_gpu_memory_to_clear(devices=list( + range(num_devices)), threshold_ratio=0.05) except Exception as e: print(f"GPU cleanup warning: {e}") @@ -277,10 +278,11 @@ class RemoteOpenAIServerCustom(RemoteOpenAIServer): # GPU memory cleaning try: - if torch.cuda.is_available(): - devices_to_clear = list(range(torch.cuda.device_count())) - if devices_to_clear: - wait_for_gpu_memory_to_clear(devices=devices_to_clear, + if current_platform.is_cuda() or current_platform.is_rocm(): + num_devices = cuda_device_count_stateless() + if num_devices > 0: + wait_for_gpu_memory_to_clear(devices=list( + range(num_devices)), threshold_ratio=0.05) except Exception as e: print(f"GPU cleanup warning: {e}")