From ab89f1c6c93eeb09d3272bd847bf265bbdae33c7 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 2 Oct 2025 01:18:39 +0300 Subject: [PATCH] Update model_optimization_nodes.py --- nodes/model_optimization_nodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nodes/model_optimization_nodes.py b/nodes/model_optimization_nodes.py index 013bc4f..f413c17 100644 --- a/nodes/model_optimization_nodes.py +++ b/nodes/model_optimization_nodes.py @@ -121,10 +121,10 @@ class BaseLoaderKJ: elif "sageattn3" in sage_attention: from sageattn3 import sageattn3_blackwell if sage_attention == "sageattn3_per_block_mean": - def func(q, k, v, is_causal=False, attn_mask=None): + def func(q, k, v, is_causal=False, attn_mask=None, **kwargs): return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True) else: - def func(q, k, v, is_causal=False, attn_mask=None): + def func(q, k, v, is_causal=False, attn_mask=None, **kwargs): return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False) return func