diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index 831a299..013bc4f 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -17,7 +17,7 @@ except ImportError: v3_available = False logging.warning("ComfyUI v3 node API not available, please update ComfyUI to access latest v3 nodes.") -sageattn_modes = ["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda", "sageattn_qk_int8_pv_fp8_cuda++"] +sageattn_modes = ["disabled", "auto", "sageattn_qk_int8_pv_fp16_cuda", "sageattn_qk_int8_pv_fp16_triton", "sageattn_qk_int8_pv_fp8_cuda", "sageattn_qk_int8_pv_fp8_cuda++", "sageattn3", "sageattn3_per_block_mean"] _initialized = False _original_functions = {} @@ -118,6 +118,15 @@ class BaseLoaderKJ: def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"): return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp16", tensor_layout=tensor_layout) return func + elif "sageattn3" in sage_attention: + from sageattn3 import sageattn3_blackwell + if sage_attention == "sageattn3_per_block_mean": + def func(q, k, v, is_causal=False, attn_mask=None): + return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True) + else: + def func(q, k, v, is_causal=False, attn_mask=None): + return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False) + return func sage_func = set_sage_func(sage_attention)