Mirror of https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git (synced 2026-01-28 03:57:15 +08:00)
Use pre_run callback for applying sageattention patch
It is still a global patch, but it is now applied right before the model runs, which makes it easier to disable for unsupported models.
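In practice, the change registers the attention patch as an ON_PRE_RUN callback on a clone of the incoming ModelPatcher instead of patching the attention functions immediately when the node executes. A minimal sketch of that pattern (assuming ComfyUI's comfy.patcher_extension.CallbacksMP API; apply_sage_patch is a stand-in name for the node's _patch_modules call):

    from comfy.patcher_extension import CallbacksMP

    def patch(self, model, sage_attention):
        model_clone = model.clone()  # leave the caller's ModelPatcher untouched

        def apply_sage_patch(model):
            # runs right before the model is used, so the patch can be
            # skipped for models that do not support sage attention
            self._patch_modules(False, sage_attention)

        model_clone.add_callback(CallbacksMP.ON_PRE_RUN, apply_sage_patch)
        return (model_clone,)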
This commit is contained in: parent 8ecf5cd05e, commit 9026379046
@@ -8,9 +8,15 @@ import folder_paths
 import comfy.model_management as mm
 from comfy.cli_args import args

-orig_attention = comfy_attention.optimized_attention
-original_patch_model = comfy.model_patcher.ModelPatcher.patch_model
-original_load_lora_for_models = comfy.sd.load_lora_for_models
+_initialized = False
+_original_functions = {}
+
+if not _initialized:
+    _original_functions["orig_attention"] = comfy_attention.optimized_attention
+    _original_functions["original_patch_model"] = comfy.model_patcher.ModelPatcher.patch_model
+    _original_functions["original_load_lora_for_models"] = comfy.sd.load_lora_for_models
+    _initialized = True
+

 class BaseLoaderKJ:
     original_linear = None
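Capturing the unpatched callables once in a module-level dict, guarded by _initialized, keeps the originals around so later code can restore them by key. A standalone toy sketch of that save-once/restore pattern (SimpleNamespace stands in for the comfy modules being patched):

    from types import SimpleNamespace

    # toy stand-in for comfy_attention / the comfy.ldm submodules
    comfy_attention = SimpleNamespace(optimized_attention=lambda *args: "default attention")

    _initialized = False
    _original_functions = {}

    if not _initialized:
        # capture the pristine callable exactly once, before any patching
        _original_functions["orig_attention"] = comfy_attention.optimized_attention
        _initialized = True

    def attention_sage(*args):
        return "sage attention"

    comfy_attention.optimized_attention = attention_sage                         # patch
    comfy_attention.optimized_attention = _original_functions["orig_attention"]  # restore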
@@ -47,6 +53,7 @@ class BaseLoaderKJ:

     @torch.compiler.disable()
     def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
         print("SAGE")
         if skip_reshape:
             b, _, _, dim_head = q.shape
             tensor_layout="HND"
@@ -86,12 +93,12 @@ class BaseLoaderKJ:
             comfy.ldm.wan.model.optimized_attention = attention_sage

         else:
-            comfy_attention.optimized_attention = orig_attention
-            comfy.ldm.hunyuan_video.model.optimized_attention = orig_attention
-            comfy.ldm.flux.math.optimized_attention = orig_attention
-            comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = orig_attention
-            comfy.ldm.cosmos.blocks.optimized_attention = orig_attention
-            comfy.ldm.wan.model.optimized_attention = orig_attention
+            comfy_attention.optimized_attention = _original_functions.get("orig_attention")
+            comfy.ldm.hunyuan_video.model.optimized_attention = _original_functions.get("orig_attention")
+            comfy.ldm.flux.math.optimized_attention = _original_functions.get("orig_attention")
+            comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = _original_functions.get("orig_attention")
+            comfy.ldm.cosmos.blocks.optimized_attention = _original_functions.get("orig_attention")
+            comfy.ldm.wan.model.optimized_attention = _original_functions.get("orig_attention")

         if patch_cublaslinear:
             if not BaseLoaderKJ.cublas_patched:
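Because the sage patch overwrites optimized_attention both in comfy_attention and in each ldm submodule (hunyuan_video, flux, genmo, cosmos, wan), the restore branch has to write the saved original back into every one of them. A hypothetical refactor, not what the node does, would keep that list in one place (assuming the same module paths the diff touches):

    import comfy.ldm.modules.attention as comfy_attention
    import comfy.ldm.hunyuan_video.model
    import comfy.ldm.flux.math
    import comfy.ldm.genmo.joint_model.asymm_models_joint
    import comfy.ldm.cosmos.blocks
    import comfy.ldm.wan.model

    # every module that keeps its own reference to optimized_attention
    _ATTENTION_MODULES = [
        comfy_attention,
        comfy.ldm.hunyuan_video.model,
        comfy.ldm.flux.math,
        comfy.ldm.genmo.joint_model.asymm_models_joint,
        comfy.ldm.cosmos.blocks,
        comfy.ldm.wan.model,
    ]

    def set_attention(fn):
        # point every module at the same attention implementation
        for module in _ATTENTION_MODULES:
            module.optimized_attention = fn

    # set_attention(attention_sage) to patch,
    # set_attention(_original_functions.get("orig_attention")) to restore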
@@ -137,8 +144,13 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
     CATEGORY = "KJNodes/experimental"

     def patch(self, model, sage_attention):
-        self._patch_modules(False, sage_attention)
-        return model,
+        from comfy.patcher_extension import CallbacksMP
+        model_clone = model.clone()
+        def patch_attention(model):
+            self._patch_modules(False, sage_attention)
+        model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_attention)
+
+        return model_clone,

 class CheckpointLoaderKJ(BaseLoaderKJ):
     @classmethod
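Note on the design: cloning the ModelPatcher attaches the callback to this node's output model rather than mutating the input, and deferring _patch_modules to ON_PRE_RUN means the global attention functions are only swapped when ComfyUI is actually about to run a model that asked for sage attention.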
@@ -318,8 +330,8 @@ class PatchModelPatcherOrder:
             comfy.model_patcher.ModelPatcher.patch_model = patched_patch_model
             comfy.sd.load_lora_for_models = patched_load_lora_for_models
         else:
-            comfy.model_patcher.ModelPatcher.patch_model = original_patch_model
-            comfy.sd.load_lora_for_models = original_load_lora_for_models
+            comfy.model_patcher.ModelPatcher.patch_model = _original_functions.get("original_patch_model")
+            comfy.sd.load_lora_for_models = _original_functions.get("original_load_lora_for_models")

         return model,