diff --git a/__init__.py b/__init__.py
index be00731..f024196 100644
--- a/__init__.py
+++ b/__init__.py
@@ -162,6 +162,7 @@ NODE_CONFIG = {
     "TorchCompileControlNet": {"class": TorchCompileControlNet, "name": "TorchCompileControlNet"},
     "PatchModelPatcherOrder": {"class": PatchModelPatcherOrder, "name": "Patch Model Patcher Order"},
     "TorchCompileLTXModel": {"class": TorchCompileLTXModel, "name": "TorchCompileLTXModel"},
+    "PatchSageAttentionKJ": {"class": PatchSageAttentionKJ, "name": "Patch Sage Attention KJ"},
 
     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 33b5d52..8593bf6 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -114,6 +114,24 @@ class BaseLoaderKJ:
         disable_weight_init.Linear = BaseLoaderKJ.original_linear
         BaseLoaderKJ.cublas_patched = False
 
+class PatchSageAttentionKJ(BaseLoaderKJ):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "model": ("MODEL",),
+            "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": "disabled", "tooltip": "Patch comfy attention to use sageattn."}),
+        }}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+    DESCRIPTION = "Experimental node for patching attention mode."
+    EXPERIMENTAL = True
+    CATEGORY = "KJNodes/experimental"
+
+    def patch(self, model, sage_attention):
+        self._patch_modules(False, sage_attention)
+        return model,
+
 class CheckpointLoaderKJ(BaseLoaderKJ):
     @classmethod
     def INPUT_TYPES(s):
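
Note: the diff registers the new node but does not show BaseLoaderKJ._patch_modules, which performs the actual attention swap. Below is a minimal sketch of what such a patch could look like, assuming ComfyUI's module-level optimized_attention function (which takes tensors shaped (batch, seq_len, heads * head_dim)) and the sageattention package's sageattn entry point. The attention_sage wrapper and the global reassignment are illustrative assumptions, not the code this diff calls.

    # Hypothetical sketch of a SageAttention patch for ComfyUI; not the
    # actual _patch_modules implementation from this repository.
    from sageattention import sageattn  # pip install sageattention

    import comfy.ldm.modules.attention

    def attention_sage(q, k, v, heads, mask=None, **kwargs):
        # ComfyUI hands attention tensors in as (batch, seq_len, heads * head_dim);
        # sageattn's default "HND" layout wants (batch, heads, seq_len, head_dim).
        b, s, _ = q.shape
        dim_head = q.shape[-1] // heads
        q, k, v = (t.view(b, s, heads, dim_head).transpose(1, 2) for t in (q, k, v))
        # Note: sageattn only supports causal masking via is_causal; arbitrary
        # attention masks are not handled in this sketch.
        out = sageattn(q, k, v, is_causal=False)
        # Flatten heads back into ComfyUI's expected layout.
        return out.transpose(1, 2).reshape(b, s, heads * dim_head)

    # Replacing the module-level function is one way to apply the patch
    # globally; the node may instead patch per-model via _patch_modules.
    comfy.ldm.modules.attention.optimized_attention = attention_sage

The "auto" option presumably lets sageattn pick the best kernel for the GPU, while the sageattn_qk_int8_pv_* options would dispatch to the correspondingly named kernels in the sageattention package; that mapping is an assumption based on the option names.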