From 4f01f72bedb6e4c1956795690806b729d8b66955 Mon Sep 17 00:00:00 2001 From: patientx Date: Tue, 14 Jan 2025 20:03:55 +0300 Subject: [PATCH] Update zluda.py --- comfy/zluda.py | 47 +++++++++++++++-------------------------------- 1 file changed, 15 insertions(+), 32 deletions(-) diff --git a/comfy/zluda.py b/comfy/zluda.py index 4f52ccf39..4f02aeb2e 100644 --- a/comfy/zluda.py +++ b/comfy/zluda.py @@ -1,33 +1,16 @@ -try: - torch_device_name = get_torch_device_name(get_torch_device()) - - if "[ZLUDA]" in torch_device_name: - _torch_stft = torch.stft - _torch_istft = torch.istft +import torch - def z_stft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs): - return _torch_stft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device) - - def z_istft(input: torch.Tensor, window: torch.Tensor, *args, **kwargs): - return _torch_istft(input=input.cpu(), window=window.cpu(), *args, **kwargs).to(input.device) - - def z_jit(f, *_, **__): - f.graph = torch._C.Graph() - return f - - # hijacks - torch.stft = z_stft - torch.istft = z_istft - torch.jit.script = z_jit - print(" ") - print("***----------------------ZLUDA--------------------------***") - print(" :: ZLUDA detected, disabling non-supported functions.") - torch.backends.cudnn.enabled = False - print(" :: (cuDNN, flash_sdp, mem_efficient_sdp disabled) ") - torch.backends.cuda.enable_flash_sdp(False) - torch.backends.cuda.enable_math_sdp(True) - torch.backends.cuda.enable_mem_efficient_sdp(False) - - print("***-----------------------------------------------------***") - print(" :: Device:", torch_device_name) - print(" ") \ No newline at end of file +if torch.cuda.get_device_name().endswith("[ZLUDA]"): + print(" ") + print("***----------------------ZLUDA-----------------------------***") + print(" :: ZLUDA detected, disabling non-supported functions.") + torch.backends.cudnn.enabled = False + print(" :: CuDNN, flash_sdp, math_sdp, mem_efficient_sdp disabled) ") + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_math_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(False) + print("***--------------------------------------------------------***") + print(" :: Device:", torch.cuda.get_device_name()) + print(" ") +else: + print(" :: ZLUDA isn't detected, please try patching it.")