error earlier if sageattention fails to import

This commit is contained in:
Jukka Seppänen 2024-11-24 22:27:26 +02:00
parent 8d6e53b556
commit f1b3bc0abf

View File

@ -139,6 +139,12 @@ class DownloadAndLoadCogVideoModel:
enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None,
attention_mode="sdpa", load_device="main_device"):
if "sage" in attention_mode:
try:
from sageattention import sageattn
except Exception as e:
raise ValueError(f"Can't import SageAttention: {str(e)}")
if precision == "fp16" and "1.5" in model:
raise ValueError("1.5 models do not currently work in fp16")
@ -446,6 +452,12 @@ class DownloadAndLoadCogVideoGGUFModel:
def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload,
block_edit=None, compile_args=None, attention_mode="sdpa"):
if "sage" in attention_mode:
try:
from sageattention import sageattn
except Exception as e:
raise ValueError(f"Can't import SageAttention: {str(e)}")
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
mm.soft_empty_cache()
@ -619,6 +631,12 @@ class CogVideoXModelLoader:
def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_offload,
block_edit=None, compile_args=None, lora=None, attention_mode="sdpa", quantization="disabled"):
if "sage" in attention_mode:
try:
from sageattention import sageattn
except Exception as e:
raise ValueError(f"Can't import SageAttention: {str(e)}")
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
manual_offloading = True