Use pre_run callback for applying sageattention patch

It's still a global patch, but it is now applied right before the model is used, which makes it easier to disable for unsupported models.
kijai 2025-04-19 16:26:59 +03:00
parent 8ecf5cd05e
commit 9026379046

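In outline, the change swaps an eager load-time patch for a callback that ComfyUI fires right before the model is actually run. A minimal sketch of that pattern, using only the ModelPatcher.clone, ModelPatcher.add_callback and CallbacksMP.ON_PRE_RUN pieces that appear in the diff below (register_lazy_patch and apply_patch are illustrative placeholders, not part of the commit):

from comfy.patcher_extension import CallbacksMP

def register_lazy_patch(model, apply_patch):
    # Work on a clone so the caller's ModelPatcher stays untouched.
    model_clone = model.clone()

    def on_pre_run(patched_model):
        # Fired when ComfyUI prepares this specific model for sampling,
        # not when the node graph is merely constructed.
        apply_patch()

    model_clone.add_callback(CallbacksMP.ON_PRE_RUN, on_pre_run)
    return model_clone

Because the patch is only installed when a model that requested it is about to run, a workflow that never runs such a model never sees the global override.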

@@ -8,9 +8,15 @@ import folder_paths
 import comfy.model_management as mm
 from comfy.cli_args import args
-orig_attention = comfy_attention.optimized_attention
-original_patch_model = comfy.model_patcher.ModelPatcher.patch_model
-original_load_lora_for_models = comfy.sd.load_lora_for_models
+_initialized = False
+_original_functions = {}
+if not _initialized:
+    _original_functions["orig_attention"] = comfy_attention.optimized_attention
+    _original_functions["original_patch_model"] = comfy.model_patcher.ModelPatcher.patch_model
+    _original_functions["original_load_lora_for_models"] = comfy.sd.load_lora_for_models
+    _initialized = True
 class BaseLoaderKJ:
     original_linear = None
@@ -47,6 +53,7 @@ class BaseLoaderKJ:
         @torch.compiler.disable()
         def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+            print("SAGE")
             if skip_reshape:
                 b, _, _, dim_head = q.shape
                 tensor_layout="HND"
@@ -86,12 +93,12 @@ class BaseLoaderKJ:
             comfy.ldm.wan.model.optimized_attention = attention_sage
         else:
-            comfy_attention.optimized_attention = orig_attention
-            comfy.ldm.hunyuan_video.model.optimized_attention = orig_attention
-            comfy.ldm.flux.math.optimized_attention = orig_attention
-            comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = orig_attention
-            comfy.ldm.cosmos.blocks.optimized_attention = orig_attention
-            comfy.ldm.wan.model.optimized_attention = orig_attention
+            comfy_attention.optimized_attention = _original_functions.get("orig_attention")
+            comfy.ldm.hunyuan_video.model.optimized_attention = _original_functions.get("orig_attention")
+            comfy.ldm.flux.math.optimized_attention = _original_functions.get("orig_attention")
+            comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = _original_functions.get("orig_attention")
+            comfy.ldm.cosmos.blocks.optimized_attention = _original_functions.get("orig_attention")
+            comfy.ldm.wan.model.optimized_attention = _original_functions.get("orig_attention")
         if patch_cublaslinear:
             if not BaseLoaderKJ.cublas_patched:
@@ -137,8 +144,13 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
     CATEGORY = "KJNodes/experimental"
     def patch(self, model, sage_attention):
-        self._patch_modules(False, sage_attention)
-        return model,
+        from comfy.patcher_extension import CallbacksMP
+        model_clone = model.clone()
+        def patch_attention(model):
+            self._patch_modules(False, sage_attention)
+        model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_attention)
+        return model_clone,
 class CheckpointLoaderKJ(BaseLoaderKJ):
     @classmethod
@@ -318,8 +330,8 @@ class PatchModelPatcherOrder:
             comfy.model_patcher.ModelPatcher.patch_model = patched_patch_model
             comfy.sd.load_lora_for_models = patched_load_lora_for_models
         else:
-            comfy.model_patcher.ModelPatcher.patch_model = original_patch_model
-            comfy.sd.load_lora_for_models = original_load_lora_for_models
+            comfy.model_patcher.ModelPatcher.patch_model = _original_functions.get("original_patch_model")
+            comfy.sd.load_lora_for_models = _original_functions.get("original_load_lora_for_models")
         return model,
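The module-level _original_functions dict is the restore half of that callback-based patching: the unpatched callables are captured once, before any patching, and the disable branches read them back with .get() instead of trusting whatever is currently installed. A generic sketch of the same save-once/restore idea (the math module and sqrt example are illustrative stand-ins, not taken from the commit):

import math as target_module  # stands in for comfy_attention, comfy.sd, etc.

_original_functions = {}
_initialized = False

if not _initialized:
    # Capture the pristine function exactly once, before it can be replaced.
    _original_functions["orig_sqrt"] = target_module.sqrt
    _initialized = True

def patched_sqrt(x):
    # Stand-in for the SageAttention replacement.
    return _original_functions["orig_sqrt"](x)

def set_patched(enabled):
    if enabled:
        target_module.sqrt = patched_sqrt
    else:
        # Restore the saved original, not the possibly already-patched attribute.
        target_module.sqrt = _original_functions.get("orig_sqrt")

Keeping the originals in one dict guarded by _initialized means repeated patch/unpatch cycles, or several loader nodes touching the same modules, always restore the true ComfyUI implementations rather than an earlier patched version.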