smarter sage patch

This commit is contained in:
kijai 2025-05-07 19:08:25 +03:00
parent ca07b9dadc
commit bec42252c6

View File

@ -92,6 +92,7 @@ class BaseLoaderKJ:
comfy.ldm.wan.model.optimized_attention = attention_sage
else:
print("Restoring initial comfy attention")
comfy_attention.optimized_attention = _original_functions.get("orig_attention")
comfy.ldm.hunyuan_video.model.optimized_attention = _original_functions.get("orig_attention")
comfy.ldm.flux.math.optimized_attention = _original_functions.get("orig_attention")
@ -145,9 +146,13 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
def patch(self, model, sage_attention):
    """Clone *model* and attach sage-attention lifecycle callbacks.

    The attention patch is applied lazily (ON_PRE_RUN) rather than at
    node-execution time, and reverted on ON_CLEANUP so other models run
    afterwards see the original attention implementation.

    Args:
        model: the model to clone; the original is left untouched.
        sage_attention: mode string forwarded to ``self._patch_modules``
            (e.g. a sage-attention variant name, or "disabled").

    Returns:
        A one-tuple containing the patched model clone (ComfyUI node
        return convention).
    """
    model_clone = model.clone()

    def patch_attention_enable(model):
        # Swap optimized_attention for the sage implementation just
        # before the model runs.
        self._patch_modules(False, sage_attention)

    def patch_attention_disable(model):
        # Restore the original attention once sampling is done.
        self._patch_modules(False, "disabled")

    model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_attention_enable)
    model_clone.add_callback(CallbacksMP.ON_CLEANUP, patch_attention_disable)
    return model_clone,