Update model_optimization_nodes.py

This commit is contained in:
kijai 2025-10-02 01:18:39 +03:00
parent 7ce88200c7
commit ab89f1c6c9

View File

@ -121,10 +121,10 @@ class BaseLoaderKJ:
elif "sageattn3" in sage_attention:
from sageattn3 import sageattn3_blackwell
if sage_attention == "sageattn3_per_block_mean":
def func(q, k, v, is_causal=False, attn_mask=None):
def func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=True)
else:
def func(q, k, v, is_causal=False, attn_mask=None):
def func(q, k, v, is_causal=False, attn_mask=None, **kwargs):
return sageattn3_blackwell(q, k, v, is_causal=is_causal, attn_mask=attn_mask, per_block_mean=False)
return func