Use pre_run callback for applying sageattention patch

It's still global patch, but now applied before using the model thus better allowing disabling it for unsupported models
This commit is contained in:
kijai 2025-04-19 16:26:59 +03:00
parent 8ecf5cd05e
commit 9026379046

View File

@ -8,9 +8,15 @@ import folder_paths
import comfy.model_management as mm import comfy.model_management as mm
from comfy.cli_args import args from comfy.cli_args import args
orig_attention = comfy_attention.optimized_attention
original_patch_model = comfy.model_patcher.ModelPatcher.patch_model _initialized = False
original_load_lora_for_models = comfy.sd.load_lora_for_models _original_functions = {}
if not _initialized:
_original_functions["orig_attention"] = comfy_attention.optimized_attention
_original_functions["original_patch_model"] = comfy.model_patcher.ModelPatcher.patch_model
_original_functions["original_load_lora_for_models"] = comfy.sd.load_lora_for_models
_initialized = True
class BaseLoaderKJ: class BaseLoaderKJ:
original_linear = None original_linear = None
@ -47,6 +53,7 @@ class BaseLoaderKJ:
@torch.compiler.disable() @torch.compiler.disable()
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
print("SAGE")
if skip_reshape: if skip_reshape:
b, _, _, dim_head = q.shape b, _, _, dim_head = q.shape
tensor_layout="HND" tensor_layout="HND"
@ -86,12 +93,12 @@ class BaseLoaderKJ:
comfy.ldm.wan.model.optimized_attention = attention_sage comfy.ldm.wan.model.optimized_attention = attention_sage
else: else:
comfy_attention.optimized_attention = orig_attention comfy_attention.optimized_attention = _original_functions.get("orig_attention")
comfy.ldm.hunyuan_video.model.optimized_attention = orig_attention comfy.ldm.hunyuan_video.model.optimized_attention = _original_functions.get("orig_attention")
comfy.ldm.flux.math.optimized_attention = orig_attention comfy.ldm.flux.math.optimized_attention = _original_functions.get("orig_attention")
comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = orig_attention comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = _original_functions.get("orig_attention")
comfy.ldm.cosmos.blocks.optimized_attention = orig_attention comfy.ldm.cosmos.blocks.optimized_attention = _original_functions.get("orig_attention")
comfy.ldm.wan.model.optimized_attention = orig_attention comfy.ldm.wan.model.optimized_attention = _original_functions.get("orig_attention")
if patch_cublaslinear: if patch_cublaslinear:
if not BaseLoaderKJ.cublas_patched: if not BaseLoaderKJ.cublas_patched:
@ -137,8 +144,13 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
CATEGORY = "KJNodes/experimental" CATEGORY = "KJNodes/experimental"
def patch(self, model, sage_attention): def patch(self, model, sage_attention):
self._patch_modules(False, sage_attention) from comfy.patcher_extension import CallbacksMP
return model, model_clone = model.clone()
def patch_attention(model):
self._patch_modules(False, sage_attention)
model_clone.add_callback(CallbacksMP.ON_PRE_RUN,patch_attention)
return model_clone,
class CheckpointLoaderKJ(BaseLoaderKJ): class CheckpointLoaderKJ(BaseLoaderKJ):
@classmethod @classmethod
@ -318,8 +330,8 @@ class PatchModelPatcherOrder:
comfy.model_patcher.ModelPatcher.patch_model = patched_patch_model comfy.model_patcher.ModelPatcher.patch_model = patched_patch_model
comfy.sd.load_lora_for_models = patched_load_lora_for_models comfy.sd.load_lora_for_models = patched_load_lora_for_models
else: else:
comfy.model_patcher.ModelPatcher.patch_model = original_patch_model comfy.model_patcher.ModelPatcher.patch_model = _original_functions.get("original_patch_model")
comfy.sd.load_lora_for_models = original_load_lora_for_models comfy.sd.load_lora_for_models = _original_functions.get("original_load_lora_for_models")
return model, return model,