From ff49e1b01f10a14496b08e21bb89b64d2b15f333 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Tue, 1 Jul 2025 12:26:27 +0300
Subject: [PATCH] Add sageattn++ as selectable mode for easier testing

---
 nodes/model_optimization_nodes.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py
index 9d7c6cb..e0b55fe 100644
--- a/nodes/model_optimization_nodes.py
+++ b/nodes/model_optimization_nodes.py
@@ -8,6 +8,7 @@ import folder_paths
 import comfy.model_management as mm
 from comfy.cli_args import args
 
+sageattn_modes = ["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda", "sageattn_qk_int8_pv_fp8_cuda++"]
 
 _initialized = False
 _original_functions = {}
@@ -49,6 +50,11 @@ class BaseLoaderKJ:
             def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
                 return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32", tensor_layout=tensor_layout)
             return func
+        elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda++":
+            from sageattention import sageattn_qk_int8_pv_fp8_cuda
+            def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
+                return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp16", tensor_layout=tensor_layout)
+            return func
 
         sage_func = set_sage_func(sage_attention)
 
@@ -136,7 +142,7 @@ class PathchSageAttentionKJ(BaseLoaderKJ):
     def INPUT_TYPES(s):
         return {"required": {
             "model": ("MODEL",),
-            "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Global patch comfy attention to use sageattn, once patched to revert back to normal you would need to run this node again with disabled option."}),
+            "sage_attention": (sageattn_modes, {"default": False, "tooltip": "Global patch comfy attention to use sageattn, once patched to revert back to normal you would need to run this node again with disabled option."}),
             }}
 
     RETURN_TYPES = ("MODEL", )
@@ -167,7 +173,7 @@ class CheckpointLoaderKJ(BaseLoaderKJ):
             "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],),
             "compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "default", "tooltip": "The compute dtype to use for the model."}),
             "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
-            "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
+            "sage_attention": (sageattn_modes, {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
             "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
             }}
 
@@ -309,7 +315,7 @@ class DiffusionModelLoaderKJ(BaseLoaderKJ):
             "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2", "fp16", "bf16", "fp32"],),
             "compute_dtype": (["default", "fp16", "bf16", "fp32"], {"default": "default", "tooltip": "The compute dtype to use for the model."}),
             "patch_cublaslinear": ("BOOLEAN", {"default": False, "tooltip": "Enable or disable the patching, won't take effect on already loaded models!"}),
-            "sage_attention": (["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda"], {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
+            "sage_attention": (sageattn_modes, {"default": False, "tooltip": "Patch comfy attention to use sageattn."}),
             "enable_fp16_accumulation": ("BOOLEAN", {"default": False, "tooltip": "Enable torch.backends.cuda.matmul.allow_fp16_accumulation, requires pytorch 2.7.0 nightly."}),
             }}
 
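
Note (not part of the patch): a minimal sketch of what the new "sageattn_qk_int8_pv_fp8_cuda++" mode dispatches to at attention time, assuming the sageattention package is installed; q, k, v are placeholder attention tensors and sage_func is a stand-in name for the function returned by set_sage_func.

    # Sketch only: mirrors the dispatch added above for the "++" mode.
    # Same FP8 PV kernel as the existing fp8_cuda mode, but with the
    # "fp32+fp16" PV accumulator dtype instead of "fp32+fp32".
    from sageattention import sageattn_qk_int8_pv_fp8_cuda

    def sage_func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"):
        return sageattn_qk_int8_pv_fp8_cuda(
            q, k, v,
            is_causal=is_causal,
            attn_mask=attn_mask,
            pv_accum_dtype="fp32+fp16",
            tensor_layout=tensor_layout,
        )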