diff --git a/model_loading.py b/model_loading.py
index 1c7aa42..ad8de59 100644
--- a/model_loading.py
+++ b/model_loading.py
@@ -139,4 +139,11 @@ class DownloadAndLoadCogVideoModel:
                   enable_sequential_cpu_offload=False, block_edit=None, lora=None, compile_args=None,
                   attention_mode="sdpa", load_device="main_device"):
+        # Fail fast with a clear error if a SageAttention mode is requested
+        # but the optional sageattention package is not installed.
+        if "sage" in attention_mode:
+            try:
+                from sageattention import sageattn  # noqa: F401 -- availability probe only
+            except Exception as e:
+                raise ValueError(f"Can't import SageAttention: {e}") from e
+
         if precision == "fp16" and "1.5" in model:
             raise ValueError("1.5 models do not currently work in fp16")
@@ -445,4 +452,11 @@ class DownloadAndLoadCogVideoGGUFModel:
     def loadmodel(self, model, vae_precision, fp8_fastmode, load_device, enable_sequential_cpu_offload,
                   block_edit=None, compile_args=None, attention_mode="sdpa"):
+
+        # Fail fast with a clear error if a SageAttention mode is requested
+        # but the optional sageattention package is not installed.
+        if "sage" in attention_mode:
+            try:
+                from sageattention import sageattn  # noqa: F401 -- availability probe only
+            except Exception as e:
+                raise ValueError(f"Can't import SageAttention: {e}") from e
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()
@@ -618,4 +632,11 @@ class CogVideoXModelLoader:
     def loadmodel(self, model, base_precision, load_device, enable_sequential_cpu_offload,
                   block_edit=None, compile_args=None, lora=None, attention_mode="sdpa", quantization="disabled"):
+
+        # Fail fast with a clear error if a SageAttention mode is requested
+        # but the optional sageattention package is not installed.
+        if "sage" in attention_mode:
+            try:
+                from sageattention import sageattn  # noqa: F401 -- availability probe only
+            except Exception as e:
+                raise ValueError(f"Can't import SageAttention: {e}") from e
         device = mm.get_torch_device()
         offload_device = mm.unet_offload_device()