From 7ce88200c7eee0afc854eef3057a1709456d17e9 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 2 Oct 2025 01:09:59 +0300 Subject: [PATCH] sageattn3 --- nodes/model_optimization_nodes.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index 831a299..013bc4f 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -17,7 +17,7 @@ except ImportError: v3_available = False logging.warning("ComfyUI v3 node API not available, please update ComfyUI to access latest v3 nodes.") -sageattn_modes = ["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda", "sageattn_qk_int8_pv_fp8_cuda++"] +sageattn_modes = ["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda", "sageattn_qk_int8_pv_fp8_cuda++", "sageattn3", "sageattn3_per_block_mean"] _initialized = False _original_functions = {} @@ -118,6 +118,15 @@ class BaseLoaderKJ: def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"): return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp16", tensor_layout=tensor_layout) return func + elif "sageattn3" in sage_attention: + from sageattn3 import sageattn3_blackwell + if sage_attention == "sageattn3_per_block_mean": + def func(q, k, v, is_causal=False, attn_mask=None): + return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True) + else: + def func(q, k, v, is_causal=False, attn_mask=None): + return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False) + return func sage_func = set_sage_func(sage_attention)