diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 69b83d9..a49a707 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -92,6 +92,7 @@ class BaseLoaderKJ:
             comfy.ldm.wan.model.optimized_attention = attention_sage
         else:
+            print("Restoring initial comfy attention")
             comfy_attention.optimized_attention = _original_functions.get("orig_attention")
             comfy.ldm.hunyuan_video.model.optimized_attention = _original_functions.get("orig_attention")
             comfy.ldm.flux.math.optimized_attention = _original_functions.get("orig_attention")
@@ -145,9 +146,13 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
     def patch(self, model, sage_attention):
         model_clone = model.clone()

-        def patch_attention(model):
+        def patch_attention_enable(model):
             self._patch_modules(False, sage_attention)

-        model_clone.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention)
+        def patch_attention_disable(model):
+            self._patch_modules(False, "disabled")
+
+        model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_attention_enable)
+        model_clone.add_callback(CallbacksMP.ON_CLEANUP, patch_attention_disable)
         return model_clone,