From 902637904643401e0e6a87507a6febb2365edec0 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sat, 19 Apr 2025 16:26:59 +0300
Subject: [PATCH] Use pre_run callback for applying sageattention patch

It's still a global patch, but it is now applied right before the model is
used, which makes it easier to disable for unsupported models.
---
 nodes/model_optimization_nodes.py | 38 ++++++++++++++++++++-----------
 1 file changed, 25 insertions(+), 13 deletions(-)

diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index cafc8a0..52c7d77 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -8,9 +8,15 @@ import folder_paths
 import comfy.model_management as mm
 from comfy.cli_args import args
 
-orig_attention = comfy_attention.optimized_attention
-original_patch_model = comfy.model_patcher.ModelPatcher.patch_model
-original_load_lora_for_models = comfy.sd.load_lora_for_models
+
+_initialized = False
+_original_functions = {}
+
+if not _initialized:
+    _original_functions["orig_attention"] = comfy_attention.optimized_attention
+    _original_functions["original_patch_model"] = comfy.model_patcher.ModelPatcher.patch_model
+    _original_functions["original_load_lora_for_models"] = comfy.sd.load_lora_for_models
+    _initialized = True
 
 class BaseLoaderKJ:
     original_linear = None
@@ -47,6 +53,7 @@ class BaseLoaderKJ:
 
         @torch.compiler.disable()
         def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+            print("SAGE")
             if skip_reshape:
                 b, _, _, dim_head = q.shape
                 tensor_layout="HND"
@@ -86,12 +93,12 @@ class BaseLoaderKJ:
             comfy.ldm.wan.model.optimized_attention = attention_sage
 
         else:
-            comfy_attention.optimized_attention = orig_attention
-            comfy.ldm.hunyuan_video.model.optimized_attention = orig_attention
-            comfy.ldm.flux.math.optimized_attention = orig_attention
-            comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = orig_attention
-            comfy.ldm.cosmos.blocks.optimized_attention = orig_attention
-            comfy.ldm.wan.model.optimized_attention = orig_attention
+            comfy_attention.optimized_attention = _original_functions.get("orig_attention")
+            comfy.ldm.hunyuan_video.model.optimized_attention = _original_functions.get("orig_attention")
+            comfy.ldm.flux.math.optimized_attention = _original_functions.get("orig_attention")
+            comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = _original_functions.get("orig_attention")
+            comfy.ldm.cosmos.blocks.optimized_attention = _original_functions.get("orig_attention")
+            comfy.ldm.wan.model.optimized_attention = _original_functions.get("orig_attention")
 
         if patch_cublaslinear:
             if not BaseLoaderKJ.cublas_patched:
@@ -137,8 +144,13 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
     CATEGORY = "KJNodes/experimental"
 
     def patch(self, model, sage_attention):
-        self._patch_modules(False, sage_attention)
-        return model,
+        from comfy.patcher_extension import CallbacksMP
+        model_clone = model.clone()
+        def patch_attention(model):
+            self._patch_modules(False, sage_attention)
+        model_clone.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention)
+
+        return model_clone,
 
 class CheckpointLoaderKJ(BaseLoaderKJ):
     @classmethod
@@ -318,8 +330,8 @@ class PatchModelPatcherOrder:
             comfy.model_patcher.ModelPatcher.patch_model = patched_patch_model
             comfy.sd.load_lora_for_models = patched_load_lora_for_models
         else:
-            comfy.model_patcher.ModelPatcher.patch_model = original_patch_model
-            comfy.sd.load_lora_for_models = original_load_lora_for_models
+            comfy.model_patcher.ModelPatcher.patch_model = _original_functions.get("original_patch_model")
+            comfy.sd.load_lora_for_models = _original_functions.get("original_load_lora_for_models")
         return model,
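
For reference, below is a minimal, self-contained sketch of the pre_run-callback pattern this change relies on: instead of swapping the global attention function when the node executes, the node clones the ModelPatcher and registers a callback that fires right before the model runs, so each run applies or reverts the patch for itself. MiniModelPatcher, make_sage_patched and the toy attention functions are illustrative stand-ins invented for this sketch; the real code uses ModelPatcher.clone(), ModelPatcher.add_callback() and comfy.patcher_extension.CallbacksMP.ON_PRE_RUN exactly as shown in the diff above.

# Stand-in for the global attention function that other modules reference,
# analogous to comfy_attention.optimized_attention.
def original_attention(q, k, v):
    return "original"

def sage_attention(q, k, v):
    return "sage"

# The currently active global implementation (the thing being monkey-patched).
active_attention = original_attention


class MiniModelPatcher:
    """Toy stand-in for comfy.model_patcher.ModelPatcher: holds pre-run callbacks."""

    def __init__(self):
        self._pre_run_callbacks = []

    def clone(self):
        clone = MiniModelPatcher()
        clone._pre_run_callbacks = list(self._pre_run_callbacks)
        return clone

    def add_callback(self, event, callback):
        # Only the ON_PRE_RUN event is modelled here.
        self._pre_run_callbacks.append(callback)

    def pre_run(self):
        # ComfyUI fires ON_PRE_RUN callbacks right before the model is used,
        # so the global patch is applied (or reverted) per run, not at node time.
        for cb in self._pre_run_callbacks:
            cb(self)


def make_sage_patched(model, enable):
    """Mirrors PathchSageAttentionKJ.patch(): clone the model, register a pre-run callback."""
    model_clone = model.clone()

    def patch_attention(_model):
        global active_attention
        active_attention = sage_attention if enable else original_attention

    model_clone.add_callback("ON_PRE_RUN", patch_attention)
    return model_clone


if __name__ == "__main__":
    base = MiniModelPatcher()
    sage_model = make_sage_patched(base, enable=True)
    plain_model = make_sage_patched(base, enable=False)

    sage_model.pre_run()
    print(active_attention(None, None, None))   # -> "sage"

    # Running an unpatched model afterwards restores the original attention,
    # which is what deferring the patch to pre_run buys over patching at node time.
    plain_model.pre_run()
    print(active_attention(None, None, None))   # -> "original"

Running the sketch prints "sage" and then "original": the second model's pre_run undoes the global patch, which is the behavior the commit message describes for unsupported models.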