Add PathchSageAttentionKJ

kijai 2024-12-22 22:31:20 +02:00
parent f5ce7d017b
commit cdbd38213f
2 changed files with 19 additions and 0 deletions


@@ -162,6 +162,7 @@ NODE_CONFIG = {
     "TorchCompileControlNet": {"class": TorchCompileControlNet, "name": "TorchCompileControlNet"},
     "PatchModelPatcherOrder": {"class": PatchModelPatcherOrder, "name": "Patch Model Patcher Order"},
     "TorchCompileLTXModel": {"class": TorchCompileLTXModel, "name": "TorchCompileLTXModel"},
+    "PathchSageAttentionKJ": {"class": PathchSageAttentionKJ, "name": "Pathch Sage Attention KJ"},
     #instance diffusion
     "CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},


@@ -114,6 +114,24 @@ class BaseLoaderKJ:
         disable_weight_init.Linear = BaseLoaderKJ.original_linear
         BaseLoaderKJ.cublas_patched = False
+
+class PathchSageAttentionKJ(BaseLoaderKJ):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "model": ("MODEL",),
+            "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": "disabled", "tooltip": "Patch comfy attention to use sageattn."}),
+        }}
+
+    RETURN_TYPES = ("MODEL", )
+    FUNCTION = "patch"
+    DESCRIPTION = "Experimental node for patching attention mode."
+    EXPERIMENTAL = True
+    CATEGORY = "KJNodes/experimental"
+
+    def patch(self, model, sage_attention):
+        self._patch_modules(False, sage_attention)
+        return model,
 class CheckpointLoaderKJ(BaseLoaderKJ):
     @classmethod
     def INPUT_TYPES(s):
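
The new node delegates to the inherited BaseLoaderKJ._patch_modules, passing False for the cuBLAS linear patch and the chosen sage_attention mode. A minimal sketch of the attention patch such a call plausibly performs: the wrapper below, its name, and the layout handling are assumptions; only sageattn (from the sageattention package) and comfy's optimized_attention are real names, and this only runs inside a ComfyUI environment.

# Sketch: globally swap comfy's optimized attention for SageAttention.
# Mask handling is omitted for brevity; everything except `sageattn` and
# `comfy_attention.optimized_attention` is an assumption.
import comfy.ldm.modules.attention as comfy_attention
from sageattention import sageattn  # "auto" kernel; the specific
# sageattn_qk_int8_pv_* variants are importable from the same package

_original_attention = comfy_attention.optimized_attention  # kept to revert

def _sage_attention(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    if skip_reshape:
        # q/k/v arrive as (batch, heads, tokens, dim_head): sageattn's "HND" layout
        b, h, n, _ = q.shape
        out = sageattn(q, k, v, is_causal=False, tensor_layout="HND")
        return out.transpose(1, 2).reshape(b, n, -1)
    # q/k/v arrive as (batch, tokens, heads * dim_head): reshape to "NHD"
    b, n, _ = q.shape
    q, k, v = (t.view(b, n, heads, -1) for t in (q, k, v))
    out = sageattn(q, k, v, is_causal=False, tensor_layout="NHD")
    return out.reshape(b, n, -1)

def patch_attention(sage_attention):
    if sage_attention == "disabled":
        comfy_attention.optimized_attention = _original_attention
    else:
        comfy_attention.optimized_attention = _sage_attention

Because the patch replaces comfy's attention function globally rather than modifying the model itself, patch() returns the input model unchanged; running the node again with "disabled" is what restores the original attention.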